summaryrefslogtreecommitdiff
path: root/plugins/websockets
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-06-01 00:10:31 +0300
committerGitHub <[email protected]>2021-06-01 00:10:31 +0300
commit548ee4432e48b316ada00feec1a6b89e67ae4f2f (patch)
tree5cd2aaeeafdb50e3e46824197c721223f54695bf /plugins/websockets
parent8cd696bbca8fac2ced30d8172c41b7434ec86650 (diff)
parentdf4d316d519cea6dff654bd917521a616a37f769 (diff)
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
Diffstat (limited to 'plugins/websockets')
-rw-r--r--plugins/websockets/commands/enums.go9
-rw-r--r--plugins/websockets/config.go58
-rw-r--r--plugins/websockets/connection/connection.go67
-rw-r--r--plugins/websockets/doc/broadcast.drawio1
-rw-r--r--plugins/websockets/doc/doc.go27
-rw-r--r--plugins/websockets/executor/executor.go226
-rw-r--r--plugins/websockets/plugin.go386
-rw-r--r--plugins/websockets/pool/workers_pool.go117
-rw-r--r--plugins/websockets/rpc.go47
-rw-r--r--plugins/websockets/schema/message.fbs10
-rw-r--r--plugins/websockets/schema/message/Message.go118
-rw-r--r--plugins/websockets/storage/storage.go79
-rw-r--r--plugins/websockets/storage/storage_test.go299
-rw-r--r--plugins/websockets/validator/access_validator.go76
14 files changed, 1520 insertions, 0 deletions
diff --git a/plugins/websockets/commands/enums.go b/plugins/websockets/commands/enums.go
new file mode 100644
index 00000000..18c63be3
--- /dev/null
+++ b/plugins/websockets/commands/enums.go
@@ -0,0 +1,9 @@
+package commands
+
+type Command string
+
+const (
+ Leave string = "leave"
+ Join string = "join"
+ Headers string = "headers"
+)
diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go
new file mode 100644
index 00000000..be4aaa82
--- /dev/null
+++ b/plugins/websockets/config.go
@@ -0,0 +1,58 @@
+package websockets
+
+import (
+ "time"
+
+ "github.com/spiral/roadrunner/v2/pkg/pool"
+)
+
+/*
+websockets:
+ # pubsubs should implement PubSub interface to be collected via endure.Collects
+
+ pubsubs:["redis", "amqp", "memory"]
+ # path used as websockets path
+ path: "/ws"
+*/
+
+// Config represents configuration for the ws plugin
+type Config struct {
+ // http path for the websocket
+ Path string `mapstructure:"path"`
+ // ["redis", "amqp", "memory"]
+ PubSubs []string `mapstructure:"pubsubs"`
+ Middleware []string `mapstructure:"middleware"`
+
+ Pool *pool.Config `mapstructure:"pool"`
+}
+
+// InitDefault initialize default values for the ws config
+func (c *Config) InitDefault() {
+ if c.Path == "" {
+ c.Path = "/ws"
+ }
+ if len(c.PubSubs) == 0 {
+ // memory used by default
+ c.PubSubs = append(c.PubSubs, "memory")
+ }
+
+ if c.Pool == nil {
+ c.Pool = &pool.Config{}
+ if c.Pool.NumWorkers == 0 {
+ // 2 workers by default
+ c.Pool.NumWorkers = 2
+ }
+
+ if c.Pool.AllocateTimeout == 0 {
+ c.Pool.AllocateTimeout = time.Minute
+ }
+
+ if c.Pool.DestroyTimeout == 0 {
+ c.Pool.DestroyTimeout = time.Minute
+ }
+ if c.Pool.Supervisor == nil {
+ return
+ }
+ c.Pool.Supervisor.InitDefaults()
+ }
+}
diff --git a/plugins/websockets/connection/connection.go b/plugins/websockets/connection/connection.go
new file mode 100644
index 00000000..2b847173
--- /dev/null
+++ b/plugins/websockets/connection/connection.go
@@ -0,0 +1,67 @@
+package connection
+
+import (
+ "sync"
+
+ "github.com/fasthttp/websocket"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+// Connection represents wrapped and safe to use from the different threads websocket connection
+type Connection struct {
+ sync.RWMutex
+ log logger.Logger
+ conn *websocket.Conn
+}
+
+func NewConnection(wsConn *websocket.Conn, log logger.Logger) *Connection {
+ return &Connection{
+ conn: wsConn,
+ log: log,
+ }
+}
+
+func (c *Connection) Write(mt int, data []byte) error {
+ c.Lock()
+ defer c.Unlock()
+
+ const op = errors.Op("websocket_write")
+ // handle a case when a goroutine tried to write into the closed connection
+ defer func() {
+ if r := recover(); r != nil {
+ c.log.Warn("panic handled, tried to write into the closed connection")
+ }
+ }()
+
+ err := c.conn.WriteMessage(mt, data)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ return nil
+}
+
+func (c *Connection) Read() (int, []byte, error) {
+ const op = errors.Op("websocket_read")
+
+ mt, data, err := c.conn.ReadMessage()
+ if err != nil {
+ return -1, nil, errors.E(op, err)
+ }
+
+ return mt, data, nil
+}
+
+func (c *Connection) Close() error {
+ c.Lock()
+ defer c.Unlock()
+ const op = errors.Op("websocket_close")
+
+ err := c.conn.Close()
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ return nil
+}
diff --git a/plugins/websockets/doc/broadcast.drawio b/plugins/websockets/doc/broadcast.drawio
new file mode 100644
index 00000000..230870f2
--- /dev/null
+++ b/plugins/websockets/doc/broadcast.drawio
@@ -0,0 +1 @@
+<mxfile host="Electron" modified="2021-05-27T20:56:56.848Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/14.5.1 Chrome/89.0.4389.128 Electron/12.0.9 Safari/537.36" etag="Pt0MY_-SPz7R7foQA1VL" version="14.5.1" type="device"><diagram id="fD2kwGC0DAS2S_q_IsmE" name="Page-1">7V1Zc9rIFv411HUeULV28WgwTjxjx06wJ8l9mRKoAU2ExEjCS3797W4t9IbAoMZLrqcqg1oLrT77OV8fOuZg8fgx9ZfzqySAUccAwWPHPOsYhmECC/0PjzwVI7ppesXILA2Dcmw9MAp/wXIQlKOrMIAZc2GeJFEeLtnBSRLHcJIzY36aJg/sZdMkYr916c+gMDCa+JE4+i0M8nkxagIA1ic+wXA2z/kzC7+6uhzI5n6QPFBD5rBjDtIkyYtPi8cBjPDyVQtT3He+4Ww9sxTG+S43PFlXl7c/+n48uTWjn1/SEMSn3V45t/ypemMYoAUoD5M0nyezJPaj4Xq0vx69TJIlukxHg//APH8q6eev8gQNzfNFVJ6Fj2H+nfr8A30GmmGXh2eYYUB18FQdxHn69J0+oG/Dx+v7yFF1o7gy5WJlySqdwIblqFjMT2cwb7qupCleLOobyoX/CJMFRBNCF6Qw8vPwnuUmv2TKWX1dfetNEqI5G6CUIN2y7eKeSoB0y2AfUsy1vG9NfvSBmsh6iDCFnEGuvFUc/53d9od/zbI/xr9On7y7rm7swSFpsooDGJTkOIBfKB75QZ3axi8Ui/xgOETOLxlaw/wUaws0ECcxrMbOQ7xYFE9x9FbLZNaBLMWwwrPp7hQPvvejVflVlpQTLv0xUvoM9fwonMXo8wQtDUzRwD1M8xAp1dPyxCIMgoJRYBb+8sfkeXiVl5iNyavY/Y59Vq87fgB8ZN641PjlzWs1S1OkgaPFdS0f3wUacD2PkbqKZPtKc/WYHntHMp1mUInUms7LSq1LSy3YUWpZLe9u0/IvL7XSpTeclxTbyjpQYvsV+gEa+Tz8fnvyATsmMMuw14O8oTRZ4Fea44PlapytxnhOMb48g3GQ4W/LictVX/aQpD9his8skR/2PjSCuYFk5eOBpjuGyyoE4zCFUOka1+H0jHssDWHrAqPIdcZ7oybS77Zlm+y6m62Qk+UR3VNAyybVRJHy27A/uh78ObwdofFPd2iVwclo+PWvIfpwfvd58EEg9cM8zOFo6RMd+ICiOZbkU6RTB0mUpORqM/ChN50QhZsmPyF1xpl4cDzFdyRxTo0D8tfknAs8sJGEumFxXrFhlwL5sA7SdKvkmDkVnzlAlb/0igzuzm7ys7zkFm2psaMt1e2XtKWiBzzy7+FaNYVJXJvLB58Yyili98pSjtPEDyZ+lq9tbvY+NGyzdALNNAxWPtsyl6brauyTLbsaOILJNAWG8H4HgiKLaRkeR9EDI6Ly0Y6hWazVdLRqmdUT1Ph/buOlNPvB+bKDNLttCZIsT4S+OUk2mpMbiHt6wFChm3X2qci3UJHsaOI52v3Fbu8VocMDWjGBtM9zdqc2/k/q7JI/0dkt/tpydh1W9TpeyQy0rwskvq6lzNcFwooeUW/uWUOgEkq1ut1Vb04iP8vCCac69XZVpyWqzs3R3qGVBvRm/hN1QambNos3qLiuruSVeuR89ztc0HyHaTffgT4U8243V/rCXsDuTsAutnxvR+GtcrKYhelxXGT1elqP+uMChg0FtOdKSM/hmbfXzO7CPJ99g+5tESiDF0HuDjUCVUkxZZDt9+Fdmc3eFQqUXNdhS0ftZBaRyXW0qgBXh2A69xx1LpboL38dnl2M8PcY4OJz92p4df31x6F+1hQ6E2lSMXB7Y5I8VJhU7AFOVoyeJKdoKPKzTOvvs+ki/1N/7D1G/t/9weWPv8r0F7Ps5h626hDrxPha7o6+ls7Yne2VvP0tTeWJ7mZqBPPWVPKxqpCw4gf7MEk+QhlA1Ltfb9CjQP9udKBoQj2woSsTzZ7jmr4kBGo13+/slO+XyWYb+f7B2Y9o8V83eFwtT/8MgssvZ3nW3afAvo8g7pmsf369ZichlC7FrngWY0cZbD0HJJ21JwjLoMQQMvXuj9d4ynGwLMX4hCmaI09iGs7EatrRvBuaXk18urMX1OTcOI7L6sS28vo9U+PTy0dLAleahGIE9+Xouae3upn0DSbO0S2dXXWrFXLqntbj4BK6oTmGAoo26SOKoAI9/WxZoIWn4SOO4mmiiqQRyBwuCGyY2LxSdevGevwsXMzQzKNwjP71f61SiF9yBmOY+mj2532MS4aplt3PWjKQllAYs3UxS+jVF9Em0lXmvu6TWGktlbIr6kx7Vlpwfytp7mgl24cWH0REQ3Qq5WQ9urlr5LnD7R3QDNNWUvTk4gqkGlWUSpqYkCLmzV3/8mL0SSDpM+MDc+IBIPMv+6Bny0L3KflrKz5wdI2DHCDbtqsCNJUpQNES7QMTeuUqcfcK8z5q0xbVZhNn7x7gG5xog31k+9llFBfwbFr6RRtzuNvuYHO4ivINTaShlcmnmwMViQ9swwikgeqZfq5ekfCJBssVtYg00dCGEpGuck9Y5S46vFvOUj8QwGgkcH1ADOhEaF79cYo+zfAnfE+JXitx3HCcJZOfmDOYR0hvxNJMrlssCrhbhBURFhY/COMZfu3iRWJ0STGwRryxMHPZ5ASeQeTK5a5EBKf5YY4EVpZZ6bO3xDWmy1Xoq2OKaUxHwjS2MssjMA2BRAyiEJZVrmeIaBtLZLlsWcSWgBg894i2WRSrii+XzOI4/67wBs3+AhmgELHaKToLlo/oX7IyoBjv5tgw43MWdQ6zcbfkXHyu9IOZ0wGcJKlPZI9cg8vKaRQiK1p/9VpMSsGpBq7qnRxIW64m+QpDXYpr0JqM+fvQ2JIfm6f8yF6vX4glPunhk/zMb5NlOCH3nBWTJTpi86xUzKGP7UkqTAKcpDAIkVIaEKW1SNKnD6KGuvGfooRsoynuTv0H9O/4KYdSdXbcNxvUavnkH6TuileJINH2AyJPyFKk2Qf5vDhdwGtWwti8+k3QVdOIOHxT4uQJmRB8fO4vwgjL+ycY3UP81Pb1isd7R6YHtCp5SKtfmc32epprK9IvkmKM6PrHAec3SzP/Haaatq2YJodJNVfu9nfPJVmNJnW71T2nSGY3GMxDM5Suw/t5vCku3lxAhUiSnTr/KMA9SvEO7X1KworAePvt53/WRs+2oXhyFK3I1ZsrWkfIzMm3egoKZhgHxAVAUkRunyRRVBa05tjEgQUFiS03rsDHZZLhwyJ+ID76PM+xnVpGK2yVXizZ1ypyp9GMdIGGTQQjx23VtnTd1mzu0dXxEbY4CEzy6fYWBeng5vLu48VngbisYtgSGAgk56vONvQCSxbMe8bYJMDpjZQWxHazDwAqW1uHXzvG7Mo2A7rCql8slhFcQJIcAjer8YhstQ7xuk3x8u4eAbchX9IY+CAiiGBGS0IEV0IEQxURJNlXtPBRmM07TIV/XAYGb0zNbRMKE2gAWJZnuLbnGIDbm6yO9y051vNV+yPPw9ao21Il52MJnFpN54n9UtoOJ/qeswWVLNywBfeMXOxe0w1qUMyWuP13UOdJL87eyVZea8MO67VjZFoWB/ppCSUCND5teiynyBbt8waswjukJ+C38raEUDcNjQuqu0drZqKLsdAn5GF0l8i7StIFCtqwoe+PbgmFk7RImj6E+bz2AnKcoszqsAgFmckk9HMYdOj6iEzuW3WZp94EyjHwY8/GSYp2XGabQzZ45ku7zJYcdPfKG2i4r9pxkNTOX2PLKhGJ+5GUJf1JvsJ4OfCQdZtKlOj//gLLWTzOSLK66MiRIkEkK1el7pG+Kr3+xTtR9u42ZW/x+7LbQTDZumZ5ki11VVMOXVPR+0huy8VCXmUNfgPymlzyua3N/JZmHY2Ckl1Ot3OMLOhm/pSkI/1s3l34S95k/4RP2GB3RQtdGnHyxOoSRokotuIO8Hu6dLuMceY6O6BYDrPurskGWK4n4hHqDsttQ4HlNBY97nUupm5I+MZzYLrt8sVI2TYl2cKrS4KJxrXYEzZeETnJc38yJ04u3eORtHeUZ//fHFEsfl+nqYtEMWT9M9QRRbRZBVEmBYyKqKsKn5VptaS8eVK4HCks3daALoqId0xqSBDulY8ZhPc8gCNb+nE1RjBVlZwgN/snMTZsVplFv617r57UeJRpCKOABmfQX0ENs5OReLwYyJJW9T4/Z76Okm5K5jekvo8OvzuIrTyhLUJPDGZNU8JSbYDL5NXGV1STV9Vk/8jBbHXh9p6RL9o23ZRkopB0Rljw/rjGBVes5y+Hp6QVawmolWDIRnf90eDrRR9ddX73mTp6D4GNubUajzxotkPvgZvvyycbKK5htzzyKkBhoCqGORv6aL03egJNdw122VtquNw1WC7Rdc1V0ROlUSPRAJzvw8Hd7fVXUZ7LJuwUKv6kEn7ck73A06sMSfWxr0NDFpIC4AxPzxWHpHzHZk9io6Vun7KEsyHW99RSwEdBhiFvI3he9spQmhSwPb7gptc/VvRiaX9DTAy8byo4EipIdioemQpiluB9U8GVUMF4cSrI9qS8Zyp4Eiq8eCHSfOFC5H5NoRXt2G0ldKtI+spDN0vcHF/UIetWP0yJge0DtMpyclhgDMZPHQZXIOWot+fmb61H2U6V8D7YsXc1dsOgfrxATXQMKyjpb0BBg2vx1E7BuKv3+Mi7q6R7pfy15T/T8KqxJVwPxVeWjdtdpR/6Cy77YVI9j+PiyofYhDE1gJBFfu4dHnuDIlSqmF8sjBSHkWHNE1VBp0zTMk3uw2AHhNvb1HXeNl3nuNxvj7Rju1yDZSTjaIZLjOJ0eSHiHRLTtjwFtOxSP2RR2a2jAWMqJfvbxIM9STwo20Z03HhQTPPyv4EJTj5eiz1A31x53uWQEj3ZPrqjolcsMSNyCWcwDkg3hE19F0ihXNJlgW/HQYrZG7ts0FV4XeuQXXs8YIm1sqT7EfCXy+fW8hXN2sCzptA+r2JSJp7UunFJhVngIUg1fAI36gynIT7PYh7C9X7lRTbrUEDfV/GeFssyslejO4+EcZdrPnLE+dt4rt9qwFGHArVgHOYa1LKJ8av9rJ/R15W/MYzCpYD1MSnPs/QDyHCFg1H3eg7LcstVNicVuDUevFP9rvF/KKX+ZQVXNXRmXDrME/QYGKicrYtnWzVqqh1z0qqGmi6BHq3bUqubjkemQ0cWqyVpPcYJ7EPJyKs4/Jcs293dxRmxCRn6hvXV+Gcy8XUFZIn8aGZNEOpxVMRykq/RwB9UvmoPv+rwEU6QaU1L3u8YjZ3MJK3TVMxMB3hq59SPi27EkjBzIhDPlIC+AkKx8xQpmfuaGufMC/VHt91qG5TCd8HGlFXlWIVcX18Wd90U0jfjOtMVSVd59Frv3eJfR+l7FOYV5qs0Ltmd4+FCjyicAbGlTXtlhDXJBWA9kTGFcyR28FvDXhyeYeVzeeO+NY98le23s3sy15rPe+/gW2MxTzAV1sEoCgXnV0mAg/7h/wA=</diagram></mxfile> \ No newline at end of file
diff --git a/plugins/websockets/doc/doc.go b/plugins/websockets/doc/doc.go
new file mode 100644
index 00000000..fc214be8
--- /dev/null
+++ b/plugins/websockets/doc/doc.go
@@ -0,0 +1,27 @@
+package doc
+
+/*
+RPC message structure:
+
+type Msg struct {
+ // Topic message been pushed into.
+ Topics_ []string `json:"topic"`
+
+ // Command (join, leave, headers)
+ Command_ string `json:"command"`
+
+ // Broker (redis, memory)
+ Broker_ string `json:"broker"`
+
+ // Payload to be broadcasted
+ Payload_ []byte `json:"payload"`
+}
+
+1. Topics - string array (slice) with topics to join or leave
+2. Command - string, command to apply on the provided topics
+3. Broker - string, pub-sub broker to use, for the one-node systems might be used `memory` broker or `redis`. For the multi-node -
+`redis` broker should be used.
+4. Payload - raw byte array to send to the subscribers (binary messages).
+
+
+*/
diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go
new file mode 100644
index 00000000..24ea19ce
--- /dev/null
+++ b/plugins/websockets/executor/executor.go
@@ -0,0 +1,226 @@
+package executor
+
+import (
+ "fmt"
+ "net/http"
+ "sync"
+
+ "github.com/fasthttp/websocket"
+ json "github.com/json-iterator/go"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/commands"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/connection"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
+)
+
+type Response struct {
+ Topic string `json:"topic"`
+ Payload []string `json:"payload"`
+}
+
+type Executor struct {
+ sync.Mutex
+ conn *connection.Connection
+ storage *storage.Storage
+ log logger.Logger
+
+ // associated connection ID
+ connID string
+
+ // map with the pubsub drivers
+ pubsub map[string]pubsub.PubSub
+ actualTopics map[string]struct{}
+
+ req *http.Request
+ accessValidator validator.AccessValidatorFn
+}
+
+// NewExecutor creates protected connection and starts command loop
+func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage,
+ connID string, pubsubs map[string]pubsub.PubSub, av validator.AccessValidatorFn, r *http.Request) *Executor {
+ return &Executor{
+ conn: conn,
+ connID: connID,
+ storage: bst,
+ log: log,
+ pubsub: pubsubs,
+ accessValidator: av,
+ actualTopics: make(map[string]struct{}, 10),
+ req: r,
+ }
+}
+
+func (e *Executor) StartCommandLoop() error { //nolint:gocognit
+ const op = errors.Op("executor_command_loop")
+ for {
+ mt, data, err := e.conn.Read()
+ if err != nil {
+ if mt == -1 {
+ e.log.Info("socket was closed", "reason", err, "message type", mt)
+ return nil
+ }
+
+ return errors.E(op, err)
+ }
+
+ msg := &pubsub.Message{}
+
+ err = json.Unmarshal(data, msg)
+ if err != nil {
+ e.log.Error("error unmarshal message", "error", err)
+ continue
+ }
+
+ // nil message, continue
+ if msg == nil {
+ e.log.Warn("get nil message, skipping")
+ continue
+ }
+
+ switch msg.Command {
+ // handle leave
+ case commands.Join:
+ e.log.Debug("get join command", "msg", msg)
+
+ val, err := e.accessValidator(e.req, msg.Topics...)
+ if err != nil {
+ if val != nil {
+ e.log.Debug("validation error", "status", val.Status, "headers", val.Header, "body", val.Body)
+ }
+
+ resp := &Response{
+ Topic: "#join",
+ Payload: msg.Topics,
+ }
+
+ packet, errJ := json.Marshal(resp)
+ if errJ != nil {
+ e.log.Error("error marshal the body", "error", errJ)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errJ))
+ }
+
+ errW := e.conn.Write(websocket.BinaryMessage, packet)
+ if errW != nil {
+ e.log.Error("error writing payload to the connection", "payload", packet, "error", errW)
+ return errors.E(op, fmt.Errorf("%v,%v", err, errW))
+ }
+
+ continue
+ }
+
+ resp := &Response{
+ Topic: "@join",
+ Payload: msg.Topics,
+ }
+
+ packet, err := json.Marshal(resp)
+ if err != nil {
+ e.log.Error("error marshal the body", "error", err)
+ return errors.E(op, err)
+ }
+
+ err = e.conn.Write(websocket.BinaryMessage, packet)
+ if err != nil {
+ e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
+ return errors.E(op, err)
+ }
+
+ // subscribe to the topic
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Set(br, msg.Topics)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ }
+
+ // handle leave
+ case commands.Leave:
+ e.log.Debug("get leave command", "msg", msg)
+
+ // prepare response
+ resp := &Response{
+ Topic: "@leave",
+ Payload: msg.Topics,
+ }
+
+ packet, err := json.Marshal(resp)
+ if err != nil {
+ e.log.Error("error marshal the body", "error", err)
+ return errors.E(op, err)
+ }
+
+ err = e.conn.Write(websocket.BinaryMessage, packet)
+ if err != nil {
+ e.log.Error("error writing payload to the connection", "payload", packet, "error", err)
+ return errors.E(op, err)
+ }
+
+ if br, ok := e.pubsub[msg.Broker]; ok {
+ err = e.Leave(br, msg.Topics)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ }
+
+ case commands.Headers:
+
+ default:
+ e.log.Warn("unknown command", "command", msg.Command)
+ }
+ }
+}
+
+func (e *Executor) Set(br pubsub.PubSub, topics []string) error {
+ // associate connection with topics
+ err := br.Subscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ // in case of error, unsubscribe connection from the dead topics
+ _ = br.Unsubscribe(topics...)
+ return err
+ }
+
+ e.storage.InsertMany(e.connID, topics)
+
+ // save topics for the connection
+ for i := 0; i < len(topics); i++ {
+ e.actualTopics[topics[i]] = struct{}{}
+ }
+
+ return nil
+}
+
+func (e *Executor) Leave(br pubsub.PubSub, topics []string) error {
+ // remove associated connections from the storage
+ e.storage.RemoveMany(e.connID, topics)
+ err := br.Unsubscribe(topics...)
+ if err != nil {
+ e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
+ return err
+ }
+
+ // remove topics for the connection
+ for i := 0; i < len(topics); i++ {
+ delete(e.actualTopics, topics[i])
+ }
+
+ return nil
+}
+
+func (e *Executor) CleanUp() {
+ for topic := range e.actualTopics {
+ // remove from the bst
+ e.storage.Remove(e.connID, topic)
+
+ for _, ps := range e.pubsub {
+ _ = ps.Unsubscribe(topic)
+ }
+ }
+
+ for k := range e.actualTopics {
+ delete(e.actualTopics, k)
+ }
+}
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
new file mode 100644
index 00000000..9b21ff8f
--- /dev/null
+++ b/plugins/websockets/plugin.go
@@ -0,0 +1,386 @@
+package websockets
+
+import (
+ "context"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/fasthttp/websocket"
+ "github.com/google/uuid"
+ json "github.com/json-iterator/go"
+ endure "github.com/spiral/endure/pkg/container"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/payload"
+ phpPool "github.com/spiral/roadrunner/v2/pkg/pool"
+ "github.com/spiral/roadrunner/v2/pkg/process"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/pkg/worker"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/http/attributes"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/server"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/connection"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/executor"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/pool"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
+)
+
+const (
+ PluginName string = "websockets"
+)
+
+type Plugin struct {
+ sync.RWMutex
+ // Collection with all available pubsubs
+ pubsubs map[string]pubsub.PubSub
+
+ cfg *Config
+ log logger.Logger
+
+ // global connections map
+ connections sync.Map
+ storage *storage.Storage
+
+ // GO workers pool
+ workersPool *pool.WorkersPool
+
+ wsUpgrade *websocket.Upgrader
+ serveExit chan struct{}
+
+ phpPool phpPool.Pool
+ server server.Server
+
+ // function used to validate access to the requested resource
+ accessValidator validator.AccessValidatorFn
+}
+
+func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error {
+ const op = errors.Op("websockets_plugin_init")
+ if !cfg.Has(PluginName) {
+ return errors.E(op, errors.Disabled)
+ }
+
+ err := cfg.UnmarshalKey(PluginName, &p.cfg)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ p.cfg.InitDefault()
+
+ p.pubsubs = make(map[string]pubsub.PubSub)
+ p.log = log
+ p.storage = storage.NewStorage()
+ p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log)
+ p.wsUpgrade = &websocket.Upgrader{
+ HandshakeTimeout: time.Second * 60,
+ }
+ p.serveExit = make(chan struct{})
+ p.server = server
+
+ return nil
+}
+
+func (p *Plugin) Serve() chan error {
+ errCh := make(chan error)
+
+ go func() {
+ var err error
+ p.Lock()
+ defer p.Unlock()
+
+ p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{
+ Debug: p.cfg.Pool.Debug,
+ NumWorkers: p.cfg.Pool.NumWorkers,
+ MaxJobs: p.cfg.Pool.MaxJobs,
+ AllocateTimeout: p.cfg.Pool.AllocateTimeout,
+ DestroyTimeout: p.cfg.Pool.DestroyTimeout,
+ Supervisor: p.cfg.Pool.Supervisor,
+ }, map[string]string{"RR_MODE": "http"})
+ if err != nil {
+ errCh <- err
+ }
+
+ p.accessValidator = p.defaultAccessValidator(p.phpPool)
+ }()
+
+ // run all pubsubs drivers
+ for _, v := range p.pubsubs {
+ go func(ps pubsub.PubSub) {
+ for {
+ select {
+ case <-p.serveExit:
+ return
+ default:
+ data, err := ps.Next()
+ if err != nil {
+ errCh <- err
+ return
+ }
+ p.workersPool.Queue(data)
+ }
+ }
+ }(v)
+ }
+
+ return errCh
+}
+
+func (p *Plugin) Stop() error {
+ // close workers pool
+ p.workersPool.Stop()
+ p.Lock()
+ p.phpPool.Destroy(context.Background())
+ p.Unlock()
+ return nil
+}
+
+func (p *Plugin) Collects() []interface{} {
+ return []interface{}{
+ p.GetPublishers,
+ }
+}
+
+func (p *Plugin) Available() {}
+
+func (p *Plugin) RPC() interface{} {
+ return &rpc{
+ plugin: p,
+ log: p.log,
+ }
+}
+
+func (p *Plugin) Name() string {
+ return PluginName
+}
+
+// GetPublishers collects all pubsubs
+func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) {
+ p.pubsubs[name.Name()] = pub
+}
+
+func (p *Plugin) Middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != p.cfg.Path {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // we need to lock here, because accessValidator might not be set in the Serve func at the moment
+ p.RLock()
+ // before we hijacked connection, we still can write to the response headers
+ val, err := p.accessValidator(r)
+ p.RUnlock()
+ if err != nil {
+ p.log.Error("validation error")
+ w.WriteHeader(400)
+ return
+ }
+
+ if val.Status != http.StatusOK {
+ for k, v := range val.Header {
+ for i := 0; i < len(v); i++ {
+ w.Header().Add(k, v[i])
+ }
+ }
+ w.WriteHeader(val.Status)
+ _, _ = w.Write(val.Body)
+ return
+ }
+
+ // upgrade connection to websocket connection
+ _conn, err := p.wsUpgrade.Upgrade(w, r, nil)
+ if err != nil {
+ // connection hijacked, do not use response.writer or request
+ p.log.Error("upgrade connection", "error", err)
+ return
+ }
+
+ // construct safe connection protected by mutexes
+ safeConn := connection.NewConnection(_conn, p.log)
+ // generate UUID from the connection
+ connectionID := uuid.NewString()
+ // store connection
+ p.connections.Store(connectionID, safeConn)
+
+ defer func() {
+ // close the connection on exit
+ err = safeConn.Close()
+ if err != nil {
+ p.log.Error("connection close", "error", err)
+ }
+
+ // when exiting - delete the connection
+ p.connections.Delete(connectionID)
+ }()
+
+ // Executor wraps a connection to have a safe abstraction
+ e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r)
+ p.log.Info("websocket client connected", "uuid", connectionID)
+ defer e.CleanUp()
+
+ err = e.StartCommandLoop()
+ if err != nil {
+ p.log.Error("command loop error, disconnecting", "error", err.Error())
+ return
+ }
+
+ p.log.Info("disconnected", "connectionID", connectionID)
+ })
+}
+
+// Workers returns slice with the process states for the workers
+func (p *Plugin) Workers() []process.State {
+ p.RLock()
+ defer p.RUnlock()
+
+ workers := p.workers()
+
+ ps := make([]process.State, 0, len(workers))
+ for i := 0; i < len(workers); i++ {
+ state, err := process.WorkerProcessState(workers[i])
+ if err != nil {
+ return nil
+ }
+ ps = append(ps, state)
+ }
+
+ return ps
+}
+
+// internal
+func (p *Plugin) workers() []worker.BaseProcess {
+ return p.phpPool.Workers()
+}
+
+// Reset destroys the old pool and replaces it with new one, waiting for old pool to die
+func (p *Plugin) Reset() error {
+ p.Lock()
+ defer p.Unlock()
+ const op = errors.Op("ws_plugin_reset")
+ p.log.Info("WS plugin got restart request. Restarting...")
+ p.phpPool.Destroy(context.Background())
+ p.phpPool = nil
+
+ var err error
+ p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{
+ Debug: p.cfg.Pool.Debug,
+ NumWorkers: p.cfg.Pool.NumWorkers,
+ MaxJobs: p.cfg.Pool.MaxJobs,
+ AllocateTimeout: p.cfg.Pool.AllocateTimeout,
+ DestroyTimeout: p.cfg.Pool.DestroyTimeout,
+ Supervisor: p.cfg.Pool.Supervisor,
+ }, map[string]string{"RR_MODE": "http"})
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ // attach validators
+ p.accessValidator = p.defaultAccessValidator(p.phpPool)
+
+ p.log.Info("WS plugin successfully restarted")
+ return nil
+}
+
+// Publish is an entry point to the websocket PUBSUB
+func (p *Plugin) Publish(msg []*pubsub.Message) error {
+ p.Lock()
+ defer p.Unlock()
+
+ for i := 0; i < len(msg); i++ {
+ for j := 0; j < len(msg[i].Topics); j++ {
+ if br, ok := p.pubsubs[msg[i].Broker]; ok {
+ err := br.Publish(msg)
+ if err != nil {
+ return errors.E(err)
+ }
+ } else {
+ p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker)
+ }
+ }
+ }
+ return nil
+}
+
+func (p *Plugin) PublishAsync(msg []*pubsub.Message) {
+ go func() {
+ p.Lock()
+ defer p.Unlock()
+ for i := 0; i < len(msg); i++ {
+ for j := 0; j < len(msg[i].Topics); j++ {
+ err := p.pubsubs[msg[i].Broker].Publish(msg)
+ if err != nil {
+ p.log.Error("publish async error", "error", err)
+ return
+ }
+ }
+ }
+ }()
+}
+
+func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn {
+ return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) {
+ p.RLock()
+ defer p.RUnlock()
+ const op = errors.Op("access_validator")
+
+ p.log.Debug("validation", "topics", topics)
+ r = attributes.Init(r)
+
+ // if channels len is eq to 0, we use serverValidator
+ if len(topics) == 0 {
+ ctx, err := validator.ServerAccessValidator(r)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val, err := exec(ctx, pool)
+ if err != nil {
+ return nil, errors.E(err)
+ }
+
+ return val, nil
+ }
+
+ ctx, err := validator.TopicsAccessValidator(r, topics...)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val, err := exec(ctx, pool)
+ if err != nil {
+ return nil, errors.E(op)
+ }
+
+ if val.Status != http.StatusOK {
+ return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status))
+ }
+
+ return val, nil
+ }
+}
+
+// go:inline
+func exec(ctx []byte, pool phpPool.Pool) (*validator.AccessValidator, error) {
+ const op = errors.Op("exec")
+ pd := payload.Payload{
+ Context: ctx,
+ }
+
+ resp, err := pool.Exec(pd)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val := &validator.AccessValidator{
+ Body: resp.Body,
+ }
+
+ err = json.Unmarshal(resp.Context, val)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return val, nil
+}
diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go
new file mode 100644
index 00000000..8f18580f
--- /dev/null
+++ b/plugins/websockets/pool/workers_pool.go
@@ -0,0 +1,117 @@
+package pool
+
+import (
+ "sync"
+
+ "github.com/fasthttp/websocket"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/connection"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+)
+
+type WorkersPool struct {
+ storage *storage.Storage
+ connections *sync.Map
+ resPool sync.Pool
+ log logger.Logger
+
+ queue chan *pubsub.Message
+ exit chan struct{}
+}
+
+// NewWorkersPool constructs worker pool for the websocket connections
+func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger.Logger) *WorkersPool {
+ wp := &WorkersPool{
+ connections: connections,
+ queue: make(chan *pubsub.Message, 100),
+ storage: storage,
+ log: log,
+ exit: make(chan struct{}),
+ }
+
+ wp.resPool.New = func() interface{} {
+ return make(map[string]struct{}, 10)
+ }
+
+ // start 10 workers
+ for i := 0; i < 10; i++ {
+ wp.do()
+ }
+
+ return wp
+}
+
+func (wp *WorkersPool) Queue(msg *pubsub.Message) {
+ wp.queue <- msg
+}
+
+func (wp *WorkersPool) Stop() {
+ for i := 0; i < 10; i++ {
+ wp.exit <- struct{}{}
+ }
+
+ close(wp.exit)
+}
+
+func (wp *WorkersPool) put(res map[string]struct{}) {
+ // optimized
+ // https://go-review.googlesource.com/c/go/+/110055/
+ // not O(n), but O(1)
+ for k := range res {
+ delete(res, k)
+ }
+}
+
+func (wp *WorkersPool) get() map[string]struct{} {
+ return wp.resPool.Get().(map[string]struct{})
+}
+
+func (wp *WorkersPool) do() { //nolint:gocognit
+ go func() {
+ for {
+ select {
+ case msg, ok := <-wp.queue:
+ if !ok {
+ return
+ }
+ // do not handle nil's
+ if msg == nil {
+ continue
+ }
+ if len(msg.Topics) == 0 {
+ continue
+ }
+ res := wp.get()
+ // get connections for the particular topic
+ wp.storage.GetByPtr(msg.Topics, res)
+ if len(res) == 0 {
+ wp.log.Info("no such topic", "topic", msg.Topics)
+ wp.put(res)
+ continue
+ }
+
+ for i := range res {
+ c, ok := wp.connections.Load(i)
+ if !ok {
+ wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker, "topics", msg.Topics)
+ continue
+ }
+
+ conn := c.(*connection.Connection)
+ err := conn.Write(websocket.BinaryMessage, msg.Payload)
+ if err != nil {
+ wp.log.Error("error sending payload over the connection", "broker", msg.Broker, "topics", msg.Topics)
+ wp.put(res)
+ continue
+ }
+ }
+
+ wp.put(res)
+ case <-wp.exit:
+ wp.log.Info("get exit signal, exiting from the workers pool")
+ return
+ }
+ }
+ }()
+}
diff --git a/plugins/websockets/rpc.go b/plugins/websockets/rpc.go
new file mode 100644
index 00000000..2fb0f1b9
--- /dev/null
+++ b/plugins/websockets/rpc.go
@@ -0,0 +1,47 @@
+package websockets
+
+import (
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+// rpc collectors struct
+type rpc struct {
+ plugin *Plugin
+ log logger.Logger
+}
+
+func (r *rpc) Publish(msg []*pubsub.Message, ok *bool) error {
+ const op = errors.Op("broadcast_publish")
+ r.log.Debug("message published", "msg", msg)
+
+ // just return in case of nil message
+ if msg == nil {
+ *ok = true
+ return nil
+ }
+
+ err := r.plugin.Publish(msg)
+ if err != nil {
+ *ok = false
+ return errors.E(op, err)
+ }
+ *ok = true
+ return nil
+}
+
+func (r *rpc) PublishAsync(msg []*pubsub.Message, ok *bool) error {
+ r.log.Debug("message published", "msg", msg)
+
+ // just return in case of nil message
+ if msg == nil {
+ *ok = true
+ return nil
+ }
+ // publish to the registered broker
+ r.plugin.PublishAsync(msg)
+
+ *ok = true
+ return nil
+}
diff --git a/plugins/websockets/schema/message.fbs b/plugins/websockets/schema/message.fbs
new file mode 100644
index 00000000..f2d92c78
--- /dev/null
+++ b/plugins/websockets/schema/message.fbs
@@ -0,0 +1,10 @@
+namespace message;
+
+table Message {
+ command:string;
+ broker:string;
+ topics:[string];
+ payload:[byte];
+}
+
+root_type Message;
diff --git a/plugins/websockets/schema/message/Message.go b/plugins/websockets/schema/message/Message.go
new file mode 100644
index 00000000..26bbd12c
--- /dev/null
+++ b/plugins/websockets/schema/message/Message.go
@@ -0,0 +1,118 @@
+// Code generated by the FlatBuffers compiler. DO NOT EDIT.
+
+package message
+
+import (
+ flatbuffers "github.com/google/flatbuffers/go"
+)
+
+type Message struct {
+ _tab flatbuffers.Table
+}
+
+func GetRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message {
+ n := flatbuffers.GetUOffsetT(buf[offset:])
+ x := &Message{}
+ x.Init(buf, n+offset)
+ return x
+}
+
+func GetSizePrefixedRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message {
+ n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:])
+ x := &Message{}
+ x.Init(buf, n+offset+flatbuffers.SizeUint32)
+ return x
+}
+
+func (rcv *Message) Init(buf []byte, i flatbuffers.UOffsetT) {
+ rcv._tab.Bytes = buf
+ rcv._tab.Pos = i
+}
+
+func (rcv *Message) Table() flatbuffers.Table {
+ return rcv._tab
+}
+
+func (rcv *Message) Command() []byte {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
+ if o != 0 {
+ return rcv._tab.ByteVector(o + rcv._tab.Pos)
+ }
+ return nil
+}
+
+func (rcv *Message) Broker() []byte {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
+ if o != 0 {
+ return rcv._tab.ByteVector(o + rcv._tab.Pos)
+ }
+ return nil
+}
+
+func (rcv *Message) Topics(j int) []byte {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
+ if o != 0 {
+ a := rcv._tab.Vector(o)
+ return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j*4))
+ }
+ return nil
+}
+
+func (rcv *Message) TopicsLength() int {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
+ if o != 0 {
+ return rcv._tab.VectorLen(o)
+ }
+ return 0
+}
+
+func (rcv *Message) Payload(j int) int8 {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
+ if o != 0 {
+ a := rcv._tab.Vector(o)
+ return rcv._tab.GetInt8(a + flatbuffers.UOffsetT(j*1))
+ }
+ return 0
+}
+
+func (rcv *Message) PayloadLength() int {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
+ if o != 0 {
+ return rcv._tab.VectorLen(o)
+ }
+ return 0
+}
+
+func (rcv *Message) MutatePayload(j int, n int8) bool {
+ o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
+ if o != 0 {
+ a := rcv._tab.Vector(o)
+ return rcv._tab.MutateInt8(a+flatbuffers.UOffsetT(j*1), n)
+ }
+ return false
+}
+
+func MessageStart(builder *flatbuffers.Builder) {
+ builder.StartObject(4)
+}
+func MessageAddCommand(builder *flatbuffers.Builder, command flatbuffers.UOffsetT) {
+ builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(command), 0)
+}
+func MessageAddBroker(builder *flatbuffers.Builder, broker flatbuffers.UOffsetT) {
+ builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(broker), 0)
+}
+func MessageAddTopics(builder *flatbuffers.Builder, topics flatbuffers.UOffsetT) {
+ builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(topics), 0)
+}
+func MessageStartTopicsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
+ return builder.StartVector(4, numElems, 4)
+}
+func MessageAddPayload(builder *flatbuffers.Builder, payload flatbuffers.UOffsetT) {
+ builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(payload), 0)
+}
+func MessageStartPayloadVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
+ return builder.StartVector(1, numElems, 1)
+}
+func MessageEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
+ return builder.EndObject()
+}
diff --git a/plugins/websockets/storage/storage.go b/plugins/websockets/storage/storage.go
new file mode 100644
index 00000000..ac256be2
--- /dev/null
+++ b/plugins/websockets/storage/storage.go
@@ -0,0 +1,79 @@
+package storage
+
+import (
+ "sync"
+
+ "github.com/spiral/roadrunner/v2/pkg/bst"
+)
+
+type Storage struct {
+ sync.RWMutex
+ BST bst.Storage
+}
+
+func NewStorage() *Storage {
+ return &Storage{
+ BST: bst.NewBST(),
+ }
+}
+
+func (s *Storage) InsertMany(connID string, topics []string) {
+ s.Lock()
+ defer s.Unlock()
+
+ for i := 0; i < len(topics); i++ {
+ s.BST.Insert(connID, topics[i])
+ }
+}
+
+func (s *Storage) Insert(connID string, topic string) {
+ s.Lock()
+ defer s.Unlock()
+
+ s.BST.Insert(connID, topic)
+}
+
+func (s *Storage) RemoveMany(connID string, topics []string) {
+ s.Lock()
+ defer s.Unlock()
+
+ for i := 0; i < len(topics); i++ {
+ s.BST.Remove(connID, topics[i])
+ }
+}
+
+func (s *Storage) Remove(connID string, topic string) {
+ s.Lock()
+ defer s.Unlock()
+
+ s.BST.Remove(connID, topic)
+}
+
+// GetByPtrTS Thread safe get
+func (s *Storage) GetByPtrTS(topics []string, res map[string]struct{}) {
+ s.Lock()
+ defer s.Unlock()
+
+ for i := 0; i < len(topics); i++ {
+ d := s.BST.Get(topics[i])
+ if len(d) > 0 {
+ for ii := range d {
+ res[ii] = struct{}{}
+ }
+ }
+ }
+}
+
+func (s *Storage) GetByPtr(topics []string, res map[string]struct{}) {
+ s.RLock()
+ defer s.RUnlock()
+
+ for i := 0; i < len(topics); i++ {
+ d := s.BST.Get(topics[i])
+ if len(d) > 0 {
+ for ii := range d {
+ res[ii] = struct{}{}
+ }
+ }
+ }
+}
diff --git a/plugins/websockets/storage/storage_test.go b/plugins/websockets/storage/storage_test.go
new file mode 100644
index 00000000..4072992a
--- /dev/null
+++ b/plugins/websockets/storage/storage_test.go
@@ -0,0 +1,299 @@
+package storage
+
+import (
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+)
+
+const predifined = "chat-1-2"
+
+func TestNewBST(t *testing.T) {
+ // create a new bst
+ g := NewStorage()
+
+ for i := 0; i < 100; i++ {
+ g.InsertMany(uuid.NewString(), []string{"comments"})
+ }
+
+ for i := 0; i < 100; i++ {
+ g.InsertMany(uuid.NewString(), []string{"comments2"})
+ }
+
+ for i := 0; i < 100; i++ {
+ g.InsertMany(uuid.NewString(), []string{"comments3"})
+ }
+
+ res := make(map[string]struct{}, 100)
+ assert.Len(t, res, 0)
+
+ // should be 100
+ g.GetByPtr([]string{"comments"}, res)
+ assert.Len(t, res, 100)
+
+ res = make(map[string]struct{}, 100)
+ assert.Len(t, res, 0)
+
+ // should be 100
+ g.GetByPtr([]string{"comments2"}, res)
+ assert.Len(t, res, 100)
+
+ res = make(map[string]struct{}, 100)
+ assert.Len(t, res, 0)
+
+ // should be 100
+ g.GetByPtr([]string{"comments3"}, res)
+ assert.Len(t, res, 100)
+}
+
+func BenchmarkGraph(b *testing.B) {
+ g := NewStorage()
+
+ for i := 0; i < 1000; i++ {
+ uid := uuid.New().String()
+ g.InsertMany(uuid.NewString(), []string{uid})
+ }
+
+ g.Insert(uuid.NewString(), predifined)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ res := make(map[string]struct{})
+
+ for i := 0; i < b.N; i++ {
+ g.GetByPtr([]string{predifined}, res)
+
+ for i := range res {
+ delete(res, i)
+ }
+ }
+}
+
+func BenchmarkBigSearch(b *testing.B) {
+ g1 := NewStorage()
+
+ predefinedSlice := make([]string, 0, 1000)
+ for i := 0; i < 1000; i++ {
+ predefinedSlice = append(predefinedSlice, uuid.NewString())
+ }
+ if predefinedSlice == nil {
+ b.FailNow()
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), uuid.NewString())
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), predefinedSlice[i])
+ }
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ res := make(map[string]struct{}, 333)
+
+ for i := 0; i < b.N; i++ {
+ g1.GetByPtr(predefinedSlice, res)
+
+ for i := range res {
+ delete(res, i)
+ }
+ }
+}
+
+func BenchmarkBigSearchWithRemoves(b *testing.B) {
+ g1 := NewStorage()
+
+ predefinedSlice := make([]string, 0, 1000)
+ for i := 0; i < 1000; i++ {
+ predefinedSlice = append(predefinedSlice, uuid.NewString())
+ }
+ if predefinedSlice == nil {
+ b.FailNow()
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), uuid.NewString())
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), predefinedSlice[i])
+ }
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ go func() {
+ tt := time.NewTicker(time.Microsecond)
+
+ res := make(map[string]struct{}, 1000)
+ for {
+ select {
+ case <-tt.C:
+ num := rand.Intn(1000) //nolint:gosec
+ g1.GetByPtr(predefinedSlice, res)
+ for k := range res {
+ g1.Remove(k, predefinedSlice[num])
+ }
+ }
+ }
+ }()
+
+ res := make(map[string]struct{}, 100)
+
+ for i := 0; i < b.N; i++ {
+ g1.GetByPtr(predefinedSlice, res)
+
+ for i := range res {
+ delete(res, i)
+ }
+ }
+}
+
+func TestBigSearchWithRemoves(t *testing.T) {
+ g1 := NewStorage()
+
+ predefinedSlice := make([]string, 0, 1000)
+ for i := 0; i < 1000; i++ {
+ predefinedSlice = append(predefinedSlice, uuid.NewString())
+ }
+ if predefinedSlice == nil {
+ t.FailNow()
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), uuid.NewString())
+ }
+
+ for i := 0; i < 1000; i++ {
+ g1.Insert(uuid.NewString(), predefinedSlice[i])
+ }
+
+ stopCh := make(chan struct{})
+
+ go func() {
+ tt := time.NewTicker(time.Microsecond)
+
+ res := make(map[string]struct{}, 1000)
+ for {
+ select {
+ case <-tt.C:
+ num := rand.Intn(1000) //nolint:gosec
+ g1.GetByPtr(predefinedSlice, res)
+ for k := range res {
+ g1.Remove(k, predefinedSlice[num])
+ }
+
+ case <-stopCh:
+ tt.Stop()
+ return
+ }
+ }
+ }()
+
+ res := make(map[string]struct{}, 100)
+
+ for i := 0; i < 1000; i++ {
+ g1.GetByPtr(predefinedSlice, res)
+
+ for i := range res {
+ delete(res, i)
+ }
+ }
+
+ stopCh <- struct{}{}
+}
+
+func TestGraph(t *testing.T) {
+ g := NewStorage()
+
+ for i := 0; i < 1000; i++ {
+ uid := uuid.New().String()
+ g.Insert(uuid.NewString(), uid)
+ }
+
+ g.Insert(uuid.NewString(), predifined)
+
+ res := make(map[string]struct{})
+
+ g.GetByPtr([]string{predifined}, res)
+ assert.NotEmpty(t, res)
+ assert.Len(t, res, 1)
+}
+
+func TestTreeConcurrentContains(t *testing.T) {
+ g := NewStorage()
+
+ key1 := uuid.NewString()
+ key2 := uuid.NewString()
+ key3 := uuid.NewString()
+ key4 := uuid.NewString()
+ key5 := uuid.NewString()
+
+ g.Insert(key1, predifined)
+ g.Insert(key2, predifined)
+ g.Insert(key3, predifined)
+ g.Insert(key4, predifined)
+ g.Insert(key5, predifined)
+
+ res := make(map[string]struct{}, 100)
+
+ for i := 0; i < 100; i++ {
+ go func() {
+ g.GetByPtrTS([]string{predifined}, res)
+ }()
+
+ go func() {
+ g.GetByPtrTS([]string{predifined}, res)
+ }()
+
+ go func() {
+ g.GetByPtrTS([]string{predifined}, res)
+ }()
+
+ go func() {
+ g.GetByPtrTS([]string{predifined}, res)
+ }()
+ }
+
+ time.Sleep(time.Second * 5)
+
+ res2 := make(map[string]struct{}, 5)
+
+ g.GetByPtr([]string{predifined}, res2)
+ assert.NotEmpty(t, res2)
+ assert.Len(t, res2, 5)
+}
+
+func TestGraphRemove(t *testing.T) {
+ g := NewStorage()
+
+ key1 := uuid.NewString()
+ key2 := uuid.NewString()
+ key3 := uuid.NewString()
+ key4 := uuid.NewString()
+ key5 := uuid.NewString()
+
+ g.Insert(key1, predifined)
+ g.Insert(key2, predifined)
+ g.Insert(key3, predifined)
+ g.Insert(key4, predifined)
+ g.Insert(key5, predifined)
+
+ res := make(map[string]struct{}, 5)
+ g.GetByPtr([]string{predifined}, res)
+ assert.NotEmpty(t, res)
+ assert.Len(t, res, 5)
+
+ g.Remove(key1, predifined)
+
+ res2 := make(map[string]struct{}, 4)
+ g.GetByPtr([]string{predifined}, res2)
+ assert.NotEmpty(t, res2)
+ assert.Len(t, res2, 4)
+}
diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go
new file mode 100644
index 00000000..e666f846
--- /dev/null
+++ b/plugins/websockets/validator/access_validator.go
@@ -0,0 +1,76 @@
+package validator
+
+import (
+ "net/http"
+ "strings"
+
+ json "github.com/json-iterator/go"
+ "github.com/spiral/errors"
+ handler "github.com/spiral/roadrunner/v2/pkg/worker_handler"
+ "github.com/spiral/roadrunner/v2/plugins/http/attributes"
+)
+
+type AccessValidatorFn = func(r *http.Request, channels ...string) (*AccessValidator, error)
+
+type AccessValidator struct {
+ Header http.Header `json:"headers"`
+ Status int `json:"status"`
+ Body []byte
+}
+
+func ServerAccessValidator(r *http.Request) ([]byte, error) {
+ const op = errors.Op("server_access_validator")
+
+ err := attributes.Set(r, "ws:joinServer", true)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ defer delete(attributes.All(r), "ws:joinServer")
+
+ req := &handler.Request{
+ RemoteAddr: handler.FetchIP(r.RemoteAddr),
+ Protocol: r.Proto,
+ Method: r.Method,
+ URI: handler.URI(r),
+ Header: r.Header,
+ Cookies: make(map[string]string),
+ RawQuery: r.URL.RawQuery,
+ Attributes: attributes.All(r),
+ }
+
+ data, err := json.Marshal(req)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return data, nil
+}
+
+func TopicsAccessValidator(r *http.Request, topics ...string) ([]byte, error) {
+ const op = errors.Op("topic_access_validator")
+ err := attributes.Set(r, "ws:joinTopics", strings.Join(topics, ","))
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ defer delete(attributes.All(r), "ws:joinTopics")
+
+ req := &handler.Request{
+ RemoteAddr: handler.FetchIP(r.RemoteAddr),
+ Protocol: r.Proto,
+ Method: r.Method,
+ URI: handler.URI(r),
+ Header: r.Header,
+ Cookies: make(map[string]string),
+ RawQuery: r.URL.RawQuery,
+ Attributes: attributes.All(r),
+ }
+
+ data, err := json.Marshal(req)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return data, nil
+}