summaryrefslogtreecommitdiff
path: root/plugins/redis/pubsub
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/redis/pubsub')
-rw-r--r--plugins/redis/pubsub/channel.go97
-rw-r--r--plugins/redis/pubsub/config.go34
-rw-r--r--plugins/redis/pubsub/pubsub.go183
3 files changed, 314 insertions, 0 deletions
diff --git a/plugins/redis/pubsub/channel.go b/plugins/redis/pubsub/channel.go
new file mode 100644
index 00000000..a1655ab2
--- /dev/null
+++ b/plugins/redis/pubsub/channel.go
@@ -0,0 +1,97 @@
+package pubsub
+
+import (
+ "context"
+ "sync"
+
+ "github.com/go-redis/redis/v8"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/common/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/utils"
+)
+
+type redisChannel struct {
+ sync.Mutex
+
+ // redis client
+ client redis.UniversalClient
+ pubsub *redis.PubSub
+
+ log logger.Logger
+
+ // out channel with all subs
+ out chan *pubsub.Message
+
+ exit chan struct{}
+}
+
+func newRedisChannel(redisClient redis.UniversalClient, log logger.Logger) *redisChannel {
+ out := make(chan *pubsub.Message, 100)
+ fi := &redisChannel{
+ out: out,
+ client: redisClient,
+ pubsub: redisClient.Subscribe(context.Background()),
+ exit: make(chan struct{}),
+ log: log,
+ }
+
+ // start reading messages
+ go fi.read()
+
+ return fi
+}
+
+func (r *redisChannel) sub(topics ...string) error {
+ const op = errors.Op("redis_sub")
+ err := r.pubsub.Subscribe(context.Background(), topics...)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ return nil
+}
+
+// read reads messages from the pubsub subscription
+func (r *redisChannel) read() {
+ for {
+ select {
+ // here we receive message from us (which we sent before in Publish)
+ // it should be compatible with the pubsub.Message structure
+ // payload should be in the redis.message.payload field
+
+ case msg, ok := <-r.pubsub.Channel():
+ // channel closed
+ if !ok {
+ return
+ }
+
+ r.out <- &pubsub.Message{
+ Topic: msg.Channel,
+ Payload: utils.AsBytes(msg.Payload),
+ }
+
+ case <-r.exit:
+ return
+ }
+ }
+}
+
+func (r *redisChannel) unsub(topic string) error {
+ const op = errors.Op("redis_unsub")
+ err := r.pubsub.Unsubscribe(context.Background(), topic)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ return nil
+}
+
+func (r *redisChannel) stop() error {
+ r.exit <- struct{}{}
+ close(r.out)
+ close(r.exit)
+ return nil
+}
+
+func (r *redisChannel) message() chan *pubsub.Message {
+ return r.out
+}
diff --git a/plugins/redis/pubsub/config.go b/plugins/redis/pubsub/config.go
new file mode 100644
index 00000000..bf8d2fc9
--- /dev/null
+++ b/plugins/redis/pubsub/config.go
@@ -0,0 +1,34 @@
+package pubsub
+
+import "time"
+
+type Config struct {
+ Addrs []string `mapstructure:"addrs"`
+ DB int `mapstructure:"db"`
+ Username string `mapstructure:"username"`
+ Password string `mapstructure:"password"`
+ MasterName string `mapstructure:"master_name"`
+ SentinelPassword string `mapstructure:"sentinel_password"`
+ RouteByLatency bool `mapstructure:"route_by_latency"`
+ RouteRandomly bool `mapstructure:"route_randomly"`
+ MaxRetries int `mapstructure:"max_retries"`
+ DialTimeout time.Duration `mapstructure:"dial_timeout"`
+ MinRetryBackoff time.Duration `mapstructure:"min_retry_backoff"`
+ MaxRetryBackoff time.Duration `mapstructure:"max_retry_backoff"`
+ PoolSize int `mapstructure:"pool_size"`
+ MinIdleConns int `mapstructure:"min_idle_conns"`
+ MaxConnAge time.Duration `mapstructure:"max_conn_age"`
+ ReadTimeout time.Duration `mapstructure:"read_timeout"`
+ WriteTimeout time.Duration `mapstructure:"write_timeout"`
+ PoolTimeout time.Duration `mapstructure:"pool_timeout"`
+ IdleTimeout time.Duration `mapstructure:"idle_timeout"`
+ IdleCheckFreq time.Duration `mapstructure:"idle_check_freq"`
+ ReadOnly bool `mapstructure:"read_only"`
+}
+
+// InitDefaults initializing fill config with default values
+func (s *Config) InitDefaults() {
+ if s.Addrs == nil {
+ s.Addrs = []string{"127.0.0.1:6379"} // default addr is pointing to local storage
+ }
+}
diff --git a/plugins/redis/pubsub/pubsub.go b/plugins/redis/pubsub/pubsub.go
new file mode 100644
index 00000000..c9ad3d58
--- /dev/null
+++ b/plugins/redis/pubsub/pubsub.go
@@ -0,0 +1,183 @@
+package pubsub
+
+import (
+ "context"
+ "sync"
+
+ "github.com/go-redis/redis/v8"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/common/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+type PubSubDriver struct {
+ sync.RWMutex
+ cfg *Config `mapstructure:"redis"`
+
+ log logger.Logger
+ channel *redisChannel
+ universalClient redis.UniversalClient
+ stopCh chan struct{}
+}
+
+func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, stopCh chan struct{}) (*PubSubDriver, error) {
+ const op = errors.Op("new_pub_sub_driver")
+ ps := &PubSubDriver{
+ log: log,
+ stopCh: stopCh,
+ }
+
+ // will be different for every connected driver
+ err := cfgPlugin.UnmarshalKey(key, &ps.cfg)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ ps.cfg.InitDefaults()
+
+ ps.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{
+ Addrs: ps.cfg.Addrs,
+ DB: ps.cfg.DB,
+ Username: ps.cfg.Username,
+ Password: ps.cfg.Password,
+ SentinelPassword: ps.cfg.SentinelPassword,
+ MaxRetries: ps.cfg.MaxRetries,
+ MinRetryBackoff: ps.cfg.MaxRetryBackoff,
+ MaxRetryBackoff: ps.cfg.MaxRetryBackoff,
+ DialTimeout: ps.cfg.DialTimeout,
+ ReadTimeout: ps.cfg.ReadTimeout,
+ WriteTimeout: ps.cfg.WriteTimeout,
+ PoolSize: ps.cfg.PoolSize,
+ MinIdleConns: ps.cfg.MinIdleConns,
+ MaxConnAge: ps.cfg.MaxConnAge,
+ PoolTimeout: ps.cfg.PoolTimeout,
+ IdleTimeout: ps.cfg.IdleTimeout,
+ IdleCheckFrequency: ps.cfg.IdleCheckFreq,
+ ReadOnly: ps.cfg.ReadOnly,
+ RouteByLatency: ps.cfg.RouteByLatency,
+ RouteRandomly: ps.cfg.RouteRandomly,
+ MasterName: ps.cfg.MasterName,
+ })
+
+ statusCmd := ps.universalClient.Ping(context.Background())
+ if statusCmd.Err() != nil {
+ return nil, statusCmd.Err()
+ }
+
+ ps.channel = newRedisChannel(ps.universalClient, log)
+
+ ps.stop()
+
+ return ps, nil
+}
+
+func (p *PubSubDriver) stop() {
+ go func() {
+ for range p.stopCh {
+ _ = p.channel.stop()
+ return
+ }
+ }()
+}
+
+func (p *PubSubDriver) Publish(msg *pubsub.Message) error {
+ p.Lock()
+ defer p.Unlock()
+
+ f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload)
+ if f.Err() != nil {
+ return f.Err()
+ }
+
+ return nil
+}
+
+func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) {
+ go func() {
+ p.Lock()
+ defer p.Unlock()
+
+ f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload)
+ if f.Err() != nil {
+ p.log.Error("redis publish", "error", f.Err())
+ }
+ }()
+}
+
+func (p *PubSubDriver) Subscribe(connectionID string, topics ...string) error {
+ // just add a connection
+ for i := 0; i < len(topics); i++ {
+ // key - topic
+ // value - connectionID
+ hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID)
+ res, err := hset.Result()
+ if err != nil {
+ return err
+ }
+ if res == 0 {
+ p.log.Warn("could not subscribe to the provided topic, you might be already subscribed to it", "connectionID", connectionID, "topic", topics[i])
+ continue
+ }
+ }
+
+ // and subscribe after
+ return p.channel.sub(topics...)
+}
+
+func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error {
+ // Remove topics from the storage
+ for i := 0; i < len(topics); i++ {
+ srem := p.universalClient.SRem(context.Background(), topics[i], connectionID)
+ if srem.Err() != nil {
+ return srem.Err()
+ }
+ }
+
+ for i := 0; i < len(topics); i++ {
+ // if there are no such topics, we can safely unsubscribe from the redis
+ exists := p.universalClient.Exists(context.Background(), topics[i])
+ res, err := exists.Result()
+ if err != nil {
+ return err
+ }
+
+ // if we have associated connections - skip
+ if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it
+ continue
+ }
+
+ // else - unsubscribe
+ err = p.channel.unsub(topics[i])
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) {
+ hget := p.universalClient.SMembersMap(context.Background(), topic)
+ r, err := hget.Result()
+ if err != nil {
+ panic(err)
+ }
+
+ // assign connections
+ // res expected to be from the sync.Pool
+ for k := range r {
+ res[k] = struct{}{}
+ }
+}
+
+// Next return next message
+func (p *PubSubDriver) Next(ctx context.Context) (*pubsub.Message, error) {
+ const op = errors.Op("redis_driver_next")
+ select {
+ case msg := <-p.channel.message():
+ return msg, nil
+ case <-ctx.Done():
+ return nil, errors.E(op, errors.TimeOut, ctx.Err())
+ }
+}