You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

376 lines
9.0 KiB

  1. // Package httpdown provides http.ConnState enabled graceful termination of
  2. // http.Server.
  3. package httpdown
  4. import (
  5. "crypto/tls"
  6. "fmt"
  7. "net"
  8. "net/http"
  9. "os"
  10. "os/signal"
  11. "sync"
  12. "syscall"
  13. "time"
  14. "github.com/facebookgo/clock"
  15. "github.com/facebookgo/stats"
  16. )
  const (
  	// defaultStopTimeout is used by HTTP.Serve when HTTP.StopTimeout is zero.
  	defaultStopTimeout = 1 * time.Minute
  	// defaultKillTimeout is used by HTTP.Serve when HTTP.KillTimeout is zero.
  	defaultKillTimeout = 1 * time.Minute
  )
  // Server is a running server that can be shut down gracefully: the listener
  // stops accepting new connections while connections that are already open
  // are given time to finish.
  type Server interface {
  	// Stop closes the listener and blocks until all connections have been
  	// closed, or until the configured stop/kill timeouts have elapsed.
  	Stop() error

  	// Wait blocks until the serving loop finishes. That happens either because
  	// Stop was called, in which case Wait returns nil, or because serving
  	// failed, in which case Wait returns that error. Wait must be called after
  	// Serve or ListenAndServe.
  	Wait() error
  }
  33. // HTTP defines the configuration for serving a http.Server. Multiple calls to
  34. // Serve or ListenAndServe can be made on the same HTTP instance. The default
  35. // timeouts of 1 minute each result in a maximum of 2 minutes before a Stop()
  36. // returns.
  37. type HTTP struct {
  38. // StopTimeout is the duration before we begin force closing connections.
  39. // Defaults to 1 minute.
  40. StopTimeout time.Duration
  41. // KillTimeout is the duration before which we completely give up and abort
  42. // even though we still have connected clients. This is useful when a large
  43. // number of client connections exist and closing them can take a long time.
  44. // Note, this is in addition to the StopTimeout. Defaults to 1 minute.
  45. KillTimeout time.Duration
  46. // Stats is optional. If provided, it will be used to record various metrics.
  47. Stats stats.Client
  48. // Clock allows for testing timing related functionality. Do not specify this
  49. // in production code.
  50. Clock clock.Clock
  51. }
  52. // Serve provides the low-level API which is useful if you're creating your own
  53. // net.Listener.
  54. func (h HTTP) Serve(s *http.Server, l net.Listener) Server {
  55. stopTimeout := h.StopTimeout
  56. if stopTimeout == 0 {
  57. stopTimeout = defaultStopTimeout
  58. }
  59. killTimeout := h.KillTimeout
  60. if killTimeout == 0 {
  61. killTimeout = defaultKillTimeout
  62. }
  63. klock := h.Clock
  64. if klock == nil {
  65. klock = clock.New()
  66. }
  67. ss := &server{
  68. stopTimeout: stopTimeout,
  69. killTimeout: killTimeout,
  70. stats: h.Stats,
  71. clock: klock,
  72. oldConnState: s.ConnState,
  73. listener: l,
  74. server: s,
  75. serveDone: make(chan struct{}),
  76. serveErr: make(chan error, 1),
  77. new: make(chan net.Conn),
  78. active: make(chan net.Conn),
  79. idle: make(chan net.Conn),
  80. closed: make(chan net.Conn),
  81. stop: make(chan chan struct{}),
  82. kill: make(chan chan struct{}),
  83. }
  84. s.ConnState = ss.connState
  85. go ss.manage()
  86. go ss.serve()
  87. return ss
  88. }
  89. // ListenAndServe returns a Server for the given http.Server. It is equivalent
  90. // to ListenAndServe from the standard library, but returns immediately.
  91. // Requests will be accepted in a background goroutine. If the http.Server has
  92. // a non-nil TLSConfig, a TLS enabled listener will be setup.
  93. func (h HTTP) ListenAndServe(s *http.Server) (Server, error) {
  94. addr := s.Addr
  95. if addr == "" {
  96. if s.TLSConfig == nil {
  97. addr = ":http"
  98. } else {
  99. addr = ":https"
  100. }
  101. }
  102. l, err := net.Listen("tcp", addr)
  103. if err != nil {
  104. stats.BumpSum(h.Stats, "listen.error", 1)
  105. return nil, err
  106. }
  107. if s.TLSConfig != nil {
  108. l = tls.NewListener(l, s.TLSConfig)
  109. }
  110. return h.Serve(s, l), nil
  111. }
  112. // server manages the serving process and allows for gracefully stopping it.
  113. type server struct {
  114. stopTimeout time.Duration
  115. killTimeout time.Duration
  116. stats stats.Client
  117. clock clock.Clock
  118. oldConnState func(net.Conn, http.ConnState)
  119. server *http.Server
  120. serveDone chan struct{}
  121. serveErr chan error
  122. listener net.Listener
  123. new chan net.Conn
  124. active chan net.Conn
  125. idle chan net.Conn
  126. closed chan net.Conn
  127. stop chan chan struct{}
  128. kill chan chan struct{}
  129. stopOnce sync.Once
  130. stopErr error
  131. }
  132. func (s *server) connState(c net.Conn, cs http.ConnState) {
  133. if s.oldConnState != nil {
  134. s.oldConnState(c, cs)
  135. }
  136. switch cs {
  137. case http.StateNew:
  138. s.new <- c
  139. case http.StateActive:
  140. s.active <- c
  141. case http.StateIdle:
  142. s.idle <- c
  143. case http.StateHijacked, http.StateClosed:
  144. s.closed <- c
  145. }
  146. }
  147. func (s *server) manage() {
  148. defer func() {
  149. close(s.new)
  150. close(s.active)
  151. close(s.idle)
  152. close(s.closed)
  153. close(s.stop)
  154. close(s.kill)
  155. }()
  156. var stopDone chan struct{}
  157. conns := map[net.Conn]http.ConnState{}
  158. var countNew, countActive, countIdle float64
  159. // decConn decrements the count associated with the current state of the
  160. // given connection.
  161. decConn := func(c net.Conn) {
  162. switch conns[c] {
  163. default:
  164. panic(fmt.Errorf("unknown existing connection: %s", c))
  165. case http.StateNew:
  166. countNew--
  167. case http.StateActive:
  168. countActive--
  169. case http.StateIdle:
  170. countIdle--
  171. }
  172. }
  173. // setup a ticker to report various values every minute. if we don't have a
  174. // Stats implementation provided, we Stop it so it never ticks.
  175. statsTicker := s.clock.Ticker(time.Minute)
  176. if s.stats == nil {
  177. statsTicker.Stop()
  178. }
  179. for {
  180. select {
  181. case <-statsTicker.C:
  182. // we'll only get here when s.stats is not nil
  183. s.stats.BumpAvg("http-state.new", countNew)
  184. s.stats.BumpAvg("http-state.active", countActive)
  185. s.stats.BumpAvg("http-state.idle", countIdle)
  186. s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle)
  187. case c := <-s.new:
  188. conns[c] = http.StateNew
  189. countNew++
  190. case c := <-s.active:
  191. decConn(c)
  192. countActive++
  193. conns[c] = http.StateActive
  194. case c := <-s.idle:
  195. decConn(c)
  196. countIdle++
  197. conns[c] = http.StateIdle
  198. // if we're already stopping, close it
  199. if stopDone != nil {
  200. c.Close()
  201. }
  202. case c := <-s.closed:
  203. stats.BumpSum(s.stats, "conn.closed", 1)
  204. decConn(c)
  205. delete(conns, c)
  206. // if we're waiting to stop and are all empty, we just closed the last
  207. // connection and we're done.
  208. if stopDone != nil && len(conns) == 0 {
  209. close(stopDone)
  210. return
  211. }
  212. case stopDone = <-s.stop:
  213. // if we're already all empty, we're already done
  214. if len(conns) == 0 {
  215. close(stopDone)
  216. return
  217. }
  218. // close current idle connections right away
  219. for c, cs := range conns {
  220. if cs == http.StateIdle {
  221. c.Close()
  222. }
  223. }
  224. // continue the loop and wait for all the ConnState updates which will
  225. // eventually close(stopDone) and return from this goroutine.
  226. case killDone := <-s.kill:
  227. // force close all connections
  228. stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns)))
  229. for c := range conns {
  230. c.Close()
  231. }
  232. // don't block the kill.
  233. close(killDone)
  234. // continue the loop and we wait for all the ConnState updates and will
  235. // return from this goroutine when we're all done. otherwise we'll try to
  236. // send those ConnState updates on closed channels.
  237. }
  238. }
  239. }
  240. func (s *server) serve() {
  241. stats.BumpSum(s.stats, "serve", 1)
  242. s.serveErr <- s.server.Serve(s.listener)
  243. close(s.serveDone)
  244. close(s.serveErr)
  245. }
  246. func (s *server) Wait() error {
  247. if err := <-s.serveErr; !isUseOfClosedError(err) {
  248. return err
  249. }
  250. return nil
  251. }
  252. func (s *server) Stop() error {
  253. s.stopOnce.Do(func() {
  254. defer stats.BumpTime(s.stats, "stop.time").End()
  255. stats.BumpSum(s.stats, "stop", 1)
  256. // first disable keep-alive for new connections
  257. s.server.SetKeepAlivesEnabled(false)
  258. // then close the listener so new connections can't connect come thru
  259. closeErr := s.listener.Close()
  260. <-s.serveDone
  261. // then trigger the background goroutine to stop and wait for it
  262. stopDone := make(chan struct{})
  263. s.stop <- stopDone
  264. // wait for stop
  265. select {
  266. case <-stopDone:
  267. case <-s.clock.After(s.stopTimeout):
  268. defer stats.BumpTime(s.stats, "kill.time").End()
  269. stats.BumpSum(s.stats, "kill", 1)
  270. // stop timed out, wait for kill
  271. killDone := make(chan struct{})
  272. s.kill <- killDone
  273. select {
  274. case <-killDone:
  275. case <-s.clock.After(s.killTimeout):
  276. // kill timed out, give up
  277. stats.BumpSum(s.stats, "kill.timeout", 1)
  278. }
  279. }
  280. if closeErr != nil && !isUseOfClosedError(closeErr) {
  281. stats.BumpSum(s.stats, "listener.close.error", 1)
  282. s.stopErr = closeErr
  283. }
  284. })
  285. return s.stopErr
  286. }
  // isUseOfClosedError reports whether err, or an error it wraps, is the
  // "use of closed network connection" error. That is what Accept returns once
  // the listener has been closed. It returns false for nil.
  //
  // *net.OpError values are unwrapped through their Err field, and other
  // errors through an Unwrap method. Matching is done on the message text, not
  // a sentinel value, so the check does not depend on net.ErrClosed (added in
  // Go 1.16).
  func isUseOfClosedError(err error) bool {
  	const closedMsg = "use of closed network connection"
  	for err != nil {
  		// Unwrap *net.OpError before looking at any message: its Error method
  		// dereferences the wrapped Err and panics when that is nil.
  		if opErr, ok := err.(*net.OpError); ok {
  			if opErr == nil {
  				return false
  			}
  			err = opErr.Err
  			continue
  		}
  		if err.Error() == closedMsg {
  			return true
  		}
  		wrapper, ok := err.(interface{ Unwrap() error })
  		if !ok {
  			return false
  		}
  		err = wrapper.Unwrap()
  	}
  	return false
  }
  296. // ListenAndServe is a convenience function to serve and wait for a SIGTERM
  297. // or SIGINT before shutting down.
  298. func ListenAndServe(s *http.Server, hd *HTTP) error {
  299. if hd == nil {
  300. hd = &HTTP{}
  301. }
  302. hs, err := hd.ListenAndServe(s)
  303. if err != nil {
  304. return err
  305. }
  306. waiterr := make(chan error, 1)
  307. go func() {
  308. defer close(waiterr)
  309. waiterr <- hs.Wait()
  310. }()
  311. signals := make(chan os.Signal, 10)
  312. signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
  313. select {
  314. case err := <-waiterr:
  315. if err != nil {
  316. return err
  317. }
  318. case <-signals:
  319. signal.Stop(signals)
  320. if err := hs.Stop(); err != nil {
  321. return err
  322. }
  323. if err := <-waiterr; err != nil {
  324. return err
  325. }
  326. }
  327. return nil
  328. }