summaryrefslogtreecommitdiff
path: root/plugins/websockets
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-05-29 00:24:30 +0300
committerValery Piashchynski <[email protected]>2021-05-29 00:24:30 +0300
commitfcda08498e8f914bbd0798da898818cd5d0e4348 (patch)
tree62d88384d07997e2373f3b273ba0cb83569ebced /plugins/websockets
parent8f13eb958c7eec49acba6e343edb77c6ede89f09 (diff)
- Add new internal plugin - channel. Which used to deliver messages from
the ws plugin to the http directly Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets')
-rw-r--r--plugins/websockets/executor/executor.go147
-rw-r--r--plugins/websockets/plugin.go61
-rw-r--r--plugins/websockets/pool/workers_pool.go30
-rw-r--r--plugins/websockets/rpc.go33
-rw-r--r--plugins/websockets/validator/access_validator.go39
5 files changed, 227 insertions, 83 deletions
diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go
index 1aa54be9..87fed3a6 100644
--- a/plugins/websockets/executor/executor.go
+++ b/plugins/websockets/executor/executor.go
@@ -1,14 +1,20 @@
package executor
import (
+ "fmt"
+ "net/http"
+ "sync"
+
"github.com/fasthttp/websocket"
json "github.com/json-iterator/go"
"github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/logger"
"github.com/spiral/roadrunner/v2/plugins/websockets/commands"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
"github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
)
type Response struct {
@@ -17,6 +23,7 @@ type Response struct {
}
type Executor struct {
+ sync.Mutex
conn *connection.Connection
storage *storage.Storage
log logger.Logger
@@ -25,17 +32,24 @@ type Executor struct {
connID string
// map with the pubsub drivers
- pubsub map[string]pubsub.PubSub
+ pubsub map[string]pubsub.PubSub
+ actualTopics map[string]struct{}
+
+ hub channel.Hub
+ req *http.Request
}
// NewExecutor creates protected connection and starts command loop
-func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub) *Executor {
+func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor {
return &Executor{
- conn: conn,
- connID: connID,
- storage: bst,
- log: log,
- pubsub: pubsubs,
+ conn: conn,
+ connID: connID,
+ storage: bst,
+ log: log,
+ pubsub: pubsubs,
+ hub: hub,
+ actualTopics: make(map[string]struct{}, 10),
+ req: r,
}
}
@@ -52,7 +66,7 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
return errors.E(op, err)
}
- msg := &pubsub.Msg{}
+ msg := &pubsub.Message{}
err = json.Unmarshal(data, msg)
if err != nil {
@@ -60,76 +74,149 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
continue
}
- switch msg.Command() {
+ // nil message, continue
+ if msg == nil {
+ e.log.Warn("get nil message, skipping")
+ continue
+ }
+
+ switch msg.Command {
// handle leave
case commands.Join:
e.log.Debug("get join command", "msg", msg)
- // associate connection with topics
- e.storage.InsertMany(e.connID, msg.Topics())
+
+ err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...)
+ if err != nil {
+ resp := &Response{
+ Topic: "#join",
+ Payload: msg.Topics,
+ }
+
+ packet, errJ := json.Marshal(resp)
+ if errJ != nil {
+ e.log.Error("error marshal the body", "error", errJ)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errJ))
+ }
+
+ errW := e.conn.Write(websocket.BinaryMessage, packet)
+ if errW != nil {
+ e.log.Error("error writing payload to the connection", "payload", packet, "error", errW)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errW))
+ }
+
+ continue
+ }
resp := &Response{
Topic: "@join",
- Payload: msg.Topics(),
+ Payload: msg.Topics,
}
packet, err := json.Marshal(resp)
if err != nil {
e.log.Error("error marshal the body", "error", err)
- continue
+ return errors.E(op, err)
}
err = e.conn.Write(websocket.BinaryMessage, packet)
if err != nil {
e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
- continue
+ return errors.E(op, err)
}
// subscribe to the topic
- if br, ok := e.pubsub[msg.Broker()]; ok {
- err = br.Subscribe(msg.Topics()...)
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Set(br, msg.Topics)
if err != nil {
- e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error())
- // in case of error, unsubscribe connection from the dead topics
- _ = br.Unsubscribe(msg.Topics()...)
- continue
+ return errors.E(op, err)
}
}
// handle leave
case commands.Leave:
e.log.Debug("get leave command", "msg", msg)
- // remove associated connections from the storage
- e.storage.RemoveMany(e.connID, msg.Topics())
+ // prepare response
resp := &Response{
Topic: "@leave",
- Payload: msg.Topics(),
+ Payload: msg.Topics,
}
packet, err := json.Marshal(resp)
if err != nil {
e.log.Error("error marshal the body", "error", err)
- continue
+ return errors.E(op, err)
}
err = e.conn.Write(websocket.BinaryMessage, packet)
if err != nil {
e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
- continue
+ return errors.E(op, err)
}
- if br, ok := e.pubsub[msg.Broker()]; ok {
- err = br.Unsubscribe(msg.Topics()...)
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Leave(br, msg.Topics)
if err != nil {
- e.log.Error("error subscribing to the provided topics", "topics", msg.Topics(), "error", err.Error())
- continue
+ return errors.E(op, err)
}
}
case commands.Headers:
default:
- e.log.Warn("unknown command", "command", msg.Command())
+ e.log.Warn("unknown command", "command", msg.Command)
}
}
}
+
+func (e *Executor) Set(br pubsub.PubSub, topics []string) error {
+ // associate connection with topics
+ err := br.Subscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ // in case of error, unsubscribe connection from the dead topics
+ _ = br.Unsubscribe(topics...)
+ return err
+ }
+
+ e.storage.InsertMany(e.connID, topics)
+
+ // save topics for the connection
+ for i := 0; i < len(topics); i++ {
+ e.actualTopics[topics[i]] = struct{}{}
+ }
+
+ return nil
+}
+
+func (e *Executor) Leave(br pubsub.PubSub, topics []string) error {
+ // remove associated connections from the storage
+ e.storage.RemoveMany(e.connID, topics)
+ err := br.Unsubscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ return err
+ }
+
+ // remove topics for the connection
+ for i := 0; i < len(topics); i++ {
+ delete(e.actualTopics, topics[i])
+ }
+
+ return nil
+}
+
+func (e *Executor) CleanUp() {
+ for topic := range e.actualTopics {
+ // remove from the bst
+ e.storage.Remove(e.connID, topic)
+
+ for _, ps := range e.pubsub {
+ _ = ps.Unsubscribe(topic)
+ }
+ }
+
+ for k := range e.actualTopics {
+ delete(e.actualTopics, k)
+ }
+}
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index 76ef800d..2a060716 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -10,12 +10,15 @@ import (
endure "github.com/spiral/endure/pkg/container"
"github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/http/attributes"
"github.com/spiral/roadrunner/v2/plugins/logger"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
"github.com/spiral/roadrunner/v2/plugins/websockets/executor"
"github.com/spiral/roadrunner/v2/plugins/websockets/pool"
"github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
)
const (
@@ -23,7 +26,7 @@ const (
)
type Plugin struct {
- sync.RWMutex
+ mu sync.RWMutex
// Collection with all available pubsubs
pubsubs map[string]pubsub.PubSub
@@ -34,10 +37,13 @@ type Plugin struct {
connections sync.Map
storage *storage.Storage
+ // GO workers pool
workersPool *pool.WorkersPool
+
+ hub channel.Hub
}
-func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
+func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error {
const op = errors.Op("websockets_plugin_init")
if !cfg.Has(PluginName) {
return errors.E(op, errors.Disabled)
@@ -52,6 +58,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
p.log = log
p.storage = storage.NewStorage()
p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log)
+ p.hub = channel
return nil
}
@@ -69,10 +76,6 @@ func (p *Plugin) Serve() chan error {
return
}
- if data == nil {
- continue
- }
-
p.workersPool.Queue(data)
}
}(v)
@@ -115,6 +118,22 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}
+ p.mu.Lock()
+
+ r = attributes.Init(r)
+
+ err := validator.NewValidator().AssertServerAccess(p.hub, r)
+ if err != nil {
+ // show the error to the user
+ if av, ok := err.(*validator.AccessValidator); ok {
+ av.Copy(w)
+ } else {
+ w.WriteHeader(400)
+ return
+ }
+ }
+
+ p.mu.Unlock()
// connection upgrader
upgraded := websocket.Upgrader{
@@ -154,13 +173,15 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
p.connections.Delete(connectionID)
}()
+ p.mu.Lock()
// Executor wraps a connection to have a safe abstraction
- p.Lock()
- e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs)
- p.Unlock()
+ e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r)
+ p.mu.Unlock()
p.log.Info("websocket client connected", "uuid", connectionID)
+ defer e.CleanUp()
+
err = e.StartCommandLoop()
if err != nil {
p.log.Error("command loop error", "error", err.Error())
@@ -170,32 +191,32 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
}
// Publish is an entry point to the websocket PUBSUB
-func (p *Plugin) Publish(msg []pubsub.Message) error {
- p.Lock()
- defer p.Unlock()
+func (p *Plugin) Publish(msg []*pubsub.Message) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
for i := 0; i < len(msg); i++ {
- for j := 0; j < len(msg[i].Topics()); j++ {
- if br, ok := p.pubsubs[msg[i].Broker()]; ok {
+ for j := 0; j < len(msg[i].Topics); j++ {
+ if br, ok := p.pubsubs[msg[i].Broker]; ok {
err := br.Publish(msg)
if err != nil {
return errors.E(err)
}
} else {
- p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker())
+ p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker)
}
}
}
return nil
}
-func (p *Plugin) PublishAsync(msg []pubsub.Message) {
+func (p *Plugin) PublishAsync(msg []*pubsub.Message) {
go func() {
- p.Lock()
- defer p.Unlock()
+ p.mu.Lock()
+ defer p.mu.Unlock()
for i := 0; i < len(msg); i++ {
- for j := 0; j < len(msg[i].Topics()); j++ {
- err := p.pubsubs[msg[i].Broker()].Publish(msg)
+ for j := 0; j < len(msg[i].Topics); j++ {
+ err := p.pubsubs[msg[i].Broker].Publish(msg)
if err != nil {
p.log.Error("publish async error", "error", err)
return
diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go
index 87e931d0..8f18580f 100644
--- a/plugins/websockets/pool/workers_pool.go
+++ b/plugins/websockets/pool/workers_pool.go
@@ -16,7 +16,7 @@ type WorkersPool struct {
resPool sync.Pool
log logger.Logger
- queue chan pubsub.Message
+ queue chan *pubsub.Message
exit chan struct{}
}
@@ -24,7 +24,7 @@ type WorkersPool struct {
func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger.Logger) *WorkersPool {
wp := &WorkersPool{
connections: connections,
- queue: make(chan pubsub.Message, 100),
+ queue: make(chan *pubsub.Message, 100),
storage: storage,
log: log,
exit: make(chan struct{}),
@@ -42,7 +42,7 @@ func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger.
return wp
}
-func (wp *WorkersPool) Queue(msg pubsub.Message) {
+func (wp *WorkersPool) Queue(msg *pubsub.Message) {
wp.queue <- msg
}
@@ -67,16 +67,26 @@ func (wp *WorkersPool) get() map[string]struct{} {
return wp.resPool.Get().(map[string]struct{})
}
-func (wp *WorkersPool) do() {
+func (wp *WorkersPool) do() { //nolint:gocognit
go func() {
for {
select {
- case msg := <-wp.queue:
+ case msg, ok := <-wp.queue:
+ if !ok {
+ return
+ }
+ // do not handle nil's
+ if msg == nil {
+ continue
+ }
+ if len(msg.Topics) == 0 {
+ continue
+ }
res := wp.get()
// get connections for the particular topic
- wp.storage.GetByPtr(msg.Topics(), res)
+ wp.storage.GetByPtr(msg.Topics, res)
if len(res) == 0 {
- wp.log.Info("no such topic", "topic", msg.Topics())
+ wp.log.Info("no such topic", "topic", msg.Topics)
wp.put(res)
continue
}
@@ -84,14 +94,14 @@ func (wp *WorkersPool) do() {
for i := range res {
c, ok := wp.connections.Load(i)
if !ok {
- wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker(), "topics", msg.Topics())
+ wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker, "topics", msg.Topics)
continue
}
conn := c.(*connection.Connection)
- err := conn.Write(websocket.BinaryMessage, msg.Payload())
+ err := conn.Write(websocket.BinaryMessage, msg.Payload)
if err != nil {
- wp.log.Error("error sending payload over the connection", "broker", msg.Broker(), "topics", msg.Topics())
+ wp.log.Error("error sending payload over the connection", "broker", msg.Broker, "topics", msg.Topics)
wp.put(res)
continue
}
diff --git a/plugins/websockets/rpc.go b/plugins/websockets/rpc.go
index f917bd53..2fb0f1b9 100644
--- a/plugins/websockets/rpc.go
+++ b/plugins/websockets/rpc.go
@@ -12,18 +12,17 @@ type rpc struct {
log logger.Logger
}
-func (r *rpc) Publish(msg []*pubsub.Msg, ok *bool) error {
+func (r *rpc) Publish(msg []*pubsub.Message, ok *bool) error {
const op = errors.Op("broadcast_publish")
r.log.Debug("message published", "msg", msg)
- // publish to the registered broker
- mi := make([]pubsub.Message, 0, len(msg))
- // golang can't convert slice in-place
- // so, we need to convert it manually
- for i := 0; i < len(msg); i++ {
- mi = append(mi, msg[i])
+ // just return in case of nil message
+ if msg == nil {
+ *ok = true
+ return nil
}
- err := r.plugin.Publish(mi)
+
+ err := r.plugin.Publish(msg)
if err != nil {
*ok = false
return errors.E(op, err)
@@ -32,16 +31,16 @@ func (r *rpc) Publish(msg []*pubsub.Msg, ok *bool) error {
return nil
}
-func (r *rpc) PublishAsync(msg []*pubsub.Msg, ok *bool) error {
- // publish to the registered broker
- mi := make([]pubsub.Message, 0, len(msg))
- // golang can't convert slice in-place
- // so, we need to convert it manually
- for i := 0; i < len(msg); i++ {
- mi = append(mi, msg[i])
- }
+func (r *rpc) PublishAsync(msg []*pubsub.Message, ok *bool) error {
+ r.log.Debug("message published", "msg", msg)
- r.plugin.PublishAsync(mi)
+ // just return in case of nil message
+ if msg == nil {
+ *ok = true
+ return nil
+ }
+ // publish to the registered broker
+ r.plugin.PublishAsync(msg)
*ok = true
return nil
diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go
index 9d9522d4..e3fde3d0 100644
--- a/plugins/websockets/validator/access_validator.go
+++ b/plugins/websockets/validator/access_validator.go
@@ -6,6 +6,7 @@ import (
"net/http"
"strings"
+ "github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
)
@@ -67,16 +68,29 @@ func (w *AccessValidator) Error() string {
// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in
// case of error
-func (w *AccessValidator) AssertServerAccess(f http.HandlerFunc, r *http.Request) error {
+func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error {
if err := attributes.Set(r, "ws:joinServer", true); err != nil {
return err
}
defer delete(attributes.All(r), "ws:joinServer")
- f(w, r)
+ hub.ReceiveCh() <- struct {
+ RW http.ResponseWriter
+ Req *http.Request
+ }{
+ w,
+ r,
+ }
+
+ resp := <-hub.SendCh()
+
+ rmsg := resp.(struct {
+ RW http.ResponseWriter
+ Req *http.Request
+ })
- if !w.IsOK() {
+ if !rmsg.RW.(*AccessValidator).IsOK() {
return w
}
@@ -85,16 +99,29 @@ func (w *AccessValidator) AssertServerAccess(f http.HandlerFunc, r *http.Request
// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies.
// the decision to authorize user will be based on response code (200).
-func (w *AccessValidator) AssertTopicsAccess(f http.HandlerFunc, r *http.Request, channels ...string) error {
+func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error {
if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil {
return err
}
defer delete(attributes.All(r), "ws:joinTopics")
- f(w, r)
+ hub.ReceiveCh() <- struct {
+ RW http.ResponseWriter
+ Req *http.Request
+ }{
+ w,
+ r,
+ }
+
+ resp := <-hub.SendCh()
+
+ rmsg := resp.(struct {
+ RW http.ResponseWriter
+ Req *http.Request
+ })
- if !w.IsOK() {
+ if !rmsg.RW.(*AccessValidator).IsOK() {
return w
}