diff options
Diffstat (limited to 'plugins/websockets/executor/executor.go')
-rw-r--r-- | plugins/websockets/executor/executor.go | 147 |
1 files changed, 117 insertions, 30 deletions
diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 1aa54be9..87fed3a6 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -1,14 +1,20 @@ package executor import ( + "fmt" + "net/http" + "sync" + "github.com/fasthttp/websocket" json "github.com/json-iterator/go" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/storage" + "github.com/spiral/roadrunner/v2/plugins/websockets/validator" ) type Response struct { @@ -17,6 +23,7 @@ type Response struct { } type Executor struct { + sync.Mutex conn *connection.Connection storage *storage.Storage log logger.Logger @@ -25,17 +32,24 @@ type Executor struct { connID string // map with the pubsub drivers - pubsub map[string]pubsub.PubSub + pubsub map[string]pubsub.PubSub + actualTopics map[string]struct{} + + hub channel.Hub + req *http.Request } // NewExecutor creates protected connection and starts command loop -func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub) *Executor { +func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor { return &Executor{ - conn: conn, - connID: connID, - storage: bst, - log: log, - pubsub: pubsubs, + conn: conn, + connID: connID, + storage: bst, + log: log, + pubsub: pubsubs, + hub: hub, + actualTopics: make(map[string]struct{}, 10), + req: r, } } @@ -52,7 +66,7 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit return errors.E(op, err) } - msg := &pubsub.Msg{} + msg := &pubsub.Message{} err = json.Unmarshal(data, msg) if err != nil { @@ -60,76 +74,149 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit continue } - switch msg.Command() { + // nil message, continue + if msg == nil { + e.log.Warn("get nil message, skipping") + continue + } + + switch msg.Command { // handle leave case commands.Join: e.log.Debug("get join command", "msg", msg) - // associate connection with topics - e.storage.InsertMany(e.connID, msg.Topics()) + + err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...) + if err != nil { + resp := &Response{ + Topic: "#join", + Payload: msg.Topics, + } + + packet, errJ := json.Marshal(resp) + if errJ != nil { + e.log.Error("error marshal the body", "error", errJ) + return errors.E(op, fmt.Errorf("%v,%v", err, errJ)) + } + + errW := e.conn.Write(websocket.BinaryMessage, packet) + if errW != nil { + e.log.Error("error writing payload to the connection", "payload", packet, "error", errW) + return errors.E(op, fmt.Errorf("%v,%v", err, errW)) + } + + continue + } resp := &Response{ Topic: "@join", - Payload: msg.Topics(), + Payload: msg.Topics, } packet, err := json.Marshal(resp) if err != nil { e.log.Error("error marshal the body", "error", err) - continue + return errors.E(op, err) } err = e.conn.Write(websocket.BinaryMessage, packet) if err != nil { e.log.Error("error writing payload to the connection", "payload", packet, "error", err) - continue + return errors.E(op, err) } // subscribe to the topic - if br, ok := e.pubsub[msg.Broker()]; ok { - err = br.Subscribe(msg.Topics()...) + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Set(br, msg.Topics) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error()) - // in case of error, unsubscribe connection from the dead topics - _ = br.Unsubscribe(msg.Topics()...) - continue + return errors.E(op, err) } } // handle leave case commands.Leave: e.log.Debug("get leave command", "msg", msg) - // remove associated connections from the storage - e.storage.RemoveMany(e.connID, msg.Topics()) + // prepare response resp := &Response{ Topic: "@leave", - Payload: msg.Topics(), + Payload: msg.Topics, } packet, err := json.Marshal(resp) if err != nil { e.log.Error("error marshal the body", "error", err) - continue + return errors.E(op, err) } err = e.conn.Write(websocket.BinaryMessage, packet) if err != nil { e.log.Error("error writing payload to the connection", "payload", packet, "error", err) - continue + return errors.E(op, err) } - if br, ok := e.pubsub[msg.Broker()]; ok { - err = br.Unsubscribe(msg.Topics()...) + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Leave(br, msg.Topics) if err != nil { - e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error()) - continue + return errors.E(op, err) } } case commands.Headers: default: - e.log.Warn("unknown command", "command", msg.Command()) + e.log.Warn("unknown command", "command", msg.Command) } } } + +func (e *Executor) Set(br pubsub.PubSub, topics []string) error { + // associate connection with topics + err := br.Subscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + // in case of error, unsubscribe connection from the dead topics + _ = br.Unsubscribe(topics...) + return err + } + + e.storage.InsertMany(e.connID, topics) + + // save topics for the connection + for i := 0; i < len(topics); i++ { + e.actualTopics[topics[i]] = struct{}{} + } + + return nil +} + +func (e *Executor) Leave(br pubsub.PubSub, topics []string) error { + // remove associated connections from the storage + e.storage.RemoveMany(e.connID, topics) + err := br.Unsubscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + return err + } + + // remove topics for the connection + for i := 0; i < len(topics); i++ { + delete(e.actualTopics, topics[i]) + } + + return nil +} + +func (e *Executor) CleanUp() { + for topic := range e.actualTopics { + // remove from the bst + e.storage.Remove(e.connID, topic) + + for _, ps := range e.pubsub { + _ = ps.Unsubscribe(topic) + } + } + + for k := range e.actualTopics { + delete(e.actualTopics, k) + } +} |