diff options
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r-- | plugins/websockets/plugin.go | 140 |
1 files changed, 38 insertions, 102 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index cf861c72..de7443fd 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -9,13 +9,12 @@ import ( "github.com/fasthttp/websocket" "github.com/google/uuid" json "github.com/json-iterator/go" - endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/interface/broadcast" + "github.com/spiral/roadrunner/v2/pkg/interface/pubsub" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/process" - "github.com/spiral/roadrunner/v2/pkg/pubsub" "github.com/spiral/roadrunner/v2/pkg/worker" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" @@ -33,16 +32,14 @@ const ( type Plugin struct { sync.RWMutex - // Collection with all available pubsubs - //pubsubs map[string]pubsub.PubSub - //psProviders map[string]pubsub.PSProvider + // subscriber+reader interfaces + subReader pubsub.SubReader + // broadcaster + broadcaster broadcast.Broadcaster - subReaders map[string]pubsub.SubReader - - cfg *Config - cfgPlugin config.Configurer - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -53,8 +50,10 @@ type Plugin struct { wsUpgrade *websocket.Upgrader serveExit chan struct{} + // workers pool phpPool phpPool.Pool - server server.Server + // server which produces commands to the pool + server server.Server // function used to validate access to the requested resource accessValidator validator.AccessValidatorFn @@ -71,14 +70,10 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se return errors.E(op, err) } - p.cfg.InitDefault() - //p.pubsubs = make(map[string]pubsub.PubSub) - //p.psProviders = make(map[string]pubsub.PSProvider) - - p.subReaders = make(map[string]pubsub.SubReader) - - p.log = log - p.cfgPlugin = cfg + err = p.cfg.InitDefault() + if err != nil { + return errors.E(op, err) + } p.wsUpgrade = &websocket.Upgrader{ HandshakeTimeout: time.Second * 60, @@ -90,19 +85,21 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se } p.serveExit = make(chan struct{}) p.server = server - + p.log = log + p.broadcaster = b return nil } func (p *Plugin) Serve() chan error { - errCh := make(chan error, 1) const op = errors.Op("websockets_plugin_serve") - - //err := p.initPubSubs() - //if err != nil { - // errCh <- errors.E(op, err) - // return errCh - //} + errCh := make(chan error, 1) + // init broadcaster + var err error + p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker) + if err != nil { + errCh <- errors.E(op, err) + return errCh + } go func() { var err error @@ -124,78 +121,28 @@ func (p *Plugin) Serve() chan error { p.accessValidator = p.defaultAccessValidator(p.phpPool) }() - p.workersPool = pool.NewWorkersPool(p.subReaders, &p.connections, p.log) + p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log) // run all pubsubs drivers - for _, v := range p.subReaders { - go func(ps pubsub.SubReader) { - for { - select { - case <-p.serveExit: + go func(ps pubsub.Reader) { + for { + select { + case <-p.serveExit: + return + default: + data, err := ps.Next() + if err != nil { + errCh <- err return - default: - data, err := ps.Next() - if err != nil { - errCh <- err - return - } - p.workersPool.Queue(data) } + p.workersPool.Queue(data) } - }(v) - } + } + }(p.subReader) return errCh } -//func (p *Plugin) initPubSubs() error { -// for i := 0; i < len(p.cfg.PubSubs); i++ { -// // don't need to have a section for the in-memory -// if p.cfg.PubSubs[i] == "memory" { -// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { -// r, err := provider.PSProvide("") -// if err != nil { -// return err -// } -// -// // append default in-memory provider -// p.pubsubs["memory"] = r -// } -// continue -// } -// // key - memory, redis -// if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok { -// // try local key -// switch { -// // try local config first -// case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])): -// r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])) -// if err != nil { -// return err -// } -// -// // append redis provider -// p.pubsubs[p.cfg.PubSubs[i]] = r -// case p.cfgPlugin.Has(p.cfg.PubSubs[i]): -// r, err := provider.PSProvide(p.cfg.PubSubs[i]) -// if err != nil { -// return err -// } -// -// // append redis provider -// p.pubsubs[p.cfg.PubSubs[i]] = r -// default: -// return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i]) -// } -// } else { -// // no such driver -// p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders) -// } -// } -// -// return nil -//} - func (p *Plugin) Stop() error { // close workers pool p.workersPool.Stop() @@ -210,23 +157,12 @@ func (p *Plugin) Stop() error { return nil } -func (p *Plugin) Collects() []interface{} { - return []interface{}{ - p.GetSubsReader, - } -} - func (p *Plugin) Available() {} func (p *Plugin) Name() string { return PluginName } -// GetSubsReader collects all plugins which implement SubReader interface -func (p *Plugin) GetSubsReader(name endure.Named, pub pubsub.SubReader) { - p.subReaders[name.Name()] = pub -} - func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != p.cfg.Path { @@ -272,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { p.connections.Store(connectionID, safeConn) // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, connectionID, nil, p.accessValidator, r) + e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) err = e.StartCommandLoop() |