summaryrefslogtreecommitdiff
path: root/plugins/websockets/plugin.go
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-06-01 00:10:31 +0300
committerGitHub <[email protected]>2021-06-01 00:10:31 +0300
commit548ee4432e48b316ada00feec1a6b89e67ae4f2f (patch)
tree5cd2aaeeafdb50e3e46824197c721223f54695bf /plugins/websockets/plugin.go
parent8cd696bbca8fac2ced30d8172c41b7434ec86650 (diff)
parentdf4d316d519cea6dff654bd917521a616a37f769 (diff)
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
#660 feat(plugin): `broadcast` and `broadcast-ws` plugins update to RR2
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r--plugins/websockets/plugin.go386
1 files changed, 386 insertions, 0 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
new file mode 100644
index 00000000..9b21ff8f
--- /dev/null
+++ b/plugins/websockets/plugin.go
@@ -0,0 +1,386 @@
+package websockets
+
+import (
+ "context"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/fasthttp/websocket"
+ "github.com/google/uuid"
+ json "github.com/json-iterator/go"
+ endure "github.com/spiral/endure/pkg/container"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/payload"
+ phpPool "github.com/spiral/roadrunner/v2/pkg/pool"
+ "github.com/spiral/roadrunner/v2/pkg/process"
+ "github.com/spiral/roadrunner/v2/pkg/pubsub"
+ "github.com/spiral/roadrunner/v2/pkg/worker"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/http/attributes"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/server"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/connection"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/executor"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/pool"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/storage"
+ "github.com/spiral/roadrunner/v2/plugins/websockets/validator"
+)
+
+const (
+ PluginName string = "websockets"
+)
+
+type Plugin struct {
+ sync.RWMutex
+ // Collection with all available pubsubs
+ pubsubs map[string]pubsub.PubSub
+
+ cfg *Config
+ log logger.Logger
+
+ // global connections map
+ connections sync.Map
+ storage *storage.Storage
+
+ // GO workers pool
+ workersPool *pool.WorkersPool
+
+ wsUpgrade *websocket.Upgrader
+ serveExit chan struct{}
+
+ phpPool phpPool.Pool
+ server server.Server
+
+ // function used to validate access to the requested resource
+ accessValidator validator.AccessValidatorFn
+}
+
+func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error {
+ const op = errors.Op("websockets_plugin_init")
+ if !cfg.Has(PluginName) {
+ return errors.E(op, errors.Disabled)
+ }
+
+ err := cfg.UnmarshalKey(PluginName, &p.cfg)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ p.cfg.InitDefault()
+
+ p.pubsubs = make(map[string]pubsub.PubSub)
+ p.log = log
+ p.storage = storage.NewStorage()
+ p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log)
+ p.wsUpgrade = &websocket.Upgrader{
+ HandshakeTimeout: time.Second * 60,
+ }
+ p.serveExit = make(chan struct{})
+ p.server = server
+
+ return nil
+}
+
+func (p *Plugin) Serve() chan error {
+ errCh := make(chan error)
+
+ go func() {
+ var err error
+ p.Lock()
+ defer p.Unlock()
+
+ p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{
+ Debug: p.cfg.Pool.Debug,
+ NumWorkers: p.cfg.Pool.NumWorkers,
+ MaxJobs: p.cfg.Pool.MaxJobs,
+ AllocateTimeout: p.cfg.Pool.AllocateTimeout,
+ DestroyTimeout: p.cfg.Pool.DestroyTimeout,
+ Supervisor: p.cfg.Pool.Supervisor,
+ }, map[string]string{"RR_MODE": "http"})
+ if err != nil {
+ errCh <- err
+ }
+
+ p.accessValidator = p.defaultAccessValidator(p.phpPool)
+ }()
+
+ // run all pubsubs drivers
+ for _, v := range p.pubsubs {
+ go func(ps pubsub.PubSub) {
+ for {
+ select {
+ case <-p.serveExit:
+ return
+ default:
+ data, err := ps.Next()
+ if err != nil {
+ errCh <- err
+ return
+ }
+ p.workersPool.Queue(data)
+ }
+ }
+ }(v)
+ }
+
+ return errCh
+}
+
+func (p *Plugin) Stop() error {
+ // close workers pool
+ p.workersPool.Stop()
+ p.Lock()
+ p.phpPool.Destroy(context.Background())
+ p.Unlock()
+ return nil
+}
+
+func (p *Plugin) Collects() []interface{} {
+ return []interface{}{
+ p.GetPublishers,
+ }
+}
+
+func (p *Plugin) Available() {}
+
+func (p *Plugin) RPC() interface{} {
+ return &rpc{
+ plugin: p,
+ log: p.log,
+ }
+}
+
+func (p *Plugin) Name() string {
+ return PluginName
+}
+
+// GetPublishers collects all pubsubs
+func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) {
+ p.pubsubs[name.Name()] = pub
+}
+
+func (p *Plugin) Middleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != p.cfg.Path {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // we need to lock here, because accessValidator might not be set in the Serve func at the moment
+ p.RLock()
+ // before we hijacked connection, we still can write to the response headers
+ val, err := p.accessValidator(r)
+ p.RUnlock()
+ if err != nil {
+ p.log.Error("validation error")
+ w.WriteHeader(400)
+ return
+ }
+
+ if val.Status != http.StatusOK {
+ for k, v := range val.Header {
+ for i := 0; i < len(v); i++ {
+ w.Header().Add(k, v[i])
+ }
+ }
+ w.WriteHeader(val.Status)
+ _, _ = w.Write(val.Body)
+ return
+ }
+
+ // upgrade connection to websocket connection
+ _conn, err := p.wsUpgrade.Upgrade(w, r, nil)
+ if err != nil {
+ // connection hijacked, do not use response.writer or request
+ p.log.Error("upgrade connection", "error", err)
+ return
+ }
+
+ // construct safe connection protected by mutexes
+ safeConn := connection.NewConnection(_conn, p.log)
+ // generate UUID from the connection
+ connectionID := uuid.NewString()
+ // store connection
+ p.connections.Store(connectionID, safeConn)
+
+ defer func() {
+ // close the connection on exit
+ err = safeConn.Close()
+ if err != nil {
+ p.log.Error("connection close", "error", err)
+ }
+
+ // when exiting - delete the connection
+ p.connections.Delete(connectionID)
+ }()
+
+ // Executor wraps a connection to have a safe abstraction
+ e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r)
+ p.log.Info("websocket client connected", "uuid", connectionID)
+ defer e.CleanUp()
+
+ err = e.StartCommandLoop()
+ if err != nil {
+ p.log.Error("command loop error, disconnecting", "error", err.Error())
+ return
+ }
+
+ p.log.Info("disconnected", "connectionID", connectionID)
+ })
+}
+
+// Workers returns slice with the process states for the workers
+func (p *Plugin) Workers() []process.State {
+ p.RLock()
+ defer p.RUnlock()
+
+ workers := p.workers()
+
+ ps := make([]process.State, 0, len(workers))
+ for i := 0; i < len(workers); i++ {
+ state, err := process.WorkerProcessState(workers[i])
+ if err != nil {
+ return nil
+ }
+ ps = append(ps, state)
+ }
+
+ return ps
+}
+
+// internal
+func (p *Plugin) workers() []worker.BaseProcess {
+ return p.phpPool.Workers()
+}
+
+// Reset destroys the old pool and replaces it with new one, waiting for old pool to die
+func (p *Plugin) Reset() error {
+ p.Lock()
+ defer p.Unlock()
+ const op = errors.Op("ws_plugin_reset")
+ p.log.Info("WS plugin got restart request. Restarting...")
+ p.phpPool.Destroy(context.Background())
+ p.phpPool = nil
+
+ var err error
+ p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{
+ Debug: p.cfg.Pool.Debug,
+ NumWorkers: p.cfg.Pool.NumWorkers,
+ MaxJobs: p.cfg.Pool.MaxJobs,
+ AllocateTimeout: p.cfg.Pool.AllocateTimeout,
+ DestroyTimeout: p.cfg.Pool.DestroyTimeout,
+ Supervisor: p.cfg.Pool.Supervisor,
+ }, map[string]string{"RR_MODE": "http"})
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ // attach validators
+ p.accessValidator = p.defaultAccessValidator(p.phpPool)
+
+ p.log.Info("WS plugin successfully restarted")
+ return nil
+}
+
+// Publish is an entry point to the websocket PUBSUB
+func (p *Plugin) Publish(msg []*pubsub.Message) error {
+ p.Lock()
+ defer p.Unlock()
+
+ for i := 0; i < len(msg); i++ {
+ for j := 0; j < len(msg[i].Topics); j++ {
+ if br, ok := p.pubsubs[msg[i].Broker]; ok {
+ err := br.Publish(msg)
+ if err != nil {
+ return errors.E(err)
+ }
+ } else {
+ p.log.Warn("no such broker", "available", p.pubsubs, "requested", msg[i].Broker)
+ }
+ }
+ }
+ return nil
+}
+
+func (p *Plugin) PublishAsync(msg []*pubsub.Message) {
+ go func() {
+ p.Lock()
+ defer p.Unlock()
+ for i := 0; i < len(msg); i++ {
+ for j := 0; j < len(msg[i].Topics); j++ {
+ err := p.pubsubs[msg[i].Broker].Publish(msg)
+ if err != nil {
+ p.log.Error("publish async error", "error", err)
+ return
+ }
+ }
+ }
+ }()
+}
+
+func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn {
+ return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) {
+ p.RLock()
+ defer p.RUnlock()
+ const op = errors.Op("access_validator")
+
+ p.log.Debug("validation", "topics", topics)
+ r = attributes.Init(r)
+
+ // if channels len is eq to 0, we use serverValidator
+ if len(topics) == 0 {
+ ctx, err := validator.ServerAccessValidator(r)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val, err := exec(ctx, pool)
+ if err != nil {
+ return nil, errors.E(err)
+ }
+
+ return val, nil
+ }
+
+ ctx, err := validator.TopicsAccessValidator(r, topics...)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val, err := exec(ctx, pool)
+ if err != nil {
+ return nil, errors.E(op)
+ }
+
+ if val.Status != http.StatusOK {
+ return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status))
+ }
+
+ return val, nil
+ }
+}
+
+// go:inline
+func exec(ctx []byte, pool phpPool.Pool) (*validator.AccessValidator, error) {
+ const op = errors.Op("exec")
+ pd := payload.Payload{
+ Context: ctx,
+ }
+
+ resp, err := pool.Exec(pd)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val := &validator.AccessValidator{
+ Body: resp.Body,
+ }
+
+ err = json.Unmarshal(resp.Context, val)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return val, nil
+}