/ internal / server / server.go
server.go
  1  package server
  2  
  3  import (
  4  	"errors"
  5  	"fmt"
  6  	"net"
  7  	"net/rpc"
  8  	"os"
  9  	"os/signal"
 10  	"path/filepath"
 11  	"syscall"
 12  	"time"
 13  
 14  	"codeflow.dananglin.me.uk/apollo/enbas/internal/gtsclient"
 15  	"codeflow.dananglin.me.uk/apollo/enbas/internal/printer"
 16  	"codeflow.dananglin.me.uk/apollo/enbas/internal/utilities"
 17  )
 18  
 19  const minIdleTimeout = 60
 20  
 21  var (
 22  	ErrSocketFileInUse        = errors.New("this socket file is used by another server")
 23  	ErrSocketFileNotSpecified = errors.New("the path to the socket file is not specified")
 24  )
 25  
 26  func Run(
 27  	printSettings printer.Settings,
 28  	client *gtsclient.GTSClient,
 29  	socketPath string,
 30  	noIdleTimeout bool,
 31  	idleTimeout int,
 32  ) error {
 33  	if socketPath == "" {
 34  		return ErrSocketFileNotSpecified
 35  	}
 36  
 37  	socketPath, err := utilities.AbsolutePath(socketPath)
 38  	if err != nil {
 39  		return fmt.Errorf(
 40  			"unable to calculate the absolute path to the socket file: %w",
 41  			err,
 42  		)
 43  	}
 44  
 45  	// Ensure that the socket file's parent folder is present.
 46  	if err := utilities.EnsureDirectory(filepath.Dir(socketPath)); err != nil {
 47  		return fmt.Errorf(
 48  			"unable to ensure the presence of the socket's parent directory: %w",
 49  			err,
 50  		)
 51  	}
 52  
 53  	if err := removeUnusedSocketFile(socketPath); err != nil {
 54  		return fmt.Errorf("error removing the unused socket file: %w", err)
 55  	}
 56  
 57  	// Create the RPC server and register the GTS client methods.
 58  	server := rpc.NewServer()
 59  
 60  	if err := server.Register(client); err != nil {
 61  		return fmt.Errorf("error registering the GTSClient methods to the server: %w", err)
 62  	}
 63  
 64  	// Create a channel for receiving the shutdown signal.
 65  	shutdown := make(chan os.Signal, 1)
 66  	signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
 67  
 68  	// Run the server without a timer.
 69  	if noIdleTimeout {
 70  		return runWithoutIdleTimeout(
 71  			printSettings,
 72  			server,
 73  			socketPath,
 74  			shutdown,
 75  		)
 76  	}
 77  
 78  	// Run the server with a timer.
 79  	return runWithIdleTimeout(
 80  		printSettings,
 81  		server,
 82  		socketPath,
 83  		idleTimeout,
 84  		shutdown,
 85  	)
 86  }
 87  
 88  // runWithIdleTimeout runs the RPC server. The server closes after a specified amount of idle time or when the
 89  // shutdown signal is received.
 90  func runWithIdleTimeout(
 91  	printSettings printer.Settings,
 92  	server *rpc.Server,
 93  	socketPath string,
 94  	idleTimeout int,
 95  	shutdown <-chan os.Signal,
 96  ) error {
 97  	listener, err := net.Listen("unix", socketPath)
 98  	if err != nil {
 99  		return fmt.Errorf("error serving socket connection: %w", err)
100  	}
101  	defer listener.Close()
102  
103  	printer.PrintInfo("Running the server using socket path: " + socketPath + "\n")
104  
105  	// Create a timer for the idle timeout.
106  	if idleTimeout < minIdleTimeout {
107  		idleTimeout = minIdleTimeout
108  	}
109  
110  	timeout := time.Duration(idleTimeout) * time.Second
111  	timer := time.NewTimer(timeout)
112  
113  	// Listen and serve connections from the client in a separate goroutine.
114  	go func() {
115  		for {
116  			conn, err := listener.Accept()
117  			if err != nil {
118  				if errors.Is(err, net.ErrClosed) {
119  					printer.PrintInfo("Network connection closed.\n")
120  
121  					return
122  				}
123  
124  				printer.PrintFailure(
125  					printSettings,
126  					"Error accepting the connection: "+err.Error()+".",
127  				)
128  
129  				os.Exit(1)
130  			}
131  
132  			timer.Reset(timeout)
133  
134  			go server.ServeConn(conn)
135  		}
136  	}()
137  
138  	select {
139  	case <-timer.C:
140  		printer.PrintInfo("Server idle timeout.\n")
141  
142  		return nil
143  	case <-shutdown:
144  		printer.PrintInfo("Shutdown signal received.\n")
145  
146  		return nil
147  	}
148  }
149  
150  // runWithoutIdleTimeout runs the RPC server. The server closes when the shutdown signal is received.
151  func runWithoutIdleTimeout(
152  	printSettings printer.Settings,
153  	server *rpc.Server,
154  	socketPath string,
155  	shutdown <-chan os.Signal,
156  ) error {
157  	listener, err := net.Listen("unix", socketPath)
158  	if err != nil {
159  		return fmt.Errorf("error serving socket connection: %w", err)
160  	}
161  	defer listener.Close()
162  
163  	printer.PrintInfo("Running the server using socket path: " + socketPath + "\n")
164  
165  	// Listen and serve connections from the client in a separate goroutine.
166  	go func() {
167  		for {
168  			conn, err := listener.Accept()
169  			if err != nil {
170  				if errors.Is(err, net.ErrClosed) {
171  					printer.PrintInfo("Network connection closed.\n")
172  
173  					return
174  				}
175  
176  				printer.PrintFailure(
177  					printSettings,
178  					"Error accepting the connection: "+err.Error()+".",
179  				)
180  
181  				os.Exit(1)
182  			}
183  
184  			go server.ServeConn(conn)
185  		}
186  	}()
187  
188  	<-shutdown
189  
190  	printer.PrintInfo("Shutdown signal received.\n")
191  
192  	return nil
193  }
194  
195  // removeUnusedSocketFile removes the socket file if it already exists and
196  // is not being used by a running server.
197  func removeUnusedSocketFile(path string) error {
198  	// Check for the existence of the socket path.
199  	exists, err := utilities.FileExists(path)
200  	if err != nil {
201  		return fmt.Errorf("received an error checking for the socket file: %w", err)
202  	}
203  
204  	if !exists {
205  		return nil
206  	}
207  
208  	// Attempt a connection to the socket path to see if it is in use.
209  	_, err = rpc.Dial("unix", path)
210  
211  	// If the connection is successful, then the socket file is currently in
212  	// use by another running server.
213  	if err == nil {
214  		return ErrSocketFileInUse
215  	}
216  
217  	// If no connection can be made then it should be safe to remove the file.
218  	if err := os.Remove(path); err != nil {
219  		return fmt.Errorf("error removing the unused socket file: %w", err)
220  	}
221  
222  	return nil
223  }