summaryrefslogtreecommitdiff
path: root/plugins/websockets
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-06-18 01:06:16 +0300
committerValery Piashchynski <[email protected]>2021-06-18 01:06:16 +0300
commitfe7bb0fe758d573fe353df028257ed66c6eccf66 (patch)
tree74392f8e61e96c85f0d8b684cfc08e3fc3664ae9 /plugins/websockets
parent68ff941c4226074206ceed9c30bd95317aa0e9fc (diff)
- Rework main parts
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets')
-rw-r--r--plugins/websockets/config.go16
-rw-r--r--plugins/websockets/executor/executor.go41
-rw-r--r--plugins/websockets/origin_test.go9
-rw-r--r--plugins/websockets/plugin.go140
-rw-r--r--plugins/websockets/pool/workers_pool.go20
5 files changed, 80 insertions, 146 deletions
diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go
index b1d5d0a8..933a12e0 100644
--- a/plugins/websockets/config.go
+++ b/plugins/websockets/config.go
@@ -4,6 +4,7 @@ import (
"strings"
"time"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/pool"
)
@@ -17,9 +18,9 @@ websockets:
// Config represents configuration for the ws plugin
type Config struct {
// http path for the websocket
- Path string `mapstructure:"path"`
-
+ Path string `mapstructure:"path"`
AllowedOrigin string `mapstructure:"allowed_origin"`
+ Broker string `mapstructure:"broker"`
// wildcard origin
allowedWOrigins []wildcard
@@ -31,11 +32,16 @@ type Config struct {
}
// InitDefault initialize default values for the ws config
-func (c *Config) InitDefault() {
+func (c *Config) InitDefault() error {
if c.Path == "" {
c.Path = "/ws"
}
+ // broker is mandatory
+ if c.Broker == "" {
+ return errors.Str("broker key should be specified")
+ }
+
if c.Pool == nil {
c.Pool = &pool.Config{}
if c.Pool.NumWorkers == 0 {
@@ -64,7 +70,7 @@ func (c *Config) InitDefault() {
if origin == "*" {
// If "*" is present in the list, turn the whole list into a match all
c.allowedAll = true
- return
+ return nil
} else if i := strings.IndexByte(origin, '*'); i >= 0 {
// Split the origin in two: start and end string without the *
w := wildcard{origin[0:i], origin[i+1:]}
@@ -72,4 +78,6 @@ func (c *Config) InitDefault() {
} else {
c.allowedOrigins = append(c.allowedOrigins, origin)
}
+
+ return nil
}
diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go
index 07f22043..799312ad 100644
--- a/plugins/websockets/executor/executor.go
+++ b/plugins/websockets/executor/executor.go
@@ -7,8 +7,8 @@ import (
json "github.com/json-iterator/go"
"github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/interface/pubsub"
websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta"
- "github.com/spiral/roadrunner/v2/pkg/pubsub"
"github.com/spiral/roadrunner/v2/plugins/logger"
"github.com/spiral/roadrunner/v2/plugins/websockets/commands"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
@@ -28,8 +28,8 @@ type Executor struct {
// associated connection ID
connID string
- // map with the pubsub drivers
- pubsub map[string]pubsub.Subscriber
+ // subscriber drivers
+ sub pubsub.Subscriber
actualTopics map[string]struct{}
req *http.Request
@@ -38,12 +38,12 @@ type Executor struct {
// NewExecutor creates protected connection and starts command loop
func NewExecutor(conn *connection.Connection, log logger.Logger,
- connID string, pubsubs map[string]pubsub.Subscriber, av validator.AccessValidatorFn, r *http.Request) *Executor {
+ connID string, sub pubsub.Subscriber, av validator.AccessValidatorFn, r *http.Request) *Executor {
return &Executor{
conn: conn,
connID: connID,
log: log,
- pubsub: pubsubs,
+ sub: sub,
accessValidator: av,
actualTopics: make(map[string]struct{}, 10),
req: r,
@@ -126,11 +126,9 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
}
// subscribe to the topic
- if br, ok := e.pubsub[msg.Broker]; ok {
- err = e.Set(br, msg.Topics)
- if err != nil {
- return errors.E(op, err)
- }
+ err = e.Set(msg.Topics)
+ if err != nil {
+ return errors.E(op, err)
}
// handle leave
@@ -155,11 +153,9 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
return errors.E(op, err)
}
- if br, ok := e.pubsub[msg.Broker]; ok {
- err = e.Leave(br, msg.Topics)
- if err != nil {
- return errors.E(op, err)
- }
+ err = e.Leave(msg.Topics)
+ if err != nil {
+ return errors.E(op, err)
}
case commands.Headers:
@@ -170,13 +166,13 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit
}
}
-func (e *Executor) Set(br pubsub.Subscriber, topics []string) error {
+func (e *Executor) Set(topics []string) error {
// associate connection with topics
- err := br.Subscribe(e.connID, topics...)
+ err := e.sub.Subscribe(e.connID, topics...)
if err != nil {
e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
// in case of error, unsubscribe connection from the dead topics
- _ = br.Unsubscribe(e.connID, topics...)
+ _ = e.sub.Unsubscribe(e.connID, topics...)
return err
}
@@ -188,9 +184,9 @@ func (e *Executor) Set(br pubsub.Subscriber, topics []string) error {
return nil
}
-func (e *Executor) Leave(br pubsub.Subscriber, topics []string) error {
+func (e *Executor) Leave(topics []string) error {
// remove associated connections from the storage
- err := br.Unsubscribe(e.connID, topics...)
+ err := e.sub.Unsubscribe(e.connID, topics...)
if err != nil {
e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error())
return err
@@ -207,10 +203,7 @@ func (e *Executor) Leave(br pubsub.Subscriber, topics []string) error {
func (e *Executor) CleanUp() {
// unsubscribe particular connection from the topics
for topic := range e.actualTopics {
- // here
- for _, ps := range e.pubsub {
- _ = ps.Unsubscribe(e.connID, topic)
- }
+ _ = e.sub.Unsubscribe(e.connID, topic)
}
// clean up the actualTopics data
diff --git a/plugins/websockets/origin_test.go b/plugins/websockets/origin_test.go
index e877fad3..ec6e1960 100644
--- a/plugins/websockets/origin_test.go
+++ b/plugins/websockets/origin_test.go
@@ -11,7 +11,8 @@ func TestConfig_Origin(t *testing.T) {
AllowedOrigin: "*",
}
- cfg.InitDefault()
+ err := cfg.InitDefault()
+ assert.NoError(t, err)
assert.True(t, isOriginAllowed("http://some.some.some.sssome", cfg))
assert.True(t, isOriginAllowed("http://", cfg))
@@ -29,7 +30,8 @@ func TestConfig_OriginWildCard(t *testing.T) {
AllowedOrigin: "https://*my.site.com",
}
- cfg.InitDefault()
+ err := cfg.InitDefault()
+ assert.NoError(t, err)
assert.True(t, isOriginAllowed("https://my.site.com", cfg))
assert.False(t, isOriginAllowed("http://", cfg))
@@ -50,7 +52,8 @@ func TestConfig_OriginWildCard2(t *testing.T) {
AllowedOrigin: "https://my.*.com",
}
- cfg.InitDefault()
+ err := cfg.InitDefault()
+ assert.NoError(t, err)
assert.True(t, isOriginAllowed("https://my.site.com", cfg))
assert.False(t, isOriginAllowed("http://", cfg))
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index cf861c72..de7443fd 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -9,13 +9,12 @@ import (
"github.com/fasthttp/websocket"
"github.com/google/uuid"
json "github.com/json-iterator/go"
- endure "github.com/spiral/endure/pkg/container"
"github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/interface/broadcast"
+ "github.com/spiral/roadrunner/v2/pkg/interface/pubsub"
"github.com/spiral/roadrunner/v2/pkg/payload"
phpPool "github.com/spiral/roadrunner/v2/pkg/pool"
"github.com/spiral/roadrunner/v2/pkg/process"
- "github.com/spiral/roadrunner/v2/pkg/pubsub"
"github.com/spiral/roadrunner/v2/pkg/worker"
"github.com/spiral/roadrunner/v2/plugins/config"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
@@ -33,16 +32,14 @@ const (
type Plugin struct {
sync.RWMutex
- // Collection with all available pubsubs
- //pubsubs map[string]pubsub.PubSub
- //psProviders map[string]pubsub.PSProvider
+ // subscriber+reader interfaces
+ subReader pubsub.SubReader
+ // broadcaster
+ broadcaster broadcast.Broadcaster
- subReaders map[string]pubsub.SubReader
-
- cfg *Config
- cfgPlugin config.Configurer
- log logger.Logger
+ cfg *Config
+ log logger.Logger
// global connections map
connections sync.Map
@@ -53,8 +50,10 @@ type Plugin struct {
wsUpgrade *websocket.Upgrader
serveExit chan struct{}
+ // workers pool
phpPool phpPool.Pool
- server server.Server
+ // server which produces commands to the pool
+ server server.Server
// function used to validate access to the requested resource
accessValidator validator.AccessValidatorFn
@@ -71,14 +70,10 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se
return errors.E(op, err)
}
- p.cfg.InitDefault()
- //p.pubsubs = make(map[string]pubsub.PubSub)
- //p.psProviders = make(map[string]pubsub.PSProvider)
-
- p.subReaders = make(map[string]pubsub.SubReader)
-
- p.log = log
- p.cfgPlugin = cfg
+ err = p.cfg.InitDefault()
+ if err != nil {
+ return errors.E(op, err)
+ }
p.wsUpgrade = &websocket.Upgrader{
HandshakeTimeout: time.Second * 60,
@@ -90,19 +85,21 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se
}
p.serveExit = make(chan struct{})
p.server = server
-
+ p.log = log
+ p.broadcaster = b
return nil
}
func (p *Plugin) Serve() chan error {
- errCh := make(chan error, 1)
const op = errors.Op("websockets_plugin_serve")
-
- //err := p.initPubSubs()
- //if err != nil {
- // errCh <- errors.E(op, err)
- // return errCh
- //}
+ errCh := make(chan error, 1)
+ // init broadcaster
+ var err error
+ p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker)
+ if err != nil {
+ errCh <- errors.E(op, err)
+ return errCh
+ }
go func() {
var err error
@@ -124,78 +121,28 @@ func (p *Plugin) Serve() chan error {
p.accessValidator = p.defaultAccessValidator(p.phpPool)
}()
- p.workersPool = pool.NewWorkersPool(p.subReaders, &p.connections, p.log)
+ p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log)
// run all pubsubs drivers
- for _, v := range p.subReaders {
- go func(ps pubsub.SubReader) {
- for {
- select {
- case <-p.serveExit:
+ go func(ps pubsub.Reader) {
+ for {
+ select {
+ case <-p.serveExit:
+ return
+ default:
+ data, err := ps.Next()
+ if err != nil {
+ errCh <- err
return
- default:
- data, err := ps.Next()
- if err != nil {
- errCh <- err
- return
- }
- p.workersPool.Queue(data)
}
+ p.workersPool.Queue(data)
}
- }(v)
- }
+ }
+ }(p.subReader)
return errCh
}
-//func (p *Plugin) initPubSubs() error {
-// for i := 0; i < len(p.cfg.PubSubs); i++ {
-// // don't need to have a section for the in-memory
-// if p.cfg.PubSubs[i] == "memory" {
-// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok {
-// r, err := provider.PSProvide("")
-// if err != nil {
-// return err
-// }
-//
-// // append default in-memory provider
-// p.pubsubs["memory"] = r
-// }
-// continue
-// }
-// // key - memory, redis
-// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok {
-// // try local key
-// switch {
-// // try local config first
-// case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])):
-// r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i]))
-// if err != nil {
-// return err
-// }
-//
-// // append redis provider
-// p.pubsubs[p.cfg.PubSubs[i]] = r
-// case p.cfgPlugin.Has(p.cfg.PubSubs[i]):
-// r, err := provider.PSProvide(p.cfg.PubSubs[i])
-// if err != nil {
-// return err
-// }
-//
-// // append redis provider
-// p.pubsubs[p.cfg.PubSubs[i]] = r
-// default:
-// return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i])
-// }
-// } else {
-// // no such driver
-// p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders)
-// }
-// }
-//
-// return nil
-//}
-
func (p *Plugin) Stop() error {
// close workers pool
p.workersPool.Stop()
@@ -210,23 +157,12 @@ func (p *Plugin) Stop() error {
return nil
}
-func (p *Plugin) Collects() []interface{} {
- return []interface{}{
- p.GetSubsReader,
- }
-}
-
func (p *Plugin) Available() {}
func (p *Plugin) Name() string {
return PluginName
}
-// GetSubsReader collects all plugins which implement SubReader interface
-func (p *Plugin) GetSubsReader(name endure.Named, pub pubsub.SubReader) {
- p.subReaders[name.Name()] = pub
-}
-
func (p *Plugin) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != p.cfg.Path {
@@ -272,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
p.connections.Store(connectionID, safeConn)
// Executor wraps a connection to have a safe abstraction
- e := executor.NewExecutor(safeConn, p.log, connectionID, nil, p.accessValidator, r)
+ e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r)
p.log.Info("websocket client connected", "uuid", connectionID)
err = e.StartCommandLoop()
diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go
index 22042d8d..cd9444da 100644
--- a/plugins/websockets/pool/workers_pool.go
+++ b/plugins/websockets/pool/workers_pool.go
@@ -4,15 +4,15 @@ import (
"sync"
json "github.com/json-iterator/go"
+ "github.com/spiral/roadrunner/v2/pkg/interface/pubsub"
websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta"
- "github.com/spiral/roadrunner/v2/pkg/pubsub"
"github.com/spiral/roadrunner/v2/plugins/logger"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
"github.com/spiral/roadrunner/v2/utils"
)
type WorkersPool struct {
- storage map[string]pubsub.SubReader
+ subscriber pubsub.Subscriber
connections *sync.Map
resPool sync.Pool
log logger.Logger
@@ -22,11 +22,11 @@ type WorkersPool struct {
}
// NewWorkersPool constructs worker pool for the websocket connections
-func NewWorkersPool(pubsubs map[string]pubsub.SubReader, connections *sync.Map, log logger.Logger) *WorkersPool {
+func NewWorkersPool(subscriber pubsub.Subscriber, connections *sync.Map, log logger.Logger) *WorkersPool {
wp := &WorkersPool{
connections: connections,
queue: make(chan *websocketsv1.Message, 100),
- storage: pubsubs,
+ subscriber: subscriber,
log: log,
exit: make(chan struct{}),
}
@@ -90,19 +90,13 @@ func (wp *WorkersPool) do() { //nolint:gocognit
continue
}
- br, ok := wp.storage[msg.Broker]
- if !ok {
- wp.log.Warn("no such broker", "requested", msg.GetBroker(), "available", wp.storage)
- continue
- }
-
// send a message to every topic
for i := 0; i < len(msg.GetTopics()); i++ {
// get free map
res := wp.get()
// get connections for the particular topic
- br.Connections(msg.GetTopics()[i], res)
+ wp.subscriber.Connections(msg.GetTopics()[i], res)
if len(res) == 0 {
wp.log.Info("no such topic", "topic", msg.GetTopics()[i])
@@ -114,7 +108,7 @@ func (wp *WorkersPool) do() { //nolint:gocognit
for topic := range res {
c, ok := wp.connections.Load(topic)
if !ok {
- wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.GetBroker(), "topics", msg.GetTopics()[i])
+ wp.log.Warn("the user disconnected connection before the message being written to it", "topics", msg.GetTopics()[i])
wp.put(res)
continue
}
@@ -135,7 +129,7 @@ func (wp *WorkersPool) do() { //nolint:gocognit
err = c.(*connection.Connection).Write(d)
if err != nil {
for i := 0; i < len(msg.GetTopics()); i++ {
- wp.log.Error("error sending payload over the connection", "error", err, "broker", msg.GetBroker(), "topics", msg.GetTopics()[i])
+ wp.log.Error("error sending payload over the connection", "error", err, "topics", msg.GetTopics()[i])
}
wp.put(res)
continue