summaryrefslogtreecommitdiff
path: root/plugins/redis
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-06-14 16:39:02 +0300
committerValery Piashchynski <[email protected]>2021-06-14 16:39:02 +0300
commit75ab1e16c64cfd0a6424fe4c546fdbc5e1b992dd (patch)
tree1e9a910071d20021ad0f7ef4fe6099bac6a341ef /plugins/redis
parentdc8ed203c247afd684f198ebbac103a10bfad72a (diff)
- Rework redis with ws plugins
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/redis')
-rw-r--r--plugins/redis/clients.go84
-rw-r--r--plugins/redis/interface.go7
-rw-r--r--plugins/redis/kv.go242
-rw-r--r--plugins/redis/plugin.go181
-rw-r--r--plugins/redis/pubsub.go189
5 files changed, 541 insertions, 162 deletions
diff --git a/plugins/redis/clients.go b/plugins/redis/clients.go
new file mode 100644
index 00000000..d0a184d2
--- /dev/null
+++ b/plugins/redis/clients.go
@@ -0,0 +1,84 @@
+package redis
+
+import (
+ "github.com/go-redis/redis/v8"
+ "github.com/spiral/errors"
+)
+
+// RedisClient return a client based on the provided section key
+// key sample: kv.some-section.redis
+// kv.redis
+// redis (root)
+func (p *Plugin) RedisClient(key string) (redis.UniversalClient, error) {
+ const op = errors.Op("redis_get_client")
+
+ if !p.cfgPlugin.Has(key) {
+ return nil, errors.E(op, errors.Errorf("no such section: %s", key))
+ }
+
+ cfg := &Config{}
+
+ err := p.cfgPlugin.UnmarshalKey(key, cfg)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ cfg.InitDefaults()
+
+ uc := redis.NewUniversalClient(&redis.UniversalOptions{
+ Addrs: cfg.Addrs,
+ DB: cfg.DB,
+ Username: cfg.Username,
+ Password: cfg.Password,
+ SentinelPassword: cfg.SentinelPassword,
+ MaxRetries: cfg.MaxRetries,
+ MinRetryBackoff: cfg.MaxRetryBackoff,
+ MaxRetryBackoff: cfg.MaxRetryBackoff,
+ DialTimeout: cfg.DialTimeout,
+ ReadTimeout: cfg.ReadTimeout,
+ WriteTimeout: cfg.WriteTimeout,
+ PoolSize: cfg.PoolSize,
+ MinIdleConns: cfg.MinIdleConns,
+ MaxConnAge: cfg.MaxConnAge,
+ PoolTimeout: cfg.PoolTimeout,
+ IdleTimeout: cfg.IdleTimeout,
+ IdleCheckFrequency: cfg.IdleCheckFreq,
+ ReadOnly: cfg.ReadOnly,
+ RouteByLatency: cfg.RouteByLatency,
+ RouteRandomly: cfg.RouteRandomly,
+ MasterName: cfg.MasterName,
+ })
+
+ return uc, nil
+}
+
+func (p *Plugin) DefaultClient() redis.UniversalClient {
+ cfg := &Config{}
+ cfg.InitDefaults()
+
+ uc := redis.NewUniversalClient(&redis.UniversalOptions{
+ Addrs: cfg.Addrs,
+ DB: cfg.DB,
+ Username: cfg.Username,
+ Password: cfg.Password,
+ SentinelPassword: cfg.SentinelPassword,
+ MaxRetries: cfg.MaxRetries,
+ MinRetryBackoff: cfg.MaxRetryBackoff,
+ MaxRetryBackoff: cfg.MaxRetryBackoff,
+ DialTimeout: cfg.DialTimeout,
+ ReadTimeout: cfg.ReadTimeout,
+ WriteTimeout: cfg.WriteTimeout,
+ PoolSize: cfg.PoolSize,
+ MinIdleConns: cfg.MinIdleConns,
+ MaxConnAge: cfg.MaxConnAge,
+ PoolTimeout: cfg.PoolTimeout,
+ IdleTimeout: cfg.IdleTimeout,
+ IdleCheckFrequency: cfg.IdleCheckFreq,
+ ReadOnly: cfg.ReadOnly,
+ RouteByLatency: cfg.RouteByLatency,
+ RouteRandomly: cfg.RouteRandomly,
+ MasterName: cfg.MasterName,
+ })
+
+ return uc
+}
diff --git a/plugins/redis/interface.go b/plugins/redis/interface.go
index c0be6137..189b0002 100644
--- a/plugins/redis/interface.go
+++ b/plugins/redis/interface.go
@@ -4,6 +4,9 @@ import "github.com/go-redis/redis/v8"
// Redis in the redis KV plugin interface
type Redis interface {
- // GetClient provides universal redis client
- GetClient() redis.UniversalClient
+ // RedisClient provides universal redis client
+ RedisClient(key string) (redis.UniversalClient, error)
+
+ // DefaultClient provide default redis client based on redis defaults
+ DefaultClient() redis.UniversalClient
}
diff --git a/plugins/redis/kv.go b/plugins/redis/kv.go
new file mode 100644
index 00000000..66cb8384
--- /dev/null
+++ b/plugins/redis/kv.go
@@ -0,0 +1,242 @@
+package redis
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ "github.com/go-redis/redis/v8"
+ "github.com/spiral/errors"
+ kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/kv"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/utils"
+)
+
+type Driver struct {
+ universalClient redis.UniversalClient
+ log logger.Logger
+ cfg *Config
+}
+
+func NewRedisDriver(log logger.Logger, key string, cfgPlugin config.Configurer) (kv.Storage, error) {
+ const op = errors.Op("new_boltdb_driver")
+
+ d := &Driver{
+ log: log,
+ }
+
+ // will be different for every connected driver
+ err := cfgPlugin.UnmarshalKey(key, &d.cfg)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ d.cfg.InitDefaults()
+ d.log = log
+
+ d.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{
+ Addrs: d.cfg.Addrs,
+ DB: d.cfg.DB,
+ Username: d.cfg.Username,
+ Password: d.cfg.Password,
+ SentinelPassword: d.cfg.SentinelPassword,
+ MaxRetries: d.cfg.MaxRetries,
+ MinRetryBackoff: d.cfg.MaxRetryBackoff,
+ MaxRetryBackoff: d.cfg.MaxRetryBackoff,
+ DialTimeout: d.cfg.DialTimeout,
+ ReadTimeout: d.cfg.ReadTimeout,
+ WriteTimeout: d.cfg.WriteTimeout,
+ PoolSize: d.cfg.PoolSize,
+ MinIdleConns: d.cfg.MinIdleConns,
+ MaxConnAge: d.cfg.MaxConnAge,
+ PoolTimeout: d.cfg.PoolTimeout,
+ IdleTimeout: d.cfg.IdleTimeout,
+ IdleCheckFrequency: d.cfg.IdleCheckFreq,
+ ReadOnly: d.cfg.ReadOnly,
+ RouteByLatency: d.cfg.RouteByLatency,
+ RouteRandomly: d.cfg.RouteRandomly,
+ MasterName: d.cfg.MasterName,
+ })
+
+ return d, nil
+}
+
+// Has checks if value exists.
+func (d *Driver) Has(keys ...string) (map[string]bool, error) {
+ const op = errors.Op("redis_driver_has")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+
+ m := make(map[string]bool, len(keys))
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+
+ exist, err := d.universalClient.Exists(context.Background(), key).Result()
+ if err != nil {
+ return nil, err
+ }
+ if exist == 1 {
+ m[key] = true
+ }
+ }
+ return m, nil
+}
+
+// Get loads key content into slice.
+func (d *Driver) Get(key string) ([]byte, error) {
+ const op = errors.Op("redis_driver_get")
+ // to get cases like " "
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+ return d.universalClient.Get(context.Background(), key).Bytes()
+}
+
+// MGet loads content of multiple values (some values might be skipped).
+// https://redis.io/commands/mget
+// Returns slice with the interfaces with values
+func (d *Driver) MGet(keys ...string) (map[string][]byte, error) {
+ const op = errors.Op("redis_driver_mget")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+ }
+
+ m := make(map[string][]byte, len(keys))
+
+ for _, k := range keys {
+ cmd := d.universalClient.Get(context.Background(), k)
+ if cmd.Err() != nil {
+ if cmd.Err() == redis.Nil {
+ continue
+ }
+ return nil, errors.E(op, cmd.Err())
+ }
+
+ m[k] = utils.AsBytes(cmd.Val())
+ }
+
+ return m, nil
+}
+
+// Set sets value with the TTL in seconds
+// https://redis.io/commands/set
+// Redis `SET key value [expiration]` command.
+//
+// Use expiration for `SETEX`-like behavior.
+// Zero expiration means the key has no expiration time.
+func (d *Driver) Set(items ...*kvv1.Item) error {
+ const op = errors.Op("redis_driver_set")
+ if items == nil {
+ return errors.E(op, errors.NoKeys)
+ }
+ now := time.Now()
+ for _, item := range items {
+ if item == nil {
+ return errors.E(op, errors.EmptyKey)
+ }
+
+ if item.Timeout == "" {
+ err := d.universalClient.Set(context.Background(), item.Key, item.Value, 0).Err()
+ if err != nil {
+ return err
+ }
+ } else {
+ t, err := time.Parse(time.RFC3339, item.Timeout)
+ if err != nil {
+ return err
+ }
+ err = d.universalClient.Set(context.Background(), item.Key, item.Value, t.Sub(now)).Err()
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// Delete one or multiple keys.
+func (d *Driver) Delete(keys ...string) error {
+ const op = errors.Op("redis_driver_delete")
+ if keys == nil {
+ return errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return errors.E(op, errors.EmptyKey)
+ }
+ }
+ return d.universalClient.Del(context.Background(), keys...).Err()
+}
+
+// MExpire https://redis.io/commands/expire
+// timeout in RFC3339
+func (d *Driver) MExpire(items ...*kvv1.Item) error {
+ const op = errors.Op("redis_driver_mexpire")
+ now := time.Now()
+ for _, item := range items {
+ if item == nil {
+ continue
+ }
+ if item.Timeout == "" || strings.TrimSpace(item.Key) == "" {
+ return errors.E(op, errors.Str("should set timeout and at least one key"))
+ }
+
+ t, err := time.Parse(time.RFC3339, item.Timeout)
+ if err != nil {
+ return err
+ }
+
+ // t guessed to be in future
+ // for Redis we use t.Sub, it will result in seconds, like 4.2s
+ d.universalClient.Expire(context.Background(), item.Key, t.Sub(now))
+ }
+
+ return nil
+}
+
+// TTL https://redis.io/commands/ttl
+// return time in seconds (float64) for a given keys
+func (d *Driver) TTL(keys ...string) (map[string]string, error) {
+ const op = errors.Op("redis_driver_ttl")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+ }
+
+ m := make(map[string]string, len(keys))
+
+ for _, key := range keys {
+ duration, err := d.universalClient.TTL(context.Background(), key).Result()
+ if err != nil {
+ return nil, err
+ }
+
+ m[key] = duration.String()
+ }
+ return m, nil
+}
diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go
index 47ffeb39..24c21b55 100644
--- a/plugins/redis/plugin.go
+++ b/plugins/redis/plugin.go
@@ -1,15 +1,14 @@
package redis
import (
- "context"
"sync"
"github.com/go-redis/redis/v8"
"github.com/spiral/errors"
- websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
"github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/kv"
"github.com/spiral/roadrunner/v2/plugins/logger"
- "google.golang.org/protobuf/proto"
)
const PluginName = "redis"
@@ -17,80 +16,37 @@ const PluginName = "redis"
type Plugin struct {
sync.RWMutex
// config for RR integration
- cfg *Config
+ cfgPlugin config.Configurer
// logger
log logger.Logger
// redis universal client
universalClient redis.UniversalClient
// fanIn implementation used to deliver messages from all channels to the single websocket point
- fanin *FanIn
-}
-
-func (p *Plugin) GetClient() redis.UniversalClient {
- return p.universalClient
+ stopCh chan struct{}
}
func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
- const op = errors.Op("redis_plugin_init")
-
- if !cfg.Has(PluginName) {
- return errors.E(op, errors.Disabled)
- }
-
- err := cfg.UnmarshalKey(PluginName, &p.cfg)
- if err != nil {
- return errors.E(op, errors.Disabled, err)
- }
-
- p.cfg.InitDefaults()
p.log = log
-
- p.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{
- Addrs: p.cfg.Addrs,
- DB: p.cfg.DB,
- Username: p.cfg.Username,
- Password: p.cfg.Password,
- SentinelPassword: p.cfg.SentinelPassword,
- MaxRetries: p.cfg.MaxRetries,
- MinRetryBackoff: p.cfg.MaxRetryBackoff,
- MaxRetryBackoff: p.cfg.MaxRetryBackoff,
- DialTimeout: p.cfg.DialTimeout,
- ReadTimeout: p.cfg.ReadTimeout,
- WriteTimeout: p.cfg.WriteTimeout,
- PoolSize: p.cfg.PoolSize,
- MinIdleConns: p.cfg.MinIdleConns,
- MaxConnAge: p.cfg.MaxConnAge,
- PoolTimeout: p.cfg.PoolTimeout,
- IdleTimeout: p.cfg.IdleTimeout,
- IdleCheckFrequency: p.cfg.IdleCheckFreq,
- ReadOnly: p.cfg.ReadOnly,
- RouteByLatency: p.cfg.RouteByLatency,
- RouteRandomly: p.cfg.RouteRandomly,
- MasterName: p.cfg.MasterName,
- })
-
- // init fanin
- p.fanin = newFanIn(p.universalClient, log)
+ p.cfgPlugin = cfg
+ p.stopCh = make(chan struct{}, 1)
return nil
}
func (p *Plugin) Serve() chan error {
- errCh := make(chan error)
- return errCh
+ return make(chan error)
}
func (p *Plugin) Stop() error {
const op = errors.Op("redis_plugin_stop")
- err := p.fanin.stop()
- if err != nil {
- return errors.E(op, err)
- }
+ p.stopCh <- struct{}{}
- err = p.universalClient.Close()
- if err != nil {
- return errors.E(op, err)
+ if p.universalClient != nil {
+ err := p.universalClient.Close()
+ if err != nil {
+ return errors.E(op, err)
+ }
}
return nil
@@ -103,112 +59,17 @@ func (p *Plugin) Name() string {
// Available interface implementation
func (p *Plugin) Available() {}
-func (p *Plugin) Publish(msg []byte) error {
- p.Lock()
- defer p.Unlock()
-
- m := &websocketsv1.Message{}
- err := proto.Unmarshal(msg, m)
- if err != nil {
- return errors.E(err)
- }
-
- for j := 0; j < len(m.GetTopics()); j++ {
- f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg)
- if f.Err() != nil {
- return f.Err()
- }
- }
- return nil
-}
-
-func (p *Plugin) PublishAsync(msg []byte) {
- go func() {
- p.Lock()
- defer p.Unlock()
- m := &websocketsv1.Message{}
- err := proto.Unmarshal(msg, m)
- if err != nil {
- p.log.Error("message unmarshal error")
- return
- }
-
- for j := 0; j < len(m.GetTopics()); j++ {
- f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg)
- if f.Err() != nil {
- p.log.Error("redis publish", "error", f.Err())
- }
- }
- }()
-}
-
-func (p *Plugin) Subscribe(connectionID string, topics ...string) error {
- // just add a connection
- for i := 0; i < len(topics); i++ {
- // key - topic
- // value - connectionID
- hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID)
- res, err := hset.Result()
- if err != nil {
- return err
- }
- if res == 0 {
- p.log.Warn("could not subscribe to the provided topic", "connectionID", connectionID, "topic", topics[i])
- continue
- }
- }
-
- // and subscribe after
- return p.fanin.sub(topics...)
-}
-
-func (p *Plugin) Unsubscribe(connectionID string, topics ...string) error {
- // Remove topics from the storage
- for i := 0; i < len(topics); i++ {
- srem := p.universalClient.SRem(context.Background(), topics[i], connectionID)
- if srem.Err() != nil {
- return srem.Err()
- }
- }
-
- for i := 0; i < len(topics); i++ {
- // if there are no such topics, we can safely unsubscribe from the redis
- exists := p.universalClient.Exists(context.Background(), topics[i])
- res, err := exists.Result()
- if err != nil {
- return err
- }
-
- // if we have associated connections - skip
- if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it
- continue
- }
-
- // else - unsubscribe
- err = p.fanin.unsub(topics[i])
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (p *Plugin) Connections(topic string, res map[string]struct{}) {
- hget := p.universalClient.SMembersMap(context.Background(), topic)
- r, err := hget.Result()
+// KVProvide provides KV storage implementation over the redis plugin
+func (p *Plugin) KVProvide(key string) (kv.Storage, error) {
+ const op = errors.Op("redis_plugin_provide")
+ st, err := NewRedisDriver(p.log, key, p.cfgPlugin)
if err != nil {
- panic(err)
+ return nil, errors.E(op, err)
}
- // assighn connections
- // res expected to be from the sync.Pool
- for k := range r {
- res[k] = struct{}{}
- }
+ return st, nil
}
-// Next return next message
-func (p *Plugin) Next() (*websocketsv1.Message, error) {
- return <-p.fanin.consume(), nil
+func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) {
+ return NewPubSubDriver(p.log, key, p.cfgPlugin, p.stopCh)
}
diff --git a/plugins/redis/pubsub.go b/plugins/redis/pubsub.go
new file mode 100644
index 00000000..dbda7ea4
--- /dev/null
+++ b/plugins/redis/pubsub.go
@@ -0,0 +1,189 @@
+package redis
+
+import (
+ "context"
+ "sync"
+
+ "github.com/go-redis/redis/v8"
+ "github.com/spiral/errors"
+ websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "google.golang.org/protobuf/proto"
+)
+
+type PubSubDriver struct {
+ sync.RWMutex
+ cfg *Config `mapstructure:"redis"`
+
+ log logger.Logger
+ fanin *FanIn
+ universalClient redis.UniversalClient
+ stopCh chan struct{}
+}
+
+func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, stopCh chan struct{}) (pubsub.PubSub, error) {
+ const op = errors.Op("new_pub_sub_driver")
+ ps := &PubSubDriver{
+ log: log,
+ stopCh: stopCh,
+ }
+
+ // will be different for every connected driver
+ err := cfgPlugin.UnmarshalKey(key, &ps.cfg)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ ps.cfg.InitDefaults()
+
+ ps.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{
+ Addrs: ps.cfg.Addrs,
+ DB: ps.cfg.DB,
+ Username: ps.cfg.Username,
+ Password: ps.cfg.Password,
+ SentinelPassword: ps.cfg.SentinelPassword,
+ MaxRetries: ps.cfg.MaxRetries,
+ MinRetryBackoff: ps.cfg.MaxRetryBackoff,
+ MaxRetryBackoff: ps.cfg.MaxRetryBackoff,
+ DialTimeout: ps.cfg.DialTimeout,
+ ReadTimeout: ps.cfg.ReadTimeout,
+ WriteTimeout: ps.cfg.WriteTimeout,
+ PoolSize: ps.cfg.PoolSize,
+ MinIdleConns: ps.cfg.MinIdleConns,
+ MaxConnAge: ps.cfg.MaxConnAge,
+ PoolTimeout: ps.cfg.PoolTimeout,
+ IdleTimeout: ps.cfg.IdleTimeout,
+ IdleCheckFrequency: ps.cfg.IdleCheckFreq,
+ ReadOnly: ps.cfg.ReadOnly,
+ RouteByLatency: ps.cfg.RouteByLatency,
+ RouteRandomly: ps.cfg.RouteRandomly,
+ MasterName: ps.cfg.MasterName,
+ })
+
+ ps.fanin = newFanIn(ps.universalClient, log)
+
+ ps.stop()
+
+ return ps, nil
+}
+
+func (p *PubSubDriver) stop() {
+ go func() {
+ for range p.stopCh {
+ _ = p.fanin.stop()
+ return
+ }
+ }()
+}
+
+func (p *PubSubDriver) Publish(msg []byte) error {
+ p.Lock()
+ defer p.Unlock()
+
+ m := &websocketsv1.Message{}
+ err := proto.Unmarshal(msg, m)
+ if err != nil {
+ return errors.E(err)
+ }
+
+ for j := 0; j < len(m.GetTopics()); j++ {
+ f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg)
+ if f.Err() != nil {
+ return f.Err()
+ }
+ }
+ return nil
+}
+
+func (p *PubSubDriver) PublishAsync(msg []byte) {
+ go func() {
+ p.Lock()
+ defer p.Unlock()
+ m := &websocketsv1.Message{}
+ err := proto.Unmarshal(msg, m)
+ if err != nil {
+ p.log.Error("message unmarshal error")
+ return
+ }
+
+ for j := 0; j < len(m.GetTopics()); j++ {
+ f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg)
+ if f.Err() != nil {
+ p.log.Error("redis publish", "error", f.Err())
+ }
+ }
+ }()
+}
+
+func (p *PubSubDriver) Subscribe(connectionID string, topics ...string) error {
+ // just add a connection
+ for i := 0; i < len(topics); i++ {
+ // key - topic
+ // value - connectionID
+ hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID)
+ res, err := hset.Result()
+ if err != nil {
+ return err
+ }
+ if res == 0 {
+ p.log.Warn("could not subscribe to the provided topic", "connectionID", connectionID, "topic", topics[i])
+ continue
+ }
+ }
+
+ // and subscribe after
+ return p.fanin.sub(topics...)
+}
+
+func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error {
+ // Remove topics from the storage
+ for i := 0; i < len(topics); i++ {
+ srem := p.universalClient.SRem(context.Background(), topics[i], connectionID)
+ if srem.Err() != nil {
+ return srem.Err()
+ }
+ }
+
+ for i := 0; i < len(topics); i++ {
+ // if there are no such topics, we can safely unsubscribe from the redis
+ exists := p.universalClient.Exists(context.Background(), topics[i])
+ res, err := exists.Result()
+ if err != nil {
+ return err
+ }
+
+ // if we have associated connections - skip
+ if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it
+ continue
+ }
+
+ // else - unsubscribe
+ err = p.fanin.unsub(topics[i])
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) {
+ hget := p.universalClient.SMembersMap(context.Background(), topic)
+ r, err := hget.Result()
+ if err != nil {
+ panic(err)
+ }
+
+ // assighn connections
+ // res expected to be from the sync.Pool
+ for k := range r {
+ res[k] = struct{}{}
+ }
+}
+
+// Next return next message
+func (p *PubSubDriver) Next() (*websocketsv1.Message, error) {
+ return <-p.fanin.consume(), nil
+}