diff options
author | Valery Piashchynski <[email protected]> | 2021-06-18 01:06:16 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-06-18 01:06:16 +0300 |
commit | fe7bb0fe758d573fe353df028257ed66c6eccf66 (patch) | |
tree | 74392f8e61e96c85f0d8b684cfc08e3fc3664ae9 /plugins/websockets | |
parent | 68ff941c4226074206ceed9c30bd95317aa0e9fc (diff) |
- Rework main parts
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets')
-rw-r--r-- | plugins/websockets/config.go | 16 | ||||
-rw-r--r-- | plugins/websockets/executor/executor.go | 41 | ||||
-rw-r--r-- | plugins/websockets/origin_test.go | 9 | ||||
-rw-r--r-- | plugins/websockets/plugin.go | 140 | ||||
-rw-r--r-- | plugins/websockets/pool/workers_pool.go | 20 |
5 files changed, 80 insertions, 146 deletions
diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go index b1d5d0a8..933a12e0 100644 --- a/plugins/websockets/config.go +++ b/plugins/websockets/config.go @@ -4,6 +4,7 @@ import ( "strings" "time" + "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pool" ) @@ -17,9 +18,9 @@ websockets: // Config represents configuration for the ws plugin type Config struct { // http path for the websocket - Path string `mapstructure:"path"` - + Path string `mapstructure:"path"` AllowedOrigin string `mapstructure:"allowed_origin"` + Broker string `mapstructure:"broker"` // wildcard origin allowedWOrigins []wildcard @@ -31,11 +32,16 @@ type Config struct { } // InitDefault initialize default values for the ws config -func (c *Config) InitDefault() { +func (c *Config) InitDefault() error { if c.Path == "" { c.Path = "/ws" } + // broker is mandatory + if c.Broker == "" { + return errors.Str("broker key should be specified") + } + if c.Pool == nil { c.Pool = &pool.Config{} if c.Pool.NumWorkers == 0 { @@ -64,7 +70,7 @@ func (c *Config) InitDefault() { if origin == "*" { // If "*" is present in the list, turn the whole list into a match all c.allowedAll = true - return + return nil } else if i := strings.IndexByte(origin, '*'); i >= 0 { // Split the origin in two: start and end string without the * w := wildcard{origin[0:i], origin[i+1:]} @@ -72,4 +78,6 @@ func (c *Config) InitDefault() { } else { c.allowedOrigins = append(c.allowedOrigins, origin) } + + return nil } diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 07f22043..799312ad 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -7,8 +7,8 @@ import ( json "github.com/json-iterator/go" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" - "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" @@ -28,8 +28,8 @@ type Executor struct { // associated connection ID connID string - // map with the pubsub drivers - pubsub map[string]pubsub.Subscriber + // subscriber drivers + sub pubsub.Subscriber actualTopics map[string]struct{} req *http.Request @@ -38,12 +38,12 @@ type Executor struct { // NewExecutor creates protected connection and starts command loop func NewExecutor(conn *connection.Connection, log logger.Logger, - connID string, pubsubs map[string]pubsub.Subscriber, av validator.AccessValidatorFn, r *http.Request) *Executor { + connID string, sub pubsub.Subscriber, av validator.AccessValidatorFn, r *http.Request) *Executor { return &Executor{ conn: conn, connID: connID, log: log, - pubsub: pubsubs, + sub: sub, accessValidator: av, actualTopics: make(map[string]struct{}, 10), req: r, @@ -126,11 +126,9 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit } // subscribe to the topic - if br, ok := e.pubsub[msg.Broker]; ok { - err = e.Set(br, msg.Topics) - if err != nil { - return errors.E(op, err) - } + err = e.Set(msg.Topics) + if err != nil { + return errors.E(op, err) } // handle leave @@ -155,11 +153,9 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit return errors.E(op, err) } - if br, ok := e.pubsub[msg.Broker]; ok { - err = e.Leave(br, msg.Topics) - if err != nil { - return errors.E(op, err) - } + err = e.Leave(msg.Topics) + if err != nil { + return errors.E(op, err) } case commands.Headers: @@ -170,13 +166,13 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit } } -func (e *Executor) Set(br pubsub.Subscriber, topics []string) error { +func (e *Executor) Set(topics []string) error { // associate connection with topics - err := br.Subscribe(e.connID, topics...) + err := e.sub.Subscribe(e.connID, topics...) if err != nil { e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) // in case of error, unsubscribe connection from the dead topics - _ = br.Unsubscribe(e.connID, topics...) + _ = e.sub.Unsubscribe(e.connID, topics...) return err } @@ -188,9 +184,9 @@ func (e *Executor) Set(br pubsub.Subscriber, topics []string) error { return nil } -func (e *Executor) Leave(br pubsub.Subscriber, topics []string) error { +func (e *Executor) Leave(topics []string) error { // remove associated connections from the storage - err := br.Unsubscribe(e.connID, topics...) + err := e.sub.Unsubscribe(e.connID, topics...) if err != nil { e.log.Error("error subscribing to the provided topics", "topics", topics, "error", err.Error()) return err @@ -207,10 +203,7 @@ func (e *Executor) Leave(br pubsub.Subscriber, topics []string) error { func (e *Executor) CleanUp() { // unsubscribe particular connection from the topics for topic := range e.actualTopics { - // here - for _, ps := range e.pubsub { - _ = ps.Unsubscribe(e.connID, topic) - } + _ = e.sub.Unsubscribe(e.connID, topic) } // clean up the actualTopics data diff --git a/plugins/websockets/origin_test.go b/plugins/websockets/origin_test.go index e877fad3..ec6e1960 100644 --- a/plugins/websockets/origin_test.go +++ b/plugins/websockets/origin_test.go @@ -11,7 +11,8 @@ func TestConfig_Origin(t *testing.T) { AllowedOrigin: "*", } - cfg.InitDefault() + err := cfg.InitDefault() + assert.NoError(t, err) assert.True(t, isOriginAllowed("http://some.some.some.sssome", cfg)) assert.True(t, isOriginAllowed("http://", cfg)) @@ -29,7 +30,8 @@ func TestConfig_OriginWildCard(t *testing.T) { AllowedOrigin: "https://*my.site.com", } - cfg.InitDefault() + err := cfg.InitDefault() + assert.NoError(t, err) assert.True(t, isOriginAllowed("https://my.site.com", cfg)) assert.False(t, isOriginAllowed("http://", cfg)) @@ -50,7 +52,8 @@ func TestConfig_OriginWildCard2(t *testing.T) { AllowedOrigin: "https://my.*.com", } - cfg.InitDefault() + err := cfg.InitDefault() + assert.NoError(t, err) assert.True(t, isOriginAllowed("https://my.site.com", cfg)) assert.False(t, isOriginAllowed("http://", cfg)) diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index cf861c72..de7443fd 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -9,13 +9,12 @@ import ( "github.com/fasthttp/websocket" "github.com/google/uuid" json "github.com/json-iterator/go" - endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/interface/broadcast" + "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/process" - "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/pkg/worker" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" @@ -33,16 +32,14 @@ const ( type Plugin struct { sync.RWMutex - // Collection with all available pubsubs - //pubsubs map[string]pubsub.PubSub - //psProviders map[string]pubsub.PSProvider + // subscriber+reader interfaces + subReader pubsub.SubReader + // broadcaster + broadcaster broadcast.Broadcaster - subReaders map[string]pubsub.SubReader - - cfg *Config - cfgPlugin config.Configurer - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -53,8 +50,10 @@ type Plugin struct { wsUpgrade *websocket.Upgrader serveExit chan struct{} + // workers pool phpPool phpPool.Pool - server server.Server + // server which produces commands to the pool + server server.Server // function used to validate access to the requested resource accessValidator validator.AccessValidatorFn @@ -71,14 +70,10 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se return errors.E(op, err) } - p.cfg.InitDefault() - //p.pubsubs = make(map[string]pubsub.PubSub) - //p.psProviders = make(map[string]pubsub.PSProvider) - - p.subReaders = make(map[string]pubsub.SubReader) - - p.log = log - p.cfgPlugin = cfg + err = p.cfg.InitDefault() + if err != nil { + return errors.E(op, err) + } p.wsUpgrade = &websocket.Upgrader{ HandshakeTimeout: time.Second * 60, @@ -90,19 +85,21 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se } p.serveExit = make(chan struct{}) p.server = server - + p.log = log + p.broadcaster = b return nil } func (p *Plugin) Serve() chan error { - errCh := make(chan error, 1) const op = errors.Op("websockets_plugin_serve") - - //err := p.initPubSubs() - //if err != nil { - // errCh <- errors.E(op, err) - // return errCh - //} + errCh := make(chan error, 1) + // init broadcaster + var err error + p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker) + if err != nil { + errCh <- errors.E(op, err) + return errCh + } go func() { var err error @@ -124,78 +121,28 @@ func (p *Plugin) Serve() chan error { p.accessValidator = p.defaultAccessValidator(p.phpPool) }() - p.workersPool = pool.NewWorkersPool(p.subReaders, &p.connections, p.log) + p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log) // run all pubsubs drivers - for _, v := range p.subReaders { - go func(ps pubsub.SubReader) { - for { - select { - case <-p.serveExit: + go func(ps pubsub.Reader) { + for { + select { + case <-p.serveExit: + return + default: + data, err := ps.Next() + if err != nil { + errCh <- err return - default: - data, err := ps.Next() - if err != nil { - errCh <- err - return - } - p.workersPool.Queue(data) } + p.workersPool.Queue(data) } - }(v) - } + } + }(p.subReader) return errCh } -//func (p *Plugin) initPubSubs() error { -// for i := 0; i < len(p.cfg.PubSubs); i++ { -// // don't need to have a section for the in-memory -// if p.cfg.PubSubs[i] == "memory" { -// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { -// r, err := provider.PSProvide("") -// if err != nil { -// return err -// } -// -// // append default in-memory provider -// p.pubsubs["memory"] = r -// } -// continue -// } -// // key - memory, redis -// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { -// // try local key -// switch { -// // try local config first -// case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])): -// r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])) -// if err != nil { -// return err -// } -// -// // append redis provider -// p.pubsubs[p.cfg.PubSubs[i]] = r -// case p.cfgPlugin.Has(p.cfg.PubSubs[i]): -// r, err := provider.PSProvide(p.cfg.PubSubs[i]) -// if err != nil { -// return err -// } -// -// // append redis provider -// p.pubsubs[p.cfg.PubSubs[i]] = r -// default: -// return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i]) -// } -// } else { -// // no such driver -// p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders) -// } -// } -// -// return nil -//} - func (p *Plugin) Stop() error { // close workers pool p.workersPool.Stop() @@ -210,23 +157,12 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) Collects() []interface{} { - return []interface{}{ - p.GetSubsReader, - } -} - func (p *Plugin) Available() {} func (p *Plugin) Name() string { return PluginName } -// GetSubsReader collects all plugins which implement SubReader interface -func (p *Plugin) GetSubsReader(name endure.Named, pub pubsub.SubReader) { - p.subReaders[name.Name()] = pub -} - func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != p.cfg.Path { @@ -272,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { p.connections.Store(connectionID, safeConn) // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, connectionID, nil, p.accessValidator, r) + e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) err = e.StartCommandLoop() diff --git a/plugins/websockets/pool/workers_pool.go b/plugins/websockets/pool/workers_pool.go index 22042d8d..cd9444da 100644 --- a/plugins/websockets/pool/workers_pool.go +++ b/plugins/websockets/pool/workers_pool.go @@ -4,15 +4,15 @@ import ( "sync" json "github.com/json-iterator/go" + "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" - "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/utils" ) type WorkersPool struct { - storage map[string]pubsub.SubReader + subscriber pubsub.Subscriber connections *sync.Map resPool sync.Pool log logger.Logger @@ -22,11 +22,11 @@ type WorkersPool struct { } // NewWorkersPool constructs worker pool for the websocket connections -func NewWorkersPool(pubsubs map[string]pubsub.SubReader, connections *sync.Map, log logger.Logger) *WorkersPool { +func NewWorkersPool(subscriber pubsub.Subscriber, connections *sync.Map, log logger.Logger) *WorkersPool { wp := &WorkersPool{ connections: connections, queue: make(chan *websocketsv1.Message, 100), - storage: pubsubs, + subscriber: subscriber, log: log, exit: make(chan struct{}), } @@ -90,19 +90,13 @@ func (wp *WorkersPool) do() { //nolint:gocognit continue } - br, ok := wp.storage[msg.Broker] - if !ok { - wp.log.Warn("no such broker", "requested", msg.GetBroker(), "available", wp.storage) - continue - } - // send a message to every topic for i := 0; i < len(msg.GetTopics()); i++ { // get free map res := wp.get() // get connections for the particular topic - br.Connections(msg.GetTopics()[i], res) + wp.subscriber.Connections(msg.GetTopics()[i], res) if len(res) == 0 { wp.log.Info("no such topic", "topic", msg.GetTopics()[i]) @@ -114,7 +108,7 @@ func (wp *WorkersPool) do() { //nolint:gocognit for topic := range res { c, ok := wp.connections.Load(topic) if !ok { - wp.log.Warn("the user disconnected connection before the message being written to it", "broker", msg.GetBroker(), "topics", msg.GetTopics()[i]) + wp.log.Warn("the user disconnected connection before the message being written to it", "topics", msg.GetTopics()[i]) wp.put(res) continue } @@ -135,7 +129,7 @@ func (wp *WorkersPool) do() { //nolint:gocognit err = c.(*connection.Connection).Write(d) if err != nil { for i := 0; i < len(msg.GetTopics()); i++ { - wp.log.Error("error sending payload over the connection", "error", err, "broker", msg.GetBroker(), "topics", msg.GetTopics()[i]) + wp.log.Error("error sending payload over the connection", "error", err, "topics", msg.GetTopics()[i]) } wp.put(res) continue |