diff options
23 files changed, 429 insertions, 146 deletions
diff --git a/pkg/pubsub/interface.go b/pkg/pubsub/interface.go index 80dab0c3..caf8783f 100644 --- a/pkg/pubsub/interface.go +++ b/pkg/pubsub/interface.go @@ -11,6 +11,7 @@ type PubSub interface { type Subscriber interface { // Subscribe broker to one or multiple topics. Subscribe(topics ...string) error + // Unsubscribe from one or multiply topics Unsubscribe(topics ...string) error } @@ -18,21 +19,14 @@ type Subscriber interface { // Publisher publish one or more messages type Publisher interface { // Publish one or multiple Channel. - Publish(messages []Message) error + Publish(messages []*Message) error // PublishAsync publish message and return immediately // If error occurred it will be printed into the logger - PublishAsync(messages []Message) + PublishAsync(messages []*Message) } // Reader interface should return next message type Reader interface { - Next() (Message, error) -} - -type Message interface { - Command() string - Payload() []byte - Topics() []string - Broker() string + Next() (*Message, error) } diff --git a/pkg/pubsub/message.go b/pkg/pubsub/message.go index 17e6780f..c1a7246a 100644 --- a/pkg/pubsub/message.go +++ b/pkg/pubsub/message.go @@ -4,40 +4,21 @@ import ( json "github.com/json-iterator/go" ) -type Msg struct { +type Message struct { // Topic message been pushed into. - Topics_ []string `json:"topic"` + Topics []string `json:"topic"` // Command (join, leave, headers) - Command_ string `json:"command"` + Command string `json:"command"` // Broker (redis, memory) - Broker_ string `json:"broker"` + Broker string `json:"broker"` // Payload to be broadcasted - Payload_ []byte `json:"payload"` + Payload []byte `json:"payload"` } // MarshalBinary needed to marshal message for the redis -func (m *Msg) MarshalBinary() ([]byte, error) { +func (m *Message) MarshalBinary() ([]byte, error) { return json.Marshal(m) } - -// Payload in raw bytes -func (m *Msg) Payload() []byte { - return m.Payload_ -} - -// Command for the connection -func (m *Msg) Command() string { - return m.Command_ -} - -// Topics to subscribe -func (m *Msg) Topics() []string { - return m.Topics_ -} - -func (m *Msg) Broker() string { - return m.Broker_ -} diff --git a/plugins/channel/interface.go b/plugins/channel/interface.go new file mode 100644 index 00000000..9a405e89 --- /dev/null +++ b/plugins/channel/interface.go @@ -0,0 +1,8 @@ +package channel + +// Hub used as a channel between two or more plugins +// this is not a PUBLIC plugin, API might be changed at any moment +type Hub interface { + SendCh() chan interface{} + ReceiveCh() chan interface{} +} diff --git a/plugins/channel/plugin.go b/plugins/channel/plugin.go new file mode 100644 index 00000000..8901c42e --- /dev/null +++ b/plugins/channel/plugin.go @@ -0,0 +1,52 @@ +package channel + +import ( + "sync" +) + +const ( + PluginName string = "hub" +) + +type Plugin struct { + sync.Mutex + send chan interface{} + receive chan interface{} +} + +func (p *Plugin) Init() error { + p.Lock() + defer p.Unlock() + p.send = make(chan interface{}) + p.receive = make(chan interface{}) + return nil +} + +func (p *Plugin) Serve() chan error { + return make(chan error) +} + +func (p *Plugin) Stop() error { + close(p.receive) + return nil +} + +func (p *Plugin) SendCh() chan interface{} { + p.Lock() + defer p.Unlock() + // bi-directional queue + return p.send +} + +func (p *Plugin) ReceiveCh() chan interface{} { + p.Lock() + defer p.Unlock() + // bi-directional queue + return p.receive +} + +func (p *Plugin) Available() {} + +func (p *Plugin) Name() string { + return PluginName +} diff --git a/plugins/http/channel.go b/plugins/http/channel.go new file mode 100644 index 00000000..42b73730 --- /dev/null +++ b/plugins/http/channel.go @@ -0,0 +1,28 @@ +package http + +import ( + "net/http" +) + +// messages method used to read messages from the ws plugin with the auth requests for the topics and server +func (p *Plugin) messages() { + for msg := range p.hub.ReceiveCh() { + p.RLock() + // msg here is the structure with http.ResponseWriter and http.Request + rmsg := msg.(struct { + RW http.ResponseWriter + Req *http.Request + }) + + p.handler.ServeHTTP(rmsg.RW, rmsg.Req) + + p.hub.SendCh() <- struct { + RW http.ResponseWriter + Req *http.Request + }{ + rmsg.RW, + rmsg.Req, + } + p.RUnlock() + } +} diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go index 8bcffb63..38b3621f 100644 --- a/plugins/http/plugin.go +++ b/plugins/http/plugin.go @@ -14,6 +14,7 @@ import ( "github.com/spiral/roadrunner/v2/pkg/process" "github.com/spiral/roadrunner/v2/pkg/worker" handler "github.com/spiral/roadrunner/v2/pkg/worker_handler" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" httpConfig "github.com/spiral/roadrunner/v2/plugins/http/config" @@ -67,11 +68,14 @@ type Plugin struct { http *http.Server https *http.Server fcgi *http.Server + + // message bus + hub channel.Hub } // Init must return configure svc and return true if svc hasStatus enabled. Must return error in case of // misconfiguration. Services must not be used without proper configuration pushed first. -func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server) error { +func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server, channel channel.Hub) error { const op = errors.Op("http_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -105,6 +109,9 @@ func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server serv p.cfg.Env[RrMode] = "http" p.server = server + p.hub = channel + + go p.messages() return nil } diff --git a/plugins/memory/plugin.go b/plugins/memory/plugin.go index 2ad041aa..49c187bc 100644 --- a/plugins/memory/plugin.go +++ b/plugins/memory/plugin.go @@ -15,14 +15,14 @@ type Plugin struct { log logger.Logger // channel with the messages from the RPC - pushCh chan pubsub.Message + pushCh chan *pubsub.Message // user-subscribed topics topics sync.Map } func (p *Plugin) Init(log logger.Logger) error { p.log = log - p.pushCh = make(chan pubsub.Message, 100) + p.pushCh = make(chan *pubsub.Message, 100) return nil } @@ -34,14 +34,14 @@ func (p *Plugin) Name() string { return PluginName } -func (p *Plugin) Publish(messages []pubsub.Message) error { +func (p *Plugin) Publish(messages []*pubsub.Message) error { for i := 0; i < len(messages); i++ { p.pushCh <- messages[i] } return nil } -func (p *Plugin) PublishAsync(messages []pubsub.Message) { +func (p *Plugin) PublishAsync(messages []*pubsub.Message) { go func() { for i := 0; i < len(messages); i++ { p.pushCh <- messages[i] @@ -63,12 +63,17 @@ func (p *Plugin) Unsubscribe(topics ...string) error { return nil } -func (p *Plugin) Next() (pubsub.Message, error) { +func (p *Plugin) Next() (*pubsub.Message, error) { msg := <-p.pushCh + + if msg == nil { + return nil, nil + } + // push only messages, which are subscribed // TODO better??? - for i := 0; i < len(msg.Topics()); i++ { - if _, ok := p.topics.Load(msg.Topics()[i]); ok { + for i := 0; i < len(msg.Topics); i++ { + if _, ok := p.topics.Load(msg.Topics[i]); ok { return msg, nil } } diff --git a/plugins/redis/fanin.go b/plugins/redis/fanin.go index 8e924b2d..93b13124 100644 --- a/plugins/redis/fanin.go +++ b/plugins/redis/fanin.go @@ -22,13 +22,13 @@ type FanIn struct { log logger.Logger // out channel with all subs - out chan pubsub.Message + out chan *pubsub.Message exit chan struct{} } func NewFanIn(redisClient redis.UniversalClient, log logger.Logger) *FanIn { - out := make(chan pubsub.Message, 100) + out := make(chan *pubsub.Message, 100) fi := &FanIn{ out: out, client: redisClient, @@ -65,7 +65,7 @@ func (fi *FanIn) read() { if !ok { return } - m := &pubsub.Msg{} + m := &pubsub.Message{} err := json.Unmarshal(utils.AsBytes(msg.Payload), m) if err != nil { fi.log.Error("failed to unmarshal payload", "error", err.Error()) @@ -95,6 +95,6 @@ func (fi *FanIn) Stop() error { return nil } -func (fi *FanIn) Consume() <-chan pubsub.Message { +func (fi *FanIn) Consume() <-chan *pubsub.Message { return fi.out } diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go index 24ed1f92..c1480de8 100644 --- a/plugins/redis/plugin.go +++ b/plugins/redis/plugin.go @@ -101,13 +101,13 @@ func (p *Plugin) Name() string { // Available interface implementation func (p *Plugin) Available() {} -func (p *Plugin) Publish(msg []pubsub.Message) error { +func (p *Plugin) Publish(msg []*pubsub.Message) error { p.Lock() defer p.Unlock() for i := 0; i < len(msg); i++ { - for j := 0; j < len(msg[i].Topics()); j++ { - f := p.universalClient.Publish(context.Background(), msg[i].Topics()[j], msg[i]) + for j := 0; j < len(msg[i].Topics); j++ { + f := p.universalClient.Publish(context.Background(), msg[i].Topics[j], msg[i]) if f.Err() != nil { return f.Err() } @@ -116,15 +116,15 @@ func (p *Plugin) Publish(msg []pubsub.Message) error { return nil } -func (p *Plugin) PublishAsync(msg []pubsub.Message) { +func (p *Plugin) PublishAsync(msg []*pubsub.Message) { go func() { p.Lock() defer p.Unlock() for i := 0; i < len(msg); i++ { - for j := 0; j < len(msg[i].Topics()); j++ { - f := p.universalClient.Publish(context.Background(), msg[i].Topics()[j], msg[i]) + for j := 0; j < len(msg[i].Topics); j++ { + f := p.universalClient.Publish(context.Background(), msg[i].Topics[j], msg[i]) if f.Err() != nil { - p.log.Error("errors publishing message", "topic", msg[i].Topics()[j], "error", f.Err().Error()) + p.log.Error("errors publishing message", "topic", msg[i].Topics[j], "error", f.Err().Error()) continue } } @@ -141,6 +141,6 @@ func (p *Plugin) Unsubscribe(topics ...string) error { } // Next return next message -func (p *Plugin) Next() (pubsub.Message, error) { +func (p *Plugin) Next() (*pubsub.Message, error) { return <-p.fanin.Consume(), nil } diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 1aa54be9..87fed3a6 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -1,14 +1,20 @@ package executor import ( + "fmt" + "net/http" + "sync" + "github.com/fasthttp/websocket" json "github.com/json-iterator/go" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/storage" + "github.com/spiral/roadrunner/v2/plugins/websockets/validator" ) type Response struct { @@ -17,6 +23,7 @@ type Response struct { } type Executor struct { + sync.Mutex conn *connection.Connection storage *storage.Storage log logger.Logger @@ -25,17 +32,24 @@ type Executor struct { connID string // map with the pubsub drivers - pubsub map[string]pubsub.PubSub + pubsub map[string]pubsub.PubSub + actualTopics map[string]struct{} + + hub channel.Hub + req *http.Request } // NewExecutor creates protected connection and starts command loop -func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub) *Executor { +func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor { return &Executor{ - conn: conn, - connID: connID, - storage: bst, - log: log, - pubsub: pubsubs, + conn: conn, + connID: connID, + storage: bst, + log: log, + pubsub: pubsubs, + hub: hub, + actualTopics: make(map[string]struct{}, 10), + req: r, } } @@ -52,7 +66,7 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit return errors.E(op, err) } - msg := &pubsub.Msg{} + msg := &pubsub.Message{} err = json.Unmarshal(data, msg) if err != nil { @@ -60,76 +74,149 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit continue } - switch msg.Command() { + // nil message, continue + if msg == nil { + e.log.Warn("get nil message, skipping") + continue + } + + switch msg.Command { // handle leave case commands.Join: e.log.Debug("get join command", "msg", msg) - // associate connection with topics - e.storage.InsertMany(e.connID, msg.Topics()) + + err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...) + if err != nil { + resp := &Response{ + Topic: "#join", + Payload: msg.Topics, + } + + packet, errJ := json.Marshal(resp) + if errJ != nil { + e.log.Error("error marshal the body", "error", errJ) + return errors.E(op, fmt.Errorf("%v,%v", err, errJ)) + } + + errW := e.conn.Write(websocket.BinaryMessage, packet) + if errW != nil { + e.log.Error("error writing payload to the connection", "payload", packet, "error", errW) + return errors.E(op, fmt.Errorf("%v,%v", err, errW)) + } + + continue + } resp := &Response{ Topic: "@join", - Payload: msg.Topics(), + Payload: msg.Topics, } packet, err := json.Marshal(resp) if err != nil { e.log.Error("error marshal the body", "error", err) - continue + return errors.E(op, err) } err = e.conn.Write(websocket.BinaryMessage, packet) if err != nil { e.log.Error("error writing payload to the connection", "payload", packet, "error", err) - continue + return errors.E(op, err) } // subscribe to the topic - if br, ok := e.pubsub[msg.Broker()]; ok { - err = br.Subscribe(msg.Topics()...) + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Set(br, msg.Topics) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error()) - // in case of error, unsubscribe connection from the dead topics - _ = br.Unsubscribe(msg.Topics()...) - continue + return errors.E(op, err) } } // handle leave case commands.Leave: e.log.Debug("get leave command", "msg", msg) - // remove associated connections from the storage - e.storage.RemoveMany(e.connID, msg.Topics()) + // prepare response resp := &Response{ Topic: "@leave", - Payload: msg.Topics(), + Payload: msg.Topics, } packet, err := json.Marshal(resp) if err != nil { e.log.Error("error marshal the body", "error", err) - continue + return errors.E(op, err) } err = e.conn.Write(websocket.BinaryMessage, packet) if err != nil { e.log.Error("error writing payload to the connection", "payload", packet, "error", err) - continue + return errors.E(op, err) } - if br, ok := e.pubsub[msg.Broker()]; ok { - err = br.Unsubscribe(msg.Topics()...) + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Leave(br, msg.Topics) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error()) - continue + return errors.E(op, err) } } case commands.Headers: default: - e.log.Warn("unknown command", "command", msg.Command()) + e.log.Warn("unknown command", "command", msg.Command) } } } + +func (e *Executor) Set(br pubsub.PubSub, topics []string) error { + // associate connection with topics + err := br.Subscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + // in case of error, unsubscribe connection from the dead topics + _ = br.Unsubscribe(topics...) + return err + } + + e.storage.InsertMany(e.connID, topics) + + // save topics for the connection + for i := 0; i < len(topics); i++ { + e.actualTopics[topics[i]] = struct{}{} + } + + return nil +} + +func (e *Executor) Leave(br pubsub.PubSub, topics []string) error { + // remove associated connections from the storage + e.storage.RemoveMany(e.connID, topics) + err := br.Unsubscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + return err + } + + // remove topics for the connection + for i := 0; i < len(topics); i++ { + delete(e.actualTopics, topics[i]) + } + + return nil +} + +func (e *Executor) CleanUp() { + for topic := range e.actualTopics { + // remove from the bst + e.storage.Remove(e.connID, topic) + + for _, ps := range e.pubsub { + _ = ps.Unsubscribe(topic) + } + } + + for k := range e.actualTopics { + delete(e.actualTopics, k) + } +} diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index 76ef800d..2a060716 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -10,12 +10,15 @@ import ( endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" "github.com/spiral/roadrunner/v2/plugins/websockets/storage" + "github.com/spiral/roadrunner/v2/plugins/websockets/validator" ) const ( @@ -23,7 +26,7 @@ const ( ) type Plugin struct { - sync.RWMutex + mu sync.RWMutex // Collection with all available pubsubs pubsubs map[string]pubsub.PubSub @@ -34,10 +37,13 @@ type Plugin struct { connections sync.Map storage *storage.Storage + // GO workers pool workersPool *pool.WorkersPool + + hub channel.Hub } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -52,6 +58,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error { p.log = log p.storage = storage.NewStorage() p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) + p.hub = channel return nil } @@ -69,10 +76,6 @@ func (p *Plugin) Serve() chan error { return } - if data == nil { - continue - } - p.workersPool.Queue(data) } }(v) @@ -115,6 +118,22 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } + p.mu.Lock() + + r = attributes.Init(r) + + err := validator.NewValidator().AssertServerAccess(p.hub, r) + if err != nil { + // show the error to the user + if av, ok := err.(*validator.AccessValidator); ok { + av.Copy(w) + } else { + w.WriteHeader(400) + return + } + } + + p.mu.Unlock() // connection upgrader upgraded := websocket.Upgrader{ @@ -154,13 +173,15 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { p.connections.Delete(connectionID) }() + p.mu.Lock() // Executor wraps a connection to have a safe abstraction - p.Lock() - e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs) - p.Unlock() + e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r) + p.mu.Unlock() p.log.Info("websocket client connected", "uuid", connectionID) + defer e.CleanUp() + err = e.StartCommandLoop() if err != nil { p.log.Error("command loop error", "error", err.Error()) @@ -170,32 +191,32 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { } // Publish is an entry point to the websocket PUBSUB -func (p *Plugin) Publish(msg []pubsub.Message) error { - p.Lock() - defer p.Unlock() +func (p *Plugin) Publish(msg []*pubsub.Message) error { + p.mu.Lock() + defer p.mu.Unlock() for i := 0; i < len(msg); i++ { - for j := 0; j < len(msg[i].Topics()); j++ { - if br, ok := p.pubsubs[msg[i].Broker()]; ok { + for j := 0; j < len(msg[i].Topics); j++ { + if br, ok := p.pubsubs[msg[i].Broker]; ok { err := br.Publish(msg) if err != nil { return errors.E(err) } } else { - p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker()) + p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker) } } } return nil } -func (p *Plugin) PublishAsync(msg []pubsub.Message) { +func (p *Plugin) PublishAsync(msg []*pubsub.Message) { go func() { - p.Lock() - defer p.Unlock() + p.mu.Lock() + defer p.mu.Unlock() for i := 0; i < len(msg); i++ { - for j := 0; j < len(msg[i].Topics()); j++ { - err := p.pubsubs[msg[i].Broker()].Publish(msg) + for j := 0; j < len(msg[i].Topics); j++ { + err := p.pubsubs[msg[i].Broker].Publish(msg) if err != nil { p.log.Error("publish async error", "error", err) return diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go index 87e931d0..8f18580f 100644 --- a/plugins/websockets/pool/workers_pool.go +++ b/plugins/websockets/pool/workers_pool.go @@ -16,7 +16,7 @@ type WorkersPool struct { resPool sync.Pool log logger.Logger - queue chan pubsub.Message + queue chan *pubsub.Message exit chan struct{} } @@ -24,7 +24,7 @@ type WorkersPool struct { func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger.Logger) *WorkersPool { wp := &WorkersPool{ connections: connections, - queue: make(chan pubsub.Message, 100), + queue: make(chan *pubsub.Message, 100), storage: storage, log: log, exit: make(chan struct{}), @@ -42,7 +42,7 @@ func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger. return wp } -func (wp *WorkersPool) Queue(msg pubsub.Message) { +func (wp *WorkersPool) Queue(msg *pubsub.Message) { wp.queue <- msg } @@ -67,16 +67,26 @@ func (wp *WorkersPool) get() map[string]struct{} { return wp.resPool.Get().(map[string]struct{}) } -func (wp *WorkersPool) do() { +func (wp *WorkersPool) do() { //nolint:gocognit go func() { for { select { - case msg := <-wp.queue: + case msg, ok := <-wp.queue: + if !ok { + return + } + // do not handle nil's + if msg == nil { + continue + } + if len(msg.Topics) == 0 { + continue + } res := wp.get() // get connections for the particular topic - wp.storage.GetByPtr(msg.Topics(), res) + wp.storage.GetByPtr(msg.Topics, res) if len(res) == 0 { - wp.log.Info("no such topic", "topic", msg.Topics()) + wp.log.Info("no such topic", "topic", msg.Topics) wp.put(res) continue } @@ -84,14 +94,14 @@ func (wp *WorkersPool) do() { for i := range res { c, ok := wp.connections.Load(i) if !ok { - wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker(), "topics", msg.Topics()) + wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker, "topics", msg.Topics) continue } conn := c.(*connection.Connection) - err := conn.Write(websocket.BinaryMessage, msg.Payload()) + err := conn.Write(websocket.BinaryMessage, msg.Payload) if err != nil { - wp.log.Error("error sending payload over the connection", "broker", msg.Broker(), "topics", msg.Topics()) + wp.log.Error("error sending payload over the connection", "broker", msg.Broker, "topics", msg.Topics) wp.put(res) continue } diff --git a/plugins/websockets/rpc.go b/plugins/websockets/rpc.go index f917bd53..2fb0f1b9 100644 --- a/plugins/websockets/rpc.go +++ b/plugins/websockets/rpc.go @@ -12,18 +12,17 @@ type rpc struct { log logger.Logger } -func (r *rpc) Publish(msg []*pubsub.Msg, ok *bool) error { +func (r *rpc) Publish(msg []*pubsub.Message, ok *bool) error { const op = errors.Op("broadcast_publish") r.log.Debug("message published", "msg", msg) - // publish to the registered broker - mi := make([]pubsub.Message, 0, len(msg)) - // golang can't convert slice in-place - // so, we need to convert it manually - for i := 0; i < len(msg); i++ { - mi = append(mi, msg[i]) + // just return in case of nil message + if msg == nil { + *ok = true + return nil } - err := r.plugin.Publish(mi) + + err := r.plugin.Publish(msg) if err != nil { *ok = false return errors.E(op, err) @@ -32,16 +31,16 @@ func (r *rpc) Publish(msg []*pubsub.Msg, ok *bool) error { return nil } -func (r *rpc) PublishAsync(msg []*pubsub.Msg, ok *bool) error { - // publish to the registered broker - mi := make([]pubsub.Message, 0, len(msg)) - // golang can't convert slice in-place - // so, we need to convert it manually - for i := 0; i < len(msg); i++ { - mi = append(mi, msg[i]) - } +func (r *rpc) PublishAsync(msg []*pubsub.Message, ok *bool) error { + r.log.Debug("message published", "msg", msg) - r.plugin.PublishAsync(mi) + // just return in case of nil message + if msg == nil { + *ok = true + return nil + } + // publish to the registered broker + r.plugin.PublishAsync(msg) *ok = true return nil diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go index 9d9522d4..e3fde3d0 100644 --- a/plugins/websockets/validator/access_validator.go +++ b/plugins/websockets/validator/access_validator.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/http/attributes" ) @@ -67,16 +68,29 @@ func (w *AccessValidator) Error() string { // AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in // case of error -func (w *AccessValidator) AssertServerAccess(f http.HandlerFunc, r *http.Request) error { +func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error { if err := attributes.Set(r, "ws:joinServer", true); err != nil { return err } defer delete(attributes.All(r), "ws:joinServer") - f(w, r) + hub.ReceiveCh() <- struct { + RW http.ResponseWriter + Req *http.Request + }{ + w, + r, + } + + resp := <-hub.SendCh() + + rmsg := resp.(struct { + RW http.ResponseWriter + Req *http.Request + }) - if !w.IsOK() { + if !rmsg.RW.(*AccessValidator).IsOK() { return w } @@ -85,16 +99,29 @@ func (w *AccessValidator) AssertServerAccess(f http.HandlerFunc, r *http.Request // AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies. // the decision to authorize user will be based on response code (200). -func (w *AccessValidator) AssertTopicsAccess(f http.HandlerFunc, r *http.Request, channels ...string) error { +func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error { if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil { return err } defer delete(attributes.All(r), "ws:joinTopics") - f(w, r) + hub.ReceiveCh() <- struct { + RW http.ResponseWriter + Req *http.Request + }{ + w, + r, + } + + resp := <-hub.SendCh() + + rmsg := resp.(struct { + RW http.ResponseWriter + Req *http.Request + }) - if !w.IsOK() { + if !rmsg.RW.(*AccessValidator).IsOK() { return w } diff --git a/tests/plugins/config/config_test.go b/tests/plugins/config/config_test.go index b6063cec..3cf026bd 100755 --- a/tests/plugins/config/config_test.go +++ b/tests/plugins/config/config_test.go @@ -7,6 +7,7 @@ import ( "time" endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/rpc" @@ -33,6 +34,11 @@ func TestViperProvider_Init(t *testing.T) { t.Fatal(err) } + err = container.Register(&channel.Plugin{}) + if err != nil { + t.Fatal(err) + } + err = container.Init() if err != nil { t.Fatal(err) @@ -82,6 +88,7 @@ func TestConfigOverwriteFail(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -103,6 +110,7 @@ func TestConfigOverwriteValid(t *testing.T) { &logger.ZapLogger{}, &rpc.Plugin{}, vp, + &channel.Plugin{}, &Foo2{}, ) assert.NoError(t, err) @@ -155,6 +163,7 @@ func TestConfigEnvVariables(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -206,6 +215,7 @@ func TestConfigEnvVariablesFail(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -237,6 +247,11 @@ func TestConfigProvider_GeneralSection(t *testing.T) { t.Fatal(err) } + err = container.Register(&channel.Plugin{}) + if err != nil { + t.Fatal(err) + } + err = container.Init() if err != nil { t.Fatal(err) diff --git a/tests/plugins/gzip/plugin_test.go b/tests/plugins/gzip/plugin_test.go index 844fd411..5294e672 100644 --- a/tests/plugins/gzip/plugin_test.go +++ b/tests/plugins/gzip/plugin_test.go @@ -11,6 +11,7 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/gzip" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -35,6 +36,7 @@ func TestGzipPlugin(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &gzip.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -128,6 +130,7 @@ func TestMiddlewareNotExist(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &gzip.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/headers/headers_plugin_test.go b/tests/plugins/headers/headers_plugin_test.go index 49d86b00..e4903335 100644 --- a/tests/plugins/headers/headers_plugin_test.go +++ b/tests/plugins/headers/headers_plugin_test.go @@ -11,6 +11,7 @@ import ( "time" endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/headers" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -34,6 +35,7 @@ func TestHeadersInit(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -100,6 +102,7 @@ func TestRequestHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -185,6 +188,7 @@ func TestResponseHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -271,6 +275,7 @@ func TestCORSHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/http/http_plugin_test.go b/tests/plugins/http/http_plugin_test.go index 128eec26..aa57077d 100644 --- a/tests/plugins/http/http_plugin_test.go +++ b/tests/plugins/http/http_plugin_test.go @@ -24,6 +24,7 @@ import ( goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" "github.com/spiral/roadrunner/v2/pkg/events" "github.com/spiral/roadrunner/v2/pkg/process" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/gzip" "github.com/spiral/roadrunner/v2/plugins/informer" @@ -62,6 +63,7 @@ func TestHTTPInit(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -126,6 +128,7 @@ func TestHTTPNoConfigSection(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -193,6 +196,7 @@ func TestHTTPInformerReset(t *testing.T) { &httpPlugin.Plugin{}, &informer.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -315,6 +319,7 @@ func TestSSL(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -451,6 +456,7 @@ func TestSSLRedirect(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -540,6 +546,7 @@ func TestSSLPushPipes(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -630,6 +637,7 @@ func TestFastCGI_RequestUri(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -724,6 +732,7 @@ func TestH2CUpgrade(t *testing.T) { mockLogger, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -815,6 +824,7 @@ func TestH2C(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -907,6 +917,7 @@ func TestHttpMiddleware(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1053,6 +1064,7 @@ logs: &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1138,6 +1150,7 @@ func TestHttpEnvVariables(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1225,6 +1238,7 @@ func TestHttpBrokenPipes(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1286,6 +1300,7 @@ func TestHTTPSupervisedPool(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &informer.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1488,6 +1503,7 @@ func TestHTTPBigRequestSize(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1580,6 +1596,7 @@ func TestStaticEtagPlugin(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1678,6 +1695,7 @@ func TestStaticPluginSecurity(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1827,6 +1845,7 @@ func TestStaticPlugin(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -1941,6 +1960,7 @@ func TestStaticDisabled_Error(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) assert.Error(t, cont.Init()) @@ -1962,6 +1982,7 @@ func TestStaticFilesDisabled(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -2054,6 +2075,7 @@ func TestStaticFilesForbid(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/logger/logger_test.go b/tests/plugins/logger/logger_test.go index d2877781..f63a6a5d 100644 --- a/tests/plugins/logger/logger_test.go +++ b/tests/plugins/logger/logger_test.go @@ -9,6 +9,7 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -98,6 +99,7 @@ func TestLoggerRawErr(t *testing.T) { mockLogger, &server.Plugin{}, &http.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -224,6 +226,7 @@ func TestLoggerNoConfig2(t *testing.T) { &logger.ZapLogger{}, &http.Plugin{}, &server.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/metrics/metrics_test.go b/tests/plugins/metrics/metrics_test.go index 8be567ec..48c01f24 100644 --- a/tests/plugins/metrics/metrics_test.go +++ b/tests/plugins/metrics/metrics_test.go @@ -15,6 +15,7 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -144,6 +145,7 @@ func TestMetricsIssue571(t *testing.T) { &server.Plugin{}, mockLogger, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/reload/reload_plugin_test.go b/tests/plugins/reload/reload_plugin_test.go index 6db7b6d0..41c9c92f 100644 --- a/tests/plugins/reload/reload_plugin_test.go +++ b/tests/plugins/reload/reload_plugin_test.go @@ -16,6 +16,7 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/reload" @@ -65,6 +66,7 @@ func TestReloadInit(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -161,6 +163,7 @@ func TestReloadHugeNumberOfFiles(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -270,6 +273,7 @@ func TestReloadFilterFileExt(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -400,6 +404,7 @@ func TestReloadCopy100(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -677,6 +682,7 @@ func TestReloadNoRecursion(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/status/plugin_test.go b/tests/plugins/status/plugin_test.go index 663f4ee3..06983199 100644 --- a/tests/plugins/status/plugin_test.go +++ b/tests/plugins/status/plugin_test.go @@ -14,6 +14,7 @@ import ( endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -38,6 +39,7 @@ func TestStatusHttp(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -125,6 +127,7 @@ func TestStatusRPC(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -204,6 +207,7 @@ func TestReadyHttp(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -291,6 +295,7 @@ func TestReadinessRPCWorkerNotReady(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/websockets/websocket_plugin_test.go b/tests/plugins/websockets/websocket_plugin_test.go index 087d1bc9..61ef186b 100644 --- a/tests/plugins/websockets/websocket_plugin_test.go +++ b/tests/plugins/websockets/websocket_plugin_test.go @@ -16,6 +16,7 @@ import ( json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -30,16 +31,16 @@ import ( type Msg struct { // Topic message been pushed into. - Topics_ []string `json:"topic"` + Topics []string `json:"topic"` // Command (join, leave, headers) - Command_ string `json:"command"` + Command string `json:"command"` // Broker (redis, memory) - Broker_ string `json:"broker"` + Broker string `json:"broker"` // Payload to be broadcasted - Payload_ []byte `json:"payload"` + Payload []byte `json:"payload"` } func TestBroadcastInit(t *testing.T) { @@ -59,6 +60,7 @@ func TestBroadcastInit(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -168,6 +170,7 @@ func TestWSRedisAndMemory(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, + &channel.Plugin{}, ) assert.NoError(t, err) @@ -504,9 +507,9 @@ func publish2(command string, broker string, topics ...string) { func message(command string, broker string, payload []byte, topics ...string) *Msg { return &Msg{ - Topics_: topics, - Command_: command, - Broker_: broker, - Payload_: payload, + Topics: topics, + Command: command, + Broker: broker, + Payload: payload, } } |