diff options
author | Valery Piashchynski <[email protected]> | 2021-06-01 00:10:31 +0300 |
---|---|---|
committer | GitHub <[email protected]> | 2021-06-01 00:10:31 +0300 |
commit | 548ee4432e48b316ada00feec1a6b89e67ae4f2f (patch) | |
tree | 5cd2aaeeafdb50e3e46824197c721223f54695bf /plugins/websockets | |
parent | 8cd696bbca8fac2ced30d8172c41b7434ec86650 (diff) | |
parent | df4d316d519cea6dff654bd917521a616a37f769 (diff) |
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
Diffstat (limited to 'plugins/websockets')
-rw-r--r-- | plugins/websockets/commands/enums.go | 9 | ||||
-rw-r--r-- | plugins/websockets/config.go | 58 | ||||
-rw-r--r-- | plugins/websockets/connection/connection.go | 67 | ||||
-rw-r--r-- | plugins/websockets/doc/broadcast.drawio | 1 | ||||
-rw-r--r-- | plugins/websockets/doc/doc.go | 27 | ||||
-rw-r--r-- | plugins/websockets/executor/executor.go | 226 | ||||
-rw-r--r-- | plugins/websockets/plugin.go | 386 | ||||
-rw-r--r-- | plugins/websockets/pool/workers_pool.go | 117 | ||||
-rw-r--r-- | plugins/websockets/rpc.go | 47 | ||||
-rw-r--r-- | plugins/websockets/schema/message.fbs | 10 | ||||
-rw-r--r-- | plugins/websockets/schema/message/Message.go | 118 | ||||
-rw-r--r-- | plugins/websockets/storage/storage.go | 79 | ||||
-rw-r--r-- | plugins/websockets/storage/storage_test.go | 299 | ||||
-rw-r--r-- | plugins/websockets/validator/access_validator.go | 76 |
14 files changed, 1520 insertions, 0 deletions
diff --git a/plugins/websockets/commands/enums.go b/plugins/websockets/commands/enums.go new file mode 100644 index 00000000..18c63be3 --- /dev/null +++ b/plugins/websockets/commands/enums.go @@ -0,0 +1,9 @@ +package commands + +type Command string + +const ( + Leave string = "leave" + Join string = "join" + Headers string = "headers" +) diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go new file mode 100644 index 00000000..be4aaa82 --- /dev/null +++ b/plugins/websockets/config.go @@ -0,0 +1,58 @@ +package websockets + +import ( + "time" + + "github.com/spiral/roadrunner/v2/pkg/pool" +) + +/* +websockets: + # pubsubs should implement PubSub interface to be collected via endure.Collects + + pubsubs:["redis", "amqp", "memory"] + # path used as websockets path + path: "/ws" +*/ + +// Config represents configuration for the ws plugin +type Config struct { + // http path for the websocket + Path string `mapstructure:"path"` + // ["redis", "amqp", "memory"] + PubSubs []string `mapstructure:"pubsubs"` + Middleware []string `mapstructure:"middleware"` + + Pool *pool.Config `mapstructure:"pool"` +} + +// InitDefault initialize default values for the ws config +func (c *Config) InitDefault() { + if c.Path == "" { + c.Path = "/ws" + } + if len(c.PubSubs) == 0 { + // memory used by default + c.PubSubs = append(c.PubSubs, "memory") + } + + if c.Pool == nil { + c.Pool = &pool.Config{} + if c.Pool.NumWorkers == 0 { + // 2 workers by default + c.Pool.NumWorkers = 2 + } + + if c.Pool.AllocateTimeout == 0 { + c.Pool.AllocateTimeout = time.Minute + } + + if c.Pool.DestroyTimeout == 0 { + c.Pool.DestroyTimeout = time.Minute + } + if c.Pool.Supervisor == nil { + return + } + c.Pool.Supervisor.InitDefaults() + } +} diff --git a/plugins/websockets/connection/connection.go b/plugins/websockets/connection/connection.go new file mode 100644 index 00000000..2b847173 --- /dev/null +++ b/plugins/websockets/connection/connection.go @@ -0,0 +1,67 @@ +package connection + +import ( + "sync" + + "github.com/fasthttp/websocket" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +// Connection represents wrapped and safe to use from the different threads websocket connection +type Connection struct { + sync.RWMutex + log logger.Logger + conn *websocket.Conn +} + +func NewConnection(wsConn *websocket.Conn, log logger.Logger) *Connection { + return &Connection{ + conn: wsConn, + log: log, + } +} + +func (c *Connection) Write(mt int, data []byte) error { + c.Lock() + defer c.Unlock() + + const op = errors.Op("websocket_write") + // handle a case when a goroutine tried to write into the closed connection + defer func() { + if r := recover(); r != nil { + c.log.Warn("panic handled, tried to write into the closed connection") + } + }() + + err := c.conn.WriteMessage(mt, data) + if err != nil { + return errors.E(op, err) + } + + return nil +} + +func (c *Connection) Read() (int, []byte, error) { + const op = errors.Op("websocket_read") + + mt, data, err := c.conn.ReadMessage() + if err != nil { + return -1, nil, errors.E(op, err) + } + + return mt, data, nil +} + +func (c *Connection) Close() error { + c.Lock() + defer c.Unlock() + const op = errors.Op("websocket_close") + + err := c.conn.Close() + if err != nil { + return errors.E(op, err) + } + + return nil +} diff --git a/plugins/websockets/doc/broadcast.drawio b/plugins/websockets/doc/broadcast.drawio new file mode 100644 index 00000000..230870f2 --- /dev/null +++ b/plugins/websockets/doc/broadcast.drawio @@ -0,0 +1 @@ +<mxfile host="Electron" modified="2021-05-27T20:56:56.848Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/14.5.1 Chrome/89.0.4389.128 Electron/12.0.9 Safari/537.36" etag="Pt0MY_-SPz7R7foQA1VL" version="14.5.1" type="device"><diagram id="fD2kwGC0DAS2S_q_IsmE" name="Page-1">7V1Zc9rIFv411HUeULV28WgwTjxjx06wJ8l9mRKoAU2ExEjCS3797W4t9IbAoMZLrqcqg1oLrT77OV8fOuZg8fgx9ZfzqySAUccAwWPHPOsYhmECC/0PjzwVI7ppesXILA2Dcmw9MAp/wXIQlKOrMIAZc2GeJFEeLtnBSRLHcJIzY36aJg/sZdMkYr916c+gMDCa+JE4+i0M8nkxagIA1ic+wXA2z/kzC7+6uhzI5n6QPFBD5rBjDtIkyYtPi8cBjPDyVQtT3He+4Ww9sxTG+S43PFlXl7c/+n48uTWjn1/SEMSn3V45t/ypemMYoAUoD5M0nyezJPaj4Xq0vx69TJIlukxHg//APH8q6eev8gQNzfNFVJ6Fj2H+nfr8A30GmmGXh2eYYUB18FQdxHn69J0+oG/Dx+v7yFF1o7gy5WJlySqdwIblqFjMT2cwb7qupCleLOobyoX/CJMFRBNCF6Qw8vPwnuUmv2TKWX1dfetNEqI5G6CUIN2y7eKeSoB0y2AfUsy1vG9NfvSBmsh6iDCFnEGuvFUc/53d9od/zbI/xr9On7y7rm7swSFpsooDGJTkOIBfKB75QZ3axi8Ui/xgOETOLxlaw/wUaws0ECcxrMbOQ7xYFE9x9FbLZNaBLMWwwrPp7hQPvvejVflVlpQTLv0xUvoM9fwonMXo8wQtDUzRwD1M8xAp1dPyxCIMgoJRYBb+8sfkeXiVl5iNyavY/Y59Vq87fgB8ZN641PjlzWs1S1OkgaPFdS0f3wUacD2PkbqKZPtKc/WYHntHMp1mUInUms7LSq1LSy3YUWpZLe9u0/IvL7XSpTeclxTbyjpQYvsV+gEa+Tz8fnvyATsmMMuw14O8oTRZ4Fea44PlapytxnhOMb48g3GQ4W/LictVX/aQpD9his8skR/2PjSCuYFk5eOBpjuGyyoE4zCFUOka1+H0jHssDWHrAqPIdcZ7oybS77Zlm+y6m62Qk+UR3VNAyybVRJHy27A/uh78ObwdofFPd2iVwclo+PWvIfpwfvd58EEg9cM8zOFo6RMd+ICiOZbkU6RTB0mUpORqM/ChN50QhZsmPyF1xpl4cDzFdyRxTo0D8tfknAs8sJGEumFxXrFhlwL5sA7SdKvkmDkVnzlAlb/0igzuzm7ys7zkFm2psaMt1e2XtKWiBzzy7+FaNYVJXJvLB58Yyili98pSjtPEDyZ+lq9tbvY+NGyzdALNNAxWPtsyl6brauyTLbsaOILJNAWG8H4HgiKLaRkeR9EDI6Ly0Y6hWazVdLRqmdUT1Ph/buOlNPvB+bKDNLttCZIsT4S+OUk2mpMbiHt6wFChm3X2qci3UJHsaOI52v3Fbu8VocMDWjGBtM9zdqc2/k/q7JI/0dkt/tpydh1W9TpeyQy0rwskvq6lzNcFwooeUW/uWUOgEkq1ut1Vb04iP8vCCac69XZVpyWqzs3R3qGVBvRm/hN1QambNos3qLiuruSVeuR89ztc0HyHaTffgT4U8243V/rCXsDuTsAutnxvR+GtcrKYhelxXGT1elqP+uMChg0FtOdKSM/hmbfXzO7CPJ99g+5tESiDF0HuDjUCVUkxZZDt9+Fdmc3eFQqUXNdhS0ftZBaRyXW0qgBXh2A69xx1LpboL38dnl2M8PcY4OJz92p4df31x6F+1hQ6E2lSMXB7Y5I8VJhU7AFOVoyeJKdoKPKzTOvvs+ki/1N/7D1G/t/9weWPv8r0F7Ps5h626hDrxPha7o6+ls7Yne2VvP0tTeWJ7mZqBPPWVPKxqpCw4gf7MEk+QhlA1Ltfb9CjQP9udKBoQj2woSsTzZ7jmr4kBGo13+/slO+XyWYb+f7B2Y9o8V83eFwtT/8MgssvZ3nW3afAvo8g7pmsf369ZichlC7FrngWY0cZbD0HJJ21JwjLoMQQMvXuj9d4ynGwLMX4hCmaI09iGs7EatrRvBuaXk18urMX1OTcOI7L6sS28vo9U+PTy0dLAleahGIE9+Xouae3upn0DSbO0S2dXXWrFXLqntbj4BK6oTmGAoo26SOKoAI9/WxZoIWn4SOO4mmiiqQRyBwuCGyY2LxSdevGevwsXMzQzKNwjP71f61SiF9yBmOY+mj2532MS4aplt3PWjKQllAYs3UxS+jVF9Em0lXmvu6TWGktlbIr6kx7Vlpwfytp7mgl24cWH0REQ3Qq5WQ9urlr5LnD7R3QDNNWUvTk4gqkGlWUSpqYkCLmzV3/8mL0SSDpM+MDc+IBIPMv+6Bny0L3KflrKz5wdI2DHCDbtqsCNJUpQNES7QMTeuUqcfcK8z5q0xbVZhNn7x7gG5xog31k+9llFBfwbFr6RRtzuNvuYHO4ivINTaShlcmnmwMViQ9swwikgeqZfq5ekfCJBssVtYg00dCGEpGuck9Y5S46vFvOUj8QwGgkcH1ADOhEaF79cYo+zfAnfE+JXitx3HCcJZOfmDOYR0hvxNJMrlssCrhbhBURFhY/COMZfu3iRWJ0STGwRryxMHPZ5ASeQeTK5a5EBKf5YY4EVpZZ6bO3xDWmy1Xoq2OKaUxHwjS2MssjMA2BRAyiEJZVrmeIaBtLZLlsWcSWgBg894i2WRSrii+XzOI4/67wBs3+AhmgELHaKToLlo/oX7IyoBjv5tgw43MWdQ6zcbfkXHyu9IOZ0wGcJKlPZI9cg8vKaRQiK1p/9VpMSsGpBq7qnRxIW64m+QpDXYpr0JqM+fvQ2JIfm6f8yF6vX4glPunhk/zMb5NlOCH3nBWTJTpi86xUzKGP7UkqTAKcpDAIkVIaEKW1SNKnD6KGuvGfooRsoynuTv0H9O/4KYdSdXbcNxvUavnkH6TuileJINH2AyJPyFKk2Qf5vDhdwGtWwti8+k3QVdOIOHxT4uQJmRB8fO4vwgjL+ycY3UP81Pb1isd7R6YHtCp5SKtfmc32epprK9IvkmKM6PrHAec3SzP/Haaatq2YJodJNVfu9nfPJVmNJnW71T2nSGY3GMxDM5Suw/t5vCku3lxAhUiSnTr/KMA9SvEO7X1KworAePvt53/WRs+2oXhyFK3I1ZsrWkfIzMm3egoKZhgHxAVAUkRunyRRVBa05tjEgQUFiS03rsDHZZLhwyJ+ID76PM+xnVpGK2yVXizZ1ypyp9GMdIGGTQQjx23VtnTd1mzu0dXxEbY4CEzy6fYWBeng5vLu48VngbisYtgSGAgk56vONvQCSxbMe8bYJMDpjZQWxHazDwAqW1uHXzvG7Mo2A7rCql8slhFcQJIcAjer8YhstQ7xuk3x8u4eAbchX9IY+CAiiGBGS0IEV0IEQxURJNlXtPBRmM07TIV/XAYGb0zNbRMKE2gAWJZnuLbnGIDbm6yO9y051vNV+yPPw9ao21Il52MJnFpN54n9UtoOJ/qeswWVLNywBfeMXOxe0w1qUMyWuP13UOdJL87eyVZea8MO67VjZFoWB/ppCSUCND5teiynyBbt8waswjukJ+C38raEUDcNjQuqu0drZqKLsdAn5GF0l8i7StIFCtqwoe+PbgmFk7RImj6E+bz2AnKcoszqsAgFmckk9HMYdOj6iEzuW3WZp94EyjHwY8/GSYp2XGabQzZ45ku7zJYcdPfKG2i4r9pxkNTOX2PLKhGJ+5GUJf1JvsJ4OfCQdZtKlOj//gLLWTzOSLK66MiRIkEkK1el7pG+Kr3+xTtR9u42ZW/x+7LbQTDZumZ5ki11VVMOXVPR+0huy8VCXmUNfgPymlzyua3N/JZmHY2Ckl1Ot3OMLOhm/pSkI/1s3l34S95k/4RP2GB3RQtdGnHyxOoSRokotuIO8Hu6dLuMceY6O6BYDrPurskGWK4n4hHqDsttQ4HlNBY97nUupm5I+MZzYLrt8sVI2TYl2cKrS4KJxrXYEzZeETnJc38yJ04u3eORtHeUZ//fHFEsfl+nqYtEMWT9M9QRRbRZBVEmBYyKqKsKn5VptaS8eVK4HCks3daALoqId0xqSBDulY8ZhPc8gCNb+nE1RjBVlZwgN/snMTZsVplFv617r57UeJRpCKOABmfQX0ENs5OReLwYyJJW9T4/Z76Okm5K5jekvo8OvzuIrTyhLUJPDGZNU8JSbYDL5NXGV1STV9Vk/8jBbHXh9p6RL9o23ZRkopB0Rljw/rjGBVes5y+Hp6QVawmolWDIRnf90eDrRR9ddX73mTp6D4GNubUajzxotkPvgZvvyycbKK5htzzyKkBhoCqGORv6aL03egJNdw122VtquNw1WC7Rdc1V0ROlUSPRAJzvw8Hd7fVXUZ7LJuwUKv6kEn7ck73A06sMSfWxr0NDFpIC4AxPzxWHpHzHZk9io6Vun7KEsyHW99RSwEdBhiFvI3he9spQmhSwPb7gptc/VvRiaX9DTAy8byo4EipIdioemQpiluB9U8GVUMF4cSrI9qS8Zyp4Eiq8eCHSfOFC5H5NoRXt2G0ldKtI+spDN0vcHF/UIetWP0yJge0DtMpyclhgDMZPHQZXIOWot+fmb61H2U6V8D7YsXc1dsOgfrxATXQMKyjpb0BBg2vx1E7BuKv3+Mi7q6R7pfy15T/T8KqxJVwPxVeWjdtdpR/6Cy77YVI9j+PiyofYhDE1gJBFfu4dHnuDIlSqmF8sjBSHkWHNE1VBp0zTMk3uw2AHhNvb1HXeNl3nuNxvj7Rju1yDZSTjaIZLjOJ0eSHiHRLTtjwFtOxSP2RR2a2jAWMqJfvbxIM9STwo20Z03HhQTPPyv4EJTj5eiz1A31x53uWQEj3ZPrqjolcsMSNyCWcwDkg3hE19F0ihXNJlgW/HQYrZG7ts0FV4XeuQXXs8YIm1sqT7EfCXy+fW8hXN2sCzptA+r2JSJp7UunFJhVngIUg1fAI36gynIT7PYh7C9X7lRTbrUEDfV/GeFssyslejO4+EcZdrPnLE+dt4rt9qwFGHArVgHOYa1LKJ8av9rJ/R15W/MYzCpYD1MSnPs/QDyHCFg1H3eg7LcstVNicVuDUevFP9rvF/KKX+ZQVXNXRmXDrME/QYGKicrYtnWzVqqh1z0qqGmi6BHq3bUqubjkemQ0cWqyVpPcYJ7EPJyKs4/Jcs293dxRmxCRn6hvXV+Gcy8XUFZIn8aGZNEOpxVMRykq/RwB9UvmoPv+rwEU6QaU1L3u8YjZ3MJK3TVMxMB3hq59SPi27EkjBzIhDPlIC+AkKx8xQpmfuaGufMC/VHt91qG5TCd8HGlFXlWIVcX18Wd90U0jfjOtMVSVd59Frv3eJfR+l7FOYV5qs0Ltmd4+FCjyicAbGlTXtlhDXJBWA9kTGFcyR28FvDXhyeYeVzeeO+NY98le23s3sy15rPe+/gW2MxTzAV1sEoCgXnV0mAg/7h/wA=</diagram></mxfile>
\ No newline at end of file diff --git a/plugins/websockets/doc/doc.go b/plugins/websockets/doc/doc.go new file mode 100644 index 00000000..fc214be8 --- /dev/null +++ b/plugins/websockets/doc/doc.go @@ -0,0 +1,27 @@ +package doc + +/* +RPC message structure: + +type Msg struct { + // Topic message been pushed into. + Topics_ []string `json:"topic"` + + // Command (join, leave, headers) + Command_ string `json:"command"` + + // Broker (redis, memory) + Broker_ string `json:"broker"` + + // Payload to be broadcasted + Payload_ []byte `json:"payload"` +} + +1. Topics - string array (slice) with topics to join or leave +2. Command - string, command to apply on the provided topics +3. Broker - string, pub-sub broker to use, for the one-node systems might be used `memory` broker or `redis`. For the multi-node - +`redis` broker should be used. +4. Payload - raw byte array to send to the subscribers (binary messages). + + +*/ diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go new file mode 100644 index 00000000..24ea19ce --- /dev/null +++ b/plugins/websockets/executor/executor.go @@ -0,0 +1,226 @@ +package executor + +import ( + "fmt" + "net/http" + "sync" + + "github.com/fasthttp/websocket" + json "github.com/json-iterator/go" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/websockets/commands" + "github.com/spiral/roadrunner/v2/plugins/websockets/connection" + "github.com/spiral/roadrunner/v2/plugins/websockets/storage" + "github.com/spiral/roadrunner/v2/plugins/websockets/validator" +) + +type Response struct { + Topic string `json:"topic"` + Payload []string `json:"payload"` +} + +type Executor struct { + sync.Mutex + conn *connection.Connection + storage *storage.Storage + log logger.Logger + + // associated connection ID + connID string + + // map with the pubsub drivers + pubsub map[string]pubsub.PubSub + actualTopics map[string]struct{} + + req *http.Request + accessValidator validator.AccessValidatorFn +} + +// NewExecutor creates protected connection and starts command loop +func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, + connID string, pubsubs map[string]pubsub.PubSub, av validator.AccessValidatorFn, r *http.Request) *Executor { + return &Executor{ + conn: conn, + connID: connID, + storage: bst, + log: log, + pubsub: pubsubs, + accessValidator: av, + actualTopics: make(map[string]struct{}, 10), + req: r, + } +} + +func (e *Executor) StartCommandLoop() error { //nolint:gocognit + const op = errors.Op("executor_command_loop") + for { + mt, data, err := e.conn.Read() + if err != nil { + if mt == -1 { + e.log.Info("socket was closed", "reason", err, "message type", mt) + return nil + } + + return errors.E(op, err) + } + + msg := &pubsub.Message{} + + err = json.Unmarshal(data, msg) + if err != nil { + e.log.Error("error unmarshal message", "error", err) + continue + } + + // nil message, continue + if msg == nil { + e.log.Warn("get nil message, skipping") + continue + } + + switch msg.Command { + // handle leave + case commands.Join: + e.log.Debug("get join command", "msg", msg) + + val, err := e.accessValidator(e.req, msg.Topics...) + if err != nil { + if val != nil { + e.log.Debug("validation error", "status", val.Status, "headers", val.Header, "body", val.Body) + } + + resp := &Response{ + Topic: "#join", + Payload: msg.Topics, + } + + packet, errJ := json.Marshal(resp) + if errJ != nil { + e.log.Error("error marshal the body", "error", errJ) + return errors.E(op, fmt.Errorf("%v,%v", err, errJ)) + } + + errW := e.conn.Write(websocket.BinaryMessage, packet) + if errW != nil { + e.log.Error("error writing payload to the connection", "payload", packet, "error", errW) + return errors.E(op, fmt.Errorf("%v,%v", err, errW)) + } + + continue + } + + resp := &Response{ + Topic: "@join", + Payload: msg.Topics, + } + + packet, err := json.Marshal(resp) + if err != nil { + e.log.Error("error marshal the body", "error", err) + return errors.E(op, err) + } + + err = e.conn.Write(websocket.BinaryMessage, packet) + if err != nil { + e.log.Error("error writing payload to the connection", "payload", packet, "error", err) + return errors.E(op, err) + } + + // subscribe to the topic + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Set(br, msg.Topics) + if err != nil { + return errors.E(op, err) + } + } + + // handle leave + case commands.Leave: + e.log.Debug("get leave command", "msg", msg) + + // prepare response + resp := &Response{ + Topic: "@leave", + Payload: msg.Topics, + } + + packet, err := json.Marshal(resp) + if err != nil { + e.log.Error("error marshal the body", "error", err) + return errors.E(op, err) + } + + err = e.conn.Write(websocket.BinaryMessage, packet) + if err != nil { + e.log.Error("error writing payload to the connection", "payload", packet, "error", err) + return errors.E(op, err) + } + + if br, ok := e.pubsub[msg.Broker]; ok { + err = e.Leave(br, msg.Topics) + if err != nil { + return errors.E(op, err) + } + } + + case commands.Headers: + + default: + e.log.Warn("unknown command", "command", msg.Command) + } + } +} + +func (e *Executor) Set(br pubsub.PubSub, topics []string) error { + // associate connection with topics + err := br.Subscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + // in case of error, unsubscribe connection from the dead topics + _ = br.Unsubscribe(topics...) + return err + } + + e.storage.InsertMany(e.connID, topics) + + // save topics for the connection + for i := 0; i < len(topics); i++ { + e.actualTopics[topics[i]] = struct{}{} + } + + return nil +} + +func (e *Executor) Leave(br pubsub.PubSub, topics []string) error { + // remove associated connections from the storage + e.storage.RemoveMany(e.connID, topics) + err := br.Unsubscribe(topics...) + if err != nil { + e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) + return err + } + + // remove topics for the connection + for i := 0; i < len(topics); i++ { + delete(e.actualTopics, topics[i]) + } + + return nil +} + +func (e *Executor) CleanUp() { + for topic := range e.actualTopics { + // remove from the bst + e.storage.Remove(e.connID, topic) + + for _, ps := range e.pubsub { + _ = ps.Unsubscribe(topic) + } + } + + for k := range e.actualTopics { + delete(e.actualTopics, k) + } +} diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go new file mode 100644 index 00000000..9b21ff8f --- /dev/null +++ b/plugins/websockets/plugin.go @@ -0,0 +1,386 @@ +package websockets + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/fasthttp/websocket" + "github.com/google/uuid" + json "github.com/json-iterator/go" + endure "github.com/spiral/endure/pkg/container" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/payload" + phpPool "github.com/spiral/roadrunner/v2/pkg/pool" + "github.com/spiral/roadrunner/v2/pkg/process" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/pkg/worker" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/spiral/roadrunner/v2/plugins/websockets/connection" + "github.com/spiral/roadrunner/v2/plugins/websockets/executor" + "github.com/spiral/roadrunner/v2/plugins/websockets/pool" + "github.com/spiral/roadrunner/v2/plugins/websockets/storage" + "github.com/spiral/roadrunner/v2/plugins/websockets/validator" +) + +const ( + PluginName string = "websockets" +) + +type Plugin struct { + sync.RWMutex + // Collection with all available pubsubs + pubsubs map[string]pubsub.PubSub + + cfg *Config + log logger.Logger + + // global connections map + connections sync.Map + storage *storage.Storage + + // GO workers pool + workersPool *pool.WorkersPool + + wsUpgrade *websocket.Upgrader + serveExit chan struct{} + + phpPool phpPool.Pool + server server.Server + + // function used to validate access to the requested resource + accessValidator validator.AccessValidatorFn +} + +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { + const op = errors.Op("websockets_plugin_init") + if !cfg.Has(PluginName) { + return errors.E(op, errors.Disabled) + } + + err := cfg.UnmarshalKey(PluginName, &p.cfg) + if err != nil { + return errors.E(op, err) + } + + p.cfg.InitDefault() + + p.pubsubs = make(map[string]pubsub.PubSub) + p.log = log + p.storage = storage.NewStorage() + p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) + p.wsUpgrade = &websocket.Upgrader{ + HandshakeTimeout: time.Second * 60, + } + p.serveExit = make(chan struct{}) + p.server = server + + return nil +} + +func (p *Plugin) Serve() chan error { + errCh := make(chan error) + + go func() { + var err error + p.Lock() + defer p.Unlock() + + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + errCh <- err + } + + p.accessValidator = p.defaultAccessValidator(p.phpPool) + }() + + // run all pubsubs drivers + for _, v := range p.pubsubs { + go func(ps pubsub.PubSub) { + for { + select { + case <-p.serveExit: + return + default: + data, err := ps.Next() + if err != nil { + errCh <- err + return + } + p.workersPool.Queue(data) + } + } + }(v) + } + + return errCh +} + +func (p *Plugin) Stop() error { + // close workers pool + p.workersPool.Stop() + p.Lock() + p.phpPool.Destroy(context.Background()) + p.Unlock() + return nil +} + +func (p *Plugin) Collects() []interface{} { + return []interface{}{ + p.GetPublishers, + } +} + +func (p *Plugin) Available() {} + +func (p *Plugin) RPC() interface{} { + return &rpc{ + plugin: p, + log: p.log, + } +} + +func (p *Plugin) Name() string { + return PluginName +} + +// GetPublishers collects all pubsubs +func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) { + p.pubsubs[name.Name()] = pub +} + +func (p *Plugin) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != p.cfg.Path { + next.ServeHTTP(w, r) + return + } + + // we need to lock here, because accessValidator might not be set in the Serve func at the moment + p.RLock() + // before we hijacked connection, we still can write to the response headers + val, err := p.accessValidator(r) + p.RUnlock() + if err != nil { + p.log.Error("validation error") + w.WriteHeader(400) + return + } + + if val.Status != http.StatusOK { + for k, v := range val.Header { + for i := 0; i < len(v); i++ { + w.Header().Add(k, v[i]) + } + } + w.WriteHeader(val.Status) + _, _ = w.Write(val.Body) + return + } + + // upgrade connection to websocket connection + _conn, err := p.wsUpgrade.Upgrade(w, r, nil) + if err != nil { + // connection hijacked, do not use response.writer or request + p.log.Error("upgrade connection", "error", err) + return + } + + // construct safe connection protected by mutexes + safeConn := connection.NewConnection(_conn, p.log) + // generate UUID from the connection + connectionID := uuid.NewString() + // store connection + p.connections.Store(connectionID, safeConn) + + defer func() { + // close the connection on exit + err = safeConn.Close() + if err != nil { + p.log.Error("connection close", "error", err) + } + + // when exiting - delete the connection + p.connections.Delete(connectionID) + }() + + // Executor wraps a connection to have a safe abstraction + e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r) + p.log.Info("websocket client connected", "uuid", connectionID) + defer e.CleanUp() + + err = e.StartCommandLoop() + if err != nil { + p.log.Error("command loop error, disconnecting", "error", err.Error()) + return + } + + p.log.Info("disconnected", "connectionID", connectionID) + }) +} + +// Workers returns slice with the process states for the workers +func (p *Plugin) Workers() []process.State { + p.RLock() + defer p.RUnlock() + + workers := p.workers() + + ps := make([]process.State, 0, len(workers)) + for i := 0; i < len(workers); i++ { + state, err := process.WorkerProcessState(workers[i]) + if err != nil { + return nil + } + ps = append(ps, state) + } + + return ps +} + +// internal +func (p *Plugin) workers() []worker.BaseProcess { + return p.phpPool.Workers() +} + +// Reset destroys the old pool and replaces it with new one, waiting for old pool to die +func (p *Plugin) Reset() error { + p.Lock() + defer p.Unlock() + const op = errors.Op("ws_plugin_reset") + p.log.Info("WS plugin got restart request. Restarting...") + p.phpPool.Destroy(context.Background()) + p.phpPool = nil + + var err error + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + return errors.E(op, err) + } + + // attach validators + p.accessValidator = p.defaultAccessValidator(p.phpPool) + + p.log.Info("WS plugin successfully restarted") + return nil +} + +// Publish is an entry point to the websocket PUBSUB +func (p *Plugin) Publish(msg []*pubsub.Message) error { + p.Lock() + defer p.Unlock() + + for i := 0; i < len(msg); i++ { + for j := 0; j < len(msg[i].Topics); j++ { + if br, ok := p.pubsubs[msg[i].Broker]; ok { + err := br.Publish(msg) + if err != nil { + return errors.E(err) + } + } else { + p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker) + } + } + } + return nil +} + +func (p *Plugin) PublishAsync(msg []*pubsub.Message) { + go func() { + p.Lock() + defer p.Unlock() + for i := 0; i < len(msg); i++ { + for j := 0; j < len(msg[i].Topics); j++ { + err := p.pubsubs[msg[i].Broker].Publish(msg) + if err != nil { + p.log.Error("publish async error", "error", err) + return + } + } + } + }() +} + +func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { + return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { + p.RLock() + defer p.RUnlock() + const op = errors.Op("access_validator") + + p.log.Debug("validation", "topics", topics) + r = attributes.Init(r) + + // if channels len is eq to 0, we use serverValidator + if len(topics) == 0 { + ctx, err := validator.ServerAccessValidator(r) + if err != nil { + return nil, errors.E(op, err) + } + + val, err := exec(ctx, pool) + if err != nil { + return nil, errors.E(err) + } + + return val, nil + } + + ctx, err := validator.TopicsAccessValidator(r, topics...) + if err != nil { + return nil, errors.E(op, err) + } + + val, err := exec(ctx, pool) + if err != nil { + return nil, errors.E(op) + } + + if val.Status != http.StatusOK { + return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status)) + } + + return val, nil + } +} + +// go:inline +func exec(ctx []byte, pool phpPool.Pool) (*validator.AccessValidator, error) { + const op = errors.Op("exec") + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + return val, nil +} diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go new file mode 100644 index 00000000..8f18580f --- /dev/null +++ b/plugins/websockets/pool/workers_pool.go @@ -0,0 +1,117 @@ +package pool + +import ( + "sync" + + "github.com/fasthttp/websocket" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/websockets/connection" + "github.com/spiral/roadrunner/v2/plugins/websockets/storage" +) + +type WorkersPool struct { + storage *storage.Storage + connections *sync.Map + resPool sync.Pool + log logger.Logger + + queue chan *pubsub.Message + exit chan struct{} +} + +// NewWorkersPool constructs worker pool for the websocket connections +func NewWorkersPool(storage *storage.Storage, connections *sync.Map, log logger.Logger) *WorkersPool { + wp := &WorkersPool{ + connections: connections, + queue: make(chan *pubsub.Message, 100), + storage: storage, + log: log, + exit: make(chan struct{}), + } + + wp.resPool.New = func() interface{} { + return make(map[string]struct{}, 10) + } + + // start 10 workers + for i := 0; i < 10; i++ { + wp.do() + } + + return wp +} + +func (wp *WorkersPool) Queue(msg *pubsub.Message) { + wp.queue <- msg +} + +func (wp *WorkersPool) Stop() { + for i := 0; i < 10; i++ { + wp.exit <- struct{}{} + } + + close(wp.exit) +} + +func (wp *WorkersPool) put(res map[string]struct{}) { + // optimized + // https://go-review.googlesource.com/c/go/+/110055/ + // not O(n), but O(1) + for k := range res { + delete(res, k) + } +} + +func (wp *WorkersPool) get() map[string]struct{} { + return wp.resPool.Get().(map[string]struct{}) +} + +func (wp *WorkersPool) do() { //nolint:gocognit + go func() { + for { + select { + case msg, ok := <-wp.queue: + if !ok { + return + } + // do not handle nil's + if msg == nil { + continue + } + if len(msg.Topics) == 0 { + continue + } + res := wp.get() + // get connections for the particular topic + wp.storage.GetByPtr(msg.Topics, res) + if len(res) == 0 { + wp.log.Info("no such topic", "topic", msg.Topics) + wp.put(res) + continue + } + + for i := range res { + c, ok := wp.connections.Load(i) + if !ok { + wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.Broker, "topics", msg.Topics) + continue + } + + conn := c.(*connection.Connection) + err := conn.Write(websocket.BinaryMessage, msg.Payload) + if err != nil { + wp.log.Error("error sending payload over the connection", "broker", msg.Broker, "topics", msg.Topics) + wp.put(res) + continue + } + } + + wp.put(res) + case <-wp.exit: + wp.log.Info("get exit signal, exiting from the workers pool") + return + } + } + }() +} diff --git a/plugins/websockets/rpc.go b/plugins/websockets/rpc.go new file mode 100644 index 00000000..2fb0f1b9 --- /dev/null +++ b/plugins/websockets/rpc.go @@ -0,0 +1,47 @@ +package websockets + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +// rpc collectors struct +type rpc struct { + plugin *Plugin + log logger.Logger +} + +func (r *rpc) Publish(msg []*pubsub.Message, ok *bool) error { + const op = errors.Op("broadcast_publish") + r.log.Debug("message published", "msg", msg) + + // just return in case of nil message + if msg == nil { + *ok = true + return nil + } + + err := r.plugin.Publish(msg) + if err != nil { + *ok = false + return errors.E(op, err) + } + *ok = true + return nil +} + +func (r *rpc) PublishAsync(msg []*pubsub.Message, ok *bool) error { + r.log.Debug("message published", "msg", msg) + + // just return in case of nil message + if msg == nil { + *ok = true + return nil + } + // publish to the registered broker + r.plugin.PublishAsync(msg) + + *ok = true + return nil +} diff --git a/plugins/websockets/schema/message.fbs b/plugins/websockets/schema/message.fbs new file mode 100644 index 00000000..f2d92c78 --- /dev/null +++ b/plugins/websockets/schema/message.fbs @@ -0,0 +1,10 @@ +namespace message; + +table Message { + command:string; + broker:string; + topics:[string]; + payload:[byte]; +} + +root_type Message; diff --git a/plugins/websockets/schema/message/Message.go b/plugins/websockets/schema/message/Message.go new file mode 100644 index 00000000..26bbd12c --- /dev/null +++ b/plugins/websockets/schema/message/Message.go @@ -0,0 +1,118 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package message + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Message struct { + _tab flatbuffers.Table +} + +func GetRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Message{} + x.Init(buf, n+offset) + return x +} + +func GetSizePrefixedRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Message{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *Message) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Message) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Message) Command() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Broker() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Topics(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j*4)) + } + return nil +} + +func (rcv *Message) TopicsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) Payload(j int) int8 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt8(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Message) PayloadLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) MutatePayload(j int, n int8) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateInt8(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func MessageStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func MessageAddCommand(builder *flatbuffers.Builder, command flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(command), 0) +} +func MessageAddBroker(builder *flatbuffers.Builder, broker flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(broker), 0) +} +func MessageAddTopics(builder *flatbuffers.Builder, topics flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(topics), 0) +} +func MessageStartTopicsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MessageAddPayload(builder *flatbuffers.Builder, payload flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(payload), 0) +} +func MessageStartPayloadVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MessageEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/plugins/websockets/storage/storage.go b/plugins/websockets/storage/storage.go new file mode 100644 index 00000000..ac256be2 --- /dev/null +++ b/plugins/websockets/storage/storage.go @@ -0,0 +1,79 @@ +package storage + +import ( + "sync" + + "github.com/spiral/roadrunner/v2/pkg/bst" +) + +type Storage struct { + sync.RWMutex + BST bst.Storage +} + +func NewStorage() *Storage { + return &Storage{ + BST: bst.NewBST(), + } +} + +func (s *Storage) InsertMany(connID string, topics []string) { + s.Lock() + defer s.Unlock() + + for i := 0; i < len(topics); i++ { + s.BST.Insert(connID, topics[i]) + } +} + +func (s *Storage) Insert(connID string, topic string) { + s.Lock() + defer s.Unlock() + + s.BST.Insert(connID, topic) +} + +func (s *Storage) RemoveMany(connID string, topics []string) { + s.Lock() + defer s.Unlock() + + for i := 0; i < len(topics); i++ { + s.BST.Remove(connID, topics[i]) + } +} + +func (s *Storage) Remove(connID string, topic string) { + s.Lock() + defer s.Unlock() + + s.BST.Remove(connID, topic) +} + +// GetByPtrTS Thread safe get +func (s *Storage) GetByPtrTS(topics []string, res map[string]struct{}) { + s.Lock() + defer s.Unlock() + + for i := 0; i < len(topics); i++ { + d := s.BST.Get(topics[i]) + if len(d) > 0 { + for ii := range d { + res[ii] = struct{}{} + } + } + } +} + +func (s *Storage) GetByPtr(topics []string, res map[string]struct{}) { + s.RLock() + defer s.RUnlock() + + for i := 0; i < len(topics); i++ { + d := s.BST.Get(topics[i]) + if len(d) > 0 { + for ii := range d { + res[ii] = struct{}{} + } + } + } +} diff --git a/plugins/websockets/storage/storage_test.go b/plugins/websockets/storage/storage_test.go new file mode 100644 index 00000000..4072992a --- /dev/null +++ b/plugins/websockets/storage/storage_test.go @@ -0,0 +1,299 @@ +package storage + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +const predifined = "chat-1-2" + +func TestNewBST(t *testing.T) { + // create a new bst + g := NewStorage() + + for i := 0; i < 100; i++ { + g.InsertMany(uuid.NewString(), []string{"comments"}) + } + + for i := 0; i < 100; i++ { + g.InsertMany(uuid.NewString(), []string{"comments2"}) + } + + for i := 0; i < 100; i++ { + g.InsertMany(uuid.NewString(), []string{"comments3"}) + } + + res := make(map[string]struct{}, 100) + assert.Len(t, res, 0) + + // should be 100 + g.GetByPtr([]string{"comments"}, res) + assert.Len(t, res, 100) + + res = make(map[string]struct{}, 100) + assert.Len(t, res, 0) + + // should be 100 + g.GetByPtr([]string{"comments2"}, res) + assert.Len(t, res, 100) + + res = make(map[string]struct{}, 100) + assert.Len(t, res, 0) + + // should be 100 + g.GetByPtr([]string{"comments3"}, res) + assert.Len(t, res, 100) +} + +func BenchmarkGraph(b *testing.B) { + g := NewStorage() + + for i := 0; i < 1000; i++ { + uid := uuid.New().String() + g.InsertMany(uuid.NewString(), []string{uid}) + } + + g.Insert(uuid.NewString(), predifined) + + b.ResetTimer() + b.ReportAllocs() + + res := make(map[string]struct{}) + + for i := 0; i < b.N; i++ { + g.GetByPtr([]string{predifined}, res) + + for i := range res { + delete(res, i) + } + } +} + +func BenchmarkBigSearch(b *testing.B) { + g1 := NewStorage() + + predefinedSlice := make([]string, 0, 1000) + for i := 0; i < 1000; i++ { + predefinedSlice = append(predefinedSlice, uuid.NewString()) + } + if predefinedSlice == nil { + b.FailNow() + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), uuid.NewString()) + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), predefinedSlice[i]) + } + + b.ResetTimer() + b.ReportAllocs() + + res := make(map[string]struct{}, 333) + + for i := 0; i < b.N; i++ { + g1.GetByPtr(predefinedSlice, res) + + for i := range res { + delete(res, i) + } + } +} + +func BenchmarkBigSearchWithRemoves(b *testing.B) { + g1 := NewStorage() + + predefinedSlice := make([]string, 0, 1000) + for i := 0; i < 1000; i++ { + predefinedSlice = append(predefinedSlice, uuid.NewString()) + } + if predefinedSlice == nil { + b.FailNow() + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), uuid.NewString()) + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), predefinedSlice[i]) + } + + b.ResetTimer() + b.ReportAllocs() + + go func() { + tt := time.NewTicker(time.Microsecond) + + res := make(map[string]struct{}, 1000) + for { + select { + case <-tt.C: + num := rand.Intn(1000) //nolint:gosec + g1.GetByPtr(predefinedSlice, res) + for k := range res { + g1.Remove(k, predefinedSlice[num]) + } + } + } + }() + + res := make(map[string]struct{}, 100) + + for i := 0; i < b.N; i++ { + g1.GetByPtr(predefinedSlice, res) + + for i := range res { + delete(res, i) + } + } +} + +func TestBigSearchWithRemoves(t *testing.T) { + g1 := NewStorage() + + predefinedSlice := make([]string, 0, 1000) + for i := 0; i < 1000; i++ { + predefinedSlice = append(predefinedSlice, uuid.NewString()) + } + if predefinedSlice == nil { + t.FailNow() + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), uuid.NewString()) + } + + for i := 0; i < 1000; i++ { + g1.Insert(uuid.NewString(), predefinedSlice[i]) + } + + stopCh := make(chan struct{}) + + go func() { + tt := time.NewTicker(time.Microsecond) + + res := make(map[string]struct{}, 1000) + for { + select { + case <-tt.C: + num := rand.Intn(1000) //nolint:gosec + g1.GetByPtr(predefinedSlice, res) + for k := range res { + g1.Remove(k, predefinedSlice[num]) + } + + case <-stopCh: + tt.Stop() + return + } + } + }() + + res := make(map[string]struct{}, 100) + + for i := 0; i < 1000; i++ { + g1.GetByPtr(predefinedSlice, res) + + for i := range res { + delete(res, i) + } + } + + stopCh <- struct{}{} +} + +func TestGraph(t *testing.T) { + g := NewStorage() + + for i := 0; i < 1000; i++ { + uid := uuid.New().String() + g.Insert(uuid.NewString(), uid) + } + + g.Insert(uuid.NewString(), predifined) + + res := make(map[string]struct{}) + + g.GetByPtr([]string{predifined}, res) + assert.NotEmpty(t, res) + assert.Len(t, res, 1) +} + +func TestTreeConcurrentContains(t *testing.T) { + g := NewStorage() + + key1 := uuid.NewString() + key2 := uuid.NewString() + key3 := uuid.NewString() + key4 := uuid.NewString() + key5 := uuid.NewString() + + g.Insert(key1, predifined) + g.Insert(key2, predifined) + g.Insert(key3, predifined) + g.Insert(key4, predifined) + g.Insert(key5, predifined) + + res := make(map[string]struct{}, 100) + + for i := 0; i < 100; i++ { + go func() { + g.GetByPtrTS([]string{predifined}, res) + }() + + go func() { + g.GetByPtrTS([]string{predifined}, res) + }() + + go func() { + g.GetByPtrTS([]string{predifined}, res) + }() + + go func() { + g.GetByPtrTS([]string{predifined}, res) + }() + } + + time.Sleep(time.Second * 5) + + res2 := make(map[string]struct{}, 5) + + g.GetByPtr([]string{predifined}, res2) + assert.NotEmpty(t, res2) + assert.Len(t, res2, 5) +} + +func TestGraphRemove(t *testing.T) { + g := NewStorage() + + key1 := uuid.NewString() + key2 := uuid.NewString() + key3 := uuid.NewString() + key4 := uuid.NewString() + key5 := uuid.NewString() + + g.Insert(key1, predifined) + g.Insert(key2, predifined) + g.Insert(key3, predifined) + g.Insert(key4, predifined) + g.Insert(key5, predifined) + + res := make(map[string]struct{}, 5) + g.GetByPtr([]string{predifined}, res) + assert.NotEmpty(t, res) + assert.Len(t, res, 5) + + g.Remove(key1, predifined) + + res2 := make(map[string]struct{}, 4) + g.GetByPtr([]string{predifined}, res2) + assert.NotEmpty(t, res2) + assert.Len(t, res2, 4) +} diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go new file mode 100644 index 00000000..e666f846 --- /dev/null +++ b/plugins/websockets/validator/access_validator.go @@ -0,0 +1,76 @@ +package validator + +import ( + "net/http" + "strings" + + json "github.com/json-iterator/go" + "github.com/spiral/errors" + handler "github.com/spiral/roadrunner/v2/pkg/worker_handler" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" +) + +type AccessValidatorFn = func(r *http.Request, channels ...string) (*AccessValidator, error) + +type AccessValidator struct { + Header http.Header `json:"headers"` + Status int `json:"status"` + Body []byte +} + +func ServerAccessValidator(r *http.Request) ([]byte, error) { + const op = errors.Op("server_access_validator") + + err := attributes.Set(r, "ws:joinServer", true) + if err != nil { + return nil, errors.E(op, err) + } + + defer delete(attributes.All(r), "ws:joinServer") + + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), + } + + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) + } + + return data, nil +} + +func TopicsAccessValidator(r *http.Request, topics ...string) ([]byte, error) { + const op = errors.Op("topic_access_validator") + err := attributes.Set(r, "ws:joinTopics", strings.Join(topics, ",")) + if err != nil { + return nil, errors.E(op, err) + } + + defer delete(attributes.All(r), "ws:joinTopics") + + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), + } + + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) + } + + return data, nil +} |