1
0

server.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. package grace
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/http"
  8. "os"
  9. "os/exec"
  10. "os/signal"
  11. "strings"
  12. "sync"
  13. "syscall"
  14. "time"
  15. )
  16. // Server embedded http.Server
  17. type Server struct {
  18. *http.Server
  19. GraceListener net.Listener
  20. SignalHooks map[int]map[os.Signal][]func()
  21. tlsInnerListener *graceListener
  22. wg sync.WaitGroup
  23. sigChan chan os.Signal
  24. isChild bool
  25. state uint8
  26. Network string
  27. }
  28. // Serve accepts incoming connections on the Listener l,
  29. // creating a new service goroutine for each.
  30. // The service goroutines read requests and then call srv.Handler to reply to them.
  31. func (srv *Server) Serve() (err error) {
  32. srv.state = StateRunning
  33. err = srv.Server.Serve(srv.GraceListener)
  34. log.Println(syscall.Getpid(), "Waiting for connections to finish...")
  35. srv.wg.Wait()
  36. srv.state = StateTerminate
  37. return
  38. }
  39. // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
  40. // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
  41. // used.
  42. func (srv *Server) ListenAndServe() (err error) {
  43. addr := srv.Addr
  44. if addr == "" {
  45. addr = ":http"
  46. }
  47. go srv.handleSignals()
  48. l, err := srv.getListener(addr)
  49. if err != nil {
  50. log.Println(err)
  51. return err
  52. }
  53. srv.GraceListener = newGraceListener(l, srv)
  54. if srv.isChild {
  55. process, err := os.FindProcess(os.Getppid())
  56. if err != nil {
  57. log.Println(err)
  58. return err
  59. }
  60. err = process.Kill()
  61. if err != nil {
  62. return err
  63. }
  64. }
  65. log.Println(os.Getpid(), srv.Addr)
  66. return srv.Serve()
  67. }
  68. // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
  69. // Serve to handle requests on incoming TLS connections.
  70. //
  71. // Filenames containing a certificate and matching private key for the server must
  72. // be provided. If the certificate is signed by a certificate authority, the
  73. // certFile should be the concatenation of the server's certificate followed by the
  74. // CA's certificate.
  75. //
  76. // If srv.Addr is blank, ":https" is used.
  77. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
  78. addr := srv.Addr
  79. if addr == "" {
  80. addr = ":https"
  81. }
  82. if srv.TLSConfig == nil {
  83. srv.TLSConfig = &tls.Config{}
  84. }
  85. if srv.TLSConfig.NextProtos == nil {
  86. srv.TLSConfig.NextProtos = []string{"http/1.1"}
  87. }
  88. srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
  89. srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  90. if err != nil {
  91. return
  92. }
  93. go srv.handleSignals()
  94. l, err := srv.getListener(addr)
  95. if err != nil {
  96. log.Println(err)
  97. return err
  98. }
  99. srv.tlsInnerListener = newGraceListener(l, srv)
  100. srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
  101. if srv.isChild {
  102. process, err := os.FindProcess(os.Getppid())
  103. if err != nil {
  104. log.Println(err)
  105. return err
  106. }
  107. err = process.Kill()
  108. if err != nil {
  109. return err
  110. }
  111. }
  112. log.Println(os.Getpid(), srv.Addr)
  113. return srv.Serve()
  114. }
  115. // getListener either opens a new socket to listen on, or takes the acceptor socket
  116. // it got passed when restarted.
  117. func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
  118. if srv.isChild {
  119. var ptrOffset uint
  120. if len(socketPtrOffsetMap) > 0 {
  121. ptrOffset = socketPtrOffsetMap[laddr]
  122. log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
  123. }
  124. f := os.NewFile(uintptr(3+ptrOffset), "")
  125. l, err = net.FileListener(f)
  126. if err != nil {
  127. err = fmt.Errorf("net.FileListener error: %v", err)
  128. return
  129. }
  130. } else {
  131. l, err = net.Listen(srv.Network, laddr)
  132. if err != nil {
  133. err = fmt.Errorf("net.Listen error: %v", err)
  134. return
  135. }
  136. }
  137. return
  138. }
  139. // handleSignals listens for os Signals and calls any hooked in function that the
  140. // user had registered with the signal.
  141. func (srv *Server) handleSignals() {
  142. var sig os.Signal
  143. signal.Notify(
  144. srv.sigChan,
  145. hookableSignals...,
  146. )
  147. pid := syscall.Getpid()
  148. for {
  149. sig = <-srv.sigChan
  150. srv.signalHooks(PreSignal, sig)
  151. switch sig {
  152. case syscall.SIGHUP:
  153. log.Println(pid, "Received SIGHUP. forking.")
  154. err := srv.fork()
  155. if err != nil {
  156. log.Println("Fork err:", err)
  157. }
  158. case syscall.SIGINT:
  159. log.Println(pid, "Received SIGINT.")
  160. srv.shutdown()
  161. case syscall.SIGTERM:
  162. log.Println(pid, "Received SIGTERM.")
  163. srv.shutdown()
  164. default:
  165. log.Printf("Received %v: nothing i care about...\n", sig)
  166. }
  167. srv.signalHooks(PostSignal, sig)
  168. }
  169. }
  170. func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
  171. if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
  172. return
  173. }
  174. for _, f := range srv.SignalHooks[ppFlag][sig] {
  175. f()
  176. }
  177. return
  178. }
  179. // shutdown closes the listener so that no new connections are accepted. it also
  180. // starts a goroutine that will serverTimeout (stop all running requests) the server
  181. // after DefaultTimeout.
  182. func (srv *Server) shutdown() {
  183. if srv.state != StateRunning {
  184. return
  185. }
  186. srv.state = StateShuttingDown
  187. if DefaultTimeout >= 0 {
  188. go srv.serverTimeout(DefaultTimeout)
  189. }
  190. err := srv.GraceListener.Close()
  191. if err != nil {
  192. log.Println(syscall.Getpid(), "Listener.Close() error:", err)
  193. } else {
  194. log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
  195. }
  196. }
  197. // serverTimeout forces the server to shutdown in a given timeout - whether it
  198. // finished outstanding requests or not. if Read/WriteTimeout are not set or the
  199. // max header size is very big a connection could hang
  200. func (srv *Server) serverTimeout(d time.Duration) {
  201. defer func() {
  202. if r := recover(); r != nil {
  203. log.Println("WaitGroup at 0", r)
  204. }
  205. }()
  206. if srv.state != StateShuttingDown {
  207. return
  208. }
  209. time.Sleep(d)
  210. log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
  211. for {
  212. if srv.state == StateTerminate {
  213. break
  214. }
  215. srv.wg.Done()
  216. }
  217. }
  218. func (srv *Server) fork() (err error) {
  219. regLock.Lock()
  220. defer regLock.Unlock()
  221. if runningServersForked {
  222. return
  223. }
  224. runningServersForked = true
  225. var files = make([]*os.File, len(runningServers))
  226. var orderArgs = make([]string, len(runningServers))
  227. for _, srvPtr := range runningServers {
  228. switch srvPtr.GraceListener.(type) {
  229. case *graceListener:
  230. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
  231. default:
  232. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
  233. }
  234. orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
  235. }
  236. log.Println(files)
  237. path := os.Args[0]
  238. var args []string
  239. if len(os.Args) > 1 {
  240. for _, arg := range os.Args[1:] {
  241. if arg == "-graceful" {
  242. break
  243. }
  244. args = append(args, arg)
  245. }
  246. }
  247. args = append(args, "-graceful")
  248. if len(runningServers) > 1 {
  249. args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
  250. log.Println(args)
  251. }
  252. cmd := exec.Command(path, args...)
  253. cmd.Stdout = os.Stdout
  254. cmd.Stderr = os.Stderr
  255. cmd.ExtraFiles = files
  256. err = cmd.Start()
  257. if err != nil {
  258. log.Fatalf("Restart: Failed to launch, error: %v", err)
  259. }
  260. return
  261. }
  262. // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
  263. func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
  264. if ppFlag != PreSignal && ppFlag != PostSignal {
  265. err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal.")
  266. return
  267. }
  268. for _, s := range hookableSignals {
  269. if s == sig {
  270. srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
  271. return
  272. }
  273. }
  274. err = fmt.Errorf("Signal '%v' is not supported.", sig)
  275. return
  276. }