summaryrefslogtreecommitdiff
path: root/plugins/websockets/executor/executor.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/websockets/executor/executor.go')
-rw-r--r--plugins/websockets/executor/executor.go147
1 files changed, 117 insertions, 30 deletions
diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go
index 1aa54be9..87fed3a6 100644
--- a/plugins/websockets/executor/executor.go
+++ b/plugins/websockets/executor/executor.go
@@ -1,14 +1,20 @@
package executor
import (
+ "fmt"
+ "net/http"
+ "sync"
+
"github.com/fasthttp/websocket"
json "github.com/json-iterator/go"
"github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/logger"
"github.com/spiral/roadrunner/v2/plugins/websockets/commands"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
"github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
)
type Response struct {
@@ -17,6 +23,7 @@ type Response struct {
}
type Executor struct {
+ sync.Mutex
conn *connection.Connection
storage *storage.Storage
log logger.Logger
@@ -25,17 +32,24 @@ type Executor struct {
connID string
// map with the pubsub drivers
- pubsub map[string]pubsub.PubSub
+ pubsub map[string]pubsub.PubSub
+ actualTopics map[string]struct{}
+
+ hub channel.Hub
+ req *http.Request
}
// NewExecutor creates protected connection and starts command loop
-func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub) *Executor {
+func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor {
return &Executor{
- conn: conn,
- connID: connID,
- storage: bst,
- log: log,
- pubsub: pubsubs,
+ conn: conn,
+ connID: connID,
+ storage: bst,
+ log: log,
+ pubsub: pubsubs,
+ hub: hub,
+ actualTopics: make(map[string]struct{}, 10),
+ req: r,
}
}
@@ -52,7 +66,7 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
return errors.E(op, err)
}
- msg := &pubsub.Msg{}
+ msg := &pubsub.Message{}
err = json.Unmarshal(data, msg)
if err != nil {
@@ -60,76 +74,149 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
continue
}
- switch msg.Command() {
+ // nil message, continue
+ if msg == nil {
+ e.log.Warn("get nil message, skipping")
+ continue
+ }
+
+ switch msg.Command {
// handle leave
case commands.Join:
e.log.Debug("get join command", "msg", msg)
- // associate connection with topics
- e.storage.InsertMany(e.connID, msg.Topics())
+
+ err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...)
+ if err != nil {
+ resp := &Response{
+ Topic: "#join",
+ Payload: msg.Topics,
+ }
+
+ packet, errJ := json.Marshal(resp)
+ if errJ != nil {
+ e.log.Error("error marshal the body", "error", errJ)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errJ))
+ }
+
+ errW := e.conn.Write(websocket.BinaryMessage, packet)
+ if errW != nil {
+ e.log.Error("error writing payload to the connection", "payload", packet, "error", errW)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errW))
+ }
+
+ continue
+ }
resp := &Response{
Topic: "@join",
- Payload: msg.Topics(),
+ Payload: msg.Topics,
}
packet, err := json.Marshal(resp)
if err != nil {
e.log.Error("error marshal the body", "error", err)
- continue
+ return errors.E(op, err)
}
err = e.conn.Write(websocket.BinaryMessage, packet)
if err != nil {
e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
- continue
+ return errors.E(op, err)
}
// subscribe to the topic
- if br, ok := e.pubsub[msg.Broker()]; ok {
- err = br.Subscribe(msg.Topics()...)
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Set(br, msg.Topics)
if err != nil {
- e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error())
- // in case of error, unsubscribe connection from the dead topics
- _ = br.Unsubscribe(msg.Topics()...)
- continue
+ return errors.E(op, err)
}
}
// handle leave
case commands.Leave:
e.log.Debug("get leave command", "msg", msg)
- // remove associated connections from the storage
- e.storage.RemoveMany(e.connID, msg.Topics())
+ // prepare response
resp := &Response{
Topic: "@leave",
- Payload: msg.Topics(),
+ Payload: msg.Topics,
}
packet, err := json.Marshal(resp)
if err != nil {
e.log.Error("error marshal the body", "error", err)
- continue
+ return errors.E(op, err)
}
err = e.conn.Write(websocket.BinaryMessage, packet)
if err != nil {
e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
- continue
+ return errors.E(op, err)
}
- if br, ok := e.pubsub[msg.Broker()]; ok {
- err = br.Unsubscribe(msg.Topics()...)
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Leave(br, msg.Topics)
if err != nil {
- e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error())
- continue
+ return errors.E(op, err)
}
}
case commands.Headers:
default:
- e.log.Warn("unknown command", "command", msg.Command())
+ e.log.Warn("unknown command", "command", msg.Command)
}
}
}
+
+func (e *Executor) Set(br pubsub.PubSub, topics []string) error {
+ // associate connection with topics
+ err := br.Subscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ // in case of error, unsubscribe connection from the dead topics
+ _ = br.Unsubscribe(topics...)
+ return err
+ }
+
+ e.storage.InsertMany(e.connID, topics)
+
+ // save topics for the connection
+ for i := 0; i < len(topics); i++ {
+ e.actualTopics[topics[i]] = struct{}{}
+ }
+
+ return nil
+}
+
+func (e *Executor) Leave(br pubsub.PubSub, topics []string) error {
+ // remove associated connections from the storage
+ e.storage.RemoveMany(e.connID, topics)
+ err := br.Unsubscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ return err
+ }
+
+ // remove topics for the connection
+ for i := 0; i < len(topics); i++ {
+ delete(e.actualTopics, topics[i])
+ }
+
+ return nil
+}
+
+func (e *Executor) CleanUp() {
+ for topic := range e.actualTopics {
+ // remove from the bst
+ e.storage.Remove(e.connID, topic)
+
+ for _, ps := range e.pubsub {
+ _ = ps.Unsubscribe(topic)
+ }
+ }
+
+ for k := range e.actualTopics {
+ delete(e.actualTopics, k)
+ }
+}