You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

238 lines
6.0 KiB

  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. // This code is highly inspired by endless go
  5. package graceful
  6. import (
  7. "crypto/tls"
  8. "net"
  9. "os"
  10. "strings"
  11. "sync"
  12. "syscall"
  13. "time"
  14. "code.gitea.io/gitea/modules/log"
  15. )
  // Package-wide server defaults. All three start at their zero values.
var (
	// DefaultReadTimeOut is the default read timeout.
	DefaultReadTimeOut time.Duration

	// DefaultWriteTimeOut is the default write timeout.
	DefaultWriteTimeOut time.Duration

	// DefaultMaxHeaderBytes is the default maximum header size. It is kept at
	// 0 so that http.DefaultMaxHeaderBytes is used instead.
	DefaultMaxHeaderBytes int
)
  24. func init() {
  25. DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
  26. }
  // ServeFunction is the accept loop that Serve runs against the server's
// listener; it returns once serving stops. Because it is a type alias, any
// func(net.Listener) error (e.g. an http.Server's Serve method value) can be
// passed directly.
type ServeFunction = func(listener net.Listener) error
  29. // Server represents our graceful server
  30. type Server struct {
  31. network string
  32. address string
  33. listener net.Listener
  34. wg sync.WaitGroup
  35. state state
  36. lock *sync.RWMutex
  37. BeforeBegin func(network, address string)
  38. OnShutdown func()
  39. }
  40. // NewServer creates a server on network at provided address
  41. func NewServer(network, address string) *Server {
  42. if GetManager().IsChild() {
  43. log.Info("Restarting new server: %s:%s on PID: %d", network, address, os.Getpid())
  44. } else {
  45. log.Info("Starting new server: %s:%s on PID: %d", network, address, os.Getpid())
  46. }
  47. srv := &Server{
  48. wg: sync.WaitGroup{},
  49. state: stateInit,
  50. lock: &sync.RWMutex{},
  51. network: network,
  52. address: address,
  53. }
  54. srv.BeforeBegin = func(network, addr string) {
  55. log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
  56. }
  57. return srv
  58. }
  59. // ListenAndServe listens on the provided network address and then calls Serve
  60. // to handle requests on incoming connections.
  61. func (srv *Server) ListenAndServe(serve ServeFunction) error {
  62. go srv.awaitShutdown()
  63. l, err := GetListener(srv.network, srv.address)
  64. if err != nil {
  65. log.Error("Unable to GetListener: %v", err)
  66. return err
  67. }
  68. srv.listener = newWrappedListener(l, srv)
  69. srv.BeforeBegin(srv.network, srv.address)
  70. return srv.Serve(serve)
  71. }
  72. // ListenAndServeTLS listens on the provided network address and then calls
  73. // Serve to handle requests on incoming TLS connections.
  74. //
  75. // Filenames containing a certificate and matching private key for the server must
  76. // be provided. If the certificate is signed by a certificate authority, the
  77. // certFile should be the concatenation of the server's certificate followed by the
  78. // CA's certificate.
  79. func (srv *Server) ListenAndServeTLS(certFile, keyFile string, serve ServeFunction) error {
  80. config := &tls.Config{}
  81. if config.NextProtos == nil {
  82. config.NextProtos = []string{"http/1.1"}
  83. }
  84. config.Certificates = make([]tls.Certificate, 1)
  85. var err error
  86. config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  87. if err != nil {
  88. log.Error("Failed to load https cert file %s for %s:%s: %v", certFile, srv.network, srv.address, err)
  89. return err
  90. }
  91. return srv.ListenAndServeTLSConfig(config, serve)
  92. }
  93. // ListenAndServeTLSConfig listens on the provided network address and then calls
  94. // Serve to handle requests on incoming TLS connections.
  95. func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction) error {
  96. go srv.awaitShutdown()
  97. l, err := GetListener(srv.network, srv.address)
  98. if err != nil {
  99. log.Error("Unable to get Listener: %v", err)
  100. return err
  101. }
  102. wl := newWrappedListener(l, srv)
  103. srv.listener = tls.NewListener(wl, tlsConfig)
  104. srv.BeforeBegin(srv.network, srv.address)
  105. return srv.Serve(serve)
  106. }
  107. // Serve accepts incoming HTTP connections on the wrapped listener l, creating a new
  108. // service goroutine for each. The service goroutines read requests and then call
  109. // handler to reply to them. Handler is typically nil, in which case the
  110. // DefaultServeMux is used.
  111. //
  112. // In addition to the standard Serve behaviour each connection is added to a
  113. // sync.Waitgroup so that all outstanding connections can be served before shutting
  114. // down the server.
  115. func (srv *Server) Serve(serve ServeFunction) error {
  116. defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
  117. srv.setState(stateRunning)
  118. GetManager().RegisterServer()
  119. err := serve(srv.listener)
  120. log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
  121. srv.wg.Wait()
  122. srv.setState(stateTerminate)
  123. GetManager().ServerDone()
  124. // use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
  125. if err != nil && strings.Contains(err.Error(), "use of closed") {
  126. return nil
  127. }
  128. return err
  129. }
  130. func (srv *Server) getState() state {
  131. srv.lock.RLock()
  132. defer srv.lock.RUnlock()
  133. return srv.state
  134. }
  135. func (srv *Server) setState(st state) {
  136. srv.lock.Lock()
  137. defer srv.lock.Unlock()
  138. srv.state = st
  139. }
  140. type filer interface {
  141. File() (*os.File, error)
  142. }
  143. type wrappedListener struct {
  144. net.Listener
  145. stopped bool
  146. server *Server
  147. }
  148. func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
  149. return &wrappedListener{
  150. Listener: l,
  151. server: srv,
  152. }
  153. }
  154. func (wl *wrappedListener) Accept() (net.Conn, error) {
  155. var c net.Conn
  156. // Set keepalive on TCPListeners connections.
  157. if tcl, ok := wl.Listener.(*net.TCPListener); ok {
  158. tc, err := tcl.AcceptTCP()
  159. if err != nil {
  160. return nil, err
  161. }
  162. _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
  163. _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
  164. c = tc
  165. } else {
  166. var err error
  167. c, err = wl.Listener.Accept()
  168. if err != nil {
  169. return nil, err
  170. }
  171. }
  172. c = wrappedConn{
  173. Conn: c,
  174. server: wl.server,
  175. }
  176. wl.server.wg.Add(1)
  177. return c, nil
  178. }
  179. func (wl *wrappedListener) Close() error {
  180. if wl.stopped {
  181. return syscall.EINVAL
  182. }
  183. wl.stopped = true
  184. return wl.Listener.Close()
  185. }
  186. func (wl *wrappedListener) File() (*os.File, error) {
  187. // returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
  188. return wl.Listener.(filer).File()
  189. }
  190. type wrappedConn struct {
  191. net.Conn
  192. server *Server
  193. }
  194. func (w wrappedConn) Close() error {
  195. err := w.Conn.Close()
  196. if err == nil {
  197. w.server.wg.Done()
  198. }
  199. return err
  200. }