|
|
- package redis
-
- import (
- "errors"
- "fmt"
- "sync"
- "time"
-
- "github.com/go-redis/redis/internal"
- "github.com/go-redis/redis/internal/pool"
- "github.com/go-redis/redis/internal/proto"
- )
-
// errPingTimeout is the reconnect reason used by the Channel health check
// when a ping was written successfully but no traffic arrived afterwards.
var errPingTimeout = errors.New("redis: ping timeout")
-
- // PubSub implements Pub/Sub commands bas described in
- // http://redis.io/topics/pubsub. Message receiving is NOT safe
- // for concurrent use by multiple goroutines.
- //
- // PubSub automatically reconnects to Redis Server and resubscribes
- // to the channels in case of network errors.
- type PubSub struct {
- opt *Options
-
- newConn func([]string) (*pool.Conn, error)
- closeConn func(*pool.Conn) error
-
- mu sync.Mutex
- cn *pool.Conn
- channels map[string]struct{}
- patterns map[string]struct{}
- closed bool
- exit chan struct{}
-
- cmd *Cmd
-
- chOnce sync.Once
- ch chan *Message
- ping chan struct{}
- }
-
- func (c *PubSub) init() {
- c.exit = make(chan struct{})
- }
-
- func (c *PubSub) conn() (*pool.Conn, error) {
- c.mu.Lock()
- cn, err := c._conn(nil)
- c.mu.Unlock()
- return cn, err
- }
-
- func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
- if c.closed {
- return nil, pool.ErrClosed
- }
- if c.cn != nil {
- return c.cn, nil
- }
-
- channels := mapKeys(c.channels)
- channels = append(channels, newChannels...)
-
- cn, err := c.newConn(channels)
- if err != nil {
- return nil, err
- }
-
- if err := c.resubscribe(cn); err != nil {
- _ = c.closeConn(cn)
- return nil, err
- }
-
- c.cn = cn
- return cn, nil
- }
-
- func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
- return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
- return writeCmd(wr, cmd)
- })
- }
-
- func (c *PubSub) resubscribe(cn *pool.Conn) error {
- var firstErr error
-
- if len(c.channels) > 0 {
- err := c._subscribe(cn, "subscribe", mapKeys(c.channels))
- if err != nil && firstErr == nil {
- firstErr = err
- }
- }
-
- if len(c.patterns) > 0 {
- err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns))
- if err != nil && firstErr == nil {
- firstErr = err
- }
- }
-
- return firstErr
- }
-
// mapKeys returns the keys of m in unspecified order. The result is a
// non-nil slice even when m is nil or empty.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
-
- func (c *PubSub) _subscribe(
- cn *pool.Conn, redisCmd string, channels []string,
- ) error {
- args := make([]interface{}, 0, 1+len(channels))
- args = append(args, redisCmd)
- for _, channel := range channels {
- args = append(args, channel)
- }
- cmd := NewSliceCmd(args...)
- return c.writeCmd(cn, cmd)
- }
-
- func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
- c.mu.Lock()
- c._releaseConn(cn, err, allowTimeout)
- c.mu.Unlock()
- }
-
- func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
- if c.cn != cn {
- return
- }
- if internal.IsBadConn(err, allowTimeout) {
- c._reconnect(err)
- }
- }
-
- func (c *PubSub) _reconnect(reason error) {
- _ = c._closeTheCn(reason)
- _, _ = c._conn(nil)
- }
-
- func (c *PubSub) _closeTheCn(reason error) error {
- if c.cn == nil {
- return nil
- }
- if !c.closed {
- internal.Logf("redis: discarding bad PubSub connection: %s", reason)
- }
- err := c.closeConn(c.cn)
- c.cn = nil
- return err
- }
-
- func (c *PubSub) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- if c.closed {
- return pool.ErrClosed
- }
- c.closed = true
- close(c.exit)
-
- err := c._closeTheCn(pool.ErrClosed)
- return err
- }
-
- // Subscribe the client to the specified channels. It returns
- // empty subscription if there are no channels.
- func (c *PubSub) Subscribe(channels ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- err := c.subscribe("subscribe", channels...)
- if c.channels == nil {
- c.channels = make(map[string]struct{})
- }
- for _, s := range channels {
- c.channels[s] = struct{}{}
- }
- return err
- }
-
- // PSubscribe the client to the given patterns. It returns
- // empty subscription if there are no patterns.
- func (c *PubSub) PSubscribe(patterns ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- err := c.subscribe("psubscribe", patterns...)
- if c.patterns == nil {
- c.patterns = make(map[string]struct{})
- }
- for _, s := range patterns {
- c.patterns[s] = struct{}{}
- }
- return err
- }
-
- // Unsubscribe the client from the given channels, or from all of
- // them if none is given.
- func (c *PubSub) Unsubscribe(channels ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- for _, channel := range channels {
- delete(c.channels, channel)
- }
- err := c.subscribe("unsubscribe", channels...)
- return err
- }
-
- // PUnsubscribe the client from the given patterns, or from all of
- // them if none is given.
- func (c *PubSub) PUnsubscribe(patterns ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- for _, pattern := range patterns {
- delete(c.patterns, pattern)
- }
- err := c.subscribe("punsubscribe", patterns...)
- return err
- }
-
- func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
- cn, err := c._conn(channels)
- if err != nil {
- return err
- }
-
- err = c._subscribe(cn, redisCmd, channels)
- c._releaseConn(cn, err, false)
- return err
- }
-
- func (c *PubSub) Ping(payload ...string) error {
- args := []interface{}{"ping"}
- if len(payload) == 1 {
- args = append(args, payload[0])
- }
- cmd := NewCmd(args...)
-
- cn, err := c.conn()
- if err != nil {
- return err
- }
-
- err = c.writeCmd(cn, cmd)
- c.releaseConn(cn, err, false)
- return err
- }
-
// Subscription is the confirmation received after a subscribe or
// unsubscribe style command has been processed.
type Subscription struct {
	// Kind is one of "subscribe", "unsubscribe", "psubscribe" or
	// "punsubscribe".
	Kind string
	// Channel is the channel name the confirmation refers to.
	Channel string
	// Count is the number of channels we are currently subscribed to.
	Count int
}

