summaryrefslogtreecommitdiff
path: root/plugins/websockets/plugin.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r--plugins/websockets/plugin.go196
1 files changed, 37 insertions, 159 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index a1002bdd..ca5f2f59 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -2,7 +2,6 @@ package websockets
import (
"context"
- "fmt"
"net/http"
"sync"
"time"
@@ -10,14 +9,13 @@ import (
"github.com/fasthttp/websocket"
"github.com/google/uuid"
json "github.com/json-iterator/go"
- endure "github.com/spiral/endure/pkg/container"
"github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/pkg/payload"
phpPool "github.com/spiral/roadrunner/v2/pkg/pool"
"github.com/spiral/roadrunner/v2/pkg/process"
- websocketsv1 "github.com/spiral/roadrunner/v2/pkg/proto/websockets/v1beta"
"github.com/spiral/roadrunner/v2/pkg/pubsub"
"github.com/spiral/roadrunner/v2/pkg/worker"
+ "github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/config"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -26,7 +24,6 @@ import (
"github.com/spiral/roadrunner/v2/plugins/websockets/executor"
"github.com/spiral/roadrunner/v2/plugins/websockets/pool"
"github.com/spiral/roadrunner/v2/plugins/websockets/validator"
- "google.golang.org/protobuf/proto"
)
const (
@@ -35,14 +32,14 @@ const (
type Plugin struct {
sync.RWMutex
- // Collection with all available pubsubs
- pubsubs map[string]pubsub.PubSub
- psProviders map[string]pubsub.PSProvider
+ // subscriber+reader interfaces
+ subReader pubsub.SubReader
+ // broadcaster
+ broadcaster broadcast.Broadcaster
- cfg *Config
- cfgPlugin config.Configurer
- log logger.Logger
+ cfg *Config
+ log logger.Logger
// global connections map
connections sync.Map
@@ -53,14 +50,16 @@ type Plugin struct {
wsUpgrade *websocket.Upgrader
serveExit chan struct{}
+ // workers pool
phpPool phpPool.Pool
- server server.Server
+ // server which produces commands to the pool
+ server server.Server
// function used to validate access to the requested resource
accessValidator validator.AccessValidatorFn
}
-func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error {
+func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server, b broadcast.Broadcaster) error {
const op = errors.Op("websockets_plugin_init")
if !cfg.Has(PluginName) {
return errors.E(op, errors.Disabled)
@@ -71,36 +70,32 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se
return errors.E(op, err)
}
- p.cfg.InitDefault()
- p.pubsubs = make(map[string]pubsub.PubSub)
- p.psProviders = make(map[string]pubsub.PSProvider)
-
- p.log = log
- p.cfgPlugin = cfg
+ err = p.cfg.InitDefault()
+ if err != nil {
+ return errors.E(op, err)
+ }
p.wsUpgrade = &websocket.Upgrader{
HandshakeTimeout: time.Second * 60,
ReadBufferSize: 1024,
WriteBufferSize: 1024,
- WriteBufferPool: nil,
- Subprotocols: nil,
- Error: nil,
CheckOrigin: func(r *http.Request) bool {
- return true
+ return isOriginAllowed(r.Header.Get("Origin"), p.cfg)
},
- EnableCompression: false,
}
p.serveExit = make(chan struct{})
p.server = server
-
+ p.log = log
+ p.broadcaster = b
return nil
}
func (p *Plugin) Serve() chan error {
- errCh := make(chan error, 1)
const op = errors.Op("websockets_plugin_serve")
-
- err := p.initPubSubs()
+ errCh := make(chan error, 1)
+ // init broadcaster
+ var err error
+ p.subReader, err = p.broadcaster.GetDriver(p.cfg.Broker)
if err != nil {
errCh <- errors.E(op, err)
return errCh
@@ -126,76 +121,26 @@ func (p *Plugin) Serve() chan error {
p.accessValidator = p.defaultAccessValidator(p.phpPool)
}()
- p.workersPool = pool.NewWorkersPool(p.pubsubs, &p.connections, p.log)
-
- // run all pubsubs drivers
- for _, v := range p.pubsubs {
- go func(ps pubsub.PubSub) {
- for {
- select {
- case <-p.serveExit:
- return
- default:
- data, err := ps.Next()
- if err != nil {
- errCh <- err
- return
- }
- p.workersPool.Queue(data)
- }
- }
- }(v)
- }
-
- return errCh
-}
-
-func (p *Plugin) initPubSubs() error {
- for i := 0; i < len(p.cfg.PubSubs); i++ {
- // don't need to have a section for the in-memory
- if p.cfg.PubSubs[i] == "memory" {
- if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok {
- r, err := provider.PSProvide("")
- if err != nil {
- return err
- }
+ p.workersPool = pool.NewWorkersPool(p.subReader, &p.connections, p.log)
- // append default in-memory provider
- p.pubsubs["memory"] = r
- }
- continue
- }
- // key - memory, redis
- if provider, ok := p.psProviders[p.cfg.PubSubs[i]]; ok {
- // try local key
- switch {
- // try local config first
- case p.cfgPlugin.Has(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i])):
- r, err := provider.PSProvide(fmt.Sprintf("%s.%s", PluginName, p.cfg.PubSubs[i]))
- if err != nil {
- return err
- }
-
- // append redis provider
- p.pubsubs[p.cfg.PubSubs[i]] = r
- case p.cfgPlugin.Has(p.cfg.PubSubs[i]):
- r, err := provider.PSProvide(p.cfg.PubSubs[i])
+ // we need here only Reader part of the interface
+ go func(ps pubsub.Reader) {
+ for {
+ select {
+ case <-p.serveExit:
+ return
+ default:
+ data, err := ps.Next()
if err != nil {
- return err
+ errCh <- err
+ return
}
-
- // append redis provider
- p.pubsubs[p.cfg.PubSubs[i]] = r
- default:
- return errors.Errorf("could not find configuration sections for the %s", p.cfg.PubSubs[i])
+ p.workersPool.Queue(data)
}
- } else {
- // no such driver
- p.log.Warn("no such driver", "requested", p.cfg.PubSubs[i], "available", p.psProviders)
}
- }
+ }(p.subReader)
- return nil
+ return errCh
}
func (p *Plugin) Stop() error {
@@ -212,30 +157,12 @@ func (p *Plugin) Stop() error {
return nil
}
-func (p *Plugin) Collects() []interface{} {
- return []interface{}{
- p.GetPublishers,
- }
-}
-
func (p *Plugin) Available() {}
-func (p *Plugin) RPC() interface{} {
- return &rpc{
- plugin: p,
- log: p.log,
- }
-}
-
func (p *Plugin) Name() string {
return PluginName
}
-// GetPublishers collects all pubsubs
-func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PSProvider) {
- p.psProviders[name.Name()] = pub
-}
-
func (p *Plugin) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != p.cfg.Path {
@@ -281,7 +208,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
p.connections.Store(connectionID, safeConn)
// Executor wraps a connection to have a safe abstraction
- e := executor.NewExecutor(safeConn, p.log, connectionID, p.pubsubs, p.accessValidator, r)
+ e := executor.NewExecutor(safeConn, p.log, connectionID, p.subReader, p.accessValidator, r)
p.log.Info("websocket client connected", "uuid", connectionID)
err = e.StartCommandLoop()
@@ -365,55 +292,6 @@ func (p *Plugin) Reset() error {
return nil
}
-// Publish is an entry point to the websocket PUBSUB
-func (p *Plugin) Publish(m []byte) error {
- p.Lock()
- defer p.Unlock()
-
- msg := &websocketsv1.Message{}
- err := proto.Unmarshal(m, msg)
- if err != nil {
- return err
- }
-
- // Get payload
- for i := 0; i < len(msg.GetTopics()); i++ {
- if br, ok := p.pubsubs[msg.GetBroker()]; ok {
- err := br.Publish(m)
- if err != nil {
- return errors.E(err)
- }
- } else {
- p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker())
- }
- }
- return nil
-}
-
-func (p *Plugin) PublishAsync(m []byte) {
- go func() {
- p.Lock()
- defer p.Unlock()
- msg := &websocketsv1.Message{}
- err := proto.Unmarshal(m, msg)
- if err != nil {
- p.log.Error("message unmarshal")
- }
-
- // Get payload
- for i := 0; i < len(msg.GetTopics()); i++ {
- if br, ok := p.pubsubs[msg.GetBroker()]; ok {
- err := br.Publish(m)
- if err != nil {
- p.log.Error("publish async error", "error", err)
- }
- } else {
- p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg.GetBroker())
- }
- }
- }()
-}
-
func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn {
return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) {
const op = errors.Op("access_validator")