diff options
-rw-r--r-- | plugins/websockets/plugin.go | 114 |
1 files changed, 85 insertions, 29 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index 4c722860..8a15699e 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -13,7 +13,9 @@ import ( "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/payload" phpPool "github.com/spiral/roadrunner/v2/pkg/pool" + "github.com/spiral/roadrunner/v2/pkg/process" "github.com/spiral/roadrunner/v2/pkg/pubsub" + "github.com/spiral/roadrunner/v2/pkg/worker" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -171,6 +173,7 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { val, err := p.accessValidator(r) p.RUnlock() if err != nil { + p.log.Error("validation error") w.WriteHeader(400) return } @@ -222,6 +225,59 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { }) } +// Workers returns slice with the process states for the workers +func (p *Plugin) Workers() []process.State { + p.RLock() + defer p.RUnlock() + + workers := p.workers() + + ps := make([]process.State, 0, len(workers)) + for i := 0; i < len(workers); i++ { + state, err := process.WorkerProcessState(workers[i]) + if err != nil { + return nil + } + ps = append(ps, state) + } + + return ps +} + +// internal +func (p *Plugin) workers() []worker.BaseProcess { + return p.phpPool.Workers() +} + +// Reset destroys the old pool and replaces it with new one, waiting for old pool to die +func (p *Plugin) Reset() error { + p.Lock() + defer p.Unlock() + const op = errors.Op("ws_plugin_reset") + p.log.Info("WS plugin got restart request. Restarting...") + p.phpPool.Destroy(context.Background()) + p.phpPool = nil + + var err error + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + return errors.E(op, err) + } + + // attach validators + p.accessValidator = p.defaultAccessValidator(p.phpPool) + + p.log.Info("WS plugin successfully restarted") + return nil +} + // Publish is an entry point to the websocket PUBSUB func (p *Plugin) Publish(msg []*pubsub.Message) error { p.Lock() @@ -264,6 +320,7 @@ func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValid defer p.RUnlock() const op = errors.Op("access_validator") + p.log.Debug("validation", "topics", topics) r = attributes.Init(r) // if channels len is eq to 0, we use serverValidator @@ -273,21 +330,9 @@ func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValid return nil, errors.E(op, err) } - pd := payload.Payload{ - Context: ctx, - } - - resp, err := pool.Exec(pd) - if err != nil { - return nil, errors.E(op, err) - } - - val := &validator.AccessValidator{ - Body: resp.Body, - } - err = json.Unmarshal(resp.Context, val) + val, err := exec(ctx, pool) if err != nil { - return nil, errors.E(op, err) + return nil, errors.E(err) } return val, nil @@ -298,22 +343,9 @@ func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValid return nil, errors.E(op, err) } - pd := payload.Payload{ - Context: ctx, - } - - resp, err := pool.Exec(pd) + val, err := exec(ctx, pool) if err != nil { - return nil, errors.E(op, err) - } - - val := &validator.AccessValidator{ - Body: resp.Body, - } - - err = json.Unmarshal(resp.Context, val) - if err != nil { - return nil, errors.E(op, err) + return nil, errors.E(op) } if val.Status != http.StatusOK { @@ -323,3 +355,27 @@ func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValid return val, nil } } + +// go:inline +func exec(ctx []byte, pool phpPool.Pool) (*validator.AccessValidator, error) { + const op = errors.Op("exec") + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + return val, nil +} |