diff options
author | Valery Piashchynski <[email protected]> | 2021-06-14 16:39:02 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-06-14 16:39:02 +0300 |
commit | 75ab1e16c64cfd0a6424fe4c546fdbc5e1b992dd (patch) | |
tree | 1e9a910071d20021ad0f7ef4fe6099bac6a341ef /plugins/redis | |
parent | dc8ed203c247afd684f198ebbac103a10bfad72a (diff) |
- Rework redis with ws plugins
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/redis')
-rw-r--r-- | plugins/redis/clients.go | 84 | ||||
-rw-r--r-- | plugins/redis/interface.go | 7 | ||||
-rw-r--r-- | plugins/redis/kv.go | 242 | ||||
-rw-r--r-- | plugins/redis/plugin.go | 181 | ||||
-rw-r--r-- | plugins/redis/pubsub.go | 189 |
5 files changed, 541 insertions, 162 deletions
diff --git a/plugins/redis/clients.go b/plugins/redis/clients.go new file mode 100644 index 00000000..d0a184d2 --- /dev/null +++ b/plugins/redis/clients.go @@ -0,0 +1,84 @@ +package redis + +import ( + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" +) + +// RedisClient return a client based on the provided section key +// key sample: kv.some-section.redis +// kv.redis +// redis (root) +func (p *Plugin) RedisClient(key string) (redis.UniversalClient, error) { + const op = errors.Op("redis_get_client") + + if !p.cfgPlugin.Has(key) { + return nil, errors.E(op, errors.Errorf("no such section: %s", key)) + } + + cfg := &Config{} + + err := p.cfgPlugin.UnmarshalKey(key, cfg) + if err != nil { + return nil, errors.E(op, err) + } + + cfg.InitDefaults() + + uc := redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: cfg.Addrs, + DB: cfg.DB, + Username: cfg.Username, + Password: cfg.Password, + SentinelPassword: cfg.SentinelPassword, + MaxRetries: cfg.MaxRetries, + MinRetryBackoff: cfg.MaxRetryBackoff, + MaxRetryBackoff: cfg.MaxRetryBackoff, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + MaxConnAge: cfg.MaxConnAge, + PoolTimeout: cfg.PoolTimeout, + IdleTimeout: cfg.IdleTimeout, + IdleCheckFrequency: cfg.IdleCheckFreq, + ReadOnly: cfg.ReadOnly, + RouteByLatency: cfg.RouteByLatency, + RouteRandomly: cfg.RouteRandomly, + MasterName: cfg.MasterName, + }) + + return uc, nil +} + +func (p *Plugin) DefaultClient() redis.UniversalClient { + cfg := &Config{} + cfg.InitDefaults() + + uc := redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: cfg.Addrs, + DB: cfg.DB, + Username: cfg.Username, + Password: cfg.Password, + SentinelPassword: cfg.SentinelPassword, + MaxRetries: cfg.MaxRetries, + MinRetryBackoff: cfg.MaxRetryBackoff, + MaxRetryBackoff: cfg.MaxRetryBackoff, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + MaxConnAge: cfg.MaxConnAge, + PoolTimeout: cfg.PoolTimeout, + IdleTimeout: cfg.IdleTimeout, + IdleCheckFrequency: cfg.IdleCheckFreq, + ReadOnly: cfg.ReadOnly, + RouteByLatency: cfg.RouteByLatency, + RouteRandomly: cfg.RouteRandomly, + MasterName: cfg.MasterName, + }) + + return uc +} diff --git a/plugins/redis/interface.go b/plugins/redis/interface.go index c0be6137..189b0002 100644 --- a/plugins/redis/interface.go +++ b/plugins/redis/interface.go @@ -4,6 +4,9 @@ import "github.com/go-redis/redis/v8" // Redis in the redis KV plugin interface type Redis interface { - // GetClient provides universal redis client - GetClient() redis.UniversalClient + // RedisClient provides universal redis client + RedisClient(key string) (redis.UniversalClient, error) + + // DefaultClient provide default redis client based on redis defaults + DefaultClient() redis.UniversalClient } diff --git a/plugins/redis/kv.go b/plugins/redis/kv.go new file mode 100644 index 00000000..66cb8384 --- /dev/null +++ b/plugins/redis/kv.go @@ -0,0 +1,242 @@ +package redis + +import ( + "context" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/utils" +) + +type Driver struct { + universalClient redis.UniversalClient + log logger.Logger + cfg *Config +} + +func NewRedisDriver(log logger.Logger, key string, cfgPlugin config.Configurer) (kv.Storage, error) { + const op = errors.Op("new_boltdb_driver") + + d := &Driver{ + log: log, + } + + // will be different for every connected driver + err := cfgPlugin.UnmarshalKey(key, &d.cfg) + if err != nil { + return nil, errors.E(op, err) + } + + d.cfg.InitDefaults() + d.log = log + + d.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: d.cfg.Addrs, + DB: d.cfg.DB, + Username: d.cfg.Username, + Password: d.cfg.Password, + SentinelPassword: d.cfg.SentinelPassword, + MaxRetries: d.cfg.MaxRetries, + MinRetryBackoff: d.cfg.MaxRetryBackoff, + MaxRetryBackoff: d.cfg.MaxRetryBackoff, + DialTimeout: d.cfg.DialTimeout, + ReadTimeout: d.cfg.ReadTimeout, + WriteTimeout: d.cfg.WriteTimeout, + PoolSize: d.cfg.PoolSize, + MinIdleConns: d.cfg.MinIdleConns, + MaxConnAge: d.cfg.MaxConnAge, + PoolTimeout: d.cfg.PoolTimeout, + IdleTimeout: d.cfg.IdleTimeout, + IdleCheckFrequency: d.cfg.IdleCheckFreq, + ReadOnly: d.cfg.ReadOnly, + RouteByLatency: d.cfg.RouteByLatency, + RouteRandomly: d.cfg.RouteRandomly, + MasterName: d.cfg.MasterName, + }) + + return d, nil +} + +// Has checks if value exists. +func (d *Driver) Has(keys ...string) (map[string]bool, error) { + const op = errors.Op("redis_driver_has") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + + m := make(map[string]bool, len(keys)) + for _, key := range keys { + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + + exist, err := d.universalClient.Exists(context.Background(), key).Result() + if err != nil { + return nil, err + } + if exist == 1 { + m[key] = true + } + } + return m, nil +} + +// Get loads key content into slice. +func (d *Driver) Get(key string) ([]byte, error) { + const op = errors.Op("redis_driver_get") + // to get cases like " " + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + return d.universalClient.Get(context.Background(), key).Bytes() +} + +// MGet loads content of multiple values (some values might be skipped). +// https://redis.io/commands/mget +// Returns slice with the interfaces with values +func (d *Driver) MGet(keys ...string) (map[string][]byte, error) { + const op = errors.Op("redis_driver_mget") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for _, key := range keys { + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + } + + m := make(map[string][]byte, len(keys)) + + for _, k := range keys { + cmd := d.universalClient.Get(context.Background(), k) + if cmd.Err() != nil { + if cmd.Err() == redis.Nil { + continue + } + return nil, errors.E(op, cmd.Err()) + } + + m[k] = utils.AsBytes(cmd.Val()) + } + + return m, nil +} + +// Set sets value with the TTL in seconds +// https://redis.io/commands/set +// Redis `SET key value [expiration]` command. +// +// Use expiration for `SETEX`-like behavior. +// Zero expiration means the key has no expiration time. +func (d *Driver) Set(items ...*kvv1.Item) error { + const op = errors.Op("redis_driver_set") + if items == nil { + return errors.E(op, errors.NoKeys) + } + now := time.Now() + for _, item := range items { + if item == nil { + return errors.E(op, errors.EmptyKey) + } + + if item.Timeout == "" { + err := d.universalClient.Set(context.Background(), item.Key, item.Value, 0).Err() + if err != nil { + return err + } + } else { + t, err := time.Parse(time.RFC3339, item.Timeout) + if err != nil { + return err + } + err = d.universalClient.Set(context.Background(), item.Key, item.Value, t.Sub(now)).Err() + if err != nil { + return err + } + } + } + return nil +} + +// Delete one or multiple keys. +func (d *Driver) Delete(keys ...string) error { + const op = errors.Op("redis_driver_delete") + if keys == nil { + return errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for _, key := range keys { + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return errors.E(op, errors.EmptyKey) + } + } + return d.universalClient.Del(context.Background(), keys...).Err() +} + +// MExpire https://redis.io/commands/expire +// timeout in RFC3339 +func (d *Driver) MExpire(items ...*kvv1.Item) error { + const op = errors.Op("redis_driver_mexpire") + now := time.Now() + for _, item := range items { + if item == nil { + continue + } + if item.Timeout == "" || strings.TrimSpace(item.Key) == "" { + return errors.E(op, errors.Str("should set timeout and at least one key")) + } + + t, err := time.Parse(time.RFC3339, item.Timeout) + if err != nil { + return err + } + + // t guessed to be in future + // for Redis we use t.Sub, it will result in seconds, like 4.2s + d.universalClient.Expire(context.Background(), item.Key, t.Sub(now)) + } + + return nil +} + +// TTL https://redis.io/commands/ttl +// return time in seconds (float64) for a given keys +func (d *Driver) TTL(keys ...string) (map[string]string, error) { + const op = errors.Op("redis_driver_ttl") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for _, key := range keys { + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + } + + m := make(map[string]string, len(keys)) + + for _, key := range keys { + duration, err := d.universalClient.TTL(context.Background(), key).Result() + if err != nil { + return nil, err + } + + m[key] = duration.String() + } + return m, nil +} diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go index 47ffeb39..24c21b55 100644 --- a/plugins/redis/plugin.go +++ b/plugins/redis/plugin.go @@ -1,15 +1,14 @@ package redis import ( - "context" "sync" "github.com/go-redis/redis/v8" "github.com/spiral/errors" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" - "google.golang.org/protobuf/proto" ) const PluginName = "redis" @@ -17,80 +16,37 @@ const PluginName = "redis" type Plugin struct { sync.RWMutex // config for RR integration - cfg *Config + cfgPlugin config.Configurer // logger log logger.Logger // redis universal client universalClient redis.UniversalClient // fanIn implementation used to deliver messages from all channels to the single websocket point - fanin *FanIn -} - -func (p *Plugin) GetClient() redis.UniversalClient { - return p.universalClient + stopCh chan struct{} } func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { - const op = errors.Op("redis_plugin_init") - - if !cfg.Has(PluginName) { - return errors.E(op, errors.Disabled) - } - - err := cfg.UnmarshalKey(PluginName, &p.cfg) - if err != nil { - return errors.E(op, errors.Disabled, err) - } - - p.cfg.InitDefaults() p.log = log - - p.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{ - Addrs: p.cfg.Addrs, - DB: p.cfg.DB, - Username: p.cfg.Username, - Password: p.cfg.Password, - SentinelPassword: p.cfg.SentinelPassword, - MaxRetries: p.cfg.MaxRetries, - MinRetryBackoff: p.cfg.MaxRetryBackoff, - MaxRetryBackoff: p.cfg.MaxRetryBackoff, - DialTimeout: p.cfg.DialTimeout, - ReadTimeout: p.cfg.ReadTimeout, - WriteTimeout: p.cfg.WriteTimeout, - PoolSize: p.cfg.PoolSize, - MinIdleConns: p.cfg.MinIdleConns, - MaxConnAge: p.cfg.MaxConnAge, - PoolTimeout: p.cfg.PoolTimeout, - IdleTimeout: p.cfg.IdleTimeout, - IdleCheckFrequency: p.cfg.IdleCheckFreq, - ReadOnly: p.cfg.ReadOnly, - RouteByLatency: p.cfg.RouteByLatency, - RouteRandomly: p.cfg.RouteRandomly, - MasterName: p.cfg.MasterName, - }) - - // init fanin - p.fanin = newFanIn(p.universalClient, log) + p.cfgPlugin = cfg + p.stopCh = make(chan struct{}, 1) return nil } func (p *Plugin) Serve() chan error { - errCh := make(chan error) - return errCh + return make(chan error) } func (p *Plugin) Stop() error { const op = errors.Op("redis_plugin_stop") - err := p.fanin.stop() - if err != nil { - return errors.E(op, err) - } + p.stopCh <- struct{}{} - err = p.universalClient.Close() - if err != nil { - return errors.E(op, err) + if p.universalClient != nil { + err := p.universalClient.Close() + if err != nil { + return errors.E(op, err) + } } return nil @@ -103,112 +59,17 @@ func (p *Plugin) Name() string { // Available interface implementation func (p *Plugin) Available() {} -func (p *Plugin) Publish(msg []byte) error { - p.Lock() - defer p.Unlock() - - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - return errors.E(err) - } - - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - return f.Err() - } - } - return nil -} - -func (p *Plugin) PublishAsync(msg []byte) { - go func() { - p.Lock() - defer p.Unlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - p.log.Error("message unmarshal error") - return - } - - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - p.log.Error("redis publish", "error", f.Err()) - } - } - }() -} - -func (p *Plugin) Subscribe(connectionID string, topics ...string) error { - // just add a connection - for i := 0; i < len(topics); i++ { - // key - topic - // value - connectionID - hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID) - res, err := hset.Result() - if err != nil { - return err - } - if res == 0 { - p.log.Warn("could not subscribe to the provided topic", "connectionID", connectionID, "topic", topics[i]) - continue - } - } - - // and subscribe after - return p.fanin.sub(topics...) -} - -func (p *Plugin) Unsubscribe(connectionID string, topics ...string) error { - // Remove topics from the storage - for i := 0; i < len(topics); i++ { - srem := p.universalClient.SRem(context.Background(), topics[i], connectionID) - if srem.Err() != nil { - return srem.Err() - } - } - - for i := 0; i < len(topics); i++ { - // if there are no such topics, we can safely unsubscribe from the redis - exists := p.universalClient.Exists(context.Background(), topics[i]) - res, err := exists.Result() - if err != nil { - return err - } - - // if we have associated connections - skip - if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it - continue - } - - // else - unsubscribe - err = p.fanin.unsub(topics[i]) - if err != nil { - return err - } - } - - return nil -} - -func (p *Plugin) Connections(topic string, res map[string]struct{}) { - hget := p.universalClient.SMembersMap(context.Background(), topic) - r, err := hget.Result() +// KVProvide provides KV storage implementation over the redis plugin +func (p *Plugin) KVProvide(key string) (kv.Storage, error) { + const op = errors.Op("redis_plugin_provide") + st, err := NewRedisDriver(p.log, key, p.cfgPlugin) if err != nil { - panic(err) + return nil, errors.E(op, err) } - // assighn connections - // res expected to be from the sync.Pool - for k := range r { - res[k] = struct{}{} - } + return st, nil } -// Next return next message -func (p *Plugin) Next() (*websocketsv1.Message, error) { - return <-p.fanin.consume(), nil +func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) { + return NewPubSubDriver(p.log, key, p.cfgPlugin, p.stopCh) } diff --git a/plugins/redis/pubsub.go b/plugins/redis/pubsub.go new file mode 100644 index 00000000..dbda7ea4 --- /dev/null +++ b/plugins/redis/pubsub.go @@ -0,0 +1,189 @@ +package redis + +import ( + "context" + "sync" + + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "google.golang.org/protobuf/proto" +) + +type PubSubDriver struct { + sync.RWMutex + cfg *Config `mapstructure:"redis"` + + log logger.Logger + fanin *FanIn + universalClient redis.UniversalClient + stopCh chan struct{} +} + +func NewPubSubDriver(log logger.Logger, key string, cfgPlugin config.Configurer, stopCh chan struct{}) (pubsub.PubSub, error) { + const op = errors.Op("new_pub_sub_driver") + ps := &PubSubDriver{ + log: log, + stopCh: stopCh, + } + + // will be different for every connected driver + err := cfgPlugin.UnmarshalKey(key, &ps.cfg) + if err != nil { + return nil, errors.E(op, err) + } + + ps.cfg.InitDefaults() + + ps.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: ps.cfg.Addrs, + DB: ps.cfg.DB, + Username: ps.cfg.Username, + Password: ps.cfg.Password, + SentinelPassword: ps.cfg.SentinelPassword, + MaxRetries: ps.cfg.MaxRetries, + MinRetryBackoff: ps.cfg.MaxRetryBackoff, + MaxRetryBackoff: ps.cfg.MaxRetryBackoff, + DialTimeout: ps.cfg.DialTimeout, + ReadTimeout: ps.cfg.ReadTimeout, + WriteTimeout: ps.cfg.WriteTimeout, + PoolSize: ps.cfg.PoolSize, + MinIdleConns: ps.cfg.MinIdleConns, + MaxConnAge: ps.cfg.MaxConnAge, + PoolTimeout: ps.cfg.PoolTimeout, + IdleTimeout: ps.cfg.IdleTimeout, + IdleCheckFrequency: ps.cfg.IdleCheckFreq, + ReadOnly: ps.cfg.ReadOnly, + RouteByLatency: ps.cfg.RouteByLatency, + RouteRandomly: ps.cfg.RouteRandomly, + MasterName: ps.cfg.MasterName, + }) + + ps.fanin = newFanIn(ps.universalClient, log) + + ps.stop() + + return ps, nil +} + +func (p *PubSubDriver) stop() { + go func() { + for range p.stopCh { + _ = p.fanin.stop() + return + } + }() +} + +func (p *PubSubDriver) Publish(msg []byte) error { + p.Lock() + defer p.Unlock() + + m := &websocketsv1.Message{} + err := proto.Unmarshal(msg, m) + if err != nil { + return errors.E(err) + } + + for j := 0; j < len(m.GetTopics()); j++ { + f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) + if f.Err() != nil { + return f.Err() + } + } + return nil +} + +func (p *PubSubDriver) PublishAsync(msg []byte) { + go func() { + p.Lock() + defer p.Unlock() + m := &websocketsv1.Message{} + err := proto.Unmarshal(msg, m) + if err != nil { + p.log.Error("message unmarshal error") + return + } + + for j := 0; j < len(m.GetTopics()); j++ { + f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) + if f.Err() != nil { + p.log.Error("redis publish", "error", f.Err()) + } + } + }() +} + +func (p *PubSubDriver) Subscribe(connectionID string, topics ...string) error { + // just add a connection + for i := 0; i < len(topics); i++ { + // key - topic + // value - connectionID + hset := p.universalClient.SAdd(context.Background(), topics[i], connectionID) + res, err := hset.Result() + if err != nil { + return err + } + if res == 0 { + p.log.Warn("could not subscribe to the provided topic", "connectionID", connectionID, "topic", topics[i]) + continue + } + } + + // and subscribe after + return p.fanin.sub(topics...) +} + +func (p *PubSubDriver) Unsubscribe(connectionID string, topics ...string) error { + // Remove topics from the storage + for i := 0; i < len(topics); i++ { + srem := p.universalClient.SRem(context.Background(), topics[i], connectionID) + if srem.Err() != nil { + return srem.Err() + } + } + + for i := 0; i < len(topics); i++ { + // if there are no such topics, we can safely unsubscribe from the redis + exists := p.universalClient.Exists(context.Background(), topics[i]) + res, err := exists.Result() + if err != nil { + return err + } + + // if we have associated connections - skip + if res == 1 { // exists means that topic still exists and some other nodes may have connections associated with it + continue + } + + // else - unsubscribe + err = p.fanin.unsub(topics[i]) + if err != nil { + return err + } + } + + return nil +} + +func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { + hget := p.universalClient.SMembersMap(context.Background(), topic) + r, err := hget.Result() + if err != nil { + panic(err) + } + + // assighn connections + // res expected to be from the sync.Pool + for k := range r { + res[k] = struct{}{} + } +} + +// Next return next message +func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { + return <-p.fanin.consume(), nil +} |