diff options
38 files changed, 454 insertions, 595 deletions
diff --git a/pkg/interface/pubsub/interface.go b/pkg/pubsub/interface.go index 30b544db..53f92cb8 100644 --- a/pkg/interface/pubsub/interface.go +++ b/pkg/pubsub/interface.go @@ -1,6 +1,6 @@ package pubsub -import websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" +import websocketsv1beta "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" /* This interface is in BETA. It might be changed. @@ -47,9 +47,9 @@ type Publisher interface { // Reader interface should return next message type Reader interface { - Next() (*websocketsv1.Message, error) + Next() (*websocketsv1beta.Message, error) } -type PSProvider interface { - PSProvide(key string) (PubSub, error) +type Constructor interface { + PSConstruct(key string) (PubSub, error) } diff --git a/pkg/interface/broadcast/broadcast.go b/plugins/broadcast/interface.go index 4c49f7c5..46709d71 100644 --- a/pkg/interface/broadcast/broadcast.go +++ b/plugins/broadcast/interface.go @@ -1,6 +1,6 @@ package broadcast -import "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" +import "github.com/spiral/roadrunner/v2/pkg/pubsub" type Broadcaster interface { GetDriver(key string) (pubsub.SubReader, error) diff --git a/plugins/broadcast/plugin.go b/plugins/broadcast/plugin.go index c43b2e4c..612b6a47 100644 --- a/plugins/broadcast/plugin.go +++ b/plugins/broadcast/plugin.go @@ -6,10 +6,10 @@ import ( endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1beta "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "google.golang.org/protobuf/proto" ) @@ -30,8 +30,8 @@ type Plugin struct { log logger.Logger // publishers implement Publisher interface // and able to receive a payload - publishers map[string]pubsub.PubSub - providers map[string]pubsub.PSProvider + publishers map[string]pubsub.PubSub + constructors map[string]pubsub.Constructor } func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { @@ -47,7 +47,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { } p.publishers = make(map[string]pubsub.PubSub) - p.providers = make(map[string]pubsub.PSProvider) + p.constructors = make(map[string]pubsub.Constructor) p.log = log p.cfgPlugin = cfg @@ -64,6 +64,8 @@ func (p *Plugin) Serve() chan error { continue } + // check type of the v + // should be a map[string]interface{} switch t := v.(type) { // correct type case map[string]interface{}: @@ -81,11 +83,11 @@ func (p *Plugin) Serve() chan error { switch v.(map[string]interface{})[driver] { case memory: - if _, ok := p.providers[memory]; !ok { + if _, ok := p.constructors[memory]; !ok { p.log.Warn("no memory drivers registered", "registered", p.publishers) continue } - ps, err := p.providers[memory].PSProvide(configKey) + ps, err := p.constructors[memory].PSConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -94,7 +96,7 @@ func (p *Plugin) Serve() chan error { // save the pubsub p.publishers[k] = ps case redis: - if _, ok := p.providers[redis]; !ok { + if _, ok := p.constructors[redis]; !ok { p.log.Warn("no redis drivers registered", "registered", p.publishers) continue } @@ -102,7 +104,7 @@ func (p *Plugin) Serve() chan error { // first - try local configuration switch { case p.cfgPlugin.Has(configKey): - ps, err := p.providers[redis].PSProvide(configKey) + ps, err := p.constructors[redis].PSConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -111,7 +113,7 @@ func (p *Plugin) Serve() chan error { // save the pubsub p.publishers[k] = ps case p.cfgPlugin.Has(redis): - ps, err := p.providers[redis].PSProvide(configKey) + ps, err := p.constructors[redis].PSConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -138,9 +140,9 @@ func (p *Plugin) Collects() []interface{} { } // CollectPublishers collect all plugins who implement pubsub.Publisher interface -func (p *Plugin) CollectPublishers(name endure.Named, subscriber pubsub.PSProvider) { +func (p *Plugin) CollectPublishers(name endure.Named, subscriber pubsub.Constructor) { // key redis, value - interface - p.providers[name.Name()] = subscriber + p.constructors[name.Name()] = subscriber } // Publish is an entry point to the websocket PUBSUB @@ -150,7 +152,7 @@ func (p *Plugin) Publish(m []byte) error { const op = errors.Op("broadcast_plugin_publish") - msg := &websocketsv1.Message{} + msg := &websocketsv1beta.Message{} err := proto.Unmarshal(m, msg) if err != nil { return errors.E(op, err) @@ -179,7 +181,7 @@ func (p *Plugin) PublishAsync(m []byte) { go func() { p.Lock() defer p.Unlock() - msg := &websocketsv1.Message{} + msg := &websocketsv1beta.Message{} err := proto.Unmarshal(m, msg) if err != nil { p.log.Error("message unmarshal") @@ -201,7 +203,7 @@ func (p *Plugin) PublishAsync(m []byte) { func (p *Plugin) GetDriver(key string) (pubsub.SubReader, error) { const op = errors.Op("broadcast_plugin_get_driver") // key - driver, default for example - // we should find `default` in the collected pubsubs providers + // we should find `default` in the collected pubsubs constructors if pub, ok := p.publishers[key]; ok { return pub, nil } diff --git a/plugins/broadcast/rpc.go b/plugins/broadcast/rpc.go index fa853421..4c27cdc3 100644 --- a/plugins/broadcast/rpc.go +++ b/plugins/broadcast/rpc.go @@ -2,8 +2,8 @@ package broadcast import ( "github.com/spiral/errors" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "google.golang.org/protobuf/proto" ) @@ -24,8 +24,7 @@ func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) erro return nil } - r.log.Debug("message published", "msg", in.Messages) - + r.log.Debug("message published", "msg", in.String()) msgLen := len(in.GetMessages()) for i := 0; i < msgLen; i++ { @@ -56,7 +55,7 @@ func (r *rpc) PublishAsync(in *websocketsv1.Request, out *websocketsv1.Response) return nil } - r.log.Debug("message published", "msg", in.Messages) + r.log.Debug("message published", "msg", in.GetMessages()) msgLen := len(in.GetMessages()) diff --git a/plugins/kv/config.go b/plugins/kv/config.go index 66095817..09ba79cd 100644 --- a/plugins/kv/config.go +++ b/plugins/kv/config.go @@ -1,6 +1,6 @@ package kv -// Config represents general storage configuration with keys as the user defined kv-names and values as the drivers +// Config represents general storage configuration with keys as the user defined kv-names and values as the constructors type Config struct { Data map[string]interface{} `mapstructure:"kv"` } diff --git a/plugins/kv/drivers/boltdb/driver.go b/plugins/kv/drivers/boltdb/driver.go index 5f4d98b1..4b675271 100644 --- a/plugins/kv/drivers/boltdb/driver.go +++ b/plugins/kv/drivers/boltdb/driver.go @@ -9,10 +9,10 @@ import ( "time" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/utils" bolt "go.etcd.io/bbolt" ) diff --git a/plugins/kv/drivers/boltdb/plugin.go b/plugins/kv/drivers/boltdb/plugin.go index 28e2a89c..6ae1a1f6 100644 --- a/plugins/kv/drivers/boltdb/plugin.go +++ b/plugins/kv/drivers/boltdb/plugin.go @@ -46,7 +46,7 @@ func (s *Plugin) Stop() error { return nil } -func (s *Plugin) KVProvide(key string) (kv.Storage, error) { +func (s *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("boltdb_plugin_provide") st, err := NewBoltDBDriver(s.log, key, s.cfgPlugin, s.stop) if err != nil { diff --git a/plugins/kv/drivers/memcached/driver.go b/plugins/kv/drivers/memcached/driver.go index c1f79cbb..a2787d72 100644 --- a/plugins/kv/drivers/memcached/driver.go +++ b/plugins/kv/drivers/memcached/driver.go @@ -6,10 +6,10 @@ import ( "github.com/bradfitz/gomemcache/memcache" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) type Driver struct { diff --git a/plugins/kv/drivers/memcached/plugin.go b/plugins/kv/drivers/memcached/plugin.go index 936b2047..22ea5cca 100644 --- a/plugins/kv/drivers/memcached/plugin.go +++ b/plugins/kv/drivers/memcached/plugin.go @@ -34,7 +34,7 @@ func (s *Plugin) Name() string { // Available interface implementation func (s *Plugin) Available() {} -func (s *Plugin) KVProvide(key string) (kv.Storage, error) { +func (s *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("boltdb_plugin_provide") st, err := NewMemcachedDriver(s.log, key, s.cfgPlugin) if err != nil { diff --git a/plugins/kv/interface.go b/plugins/kv/interface.go index fd906041..ffdbbe62 100644 --- a/plugins/kv/interface.go +++ b/plugins/kv/interface.go @@ -1,6 +1,6 @@ package kv -import kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" +import kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" // Storage represents single abstract storage. type Storage interface { @@ -29,13 +29,8 @@ type Storage interface { Delete(keys ...string) error } -// StorageDriver interface provide storage -type StorageDriver interface { - Provider -} - -// Provider provides storage based on the config -type Provider interface { - // KVProvide provides Storage based on the config key - KVProvide(key string) (Storage, error) +// Constructor provides storage based on the config +type Constructor interface { + // KVConstruct provides Storage based on the config key + KVConstruct(key string) (Storage, error) } diff --git a/plugins/kv/plugin.go b/plugins/kv/plugin.go index 716e0d4c..03dbaed6 100644 --- a/plugins/kv/plugin.go +++ b/plugins/kv/plugin.go @@ -24,8 +24,8 @@ const ( // Plugin for the unified storage type Plugin struct { log logger.Logger - // drivers contains general storage drivers, such as boltdb, memory, memcached, redis. - drivers map[string]StorageDriver + // constructors contains general storage constructors, such as boltdb, memory, memcached, redis. + constructors map[string]Constructor // storages contains user-defined storages, such as boltdb-north, memcached-us and so on. storages map[string]Storage // KV configuration @@ -43,7 +43,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { if err != nil { return errors.E(op, err) } - p.drivers = make(map[string]StorageDriver, 5) + p.constructors = make(map[string]Constructor, 5) p.storages = make(map[string]Storage, 5) p.log = log p.cfgPlugin = cfg @@ -81,7 +81,7 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit addr: [ "localhost:11211" ] - For this config we should have 3 drivers: memory, boltdb and memcached but 4 KVs: default, boltdb-south, boltdb-north and memcached + For this config we should have 3 constructors: memory, boltdb and memcached but 4 KVs: default, boltdb-south, boltdb-north and memcached when user requests for example boltdb-south, we should provide that particular preconfigured storage */ for k, v := range p.cfg.Data { @@ -90,9 +90,18 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit continue } - if _, ok := v.(map[string]interface{})[driver]; !ok { - errCh <- errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", k)) - return errCh + // check type of the v + // should be a map[string]interface{} + switch t := v.(type) { + // correct type + case map[string]interface{}: + if _, ok := t[driver]; !ok { + errCh <- errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", k)) + return errCh + } + default: + p.log.Warn("wrong type detected in the configuration, please, check yaml indentation") + continue } // config key for the particular sub-driver kv.memcached @@ -100,12 +109,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // at this point we know, that driver field present in the configuration switch v.(map[string]interface{})[driver] { case memcached: - if _, ok := p.drivers[memcached]; !ok { - p.log.Warn("no memcached drivers registered", "registered", p.drivers) + if _, ok := p.constructors[memcached]; !ok { + p.log.Warn("no memcached constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[memcached].KVProvide(configKey) + storage, err := p.constructors[memcached].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -115,12 +124,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit p.storages[k] = storage case boltdb: - if _, ok := p.drivers[boltdb]; !ok { - p.log.Warn("no boltdb drivers registered", "registered", p.drivers) + if _, ok := p.constructors[boltdb]; !ok { + p.log.Warn("no boltdb constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[boltdb].KVProvide(configKey) + storage, err := p.constructors[boltdb].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -129,12 +138,12 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case memory: - if _, ok := p.drivers[memory]; !ok { - p.log.Warn("no in-memory drivers registered", "registered", p.drivers) + if _, ok := p.constructors[memory]; !ok { + p.log.Warn("no in-memory constructors registered", "registered", p.constructors) continue } - storage, err := p.drivers[memory].KVProvide(configKey) + storage, err := p.constructors[memory].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -143,15 +152,15 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case redis: - if _, ok := p.drivers[redis]; !ok { - p.log.Warn("no redis drivers registered", "registered", p.drivers) + if _, ok := p.constructors[redis]; !ok { + p.log.Warn("no redis constructors registered", "registered", p.constructors) continue } // first - try local configuration switch { case p.cfgPlugin.Has(configKey): - storage, err := p.drivers[redis].KVProvide(configKey) + storage, err := p.constructors[redis].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -160,7 +169,7 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // save the storage p.storages[k] = storage case p.cfgPlugin.Has(redis): - storage, err := p.drivers[redis].KVProvide(configKey) + storage, err := p.constructors[redis].KVConstruct(configKey) if err != nil { errCh <- errors.E(op, err) return errCh @@ -194,9 +203,9 @@ func (p *Plugin) Collects() []interface{} { } } -func (p *Plugin) GetAllStorageDrivers(name endure.Named, storage StorageDriver) { - // save the storage driver - p.drivers[name.Name()] = storage +func (p *Plugin) GetAllStorageDrivers(name endure.Named, constructor Constructor) { + // save the storage constructor + p.constructors[name.Name()] = constructor } // RPC returns associated rpc service. diff --git a/plugins/kv/rpc.go b/plugins/kv/rpc.go index ab1f7f31..af763600 100644 --- a/plugins/kv/rpc.go +++ b/plugins/kv/rpc.go @@ -2,8 +2,8 @@ package kv import ( "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) // Wrapper for the plugin diff --git a/plugins/memory/kv.go b/plugins/memory/kv.go index 9b7d7259..1cf031d1 100644 --- a/plugins/memory/kv.go +++ b/plugins/memory/kv.go @@ -6,10 +6,10 @@ import ( "time" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" ) type Driver struct { diff --git a/plugins/memory/plugin.go b/plugins/memory/plugin.go index d4d535bf..70badf15 100644 --- a/plugins/memory/plugin.go +++ b/plugins/memory/plugin.go @@ -2,7 +2,7 @@ package memory import ( "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -41,11 +41,11 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) { +func (p *Plugin) PSConstruct(key string) (pubsub.PubSub, error) { return NewPubSubDriver(p.log, key) } -func (p *Plugin) KVProvide(key string) (kv.Storage, error) { +func (p *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("inmemory_plugin_provide") st, err := NewInMemoryDriver(p.log, key, p.cfgPlugin, p.stop) if err != nil { diff --git a/plugins/memory/pubsub.go b/plugins/memory/pubsub.go index 02246a8f..87638bd8 100644 --- a/plugins/memory/pubsub.go +++ b/plugins/memory/pubsub.go @@ -4,9 +4,9 @@ import ( "sync" "github.com/spiral/roadrunner/v2/pkg/bst" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "google.golang.org/protobuf/proto" ) diff --git a/plugins/redis/fanin.go b/plugins/redis/fanin.go index ac9ebcc2..0bdd4cf5 100644 --- a/plugins/redis/fanin.go +++ b/plugins/redis/fanin.go @@ -4,8 +4,8 @@ import ( "context" "sync" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "google.golang.org/protobuf/proto" "github.com/go-redis/redis/v8" diff --git a/plugins/redis/kv.go b/plugins/redis/kv.go index 66cb8384..320b7443 100644 --- a/plugins/redis/kv.go +++ b/plugins/redis/kv.go @@ -7,10 +7,10 @@ import ( "github.com/go-redis/redis/v8" "github.com/spiral/errors" - kvv1 "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" + kvv1 "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/utils" ) diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go index 8d997041..9d98790b 100644 --- a/plugins/redis/plugin.go +++ b/plugins/redis/plugin.go @@ -5,7 +5,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -59,8 +59,8 @@ func (p *Plugin) Name() string { // Available interface implementation func (p *Plugin) Available() {} -// KVProvide provides KV storage implementation over the redis plugin -func (p *Plugin) KVProvide(key string) (kv.Storage, error) { +// KVConstruct provides KV storage implementation over the redis plugin +func (p *Plugin) KVConstruct(key string) (kv.Storage, error) { const op = errors.Op("redis_plugin_provide") st, err := NewRedisDriver(p.log, key, p.cfgPlugin) if err != nil { @@ -70,6 +70,6 @@ func (p *Plugin) KVProvide(key string) (kv.Storage, error) { return st, nil } -func (p *Plugin) PSProvide(key string) (pubsub.PubSub, error) { +func (p *Plugin) PSConstruct(key string) (pubsub.PubSub, error) { return NewPubSubDriver(p.log, key, p.cfgPlugin, p.stopCh) } diff --git a/plugins/redis/pubsub.go b/plugins/redis/pubsub.go index c2a88abe..dc391c20 100644 --- a/plugins/redis/pubsub.go +++ b/plugins/redis/pubsub.go @@ -6,10 +6,10 @@ import ( "github.com/go-redis/redis/v8" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "google.golang.org/protobuf/proto" ) diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 799312ad..0583be0c 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -7,12 +7,12 @@ import ( json "github.com/json-iterator/go" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/validator" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" ) type Response struct { diff --git a/plugins/websockets/origin_test.go b/plugins/websockets/origin_test.go index ec6e1960..bbc49bbb 100644 --- a/plugins/websockets/origin_test.go +++ b/plugins/websockets/origin_test.go @@ -9,6 +9,7 @@ import ( func TestConfig_Origin(t *testing.T) { cfg := &Config{ AllowedOrigin: "*", + Broker: "any", } err := cfg.InitDefault() @@ -28,6 +29,7 @@ func TestConfig_Origin(t *testing.T) { func TestConfig_OriginWildCard(t *testing.T) { cfg := &Config{ AllowedOrigin: "https://*my.site.com", + Broker: "any", } err := cfg.InitDefault() @@ -50,6 +52,7 @@ func TestConfig_OriginWildCard(t *testing.T) { func TestConfig_OriginWildCard2(t *testing.T) { cfg := &Config{ AllowedOrigin: "https://my.*.com", + Broker: "any", } err := cfg.InitDefault() diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index de7443fd..f0b7c6c3 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -10,12 +10,12 @@ import ( "github.com/google/uuid" json "github.com/json-iterator/go" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/pkg/interface/broadcast" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/process" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/pkg/worker" + "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go index cd9444da..3d95ede0 100644 --- a/plugins/websockets/pool/workers_pool.go +++ b/plugins/websockets/pool/workers_pool.go @@ -4,10 +4,10 @@ import ( "sync" json "github.com/json-iterator/go" - "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/utils" ) @@ -105,10 +105,10 @@ func (wp *WorkersPool) do() { //nolint:gocognit } // res is a map with a connectionsID - for topic := range res { - c, ok := wp.connections.Load(topic) + for connID := range res { + c, ok := wp.connections.Load(connID) if !ok { - wp.log.Warn("the user disconnected connection before the message being written to it", "topics", msg.GetTopics()[i]) + wp.log.Warn("the websocket disconnected before the message being written to it", "topics", msg.GetTopics()[i]) wp.put(res) continue } diff --git a/pkg/proto/kv/v1beta/kv.pb.go b/proto/kv/v1beta/kv.pb.go index 622967b8..622967b8 100644 --- a/pkg/proto/kv/v1beta/kv.pb.go +++ b/proto/kv/v1beta/kv.pb.go diff --git a/pkg/proto/kv/v1beta/kv.proto b/proto/kv/v1beta/kv.proto index 1e3b8177..1e3b8177 100644 --- a/pkg/proto/kv/v1beta/kv.proto +++ b/proto/kv/v1beta/kv.proto diff --git a/pkg/proto/websockets/v1beta/websockets.pb.go b/proto/websockets/v1beta/websockets.pb.go index ad4ebbe7..ad4ebbe7 100644 --- a/pkg/proto/websockets/v1beta/websockets.pb.go +++ b/proto/websockets/v1beta/websockets.pb.go diff --git a/pkg/proto/websockets/v1beta/websockets.proto b/proto/websockets/v1beta/websockets.proto index 5be6f70f..5be6f70f 100644 --- a/pkg/proto/websockets/v1beta/websockets.proto +++ b/proto/websockets/v1beta/websockets.proto diff --git a/tests/plugins/broadcast/broadcast_plugin_test.go b/tests/plugins/broadcast/broadcast_plugin_test.go index 5b195bd0..585a81a9 100644 --- a/tests/plugins/broadcast/broadcast_plugin_test.go +++ b/tests/plugins/broadcast/broadcast_plugin_test.go @@ -1,22 +1,13 @@ package broadcast import ( - "net" - "net/http" - "net/rpc" - "net/url" "os" "os/signal" "sync" "syscall" "testing" - "time" - "github.com/fasthttp/websocket" - json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" - goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -26,7 +17,6 @@ import ( rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets" - "github.com/spiral/roadrunner/v2/utils" "github.com/stretchr/testify/assert" ) @@ -98,109 +88,7 @@ func TestBroadcastInit(t *testing.T) { } }() - t.Run("TestWSInit", wsInit) - stopCh <- struct{}{} wg.Wait() } - -func wsInit(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", "memory", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func publishAsync(t *testing.T, command string, broker string, topics ...string) { - conn, err := net.Dial("tcp", "127.0.0.1:6001") - if err != nil { - panic(err) - } - - client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) - - ret := &websocketsv1.Response{} - err = client.Call("websockets.PublishAsync", makeMessage(command, broker, []byte("hello, PHP"), topics...), ret) - assert.NoError(t, err) - assert.True(t, ret.Ok) -} - -func publishAsync2(t *testing.T, command string, broker string, topics ...string) { - conn, err := net.Dial("tcp", "127.0.0.1:6001") - if err != nil { - panic(err) - } - - client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) - - ret := &websocketsv1.Response{} - err = client.Call("websockets.PublishAsync", makeMessage(command, broker, []byte("hello, PHP2"), topics...), ret) - assert.NoError(t, err) - assert.True(t, ret.Ok) -} - -func publish2(t *testing.T, command string, broker string, topics ...string) { - conn, err := net.Dial("tcp", "127.0.0.1:6001") - if err != nil { - panic(err) - } - - client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) - - ret := &websocketsv1.Response{} - err = client.Call("websockets.Publish", makeMessage(command, broker, []byte("hello, PHP2"), topics...), ret) - assert.NoError(t, err) - assert.True(t, ret.Ok) -} - -func messageWS(command string, broker string, payload []byte, topics ...string) *websocketsv1.Message { - return &websocketsv1.Message{ - Topics: topics, - Command: command, - //Broker: broker, - Payload: payload, - } -} - -func makeMessage(command string, broker string, payload []byte, topics ...string) *websocketsv1.Request { - m := &websocketsv1.Request{ - Messages: []*websocketsv1.Message{ - { - Topics: topics, - Command: command, - //Broker: broker, - Payload: payload, - }, - }, - } - - return m -} diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml index 6962eeb5..8436b65f 100644 --- a/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml +++ b/tests/plugins/broadcast/configs/.rr-broadcast-init.yaml @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:11111 + address: 127.0.0.1:21345 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] @@ -37,7 +37,7 @@ websockets: logs: mode: development - level: debug + level: error endure: grace_period: 120s diff --git a/tests/plugins/kv/storage_plugin_test.go b/tests/plugins/kv/storage_plugin_test.go index 24b66ae1..1e466e06 100644 --- a/tests/plugins/kv/storage_plugin_test.go +++ b/tests/plugins/kv/storage_plugin_test.go @@ -12,7 +12,6 @@ import ( endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - payload "github.com/spiral/roadrunner/v2/pkg/proto/kv/v1beta" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/kv" "github.com/spiral/roadrunner/v2/plugins/kv/drivers/boltdb" @@ -21,6 +20,7 @@ import ( "github.com/spiral/roadrunner/v2/plugins/memory" "github.com/spiral/roadrunner/v2/plugins/redis" rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + payload "github.com/spiral/roadrunner/v2/proto/kv/v1beta" "github.com/stretchr/testify/assert" ) diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-allow.yaml b/tests/plugins/websockets/configs/.rr-websockets-allow.yaml index f81e13e3..e6c43857 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-allow.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-allow.yaml @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:11113 + address: 127.0.0.1:41278 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] diff --git a/tests/plugins/websockets/configs/.rr-websockets-redis-memory-local.yaml b/tests/plugins/websockets/configs/.rr-websockets-allow2.yaml index a077bf9e..d537a80b 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-redis-memory-local.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-allow2.yaml @@ -2,14 +2,14 @@ rpc: listen: tcp://127.0.0.1:6001 server: - command: "php ../../psr-worker-bench.php" + command: "php ../../worker-ok.php" user: "" group: "" relay: "pipes" relay_timeout: "20s" http: - address: 127.0.0.1:13235 + address: 127.0.0.1:41270 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] @@ -19,10 +19,15 @@ http: allocate_timeout: 60s destroy_timeout: 60s +redis: + addrs: + - "localhost:6379" broadcast: test: - driver: memory + driver: redis + addrs: + - "localhost:6379" websockets: broker: test @@ -31,7 +36,7 @@ websockets: logs: mode: development - level: debug + level: error endure: grace_period: 120s diff --git a/tests/plugins/websockets/configs/.rr-websockets-redis-no-section.yaml b/tests/plugins/websockets/configs/.rr-websockets-broker-no-section.yaml index d80993f2..ada23845 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-redis-no-section.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-broker-no-section.yaml @@ -19,6 +19,9 @@ http: allocate_timeout: 60s destroy_timeout: 60s +broadcast: + test1: + driver: no websockets: broker: test diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-deny.yaml b/tests/plugins/websockets/configs/.rr-websockets-deny.yaml index decb7dcf..594a746d 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-deny.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-deny.yaml @@ -9,7 +9,7 @@ server: relay_timeout: "20s" http: - address: 127.0.0.1:11112 + address: 127.0.0.1:15587 max_request_size: 1024 middleware: [ "websockets" ] trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] diff --git a/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml b/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml new file mode 100644 index 00000000..4deea30a --- /dev/null +++ b/tests/plugins/websockets/configs/.rr-websockets-deny2.yaml @@ -0,0 +1,40 @@ +rpc: + listen: tcp://127.0.0.1:6001 + +server: + command: "php ../../worker-deny.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:15588 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + test: + driver: redis + addrs: + - "localhost:6379" + +websockets: + broker: test + allowed_origin: "*" + path: "/ws" + +logs: + mode: development + level: error + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/websockets/configs/.rr-websockets-init.yaml b/tests/plugins/websockets/configs/.rr-websockets-init.yaml index b6882d84..14472f8a 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-init.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-init.yaml @@ -25,9 +25,7 @@ redis: broadcast: default: - driver: redis - addrs: - - "localhost:6379" + driver: memory websockets: broker: default diff --git a/tests/plugins/websockets/configs/.rr-websockets-memory-stop.yaml b/tests/plugins/websockets/configs/.rr-websockets-stop.yaml index 5377aef2..5377aef2 100644 --- a/tests/plugins/websockets/configs/.rr-websockets-memory-stop.yaml +++ b/tests/plugins/websockets/configs/.rr-websockets-stop.yaml diff --git a/tests/plugins/websockets/websocket_plugin_test.go b/tests/plugins/websockets/websocket_plugin_test.go index cb78117f..29bf28be 100644 --- a/tests/plugins/websockets/websocket_plugin_test.go +++ b/tests/plugins/websockets/websocket_plugin_test.go @@ -16,7 +16,6 @@ import ( json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -26,6 +25,7 @@ import ( rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/utils" "github.com/stretchr/testify/assert" ) @@ -100,54 +100,20 @@ func TestBroadcastInit(t *testing.T) { time.Sleep(time.Second * 1) t.Run("TestWSInit", wsInit) - t.Run("RPCWsMemoryPubAsync", RPCWsMemoryPubAsync) - t.Run("RPCWsMemory", RPCWsMemory) + t.Run("RPCWsMemoryPubAsync", RPCWsPubAsync("11111")) + t.Run("RPCWsMemory", RPCWsPub("11111")) stopCh <- struct{}{} wg.Wait() } -func wsInit(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func TestWSRedisAndMemory(t *testing.T) { +func TestWSRedis(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis-memory.yaml", + Path: "configs/.rr-websockets-redis.yaml", Prefix: "rr", } @@ -159,7 +125,6 @@ func TestWSRedisAndMemory(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, - &memory.Plugin{}, &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -210,21 +175,20 @@ func TestWSRedisAndMemory(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryPubAsync", RPCWsMemoryPubAsync) - t.Run("RPCWsMemory", RPCWsMemory) - t.Run("RPCWsRedis", RPCWsRedis) + t.Run("RPCWsRedisPubAsync", RPCWsPubAsync("13235")) + t.Run("RPCWsRedisPub", RPCWsPub("13235")) stopCh <- struct{}{} wg.Wait() } -func TestWSRedisAndMemoryGlobal(t *testing.T) { +func TestWSRedisNoSection(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis.yaml", + Path: "configs/.rr-websockets-broker-no-section.yaml", Prefix: "rr", } @@ -236,6 +200,35 @@ func TestWSRedisAndMemoryGlobal(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, + &broadcast.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + _, err = cont.Serve() + assert.Error(t, err) +} + +func TestWSDeny(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-websockets-deny.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, &memory.Plugin{}, &broadcast.Plugin{}, ) @@ -287,20 +280,19 @@ func TestWSRedisAndMemoryGlobal(t *testing.T) { }() time.Sleep(time.Second * 1) - - t.Run("RPCWsRedis", RPCWsRedis) + t.Run("RPCWsMemoryDeny", RPCWsDeny("15587")) stopCh <- struct{}{} wg.Wait() } -func TestWSRedisNoSection(t *testing.T) { +func TestWSDeny2(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-redis-no-section.yaml", + Path: "configs/.rr-websockets-deny2.yaml", Prefix: "rr", } @@ -309,10 +301,9 @@ func TestWSRedisNoSection(t *testing.T) { &rpcPlugin.Plugin{}, &logger.ZapLogger{}, &server.Plugin{}, - &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, - &memory.Plugin{}, + &redis.Plugin{}, &broadcast.Plugin{}, ) assert.NoError(t, err) @@ -322,234 +313,60 @@ func TestWSRedisNoSection(t *testing.T) { t.Fatal(err) } - _, err = cont.Serve() - assert.Error(t, err) -} - -func RPCWsMemoryPubAsync(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publishAsync(t, "", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + ch, err := cont.Serve() if err != nil { - panic(err) + t.Fatal(err) } - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + wg := &sync.WaitGroup{} + wg.Add(1) - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publishAsync(t, "", "memory", "foo") + stopCh := make(chan struct{}, 1) go func() { - time.Sleep(time.Second * 5) - publishAsync2(t, "", "memory", "foo2") - }() - - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP\"}", retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func RPCWsMemory(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } } }() - d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publish("", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "memory", "foo") - - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "memory", "foo2") - }() - - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP2\"}", retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) -} - -func RPCWsRedis(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:13235", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - - defer func() { - _ = resp.Body.Close() - }() - - d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - - publish("", "redis", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "redis", "foo") - - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "redis", "foo2") - }() + time.Sleep(time.Second * 1) + t.Run("RPCWsRedisDeny", RPCWsDeny("15588")) - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP2\"}", retMsg) + stopCh <- struct{}{} - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) + wg.Wait() } -func TestWSMemoryDeny(t *testing.T) { +func TestWSStop(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-deny.yaml", + Path: "configs/.rr-websockets-stop.yaml", Prefix: "rr", } @@ -612,73 +429,40 @@ func TestWSMemoryDeny(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryDeny", RPCWsMemoryDeny) + t.Run("RPCWsStop", RPCWsMemoryStop("11114")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryDeny(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11112", Path: "/ws"} +func RPCWsMemoryStop(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } - c, resp, err := da.Dial(connURL.String(), nil) - assert.NoError(t, err) - assert.NotNil(t, c) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} - defer func() { - if resp != nil && resp.Body != nil { + c, resp, err := da.Dial(connURL.String(), nil) + assert.NotNil(t, resp) + assert.Error(t, err) + assert.Nil(t, c) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) //nolint:staticcheck + assert.Equal(t, resp.Header.Get("Stop"), "we-dont-like-you") //nolint:staticcheck + if resp != nil && resp.Body != nil { //nolint:staticcheck _ = resp.Body.Close() } - }() - - d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) - if err != nil { - panic(err) } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err := c.ReadMessage() - retMsg := utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"#join","payload":["foo","foo2"]}`, retMsg) - - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) - } - - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) - - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) - - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) } -func TestWSMemoryStop(t *testing.T) { +func TestWSAllow(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-stop.yaml", + Path: "configs/.rr-websockets-allow.yaml", Prefix: "rr", } @@ -741,38 +525,19 @@ func TestWSMemoryStop(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryStop", RPCWsMemoryStop) + t.Run("RPCWsMemoryAllow", RPCWsPub("41278")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryStop(t *testing.T) { - da := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: time.Second * 20, - } - - connURL := url.URL{Scheme: "ws", Host: "localhost:11114", Path: "/ws"} - - c, resp, err := da.Dial(connURL.String(), nil) - assert.NotNil(t, resp) - assert.Error(t, err) - assert.Nil(t, c) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) //nolint:staticcheck - assert.Equal(t, resp.Header.Get("Stop"), "we-dont-like-you") //nolint:staticcheck - if resp != nil && resp.Body != nil { //nolint:staticcheck - _ = resp.Body.Close() - } -} - -func TestWSMemoryOk(t *testing.T) { +func TestWSAllow2(t *testing.T) { cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) assert.NoError(t, err) cfg := &config.Viper{ - Path: "configs/.rr-websockets-memory-allow.yaml", + Path: "configs/.rr-websockets-allow2.yaml", Prefix: "rr", } @@ -835,30 +600,26 @@ func TestWSMemoryOk(t *testing.T) { }() time.Sleep(time.Second * 1) - t.Run("RPCWsMemoryAllow", RPCWsMemoryAllow) + t.Run("RPCWsMemoryAllow", RPCWsPub("41270")) stopCh <- struct{}{} wg.Wait() } -func RPCWsMemoryAllow(t *testing.T) { +func wsInit(t *testing.T) { da := websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: time.Second * 20, } - connURL := url.URL{Scheme: "ws", Host: "localhost:11113", Path: "/ws"} + connURL := url.URL{Scheme: "ws", Host: "localhost:11111", Path: "/ws"} c, resp, err := da.Dial(connURL.String(), nil) assert.NoError(t, err) - assert.NotNil(t, c) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } + _ = resp.Body.Close() }() d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) @@ -876,48 +637,218 @@ func RPCWsMemoryAllow(t *testing.T) { // subscription done assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) - publish("", "memory", "foo") - - // VERIFY a makeMessage - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) +} - // //// LEAVE foo, foo2 ///////// - d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) - if err != nil { - panic(err) +func RPCWsPubAsync(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 18, + } + + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) + + defer func() { + _ = resp.Body.Close() + }() + + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) + + publishAsync(t, "placeholder", "foo") + + // VERIFY a makeMessage + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) + + // //// LEAVE foo ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC + publishAsync(t, "placeholder", "foo") + + go func() { + time.Sleep(time.Second * 3) + publishAsync(t, "placeholder", "foo2") + }() + + // should be only makeMessage from the subscribed foo0 topic + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP\"}", retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) } +} - err = c.WriteMessage(websocket.BinaryMessage, d) - assert.NoError(t, err) +func RPCWsPub(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} - // subscription done - assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) - // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC - publish("", "memory", "foo") + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() - go func() { - time.Sleep(time.Second * 5) - publish2(t, "", "memory", "foo2") - }() + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } - // should be only makeMessage from the subscribed foo2 topic - _, msg, err = c.ReadMessage() - retMsg = utils.AsString(msg) - assert.NoError(t, err) - assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP2\"}", retMsg) + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) - err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) - assert.NoError(t, err) + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@join","payload":["foo","foo2"]}`, retMsg) + + publish("", "foo") + + // VERIFY a makeMessage + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo\",\"payload\":\"hello, PHP\"}", retMsg) + + // //// LEAVE foo, foo2 ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + // TRY TO PUBLISH TO UNSUBSCRIBED TOPIC + publish("", "foo") + + go func() { + time.Sleep(time.Second * 5) + publish2(t, "", "foo2") + }() + + // should be only makeMessage from the subscribed foo2 topic + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + assert.Equal(t, "{\"topic\":\"foo2\",\"payload\":\"hello, PHP2\"}", retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) + } +} + +func RPCWsDeny(port string) func(t *testing.T) { + return func(t *testing.T) { + da := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: time.Second * 20, + } + + connURL := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, resp, err := da.Dial(connURL.String(), nil) + assert.NoError(t, err) + assert.NotNil(t, c) + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + d, err := json.Marshal(messageWS("join", []byte("hello websockets"), "foo", "foo2")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err := c.ReadMessage() + retMsg := utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"#join","payload":["foo","foo2"]}`, retMsg) + + // //// LEAVE foo, foo2 ///////// + d, err = json.Marshal(messageWS("leave", []byte("hello websockets"), "foo")) + if err != nil { + panic(err) + } + + err = c.WriteMessage(websocket.BinaryMessage, d) + assert.NoError(t, err) + + _, msg, err = c.ReadMessage() + retMsg = utils.AsString(msg) + assert.NoError(t, err) + + // subscription done + assert.Equal(t, `{"topic":"@leave","payload":["foo"]}`, retMsg) + + err = c.WriteControl(websocket.CloseMessage, nil, time.Time{}) + assert.NoError(t, err) + } } +// --------------------------------------------------------------------------------------------------- + func publish(command string, topics ...string) { conn, err := net.Dial("tcp", "127.0.0.1:6001") if err != nil { @@ -947,20 +878,6 @@ func publishAsync(t *testing.T, command string, topics ...string) { assert.True(t, ret.Ok) } -func publishAsync2(t *testing.T, command string, topics ...string) { - conn, err := net.Dial("tcp", "127.0.0.1:6001") - if err != nil { - panic(err) - } - - client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) - - ret := &websocketsv1.Response{} - err = client.Call("broadcast.PublishAsync", makeMessage(command, []byte("hello, PHP2"), topics...), ret) - assert.NoError(t, err) - assert.True(t, ret.Ok) -} - func publish2(t *testing.T, command string, topics ...string) { conn, err := net.Dial("tcp", "127.0.0.1:6001") if err != nil { |