123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- package grace
- import (
- "crypto/tls"
- "fmt"
- "log"
- "net"
- "net/http"
- "os"
- "os/exec"
- "os/signal"
- "strings"
- "sync"
- "syscall"
- "time"
- )
- // Server embedded http.Server
- type Server struct {
- *http.Server
- GraceListener net.Listener
- SignalHooks map[int]map[os.Signal][]func()
- tlsInnerListener *graceListener
- wg sync.WaitGroup
- sigChan chan os.Signal
- isChild bool
- state uint8
- Network string
- }
- // Serve accepts incoming connections on the Listener l,
- // creating a new service goroutine for each.
- // The service goroutines read requests and then call srv.Handler to reply to them.
- func (srv *Server) Serve() (err error) {
- srv.state = StateRunning
- err = srv.Server.Serve(srv.GraceListener)
- log.Println(syscall.Getpid(), "Waiting for connections to finish...")
- srv.wg.Wait()
- srv.state = StateTerminate
- return
- }
- // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
- // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
- // used.
- func (srv *Server) ListenAndServe() (err error) {
- addr := srv.Addr
- if addr == "" {
- addr = ":http"
- }
- go srv.handleSignals()
- l, err := srv.getListener(addr)
- if err != nil {
- log.Println(err)
- return err
- }
- srv.GraceListener = newGraceListener(l, srv)
- if srv.isChild {
- process, err := os.FindProcess(os.Getppid())
- if err != nil {
- log.Println(err)
- return err
- }
- err = process.Kill()
- if err != nil {
- return err
- }
- }
- log.Println(os.Getpid(), srv.Addr)
- return srv.Serve()
- }
- // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
- // Serve to handle requests on incoming TLS connections.
- //
- // Filenames containing a certificate and matching private key for the server must
- // be provided. If the certificate is signed by a certificate authority, the
- // certFile should be the concatenation of the server's certificate followed by the
- // CA's certificate.
- //
- // If srv.Addr is blank, ":https" is used.
- func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) {
- addr := srv.Addr
- if addr == "" {
- addr = ":https"
- }
- if srv.TLSConfig == nil {
- srv.TLSConfig = &tls.Config{}
- }
- if srv.TLSConfig.NextProtos == nil {
- srv.TLSConfig.NextProtos = []string{"http/1.1"}
- }
- srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
- srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
- if err != nil {
- return
- }
- go srv.handleSignals()
- l, err := srv.getListener(addr)
- if err != nil {
- log.Println(err)
- return err
- }
- srv.tlsInnerListener = newGraceListener(l, srv)
- srv.GraceListener = tls.NewListener(srv.tlsInnerListener, srv.TLSConfig)
- if srv.isChild {
- process, err := os.FindProcess(os.Getppid())
- if err != nil {
- log.Println(err)
- return err
- }
- err = process.Kill()
- if err != nil {
- return err
- }
- }
- log.Println(os.Getpid(), srv.Addr)
- return srv.Serve()
- }
- // getListener either opens a new socket to listen on, or takes the acceptor socket
- // it got passed when restarted.
- func (srv *Server) getListener(laddr string) (l net.Listener, err error) {
- if srv.isChild {
- var ptrOffset uint
- if len(socketPtrOffsetMap) > 0 {
- ptrOffset = socketPtrOffsetMap[laddr]
- log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
- }
- f := os.NewFile(uintptr(3+ptrOffset), "")
- l, err = net.FileListener(f)
- if err != nil {
- err = fmt.Errorf("net.FileListener error: %v", err)
- return
- }
- } else {
- l, err = net.Listen(srv.Network, laddr)
- if err != nil {
- err = fmt.Errorf("net.Listen error: %v", err)
- return
- }
- }
- return
- }
- // handleSignals listens for os Signals and calls any hooked in function that the
- // user had registered with the signal.
- func (srv *Server) handleSignals() {
- var sig os.Signal
- signal.Notify(
- srv.sigChan,
- hookableSignals...,
- )
- pid := syscall.Getpid()
- for {
- sig = <-srv.sigChan
- srv.signalHooks(PreSignal, sig)
- switch sig {
- case syscall.SIGHUP:
- log.Println(pid, "Received SIGHUP. forking.")
- err := srv.fork()
- if err != nil {
- log.Println("Fork err:", err)
- }
- case syscall.SIGINT:
- log.Println(pid, "Received SIGINT.")
- srv.shutdown()
- case syscall.SIGTERM:
- log.Println(pid, "Received SIGTERM.")
- srv.shutdown()
- default:
- log.Printf("Received %v: nothing i care about...\n", sig)
- }
- srv.signalHooks(PostSignal, sig)
- }
- }
- func (srv *Server) signalHooks(ppFlag int, sig os.Signal) {
- if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
- return
- }
- for _, f := range srv.SignalHooks[ppFlag][sig] {
- f()
- }
- return
- }
- // shutdown closes the listener so that no new connections are accepted. it also
- // starts a goroutine that will serverTimeout (stop all running requests) the server
- // after DefaultTimeout.
- func (srv *Server) shutdown() {
- if srv.state != StateRunning {
- return
- }
- srv.state = StateShuttingDown
- if DefaultTimeout >= 0 {
- go srv.serverTimeout(DefaultTimeout)
- }
- err := srv.GraceListener.Close()
- if err != nil {
- log.Println(syscall.Getpid(), "Listener.Close() error:", err)
- } else {
- log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
- }
- }
- // serverTimeout forces the server to shutdown in a given timeout - whether it
- // finished outstanding requests or not. if Read/WriteTimeout are not set or the
- // max header size is very big a connection could hang
- func (srv *Server) serverTimeout(d time.Duration) {
- defer func() {
- if r := recover(); r != nil {
- log.Println("WaitGroup at 0", r)
- }
- }()
- if srv.state != StateShuttingDown {
- return
- }
- time.Sleep(d)
- log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
- for {
- if srv.state == StateTerminate {
- break
- }
- srv.wg.Done()
- }
- }
- func (srv *Server) fork() (err error) {
- regLock.Lock()
- defer regLock.Unlock()
- if runningServersForked {
- return
- }
- runningServersForked = true
- var files = make([]*os.File, len(runningServers))
- var orderArgs = make([]string, len(runningServers))
- for _, srvPtr := range runningServers {
- switch srvPtr.GraceListener.(type) {
- case *graceListener:
- files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
- default:
- files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
- }
- orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
- }
- log.Println(files)
- path := os.Args[0]
- var args []string
- if len(os.Args) > 1 {
- for _, arg := range os.Args[1:] {
- if arg == "-graceful" {
- break
- }
- args = append(args, arg)
- }
- }
- args = append(args, "-graceful")
- if len(runningServers) > 1 {
- args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
- log.Println(args)
- }
- cmd := exec.Command(path, args...)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- cmd.ExtraFiles = files
- err = cmd.Start()
- if err != nil {
- log.Fatalf("Restart: Failed to launch, error: %v", err)
- }
- return
- }
- // RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
- func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
- if ppFlag != PreSignal && ppFlag != PostSignal {
- err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal.")
- return
- }
- for _, s := range hookableSignals {
- if s == sig {
- srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
- return
- }
- }
- err = fmt.Errorf("Signal '%v' is not supported.", sig)
- return
- }
|