diff options
author | Valery Piashchynski <[email protected]> | 2021-06-21 11:41:42 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-06-21 11:41:42 +0300 |
commit | bdcfdd28d705e401973da2beb8a11543e362bda4 (patch) | |
tree | 6a80b5b78ce18c7ddf298861d5b0cd05d8c64ccf /plugins/websockets/plugin.go | |
parent | cee4bc46097506d6e892b6af194751434700621a (diff) | |
parent | 87d023d32feef5fe28c9bb65a796deb77d536b15 (diff) |
Merge remote-tracking branch 'origin/master' into feature/jobs_plugin
# Conflicts:
# plugins/websockets/plugin.go
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r-- | plugins/websockets/plugin.go | 196 |
1 files changed, 37 insertions, 159 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index a1002bdd..ca5f2f59 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -2,7 +2,6 @@ package websockets import ( "context" - "fmt" "net/http" "sync" "time" @@ -10,14 +9,13 @@ import ( "github.com/fasthttp/websocket" "github.com/google/uuid" json "github.com/json-iterator/go" - endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/process" - websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta" "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/pkg/worker" + "github.com/spiral/roadrunner/v2/plugins/broadcast" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -26,7 +24,6 @@ import ( "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" "github.com/spiral/roadrunner/v2/plugins/websockets/validator" - "google.golang.org/protobuf/proto" ) const ( @@ -35,14 +32,14 @@ const ( type Plugin struct { sync.RWMutex - // Collection with all available pubsubs - pubsubs map[string]pubsub.PubSub - psProviders map[string]pubsub.PSProvider + // subscriber+reader interfaces + subReader pubsub.SubReader + // broadcaster + broadcaster broadcast.Broadcaster - cfg *Config - cfgPlugin config.Configurer - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -53,14 +50,16 @@ type Plugin struct { wsUpgrade *websocket.Upgrader serveExit chan struct{} + // workers pool phpPool phpPool.Pool - server server.Server + // server which produces commands to the pool + server server.Server // function used to validate access to the requested resource accessValidator validator.AccessValidatorFn } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server, b broadcast.Broadcaster) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -71,36 +70,32 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se return errors.E(op, err) } - p.cfg.InitDefault() - p.pubsubs = make(map[string]pubsub.PubSub) - p.psProviders = make(map[string]pubsub.PSProvider) - - p.log = log - p.cfgPlugin = cfg + err = p.cfg.InitDefault() + if err != nil { + return errors.E(op, err) + } p.wsUpgrade = &websocket.Upgrader{ HandshakeTimeout: time.Second * 60, ReadBufferSize: 1024, WriteBufferSize: 1024, - WriteBufferPool: nil, - Subprotocols: nil, - Error: nil, CheckOrigin: func(r *http.Request) bool { - return true + return isOriginAllowed(r.Header.Get("Origin"), p.cfg) }, - EnableCompression: false, } p.serveExit = make(chan struct{}) p.server = server - + p.log = log + p.broadcaster = b return nil } func (p *Plugin) Serve() chan error { - errCh := make(chan error, 1) const op = errors.Op("websockets_plugin_serve") - - err := p.initPubSubs() + errCh := make(chan error, 1) + // init broadcaster + var err error + p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker) if err != nil { errCh <- errors.E(op, err) return errCh @@ -126,76 +121,26 @@ func (p *Plugin) Serve() chan error { p.accessValidator = p.defaultAccessValidator(p.phpPool) }() - p.workersPool = pool.NewWorkersPool(p.pubsubs, &p.connections, p.log) - - // run all pubsubs drivers - for _, v := range p.pubsubs { - go func(ps pubsub.PubSub) { - for { - select { - case <-p.serveExit: - return - default: - data, err := ps.Next() - if err != nil { - errCh <- err - return - } - p.workersPool.Queue(data) - } - } - }(v) - } - - return errCh -} - -func (p *Plugin) initPubSubs() error { - for i := 0; i < len(p.cfg.PubSubs); i++ { - // don't need to have a section for the in-memory - if p.cfg.PubSubs[i] == "memory" { - if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { - r, err := provider.PSProvide("") - if err != nil { - return err - } + p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log) - // append default in-memory provider - p.pubsubs["memory"] = r - } - continue - } - // key - memory, redis - if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { - // try local key - switch { - // try local config first - case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])): - r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])) - if err != nil { - return err - } - - // append redis provider - p.pubsubs[p.cfg.PubSubs[i]] = r - case p.cfgPlugin.Has(p.cfg.PubSubs[i]): - r, err := provider.PSProvide(p.cfg.PubSubs[i]) + // we need here only Reader part of the interface + go func(ps pubsub.Reader) { + for { + select { + case <-p.serveExit: + return + default: + data, err := ps.Next() if err != nil { - return err + errCh <- err + return } - - // append redis provider - p.pubsubs[p.cfg.PubSubs[i]] = r - default: - return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i]) + p.workersPool.Queue(data) } - } else { - // no such driver - p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders) } - } + }(p.subReader) - return nil + return errCh } func (p *Plugin) Stop() error { @@ -212,30 +157,12 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) Collects() []interface{} { - return []interface{}{ - p.GetPublishers, - } -} - func (p *Plugin) Available() {} -func (p *Plugin) RPC() interface{} { - return &rpc{ - plugin: p, - log: p.log, - } -} - func (p *Plugin) Name() string { return PluginName } -// GetPublishers collects all pubsubs -func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PSProvider) { - p.psProviders[name.Name()] = pub -} - func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != p.cfg.Path { @@ -281,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { p.connections.Store(connectionID, safeConn) // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, connectionID, p.pubsubs, p.accessValidator, r) + e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) err = e.StartCommandLoop() @@ -365,55 +292,6 @@ func (p *Plugin) Reset() error { return nil } -// Publish is an entry point to the websocket PUBSUB -func (p *Plugin) Publish(m []byte) error { - p.Lock() - defer p.Unlock() - - msg := &websocketsv1.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - return err - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if br, ok := p.pubsubs[msg.GetBroker()]; ok { - err := br.Publish(m) - if err != nil { - return errors.E(err) - } - } else { - p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker()) - } - } - return nil -} - -func (p *Plugin) PublishAsync(m []byte) { - go func() { - p.Lock() - defer p.Unlock() - msg := &websocketsv1.Message{} - err := proto.Unmarshal(m, msg) - if err != nil { - p.log.Error("message unmarshal") - } - - // Get payload - for i := 0; i < len(msg.GetTopics()); i++ { - if br, ok := p.pubsubs[msg.GetBroker()]; ok { - err := br.Publish(m) - if err != nil { - p.log.Error("publish async error", "error", err) - } - } else { - p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker()) - } - } - }() -} - func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { const op = errors.Op("access_validator") |