diff options
Diffstat (limited to 'plugins/redis/pubsub')
-rw-r--r-- | plugins/redis/pubsub/channel.go | 97 | ||||
-rw-r--r-- | plugins/redis/pubsub/config.go | 34 | ||||
-rw-r--r-- | plugins/redis/pubsub/pubsub.go | 183 |
3 files changed, 314 insertions, 0 deletions
diff --git a/plugins/redis/pubsub/channel.go b/plugins/redis/pubsub/channel.go new file mode 100644 index 00000000..a1655ab2 --- /dev/null +++ b/plugins/redis/pubsub/channel.go @@ -0,0 +1,97 @@ +package pubsub + +import ( + "context" + "sync" + + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/common/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/utils" +) + +type redisChannel struct { + sync.Mutex + + // redis client + client redis.UniversalClient + pubsub *redis.PubSub + + log logger.Logger + + // out channel with all subs + out chan *pubsub.Message + + exit chan struct{} +} + +func newRedisChannel(redisClient redis.UniversalClient, log logger.Logger) *redisChannel { + out := make(chan *pubsub.Message, 100) + fi := &redisChannel{ + out: out, + client: redisClient, + pubsub: redisClient.Subscribe(context.Background()), + exit: make(chan struct{}), + log: log, + } + + // start reading messages + go fi.read() + + return fi +} + +func (r *redisChannel) sub(topics ...string) error { + const op = errors.Op("redis_sub") + err := r.pubsub.Subscribe(context.Background(), topics...) + if err != nil { + return errors.E(op, err) + } + return nil +} + +// read reads messages from the pubsub subscription +func (r *redisChannel) read() { + for { + select { + // here we receive message from us (which we sent before in Publish) + // it should be compatible with the pubsub.Message structure + // payload should be in the redis.message.payload field + + case msg, ok := <-r.pubsub.Channel(): + // channel closed + if !ok { + return + } + + r.out <- &pubsub.Message{ + Topic: msg.Channel, + Payload: utils.AsBytes(msg.Payload), + } + + case <-r.exit: + return + } + } +} + +func (r *redisChannel) unsub(topic string) error { + const op = errors.Op("redis_unsub") + err := r.pubsub.Unsubscribe(context.Background(), topic) + if err != nil { + return errors.E(op, err) + } + return nil +} + +func (r *redisChannel) stop() error { + r.exit <- struct{}{} + close(r.out) + close(r.exit) + return nil +} + +func (r *redisChannel) message() chan *pubsub.Message { + return r.out +} diff --git a/plugins/redis/pubsub/config.go b/plugins/redis/pubsub/config.go new file mode 100644 index 00000000..bf8d2fc9 --- /dev/null +++ b/plugins/redis/pubsub/config.go @@ -0,0 +1,34 @@ +package pubsub + +import "time" + +type Config struct { + Addrs []string `mapstructure:"addrs"` + DB int `mapstructure:"db"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + MasterName string `mapstructure:"master_name"` + SentinelPassword string `mapstructure:"sentinel_password"` + RouteByLatency bool `mapstructure:"route_by_latency"` + RouteRandomly bool `mapstructure:"route_randomly"` + MaxRetries int `mapstructure:"max_retries"` + DialTimeout time.Duration `mapstructure:"dial_timeout"` + MinRetryBackoff time.Duration `mapstructure:"min_retry_backoff"` + MaxRetryBackoff time.Duration `mapstructure:"max_retry_backoff"` + PoolSize int `mapstructure:"pool_size"` + MinIdleConns int `mapstructure:"min_idle_conns"` + MaxConnAge time.Duration `mapstructure:"max_conn_age"` + ReadTimeout time.Duration `mapstructure:"read_timeout"` + WriteTimeout time.Duration `mapstructure:"write_timeout"` + PoolTimeout time.Duration `mapstructure:"pool_timeout"` + IdleTimeout time.Duration `mapstructure:"idle_timeout"` + IdleCheckFreq time.Duration `mapstructure:"idle_check_freq"` + ReadOnly bool `mapstructure:"read_only"` +} + +// InitDefaults initializing fill config with default values +func (s *Config) InitDefaults() { + if s.Addrs == nil { + s.Addrs = []string{"127.0.0.1:6379"} // default addr is pointing to local storage + } +} diff --git a/plugins/redis/pubsub/pubsub.go b/plugins/redis/pubsub/pubsub.go new file mode 100644 index 00000000..c9ad3d58 --- /dev/null +++ b/plugins/redis/pubsub/pubsub.go @@ -0,0 +1,183 @@ +package pubsub + +import ( + "context" + "sync" + + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/common/pubsub" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +type PubSubDriver struct { + sync.RWMutex + cfg *Config `mapstructure:"redis"` + + log logger.Logger + channel *redisChannel + universalClient redis.UniversalClient + stopCh chan struct{} +} + +func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, stopCh chan struct{}) (*PubSubDriver, error) { + const op = errors.Op("new_pub_sub_driver") + ps := &PubSubDriver{ + log: log, + stopCh: stopCh, + } + + // will be different for every connected driver + err := cfgPlugin.UnmarshalKey(key, &ps.cfg) + if err != nil { + return nil, errors.E(op, err) + } + + ps.cfg.InitDefaults() + + ps.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: ps.cfg.Addrs, + DB: ps.cfg.DB, + Username: ps.cfg.Username, + Password: ps.cfg.Password, + SentinelPassword: ps.cfg.SentinelPassword, + MaxRetries: ps.cfg.MaxRetries, + MinRetryBackoff: ps.cfg.MaxRetryBackoff, + MaxRetryBackoff: ps.cfg.MaxRetryBackoff, + DialTimeout: ps.cfg.DialTimeout, + ReadTimeout: ps.cfg.ReadTimeout, + WriteTimeout: ps.cfg.WriteTimeout, + PoolSize: ps.cfg.PoolSize, + MinIdleConns: ps.cfg.MinIdleConns, + MaxConnAge: ps.cfg.MaxConnAge, + PoolTimeout: ps.cfg.PoolTimeout, + IdleTimeout: ps.cfg.IdleTimeout, + IdleCheckFrequency: ps.cfg.IdleCheckFreq, + ReadOnly: ps.cfg.ReadOnly, + RouteByLatency: ps.cfg.RouteByLatency, + RouteRandomly: ps.cfg.RouteRandomly, + MasterName: ps.cfg.MasterName, + }) + + statusCmd := ps.universalClient.Ping(context.Background()) + if statusCmd.Err() != nil { + return nil, statusCmd.Err() + } + + ps.channel = newRedisChannel(ps.universalClient, log) + + ps.stop() + + return ps, nil +} + +func (p *PubSubDriver) stop() { + go func() { + for range p.stopCh { + _ = p.channel.stop() + return + } + }() +} + +func (p *PubSubDriver) Publish(msg *pubsub.Message) error { + p.Lock() + defer p.Unlock() + + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + return f.Err() + } + + return nil +} + +func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) { + go func() { + p.Lock() + defer p.Unlock() + + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + p.log.Error("redis publish", "error", f.Err()) + } + }() +} + +func (p *PubSubDriver) Subscribe(connectionID string, topics ...string) error { + // just add a connection + for i := 0; i < len(topics); i++ { + // key - topic + // value - connectionID + hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID) + res, err := hset.Result() + if err != nil { + return err + } + if res == 0 { + p.log.Warn("could not subscribe to the provided topic, you might be already subscribed to it", "connectionID", connectionID, "topic", topics[i]) + continue + } + } + + // and subscribe after + return p.channel.sub(topics...) +} + +func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error { + // Remove topics from the storage + for i := 0; i < len(topics); i++ { + srem := p.universalClient.SRem(context.Background(), topics[i], connectionID) + if srem.Err() != nil { + return srem.Err() + } + } + + for i := 0; i < len(topics); i++ { + // if there are no such topics, we can safely unsubscribe from the redis + exists := p.universalClient.Exists(context.Background(), topics[i]) + res, err := exists.Result() + if err != nil { + return err + } + + // if we have associated connections - skip + if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it + continue + } + + // else - unsubscribe + err = p.channel.unsub(topics[i]) + if err != nil { + return err + } + } + + return nil +} + +func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { + hget := p.universalClient.SMembersMap(context.Background(), topic) + r, err := hget.Result() + if err != nil { + panic(err) + } + + // assign connections + // res expected to be from the sync.Pool + for k := range r { + res[k] = struct{}{} + } +} + +// Next return next message +func (p *PubSubDriver) Next(ctx context.Context) (*pubsub.Message, error) { + const op = errors.Op("redis_driver_next") + select { + case msg := <-p.channel.message(): + return msg, nil + case <-ctx.Done(): + return nil, errors.E(op, errors.TimeOut, ctx.Err()) + } +} |