// String renders the subscription as "<kind>: <channel>".
func (m *Subscription) String() string {
	return fmt.Sprint(m.Kind, ": ", m.Channel)
}
-
// Message is received as the result of a PUBLISH command issued by
// another client. Pattern is filled in only for messages delivered
// through a pattern subscription.
type Message struct {
	Channel string
	Pattern string
	Payload string
}

// String renders the message as "Message<channel: payload>"; the pattern
// is not included.
func (m *Message) String() string {
	const layout = "Message<%s: %s>"
	return fmt.Sprintf(layout, m.Channel, m.Payload)
}
-
// Pong is received as the reply to a PING command sent on the PubSub
// connection.
type Pong struct {
	Payload string
}

// String returns "Pong" for an empty payload and "Pong<payload>" otherwise.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
-
- func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
- switch reply := reply.(type) {
- case string:
- return &Pong{
- Payload: reply,
- }, nil
- case []interface{}:
- switch kind := reply[0].(string); kind {
- case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
- return &Subscription{
- Kind: kind,
- Channel: reply[1].(string),
- Count: int(reply[2].(int64)),
- }, nil
- case "message":
- return &Message{
- Channel: reply[1].(string),
- Payload: reply[2].(string),
- }, nil
- case "pmessage":
- return &Message{
- Pattern: reply[1].(string),
- Channel: reply[2].(string),
- Payload: reply[3].(string),
- }, nil
- case "pong":
- return &Pong{
- Payload: reply[1].(string),
- }, nil
- default:
- return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
- }
- default:
- return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
- }
- }
-
- // ReceiveTimeout acts like Receive but returns an error if message
- // is not received in time. This is low-level API and in most cases
- // Channel should be used instead.
- func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
- if c.cmd == nil {
- c.cmd = NewCmd()
- }
-
- cn, err := c.conn()
- if err != nil {
- return nil, err
- }
-
- err = cn.WithReader(timeout, func(rd *proto.Reader) error {
- return c.cmd.readReply(rd)
- })
-
- c.releaseConn(cn, err, timeout > 0)
- if err != nil {
- return nil, err
- }
-
- return c.newMessage(c.cmd.Val())
- }
-
- // Receive returns a message as a Subscription, Message, Pong or error.
- // See PubSub example for details. This is low-level API and in most cases
- // Channel should be used instead.
- func (c *PubSub) Receive() (interface{}, error) {
- return c.ReceiveTimeout(0)
- }
-
- // ReceiveMessage returns a Message or error ignoring Subscription and Pong
- // messages. This is low-level API and in most cases Channel should be used
- // instead.
- func (c *PubSub) ReceiveMessage() (*Message, error) {
- for {
- msg, err := c.Receive()
- if err != nil {
- return nil, err
- }
-
- switch msg := msg.(type) {
- case *Subscription:
- // Ignore.
- case *Pong:
- // Ignore.
- case *Message:
- return msg, nil
- default:
- err := fmt.Errorf("redis: unknown message: %T", msg)
- return nil, err
- }
- }
- }
-
- // Channel returns a Go channel for concurrently receiving messages.
- // It periodically sends Ping messages to test connection health.
- // The channel is closed with PubSub. Receive* APIs can not be used
- // after channel is created.
- func (c *PubSub) Channel() <-chan *Message {
- c.chOnce.Do(c.initChannel)
- return c.ch
- }
-
- func (c *PubSub) initChannel() {
- c.ch = make(chan *Message, 100)
- c.ping = make(chan struct{}, 10)
-
- go func() {
- var errCount int
- for {
- msg, err := c.Receive()
- if err != nil {
- if err == pool.ErrClosed {
- close(c.ch)
- return
- }
- if errCount > 0 {
- time.Sleep(c.retryBackoff(errCount))
- }
- errCount++
- continue
- }
- errCount = 0
-
- // Any message is as good as a ping.
- select {
- case c.ping <- struct{}{}:
- default:
- }
-
- switch msg := msg.(type) {
- case *Subscription:
- // Ignore.
- case *Pong:
- // Ignore.
- case *Message:
- c.ch <- msg
- default:
- internal.Logf("redis: unknown message: %T", msg)
- }
- }
- }()
-
- go func() {
- const timeout = 5 * time.Second
-
- timer := time.NewTimer(timeout)
- timer.Stop()
-
- healthy := true
- for {
- timer.Reset(timeout)
- select {
- case <-c.ping:
- healthy = true
- if !timer.Stop() {
- <-timer.C
- }
- case <-timer.C:
- pingErr := c.Ping()
- if healthy {
- healthy = false
- } else {
- if pingErr == nil {
- pingErr = errPingTimeout
- }
- c.mu.Lock()
- c._reconnect(pingErr)
- c.mu.Unlock()
- }
- case <-c.exit:
- return
- }
- }
- }()
- }
-
- func (c *PubSub) retryBackoff(attempt int) time.Duration {
- return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
- }
|