diff options
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r-- | plugins/websockets/plugin.go | 208 |
1 files changed, 154 insertions, 54 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index c51c7ca1..4c722860 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -1,19 +1,23 @@ package websockets import ( + "context" "net/http" "sync" "time" "github.com/fasthttp/websocket" "github.com/google/uuid" + json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/payload" + phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/pubsub" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" @@ -26,12 +30,12 @@ const ( ) type Plugin struct { - mu sync.RWMutex + sync.RWMutex // Collection with all available pubsubs pubsubs map[string]pubsub.PubSub - Config *Config - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -40,25 +44,38 @@ type Plugin struct { // GO workers pool workersPool *pool.WorkersPool - hub channel.Hub + wsUpgrade *websocket.Upgrader + serveExit chan struct{} + + phpPool phpPool.Pool + server server.Server + + // function used to validate access to the requested resource + accessValidator validator.AccessValidatorFn } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) } - err := cfg.UnmarshalKey(PluginName, &p.Config) + err := cfg.UnmarshalKey(PluginName, &p.cfg) if err != nil { return errors.E(op, err) } + p.cfg.InitDefault() + p.pubsubs = make(map[string]pubsub.PubSub) p.log = log p.storage = storage.NewStorage() p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) - p.hub = channel + p.wsUpgrade = &websocket.Upgrader{ + HandshakeTimeout: time.Second * 60, + } + p.serveExit = make(chan struct{}) + p.server = server return nil } @@ -66,25 +83,54 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel. func (p *Plugin) Serve() chan error { errCh := make(chan error) + go func() { + var err error + p.Lock() + defer p.Unlock() + + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + errCh <- err + } + + p.accessValidator = p.defaultAccessValidator(p.phpPool) + }() + // run all pubsubs drivers for _, v := range p.pubsubs { go func(ps pubsub.PubSub) { for { - data, err := ps.Next() - if err != nil { - errCh <- err + select { + case <-p.serveExit: return + default: + data, err := ps.Next() + if err != nil { + errCh <- err + return + } + p.workersPool.Queue(data) } - - p.workersPool.Queue(data) } }(v) } + return errCh } func (p *Plugin) Stop() error { + // close workers pool p.workersPool.Stop() + p.Lock() + p.phpPool.Destroy(context.Background()) + p.Unlock() return nil } @@ -114,84 +160,72 @@ func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) { func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != p.Config.Path { + if r.URL.Path != p.cfg.Path { next.ServeHTTP(w, r) return } - r = attributes.Init(r) - - err := validator.NewValidator().AssertServerAccess(p.hub, r) + // we need to lock here, because accessValidator might not be set in the Serve func at the moment + p.RLock() + // before we hijacked connection, we still can write to the response headers + val, err := p.accessValidator(r) + p.RUnlock() if err != nil { - // show the error to the user - if av, ok := err.(*validator.AccessValidator); ok { - av.Copy(w) - } else { - w.WriteHeader(400) - return - } + w.WriteHeader(400) + return } - // connection upgrader - upgraded := websocket.Upgrader{ - HandshakeTimeout: time.Second * 60, - ReadBufferSize: 0, - WriteBufferSize: 0, - WriteBufferPool: nil, - Subprotocols: nil, - Error: nil, - CheckOrigin: nil, - EnableCompression: false, + if val.Status != http.StatusOK { + _, _ = w.Write(val.Body) + w.WriteHeader(val.Status) + return } // upgrade connection to websocket connection - _conn, err := upgraded.Upgrade(w, r, nil) + _conn, err := p.wsUpgrade.Upgrade(w, r, nil) if err != nil { // connection hijacked, do not use response.writer or request - p.log.Error("upgrade connection error", "error", err) + p.log.Error("upgrade connection", "error", err) return } // construct safe connection protected by mutexes safeConn := connection.NewConnection(_conn, p.log) + // generate UUID from the connection + connectionID := uuid.NewString() + // store connection + p.connections.Store(connectionID, safeConn) + defer func() { // close the connection on exit err = safeConn.Close() if err != nil { - p.log.Error("connection close error", "error", err) + p.log.Error("connection close", "error", err) } - }() - // generate UUID from the connection - connectionID := uuid.NewString() - // store connection - p.connections.Store(connectionID, safeConn) - // when exiting - delete the connection - defer func() { + // when exiting - delete the connection p.connections.Delete(connectionID) }() - p.mu.Lock() // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r) - p.mu.Unlock() - + e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) - defer e.CleanUp() err = e.StartCommandLoop() if err != nil { - p.log.Error("command loop error", "error", err.Error()) + p.log.Error("command loop error, disconnecting", "error", err.Error()) return } + + p.log.Info("disconnected", "connectionID", connectionID) }) } // Publish is an entry point to the websocket PUBSUB func (p *Plugin) Publish(msg []*pubsub.Message) error { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { @@ -210,8 +244,8 @@ func (p *Plugin) Publish(msg []*pubsub.Message) error { func (p *Plugin) PublishAsync(msg []*pubsub.Message) { go func() { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { err := p.pubsubs[msg[i].Broker].Publish(msg) @@ -223,3 +257,69 @@ func (p *Plugin) PublishAsync(msg []*pubsub.Message) { } }() } + +func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { + return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { + p.RLock() + defer p.RUnlock() + const op = errors.Op("access_validator") + + r = attributes.Init(r) + + // if channels len is eq to 0, we use serverValidator + if len(topics) == 0 { + ctx, err := validator.TopicsAccessValidator(r, topics...) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + return val, nil + } + + ctx, err := validator.ServerAccessValidator(r) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + if val.Status != http.StatusOK { + return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status)) + } + + return val, nil + } +} |