diff options
22 files changed, 839 insertions, 244 deletions
diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 69269557..62987771 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -91,6 +91,7 @@ jobs: go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/resetter.txt -covermode=atomic ./tests/plugins/resetter go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/rpc.txt -covermode=atomic ./tests/plugins/rpc go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/kv_plugin.txt -covermode=atomic ./tests/plugins/kv + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/broadcast_plugin.txt -covermode=atomic ./tests/plugins/broadcast go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/websockets.txt -covermode=atomic ./tests/plugins/websockets go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage-ci/ws_origin.txt -covermode=atomic ./plugins/websockets docker-compose -f ./tests/docker-compose.yaml down diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 227c725b..f23f9b5d 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -90,6 +90,7 @@ jobs: go test -v -race ./tests/plugins/resetter go test -v -race ./tests/plugins/rpc go test -v -race ./tests/plugins/kv + go test -v -race ./tests/plugins/broadcast go test -v -race ./tests/plugins/websockets go test -v -race ./plugins/websockets docker-compose -f ./tests/docker-compose.yaml down @@ -31,6 +31,7 @@ test_coverage: go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/resetter.out -covermode=atomic ./tests/plugins/resetter go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/rpc.out -covermode=atomic ./tests/plugins/rpc go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/kv_plugin.out -covermode=atomic ./tests/plugins/kv + go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/broadcast_plugin.out -covermode=atomic ./tests/plugins/broadcast go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/ws_plugin.out -covermode=atomic ./tests/plugins/websockets go test -v -race -cover -tags=debug -coverpkg=./... -coverprofile=./coverage/ws_origin.out -covermode=atomic ./plugins/websockets cat ./coverage/*.out > ./coverage/summary.out @@ -61,6 +62,7 @@ test: ## Run application tests go test -v -race -tags=debug ./tests/plugins/resetter go test -v -race -tags=debug ./tests/plugins/rpc go test -v -race -tags=debug ./tests/plugins/kv + go test -v -race -tags=debug ./tests/plugins/broadcast go test -v -race -tags=debug ./tests/plugins/websockets go test -v -race -tags=debug ./plugins/websockets docker-compose -f tests/docker-compose.yaml down @@ -91,5 +93,6 @@ testGo1.17beta1: ## Run application tests go1.17beta1 test -v -race -tags=debug ./tests/plugins/rpc go1.17beta1 test -v -race -tags=debug ./tests/plugins/kv go1.17beta1 test -v -race -tags=debug ./tests/plugins/websockets + go1.17beta1 test -v -race -tags=debug ./tests/plugins/broadcast go1.17beta1 test -v -race -tags=debug ./plugins/websockets docker-compose -f tests/docker-compose.yaml down diff --git a/pkg/pubsub/interface.go b/pkg/pubsub/interface.go index 53f92cb8..06252d70 100644 --- a/pkg/pubsub/interface.go +++ b/pkg/pubsub/interface.go @@ -1,7 +1,5 @@ package pubsub -import websocketsv1beta "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - /* This interface is in BETA. It might be changed. */ @@ -38,18 +36,19 @@ type Subscriber interface { // BETA interface type Publisher interface { // Publish one or multiple Channel. - Publish(messages []byte) error + Publish(message *Message) error // PublishAsync publish message and return immediately // If error occurred it will be printed into the logger - PublishAsync(messages []byte) + PublishAsync(message *Message) } // Reader interface should return next message type Reader interface { - Next() (*websocketsv1beta.Message, error) + Next() (*Message, error) } +// Constructor is a special pub-sub interface made to return a constructed PubSub type type Constructor interface { PSConstruct(key string) (PubSub, error) } diff --git a/pkg/pubsub/psmessage.go b/pkg/pubsub/psmessage.go new file mode 100644 index 00000000..e33d9284 --- /dev/null +++ b/pkg/pubsub/psmessage.go @@ -0,0 +1,15 @@ +package pubsub + +import json "github.com/json-iterator/go" + +// Message represents a single message with payload bound to a particular topic +type Message struct { + // Topic (channel in terms of redis) + Topic string `json:"topic"` + // Payload (on some decode stages might be represented as base64 string) + Payload []byte `json:"payload"` +} + +func (m *Message) MarshalBinary() (data []byte, err error) { + return json.Marshal(m) +} diff --git a/plugins/broadcast/plugin.go b/plugins/broadcast/plugin.go index 3b420a4b..04a4fb80 100644 --- a/plugins/broadcast/plugin.go +++ b/plugins/broadcast/plugin.go @@ -4,13 +4,12 @@ import ( "fmt" "sync" + "github.com/google/uuid" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" - websocketsv1beta "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - "google.golang.org/protobuf/proto" ) const ( @@ -55,78 +54,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { } func (p *Plugin) Serve() chan error { - const op = errors.Op("broadcast_plugin_serve") - errCh := make(chan error, 1) - - // iterate over config - for k, v := range p.cfg.Data { - if v == nil { - continue - } - - // check type of the v - // should be a map[string]interface{} - switch t := v.(type) { - // correct type - case map[string]interface{}: - if _, ok := t[driver]; !ok { - errCh <- errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", k)) - return errCh - } - default: - errCh <- errors.E(op, errors.Str("wrong type detected in the configuration, please, check yaml indentation")) - return errCh - } - - // config key for the particular sub-driver kv.memcached - configKey := fmt.Sprintf("%s.%s", PluginName, k) - - switch v.(map[string]interface{})[driver] { - case memory: - if _, ok := p.constructors[memory]; !ok { - p.log.Warn("no memory drivers registered", "registered", p.publishers) - continue - } - ps, err := p.constructors[memory].PSConstruct(configKey) - if err != nil { - errCh <- errors.E(op, err) - return errCh - } - - // save the pubsub - p.publishers[k] = ps - case redis: - if _, ok := p.constructors[redis]; !ok { - p.log.Warn("no redis drivers registered", "registered", p.publishers) - continue - } - - // first - try local configuration - switch { - case p.cfgPlugin.Has(configKey): - ps, err := p.constructors[redis].PSConstruct(configKey) - if err != nil { - errCh <- errors.E(op, err) - return errCh - } - - // save the pubsub - p.publishers[k] = ps - case p.cfgPlugin.Has(redis): - ps, err := p.constructors[redis].PSConstruct(configKey) - if err != nil { - errCh <- errors.E(op, err) - return errCh - } - - // save the pubsub - p.publishers[k] = ps - continue - } - } - } - - return errCh + return make(chan error) } func (p *Plugin) Stop() error { @@ -140,61 +68,49 @@ func (p *Plugin) Collects() []interface{} { } // CollectPublishers collect all plugins who implement pubsub.Publisher interface -func (p *Plugin) CollectPublishers(name endure.Named, subscriber pubsub.Constructor) { +func (p *Plugin) CollectPublishers(name endure.Named, constructor pubsub.Constructor) { // key redis, value - interface - p.constructors[name.Name()] = subscriber + p.constructors[name.Name()] = constructor } // Publish is an entry point to the websocket PUBSUB -func (p *Plugin) Publish(m []byte) error { +func (p *Plugin) Publish(m *pubsub.Message) error { p.Lock() defer p.Unlock() const op = errors.Op("broadcast_plugin_publish") - msg := &websocketsv1beta.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - return errors.E(op, err) - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if len(p.publishers) > 0 { - for j := range p.publishers { - err = p.publishers[j].Publish(m) - if err != nil { - return errors.E(op, err) - } + // check if any publisher registered + if len(p.publishers) > 0 { + for j := range p.publishers { + err := p.publishers[j].Publish(m) + if err != nil { + return errors.E(op, err) } - - return nil } - + return nil + } else { p.log.Warn("no publishers registered") } return nil } -func (p *Plugin) PublishAsync(m []byte) { +func (p *Plugin) PublishAsync(m *pubsub.Message) { go func() { p.Lock() defer p.Unlock() - msg := &websocketsv1beta.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - p.log.Error("message unmarshal") - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if len(p.publishers) > 0 { - for j := range p.publishers { - p.publishers[j].PublishAsync(m) + // check if any publisher registered + if len(p.publishers) > 0 { + for j := range p.publishers { + err := p.publishers[j].Publish(m) + if err != nil { + p.log.Error("publishAsync", "error", err) + // continue publish to other registered publishers + continue } - return } + } else { p.log.Warn("no publishers registered") } }() @@ -202,10 +118,67 @@ func (p *Plugin) PublishAsync(m []byte) { func (p *Plugin) GetDriver(key string) (pubsub.SubReader, error) { const op = errors.Op("broadcast_plugin_get_driver") - // key - driver, default for example - // we should find `default` in the collected pubsubs constructors - if pub, ok := p.publishers[key]; ok { - return pub, nil + + // choose a driver + if val, ok := p.cfg.Data[key]; ok { + // check type of the v + // should be a map[string]interface{} + switch t := val.(type) { + // correct type + case map[string]interface{}: + if _, ok := t[driver]; !ok { + panic(errors.E(op, errors.Errorf("could not find mandatory driver field in the %s storage", val))) + } + default: + return nil, errors.E(op, errors.Str("wrong type detected in the configuration, please, check yaml indentation")) + } + + // config key for the particular sub-driver kv.memcached + configKey := fmt.Sprintf("%s.%s", PluginName, key) + + switch val.(map[string]interface{})[driver] { + case memory: + if _, ok := p.constructors[memory]; !ok { + return nil, errors.E(op, errors.Errorf("no memory drivers registered, registered: %s", p.publishers)) + } + ps, err := p.constructors[memory].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // save the initialized publisher channel + // for the in-memory, register new publishers + p.publishers[uuid.NewString()] = ps + + return ps, nil + case redis: + if _, ok := p.constructors[redis]; !ok { + return nil, errors.E(op, errors.Errorf("no redis drivers registered, registered: %s", p.publishers)) + } + + // first - try local configuration + switch { + case p.cfgPlugin.Has(configKey): + ps, err := p.constructors[redis].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // save the pubsub under a config key + // + p.publishers[configKey] = ps + return ps, nil + case p.cfgPlugin.Has(redis): + ps, err := p.constructors[redis].PSConstruct(configKey) + if err != nil { + return nil, errors.E(op, err) + } + + // save the pubsub + p.publishers[configKey] = ps + return ps, nil + } + } } return nil, errors.E(op, errors.Str("could not find driver by provided key")) } diff --git a/plugins/broadcast/rpc.go b/plugins/broadcast/rpc.go index 4c27cdc3..2ee211f8 100644 --- a/plugins/broadcast/rpc.go +++ b/plugins/broadcast/rpc.go @@ -2,9 +2,9 @@ package broadcast import ( "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - "google.golang.org/protobuf/proto" ) // rpc collectors struct @@ -14,7 +14,7 @@ type rpc struct { } // Publish ... msg is a proto decoded payload -// see: pkg/pubsub/message.fbs +// see: root/proto func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) error { const op = errors.Op("broadcast_publish") @@ -28,15 +28,23 @@ func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) erro msgLen := len(in.GetMessages()) for i := 0; i < msgLen; i++ { - bb, err := proto.Marshal(in.GetMessages()[i]) - if err != nil { - return errors.E(op, err) - } + for j := 0; j < len(in.GetMessages()[i].GetTopics()); j++ { + if in.GetMessages()[i].GetTopics()[j] == "" { + r.log.Warn("message with empty topic, skipping") + // skip empty topics + continue + } + + tmp := &pubsub.Message{ + Topic: in.GetMessages()[i].GetTopics()[j], + Payload: in.GetMessages()[i].GetPayload(), + } - err = r.plugin.Publish(bb) - if err != nil { - out.Ok = false - return errors.E(op, err) + err := r.plugin.Publish(tmp) + if err != nil { + out.Ok = false + return errors.E(op, err) + } } } @@ -45,10 +53,8 @@ func (r *rpc) Publish(in *websocketsv1.Request, out *websocketsv1.Response) erro } // PublishAsync ... -// see: pkg/pubsub/message.fbs +// see: root/proto func (r *rpc) PublishAsync(in *websocketsv1.Request, out *websocketsv1.Response) error { - const op = errors.Op("publish_async") - // just return in case of nil message if in == nil { out.Ok = false @@ -60,13 +66,20 @@ func (r *rpc) PublishAsync(in *websocketsv1.Request, out *websocketsv1.Response) msgLen := len(in.GetMessages()) for i := 0; i < msgLen; i++ { - bb, err := proto.Marshal(in.GetMessages()[i]) - if err != nil { - out.Ok = false - return errors.E(op, err) - } + for j := 0; j < len(in.GetMessages()[i].GetTopics()); j++ { + if in.GetMessages()[i].GetTopics()[j] == "" { + r.log.Warn("message with empty topic, skipping") + // skip empty topics + continue + } + + tmp := &pubsub.Message{ + Topic: in.GetMessages()[i].GetTopics()[j], + Payload: in.GetMessages()[i].GetPayload(), + } - r.plugin.PublishAsync(bb) + r.plugin.PublishAsync(tmp) + } } out.Ok = true diff --git a/plugins/memory/pubsub.go b/plugins/memory/pubsub.go index 87638bd8..d027a8a5 100644 --- a/plugins/memory/pubsub.go +++ b/plugins/memory/pubsub.go @@ -6,14 +6,12 @@ import ( "github.com/spiral/roadrunner/v2/pkg/bst" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" - websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - "google.golang.org/protobuf/proto" ) type PubSubDriver struct { sync.RWMutex // channel with the messages from the RPC - pushCh chan []byte + pushCh chan *pubsub.Message // user-subscribed topics storage bst.Storage log logger.Logger @@ -21,21 +19,21 @@ type PubSubDriver struct { func NewPubSubDriver(log logger.Logger, _ string) (pubsub.PubSub, error) { ps := &PubSubDriver{ - pushCh: make(chan []byte, 10), + pushCh: make(chan *pubsub.Message, 10), storage: bst.NewBST(), log: log, } return ps, nil } -func (p *PubSubDriver) Publish(message []byte) error { - p.pushCh <- message +func (p *PubSubDriver) Publish(msg *pubsub.Message) error { + p.pushCh <- msg return nil } -func (p *PubSubDriver) PublishAsync(message []byte) { +func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) { go func() { - p.pushCh <- message + p.pushCh <- msg }() } @@ -67,7 +65,7 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { } } -func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { +func (p *PubSubDriver) Next() (*pubsub.Message, error) { msg := <-p.pushCh if msg == nil { return nil, nil @@ -76,20 +74,13 @@ func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { p.RLock() defer p.RUnlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - return nil, err - } - - // push only messages, which are subscribed + // push only messages, which topics are subscibed // TODO better??? - for i := 0; i < len(m.GetTopics()); i++ { - // if we have active subscribers - send a message to a topic - // or send nil instead - if ok := p.storage.Contains(m.GetTopics()[i]); ok { - return m, nil - } + // if we have active subscribers - send a message to a topic + // or send nil instead + if ok := p.storage.Contains(msg.Topic); ok { + return msg, nil } + return nil, nil } diff --git a/plugins/redis/fanin.go b/plugins/redis/fanin.go index 0bdd4cf5..40a99d20 100644 --- a/plugins/redis/fanin.go +++ b/plugins/redis/fanin.go @@ -4,12 +4,10 @@ import ( "context" "sync" - "github.com/spiral/roadrunner/v2/plugins/logger" - websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - "google.golang.org/protobuf/proto" - "github.com/go-redis/redis/v8" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/utils" ) @@ -23,13 +21,13 @@ type FanIn struct { log logger.Logger // out channel with all subs - out chan *websocketsv1.Message + out chan *pubsub.Message exit chan struct{} } func newFanIn(redisClient redis.UniversalClient, log logger.Logger) *FanIn { - out := make(chan *websocketsv1.Message, 100) + out := make(chan *pubsub.Message, 100) fi := &FanIn{ out: out, client: redisClient, @@ -67,14 +65,11 @@ func (fi *FanIn) read() { return } - m := &websocketsv1.Message{} - err := proto.Unmarshal(utils.AsBytes(msg.Payload), m) - if err != nil { - fi.log.Error("message unmarshal") - continue + fi.out <- &pubsub.Message{ + Topic: msg.Channel, + Payload: utils.AsBytes(msg.Payload), } - fi.out <- m case <-fi.exit: return } @@ -97,6 +92,6 @@ func (fi *FanIn) stop() error { return nil } -func (fi *FanIn) consume() <-chan *websocketsv1.Message { +func (fi *FanIn) consume() <-chan *pubsub.Message { return fi.out } diff --git a/plugins/redis/pubsub.go b/plugins/redis/pubsub.go index 9c3d0134..6ab281f3 100644 --- a/plugins/redis/pubsub.go +++ b/plugins/redis/pubsub.go @@ -9,8 +9,6 @@ import ( "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" - websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" - "google.golang.org/protobuf/proto" ) type PubSubDriver struct { @@ -83,41 +81,26 @@ func (p *PubSubDriver) stop() { }() } -func (p *PubSubDriver) Publish(msg []byte) error { +func (p *PubSubDriver) Publish(msg *pubsub.Message) error { p.Lock() defer p.Unlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - return errors.E(err) + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + return f.Err() } - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - return f.Err() - } - } return nil } -func (p *PubSubDriver) PublishAsync(msg []byte) { +func (p *PubSubDriver) PublishAsync(msg *pubsub.Message) { go func() { p.Lock() defer p.Unlock() - m := &websocketsv1.Message{} - err := proto.Unmarshal(msg, m) - if err != nil { - p.log.Error("message unmarshal error") - return - } - for j := 0; j < len(m.GetTopics()); j++ { - f := p.universalClient.Publish(context.Background(), m.GetTopics()[j], msg) - if f.Err() != nil { - p.log.Error("redis publish", "error", f.Err()) - } + f := p.universalClient.Publish(context.Background(), msg.Topic, msg.Payload) + if f.Err() != nil { + p.log.Error("redis publish", "error", f.Err()) } }() } @@ -189,6 +172,6 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) { } // Next return next message -func (p *PubSubDriver) Next() (*websocketsv1.Message, error) { +func (p *PubSubDriver) Next() (*pubsub.Message, error) { return <-p.fanin.consume(), nil } diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go index 3d95ede0..00e053ec 100644 --- a/plugins/websockets/pool/workers_pool.go +++ b/plugins/websockets/pool/workers_pool.go @@ -7,7 +7,6 @@ import ( "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" - websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/utils" ) @@ -17,7 +16,7 @@ type WorkersPool struct { resPool sync.Pool log logger.Logger - queue chan *websocketsv1.Message + queue chan *pubsub.Message exit chan struct{} } @@ -25,7 +24,7 @@ type WorkersPool struct { func NewWorkersPool(subscriber pubsub.Subscriber, connections *sync.Map, log logger.Logger) *WorkersPool { wp := &WorkersPool{ connections: connections, - queue: make(chan *websocketsv1.Message, 100), + queue: make(chan *pubsub.Message, 100), subscriber: subscriber, log: log, exit: make(chan struct{}), @@ -43,7 +42,7 @@ func NewWorkersPool(subscriber pubsub.Subscriber, connections *sync.Map, log log return wp } -func (wp *WorkersPool) Queue(msg *websocketsv1.Message) { +func (wp *WorkersPool) Queue(msg *pubsub.Message) { wp.queue <- msg } @@ -83,57 +82,48 @@ func (wp *WorkersPool) do() { //nolint:gocognit return } _ = msg - if msg == nil { + if msg == nil || msg.Topic == "" { continue } - if len(msg.GetTopics()) == 0 { + + // get free map + res := wp.get() + + // get connections for the particular topic + wp.subscriber.Connections(msg.Topic, res) + + if len(res) == 0 { + wp.log.Info("no such topic", "topic", msg.Topic) + wp.put(res) continue } - // send a message to every topic - for i := 0; i < len(msg.GetTopics()); i++ { - // get free map - res := wp.get() + // res is a map with a connectionsID + for connID := range res { + c, ok := wp.connections.Load(connID) + if !ok { + wp.log.Warn("the websocket disconnected before the message being written to it", "topics", msg.Topic) + wp.put(res) + continue + } - // get connections for the particular topic - wp.subscriber.Connections(msg.GetTopics()[i], res) + d, err := json.Marshal(&Response{ + Topic: msg.Topic, + Payload: utils.AsString(msg.Payload), + }) - if len(res) == 0 { - wp.log.Info("no such topic", "topic", msg.GetTopics()[i]) + if err != nil { + wp.log.Error("error marshaling response", "error", err) wp.put(res) - continue + break } - // res is a map with a connectionsID - for connID := range res { - c, ok := wp.connections.Load(connID) - if !ok { - wp.log.Warn("the websocket disconnected before the message being written to it", "topics", msg.GetTopics()[i]) - wp.put(res) - continue - } - - response := &Response{ - Topic: msg.GetTopics()[i], - Payload: utils.AsString(msg.GetPayload()), - } - - d, err := json.Marshal(response) - if err != nil { - wp.log.Error("error marshaling response", "error", err) - wp.put(res) - break - } - - // put data into the bytes buffer - err = c.(*connection.Connection).Write(d) - if err != nil { - for i := 0; i < len(msg.GetTopics()); i++ { - wp.log.Error("error sending payload over the connection", "error", err, "topics", msg.GetTopics()[i]) - } - wp.put(res) - continue - } + // put data into the bytes buffer + err = c.(*connection.Connection).Write(d) + if err != nil { + wp.log.Error("error sending payload over the connection", "error", err, "topic", msg.Topic) + wp.put(res) + continue } } case <-wp.exit: diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 67d5476b..b6ba0f66 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -9,3 +9,7 @@ services: image: redis:6 ports: - "6379:6379" + redis2: + image: redis:6 + ports: + - "6378:6379" diff --git a/tests/plugins/broadcast/broadcast_plugin_test.go b/tests/plugins/broadcast/broadcast_plugin_test.go index ce1aed45..d6510058 100644 --- a/tests/plugins/broadcast/broadcast_plugin_test.go +++ b/tests/plugins/broadcast/broadcast_plugin_test.go @@ -1,13 +1,18 @@ package broadcast import ( + "net" + "net/rpc" "os" "os/signal" "sync" "syscall" "testing" + "time" + "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" + goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -17,6 +22,9 @@ import ( rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets" + websocketsv1 "github.com/spiral/roadrunner/v2/proto/websockets/v1beta" + "github.com/spiral/roadrunner/v2/tests/mocks" + "github.com/spiral/roadrunner/v2/tests/plugins/broadcast/plugins" "github.com/stretchr/testify/assert" ) @@ -112,6 +120,8 @@ func TestBroadcastConfigError(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + + &plugins.Plugin1{}, ) assert.NoError(t, err) @@ -134,11 +144,18 @@ func TestBroadcastNoConfig(t *testing.T) { Prefix: "rr", } + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("worker destructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("worker constructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("Started RPC service", "address", "tcp://127.0.0.1:6001", "services", []string{}).MinTimes(1) + err = cont.RegisterAll( cfg, &broadcast.Plugin{}, &rpcPlugin.Plugin{}, - &logger.ZapLogger{}, + mockLogger, &server.Plugin{}, &redis.Plugin{}, &websockets.Plugin{}, @@ -157,3 +174,183 @@ func TestBroadcastNoConfig(t *testing.T) { _, err = cont.Serve() assert.NoError(t, err) } + +func TestBroadcastSameSubscriber(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.ErrorLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broadcast-same-section.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("worker destructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("worker constructed", "pid", gomock.Any()).AnyTimes() + mockLogger.EXPECT().Debug("Started RPC service", "address", "tcp://127.0.0.1:6002", "services", []string{"broadcast"}).MinTimes(1) + mockLogger.EXPECT().Debug("message published", "msg", gomock.Any()).MinTimes(1) + + mockLogger.EXPECT().Info(`plugin1: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo2 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin1: {foo3 hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin2: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin3: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin4: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin5: {foo hello}`).Times(3) + mockLogger.EXPECT().Info(`plugin6: {foo hello}`).Times(3) + + err = cont.RegisterAll( + cfg, + &broadcast.Plugin{}, + &rpcPlugin.Plugin{}, + mockLogger, + &server.Plugin{}, + &redis.Plugin{}, + &websockets.Plugin{}, + &httpPlugin.Plugin{}, + &memory.Plugin{}, + + // test - redis + // test2 - redis (port 6378) + // test3 - memory + // test4 - memory + &plugins.Plugin1{}, // foo, foo2, foo3 test + &plugins.Plugin2{}, // foo, test + &plugins.Plugin3{}, // foo, test2 + &plugins.Plugin4{}, // foo, test3 + &plugins.Plugin5{}, // foo, test4 + &plugins.Plugin6{}, // foo, test3 + ) + + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + stopCh := make(chan struct{}, 1) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-stopCh: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 2) + + t.Run("PublishHelloFooFoo2Foo3", BroadcastPublishFooFoo2Foo3) + t.Run("PublishHelloFoo2", BroadcastPublishFoo2) + t.Run("PublishHelloFoo3", BroadcastPublishFoo3) + t.Run("PublishAsyncHelloFooFoo2Foo3", BroadcastPublishAsyncFooFoo2Foo3) + + time.Sleep(time.Second * 4) + stopCh <- struct{}{} + + wg.Wait() +} + +func BroadcastPublishFooFoo2Foo3(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6002") + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo", "foo2", "foo3"), ret) + if err != nil { + t.Fatal(err) + } +} + +func BroadcastPublishFoo2(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6002") + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo"), ret) + if err != nil { + t.Fatal(err) + } +} +func BroadcastPublishFoo3(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6002") + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.Publish", makeMessage([]byte("hello"), "foo3"), ret) + if err != nil { + t.Fatal(err) + } +} +func BroadcastPublishAsyncFooFoo2Foo3(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6002") + if err != nil { + t.Fatal(err) + } + + client := rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)) + + ret := &websocketsv1.Response{} + err = client.Call("broadcast.PublishAsync", makeMessage([]byte("hello"), "foo", "foo2", "foo3"), ret) + if err != nil { + t.Fatal(err) + } +} + +func makeMessage(payload []byte, topics ...string) *websocketsv1.Request { + m := &websocketsv1.Request{ + Messages: []*websocketsv1.Message{ + { + Topics: topics, + Payload: payload, + }, + }, + } + + return m +} diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml index b01dad1e..d8daa251 100644 --- a/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml +++ b/tests/plugins/broadcast/configs/.rr-broadcast-config-error.yaml @@ -25,7 +25,7 @@ broadcast: logs: mode: development - level: error + level: debug endure: grace_period: 120s diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml index d324284d..90790869 100644 --- a/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml +++ b/tests/plugins/broadcast/configs/.rr-broadcast-no-config.yaml @@ -21,7 +21,7 @@ http: logs: mode: development - level: error + level: debug endure: grace_period: 120s diff --git a/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml b/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml new file mode 100644 index 00000000..360e05e5 --- /dev/null +++ b/tests/plugins/broadcast/configs/.rr-broadcast-same-section.yaml @@ -0,0 +1,43 @@ +rpc: + listen: tcp://127.0.0.1:6002 + +server: + command: "php ../../psr-worker-bench.php" + user: "" + group: "" + relay: "pipes" + relay_timeout: "20s" + +http: + address: 127.0.0.1:21345 + max_request_size: 1024 + middleware: [ "websockets" ] + trusted_subnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + num_workers: 2 + max_jobs: 0 + allocate_timeout: 60s + destroy_timeout: 60s + +broadcast: + test: + driver: redis + addrs: + - "localhost:6379" + test2: + driver: redis + addrs: + - "localhost:6378" + test3: + driver: memory + test4: + driver: memory + +logs: + mode: development + level: debug + +endure: + grace_period: 120s + print_graph: false + log_level: error diff --git a/tests/plugins/broadcast/plugins/plugin1.go b/tests/plugins/broadcast/plugins/plugin1.go new file mode 100644 index 00000000..d3b16256 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin1.go @@ -0,0 +1,67 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin1Name = "plugin1" + +type Plugin1 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin1) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + errCh <- err + return errCh + } + + err = p.driver.Subscribe("1", "foo", "foo2", "foo3") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin1Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin1) Stop() error { + _ = p.driver.Unsubscribe("1", "foo") + _ = p.driver.Unsubscribe("1", "foo2") + _ = p.driver.Unsubscribe("1", "foo3") + return nil +} + +func (p *Plugin1) Name() string { + return Plugin1Name +} diff --git a/tests/plugins/broadcast/plugins/plugin2.go b/tests/plugins/broadcast/plugins/plugin2.go new file mode 100644 index 00000000..2bd819d2 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin2.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin2Name = "plugin2" + +type Plugin2 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin2) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin2) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("2", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin2Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin2) Stop() error { + _ = p.driver.Unsubscribe("2", "foo") + return nil +} + +func (p *Plugin2) Name() string { + return Plugin2Name +} diff --git a/tests/plugins/broadcast/plugins/plugin3.go b/tests/plugins/broadcast/plugins/plugin3.go new file mode 100644 index 00000000..ef926222 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin3.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin3Name = "plugin3" + +type Plugin3 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin3) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin3) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test2") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("3", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin3Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin3) Stop() error { + _ = p.driver.Unsubscribe("3", "foo") + return nil +} + +func (p *Plugin3) Name() string { + return Plugin3Name +} diff --git a/tests/plugins/broadcast/plugins/plugin4.go b/tests/plugins/broadcast/plugins/plugin4.go new file mode 100644 index 00000000..c9b94777 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin4.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin4Name = "plugin4" + +type Plugin4 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin4) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin4) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test3") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("4", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin4Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin4) Stop() error { + _ = p.driver.Unsubscribe("4", "foo") + return nil +} + +func (p *Plugin4) Name() string { + return Plugin4Name +} diff --git a/tests/plugins/broadcast/plugins/plugin5.go b/tests/plugins/broadcast/plugins/plugin5.go new file mode 100644 index 00000000..01562a8f --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin5.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin5Name = "plugin5" + +type Plugin5 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin5) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin5) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test4") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("5", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin5Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin5) Stop() error { + _ = p.driver.Unsubscribe("5", "foo") + return nil +} + +func (p *Plugin5) Name() string { + return Plugin5Name +} diff --git a/tests/plugins/broadcast/plugins/plugin6.go b/tests/plugins/broadcast/plugins/plugin6.go new file mode 100644 index 00000000..76f2d6e8 --- /dev/null +++ b/tests/plugins/broadcast/plugins/plugin6.go @@ -0,0 +1,64 @@ +package plugins + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/broadcast" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const Plugin6Name = "plugin6" + +type Plugin6 struct { + log logger.Logger + b broadcast.Broadcaster + driver pubsub.SubReader +} + +func (p *Plugin6) Init(log logger.Logger, b broadcast.Broadcaster) error { + p.log = log + p.b = b + return nil +} + +func (p *Plugin6) Serve() chan error { + errCh := make(chan error, 1) + + var err error + p.driver, err = p.b.GetDriver("test") + if err != nil { + panic(err) + } + + err = p.driver.Subscribe("6", "foo") + if err != nil { + panic(err) + } + + go func() { + for { + msg, err := p.driver.Next() + if err != nil { + panic(err) + } + + if msg == nil { + continue + } + + p.log.Info(fmt.Sprintf("%s: %s", Plugin6Name, *msg)) + } + }() + + return errCh +} + +func (p *Plugin6) Stop() error { + _ = p.driver.Unsubscribe("6", "foo") + return nil +} + +func (p *Plugin6) Name() string { + return Plugin6Name +} |