diff options
24 files changed, 392 insertions, 433 deletions
diff --git a/pkg/pubsub/message.go b/pkg/pubsub/message.go index c1a7246a..c17d153b 100644 --- a/pkg/pubsub/message.go +++ b/pkg/pubsub/message.go @@ -5,15 +5,15 @@ import ( ) type Message struct { - // Topic message been pushed into. - Topics []string `json:"topic"` - // Command (join, leave, headers) Command string `json:"command"` // Broker (redis, memory) Broker string `json:"broker"` + // Topic message been pushed into. + Topics []string `json:"topic"` + // Payload to be broadcasted Payload []byte `json:"payload"` } diff --git a/pkg/worker_handler/handler.go b/pkg/worker_handler/handler.go index d98cdef0..e0d1aae0 100644 --- a/pkg/worker_handler/handler.go +++ b/pkg/worker_handler/handler.go @@ -202,16 +202,16 @@ func (h *Handler) resolveIP(r *Request) { // CF-Connecting-IP is an Enterprise feature and we check it last in order. // This operations are near O(1) because Headers struct are the map type -> type MIMEHeader map[string][]string if r.Header.Get("X-Real-Ip") != "" { - r.RemoteAddr = fetchIP(r.Header.Get("X-Real-Ip")) + r.RemoteAddr = FetchIP(r.Header.Get("X-Real-Ip")) return } if r.Header.Get("True-Client-IP") != "" { - r.RemoteAddr = fetchIP(r.Header.Get("True-Client-IP")) + r.RemoteAddr = FetchIP(r.Header.Get("True-Client-IP")) return } if r.Header.Get("CF-Connecting-IP") != "" { - r.RemoteAddr = fetchIP(r.Header.Get("CF-Connecting-IP")) + r.RemoteAddr = FetchIP(r.Header.Get("CF-Connecting-IP")) } } diff --git a/pkg/worker_handler/request.go b/pkg/worker_handler/request.go index 178bc827..75ee8381 100644 --- a/pkg/worker_handler/request.go +++ b/pkg/worker_handler/request.go @@ -61,7 +61,7 @@ type Request struct { body interface{} } -func fetchIP(pair string) string { +func FetchIP(pair string) string { if !strings.ContainsRune(pair, ':') { return pair } @@ -73,10 +73,10 @@ func fetchIP(pair string) string { // NewRequest creates new PSR7 compatible request using net/http request. func NewRequest(r *http.Request, cfg config.Uploads) (*Request, error) { req := &Request{ - RemoteAddr: fetchIP(r.RemoteAddr), + RemoteAddr: FetchIP(r.RemoteAddr), Protocol: r.Proto, Method: r.Method, - URI: uri(r), + URI: URI(r), Header: r.Header, Cookies: make(map[string]string), RawQuery: r.URL.RawQuery, @@ -174,8 +174,8 @@ func (r *Request) contentType() int { return contentStream } -// uri fetches full uri from request in a form of string (including https scheme if TLS connection is enabled). -func uri(r *http.Request) string { +// URI fetches full uri from request in a form of string (including https scheme if TLS connection is enabled). +func URI(r *http.Request) string { if r.URL.Host != "" { return r.URL.String() } diff --git a/plugins/channel/interface.go b/plugins/channel/interface.go deleted file mode 100644 index 50fc9c96..00000000 --- a/plugins/channel/interface.go +++ /dev/null @@ -1,8 +0,0 @@ -package channel - -// Hub used as a channel between two or more plugins -// this is not a PUBLIC plugin, API might be changed at any moment -type Hub interface { - FromWorker() chan interface{} - ToWorker() chan interface{} -} diff --git a/plugins/channel/plugin.go b/plugins/channel/plugin.go deleted file mode 100644 index 362dbc07..00000000 --- a/plugins/channel/plugin.go +++ /dev/null @@ -1,59 +0,0 @@ -package channel - -import ( - "sync" -) - -const ( - PluginName string = "hub" -) - -type Plugin struct { - sync.Mutex - fromCh chan interface{} - toCh chan interface{} -} - -func (p *Plugin) Init() error { - p.Lock() - defer p.Unlock() - - p.fromCh = make(chan interface{}) - p.toCh = make(chan interface{}) - return nil -} - -func (p *Plugin) Serve() chan error { - return make(chan error) -} - -func (p *Plugin) Stop() error { - // read from the channels on stop to prevent blocking - go func() { - for range p.fromCh { - } - }() - go func() { - for range p.toCh { - } - }() - return nil -} - -func (p *Plugin) FromWorker() chan interface{} { - p.Lock() - defer p.Unlock() - // one-directional queue - return p.fromCh -} - -func (p *Plugin) ToWorker() chan interface{} { - p.Lock() - defer p.Unlock() - // one-directional queue - return p.toCh -} - -func (p *Plugin) Name() string { - return PluginName -} diff --git a/plugins/http/channel.go b/plugins/http/channel.go deleted file mode 100644 index 23b5ff3e..00000000 --- a/plugins/http/channel.go +++ /dev/null @@ -1,29 +0,0 @@ -package http - -import ( - "net/http" -) - -// messages method used to read messages from the ws plugin with the auth requests for the topics and server -func (p *Plugin) messages() { - for msg := range p.hub.ToWorker() { - p.RLock() - // msg here is the structure with http.ResponseWriter and http.Request - rmsg := msg.(struct { - RW http.ResponseWriter - Req *http.Request - }) - - // invoke handler with redirected responsewriter and request - p.handler.ServeHTTP(rmsg.RW, rmsg.Req) - - p.hub.FromWorker() <- struct { - RW http.ResponseWriter - Req *http.Request - }{ - rmsg.RW, - rmsg.Req, - } - p.RUnlock() - } -} diff --git a/plugins/http/config/http.go b/plugins/http/config/http.go index 8b63395f..a1c2afa6 100644 --- a/plugins/http/config/http.go +++ b/plugins/http/config/http.go @@ -7,7 +7,7 @@ import ( "time" "github.com/spiral/errors" - poolImpl "github.com/spiral/roadrunner/v2/pkg/pool" + "github.com/spiral/roadrunner/v2/pkg/pool" ) // HTTP configures RoadRunner HTTP server. @@ -34,7 +34,7 @@ type HTTP struct { Uploads *Uploads `mapstructure:"uploads"` // Pool configures worker pool. - Pool *poolImpl.Config `mapstructure:"pool"` + Pool *pool.Config `mapstructure:"pool"` // Env is environment variables passed to the http pool Env map[string]string @@ -70,7 +70,7 @@ func (c *HTTP) EnableFCGI() bool { func (c *HTTP) InitDefaults() error { if c.Pool == nil { // default pool - c.Pool = &poolImpl.Config{ + c.Pool = &pool.Config{ Debug: false, NumWorkers: uint64(runtime.NumCPU()), MaxJobs: 0, diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go index c4a1a83f..770ca8ca 100644 --- a/plugins/http/plugin.go +++ b/plugins/http/plugin.go @@ -6,7 +6,6 @@ import ( "log" "net/http" "sync" - "time" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" @@ -14,7 +13,6 @@ import ( "github.com/spiral/roadrunner/v2/pkg/process" "github.com/spiral/roadrunner/v2/pkg/worker" handler "github.com/spiral/roadrunner/v2/pkg/worker_handler" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" httpConfig "github.com/spiral/roadrunner/v2/plugins/http/config" @@ -68,14 +66,11 @@ type Plugin struct { http *http.Server https *http.Server fcgi *http.Server - - // message bus - hub channel.Hub } // Init must return configure svc and return true if svc hasStatus enabled. Must return error in case of // misconfiguration. Services must not be used without proper configuration pushed first. -func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server, hub channel.Hub) error { +func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server) error { const op = errors.Op("http_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -109,7 +104,6 @@ func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server serv p.cfg.Env[RrMode] = "http" p.server = server - p.hub = hub return nil } @@ -174,9 +168,8 @@ func (p *Plugin) serve(errCh chan error) { } } else { p.http = &http.Server{ - Handler: p, - ErrorLog: p.stdLog, - ReadHeaderTimeout: time.Second, + Handler: p, + ErrorLog: p.stdLog, } } } @@ -216,9 +209,6 @@ func (p *Plugin) serve(errCh chan error) { go func() { p.serveFCGI(errCh) }() - - // read messages from the ws - go p.messages() } // Stop stops the http. @@ -229,21 +219,21 @@ func (p *Plugin) Stop() error { if p.fcgi != nil { err := p.fcgi.Shutdown(context.Background()) if err != nil && err != http.ErrServerClosed { - p.log.Error("error shutting down the fcgi server", "error", err) + p.log.Error("fcgi shutdown", "error", err) } } if p.https != nil { err := p.https.Shutdown(context.Background()) if err != nil && err != http.ErrServerClosed { - p.log.Error("error shutting down the https server", "error", err) + p.log.Error("https shutdown", "error", err) } } if p.http != nil { - err := p.http.Close() + err := p.http.Shutdown(context.Background()) if err != nil && err != http.ErrServerClosed { - p.log.Error("error shutting down the http server", "error", err) + p.log.Error("http shutdown", "error", err) } } @@ -337,6 +327,7 @@ func (p *Plugin) Reset() error { p.cfg.Cidrs, p.pool, ) + if err != nil { return errors.E(op, err) } diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go index f3cb8e12..be4aaa82 100644 --- a/plugins/websockets/config.go +++ b/plugins/websockets/config.go @@ -1,22 +1,16 @@ package websockets -import "time" +import ( + "time" + + "github.com/spiral/roadrunner/v2/pkg/pool" +) /* websockets: # pubsubs should implement PubSub interface to be collected via endure.Collects - # also, they should implement RPC methods to publish data into them - # pubsubs might use general config section or its own pubsubs:["redis", "amqp", "memory"] - - # sample of the own config section for the redis pubsub driver - redis: - address: - - localhost:1111 - .... the rest - - # path used as websockets path path: "/ws" */ @@ -26,33 +20,10 @@ type Config struct { // http path for the websocket Path string `mapstructure:"path"` // ["redis", "amqp", "memory"] - PubSubs []string `mapstructure:"pubsubs"` - Middleware []string `mapstructure:"middleware"` - Redis *RedisConfig `mapstructure:"redis"` -} + PubSubs []string `mapstructure:"pubsubs"` + Middleware []string `mapstructure:"middleware"` -type RedisConfig struct { - Addrs []string `mapstructure:"addrs"` - DB int `mapstructure:"db"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - MasterName string `mapstructure:"master_name"` - SentinelPassword string `mapstructure:"sentinel_password"` - RouteByLatency bool `mapstructure:"route_by_latency"` - RouteRandomly bool `mapstructure:"route_randomly"` - MaxRetries int `mapstructure:"max_retries"` - DialTimeout time.Duration `mapstructure:"dial_timeout"` - MinRetryBackoff time.Duration `mapstructure:"min_retry_backoff"` - MaxRetryBackoff time.Duration `mapstructure:"max_retry_backoff"` - PoolSize int `mapstructure:"pool_size"` - MinIdleConns int `mapstructure:"min_idle_conns"` - MaxConnAge time.Duration `mapstructure:"max_conn_age"` - ReadTimeout time.Duration `mapstructure:"read_timeout"` - WriteTimeout time.Duration `mapstructure:"write_timeout"` - PoolTimeout time.Duration `mapstructure:"pool_timeout"` - IdleTimeout time.Duration `mapstructure:"idle_timeout"` - IdleCheckFreq time.Duration `mapstructure:"idle_check_freq"` - ReadOnly bool `mapstructure:"read_only"` + Pool *pool.Config `mapstructure:"pool"` } // InitDefault initialize default values for the ws config @@ -64,4 +35,24 @@ func (c *Config) InitDefault() { // memory used by default c.PubSubs = append(c.PubSubs, "memory") } + + if c.Pool == nil { + c.Pool = &pool.Config{} + if c.Pool.NumWorkers == 0 { + // 2 workers by default + c.Pool.NumWorkers = 2 + } + + if c.Pool.AllocateTimeout == 0 { + c.Pool.AllocateTimeout = time.Minute + } + + if c.Pool.DestroyTimeout == 0 { + c.Pool.DestroyTimeout = time.Minute + } + if c.Pool.Supervisor == nil { + return + } + c.Pool.Supervisor.InitDefaults() + } } diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 87fed3a6..24ea19ce 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -9,7 +9,6 @@ import ( json "github.com/json-iterator/go" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" @@ -35,21 +34,22 @@ type Executor struct { pubsub map[string]pubsub.PubSub actualTopics map[string]struct{} - hub channel.Hub - req *http.Request + req *http.Request + accessValidator validator.AccessValidatorFn } // NewExecutor creates protected connection and starts command loop -func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor { +func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, + connID string, pubsubs map[string]pubsub.PubSub, av validator.AccessValidatorFn, r *http.Request) *Executor { return &Executor{ - conn: conn, - connID: connID, - storage: bst, - log: log, - pubsub: pubsubs, - hub: hub, - actualTopics: make(map[string]struct{}, 10), - req: r, + conn: conn, + connID: connID, + storage: bst, + log: log, + pubsub: pubsubs, + accessValidator: av, + actualTopics: make(map[string]struct{}, 10), + req: r, } } @@ -85,8 +85,12 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit case commands.Join: e.log.Debug("get join command", "msg", msg) - err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...) + val, err := e.accessValidator(e.req, msg.Topics...) if err != nil { + if val != nil { + e.log.Debug("validation error", "status", val.Status, "headers", val.Header, "body", val.Body) + } + resp := &Response{ Topic: "#join", Payload: msg.Topics, diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index c51c7ca1..4c722860 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -1,19 +1,23 @@ package websockets import ( + "context" "net/http" "sync" "time" "github.com/fasthttp/websocket" "github.com/google/uuid" + json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/payload" + phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/pubsub" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" @@ -26,12 +30,12 @@ const ( ) type Plugin struct { - mu sync.RWMutex + sync.RWMutex // Collection with all available pubsubs pubsubs map[string]pubsub.PubSub - Config *Config - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -40,25 +44,38 @@ type Plugin struct { // GO workers pool workersPool *pool.WorkersPool - hub channel.Hub + wsUpgrade *websocket.Upgrader + serveExit chan struct{} + + phpPool phpPool.Pool + server server.Server + + // function used to validate access to the requested resource + accessValidator validator.AccessValidatorFn } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) } - err := cfg.UnmarshalKey(PluginName, &p.Config) + err := cfg.UnmarshalKey(PluginName, &p.cfg) if err != nil { return errors.E(op, err) } + p.cfg.InitDefault() + p.pubsubs = make(map[string]pubsub.PubSub) p.log = log p.storage = storage.NewStorage() p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) - p.hub = channel + p.wsUpgrade = &websocket.Upgrader{ + HandshakeTimeout: time.Second * 60, + } + p.serveExit = make(chan struct{}) + p.server = server return nil } @@ -66,25 +83,54 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel. func (p *Plugin) Serve() chan error { errCh := make(chan error) + go func() { + var err error + p.Lock() + defer p.Unlock() + + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + errCh <- err + } + + p.accessValidator = p.defaultAccessValidator(p.phpPool) + }() + // run all pubsubs drivers for _, v := range p.pubsubs { go func(ps pubsub.PubSub) { for { - data, err := ps.Next() - if err != nil { - errCh <- err + select { + case <-p.serveExit: return + default: + data, err := ps.Next() + if err != nil { + errCh <- err + return + } + p.workersPool.Queue(data) } - - p.workersPool.Queue(data) } }(v) } + return errCh } func (p *Plugin) Stop() error { + // close workers pool p.workersPool.Stop() + p.Lock() + p.phpPool.Destroy(context.Background()) + p.Unlock() return nil } @@ -114,84 +160,72 @@ func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) { func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != p.Config.Path { + if r.URL.Path != p.cfg.Path { next.ServeHTTP(w, r) return } - r = attributes.Init(r) - - err := validator.NewValidator().AssertServerAccess(p.hub, r) + // we need to lock here, because accessValidator might not be set in the Serve func at the moment + p.RLock() + // before we hijacked connection, we still can write to the response headers + val, err := p.accessValidator(r) + p.RUnlock() if err != nil { - // show the error to the user - if av, ok := err.(*validator.AccessValidator); ok { - av.Copy(w) - } else { - w.WriteHeader(400) - return - } + w.WriteHeader(400) + return } - // connection upgrader - upgraded := websocket.Upgrader{ - HandshakeTimeout: time.Second * 60, - ReadBufferSize: 0, - WriteBufferSize: 0, - WriteBufferPool: nil, - Subprotocols: nil, - Error: nil, - CheckOrigin: nil, - EnableCompression: false, + if val.Status != http.StatusOK { + _, _ = w.Write(val.Body) + w.WriteHeader(val.Status) + return } // upgrade connection to websocket connection - _conn, err := upgraded.Upgrade(w, r, nil) + _conn, err := p.wsUpgrade.Upgrade(w, r, nil) if err != nil { // connection hijacked, do not use response.writer or request - p.log.Error("upgrade connection error", "error", err) + p.log.Error("upgrade connection", "error", err) return } // construct safe connection protected by mutexes safeConn := connection.NewConnection(_conn, p.log) + // generate UUID from the connection + connectionID := uuid.NewString() + // store connection + p.connections.Store(connectionID, safeConn) + defer func() { // close the connection on exit err = safeConn.Close() if err != nil { - p.log.Error("connection close error", "error", err) + p.log.Error("connection close", "error", err) } - }() - // generate UUID from the connection - connectionID := uuid.NewString() - // store connection - p.connections.Store(connectionID, safeConn) - // when exiting - delete the connection - defer func() { + // when exiting - delete the connection p.connections.Delete(connectionID) }() - p.mu.Lock() // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r) - p.mu.Unlock() - + e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) - defer e.CleanUp() err = e.StartCommandLoop() if err != nil { - p.log.Error("command loop error", "error", err.Error()) + p.log.Error("command loop error, disconnecting", "error", err.Error()) return } + + p.log.Info("disconnected", "connectionID", connectionID) }) } // Publish is an entry point to the websocket PUBSUB func (p *Plugin) Publish(msg []*pubsub.Message) error { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { @@ -210,8 +244,8 @@ func (p *Plugin) Publish(msg []*pubsub.Message) error { func (p *Plugin) PublishAsync(msg []*pubsub.Message) { go func() { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { err := p.pubsubs[msg[i].Broker].Publish(msg) @@ -223,3 +257,69 @@ func (p *Plugin) PublishAsync(msg []*pubsub.Message) { } }() } + +func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { + return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { + p.RLock() + defer p.RUnlock() + const op = errors.Op("access_validator") + + r = attributes.Init(r) + + // if channels len is eq to 0, we use serverValidator + if len(topics) == 0 { + ctx, err := validator.TopicsAccessValidator(r, topics...) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + return val, nil + } + + ctx, err := validator.ServerAccessValidator(r) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + if val.Status != http.StatusOK { + return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status)) + } + + return val, nil + } +} diff --git a/plugins/websockets/schema/message.fbs b/plugins/websockets/schema/message.fbs new file mode 100644 index 00000000..f2d92c78 --- /dev/null +++ b/plugins/websockets/schema/message.fbs @@ -0,0 +1,10 @@ +namespace message; + +table Message { + command:string; + broker:string; + topics:[string]; + payload:[byte]; +} + +root_type Message; diff --git a/plugins/websockets/schema/message/Message.go b/plugins/websockets/schema/message/Message.go new file mode 100644 index 00000000..26bbd12c --- /dev/null +++ b/plugins/websockets/schema/message/Message.go @@ -0,0 +1,118 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package message + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Message struct { + _tab flatbuffers.Table +} + +func GetRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Message{} + x.Init(buf, n+offset) + return x +} + +func GetSizePrefixedRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Message{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *Message) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Message) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Message) Command() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Broker() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Topics(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j*4)) + } + return nil +} + +func (rcv *Message) TopicsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) Payload(j int) int8 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt8(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Message) PayloadLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) MutatePayload(j int, n int8) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateInt8(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func MessageStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func MessageAddCommand(builder *flatbuffers.Builder, command flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(command), 0) +} +func MessageAddBroker(builder *flatbuffers.Builder, broker flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(broker), 0) +} +func MessageAddTopics(builder *flatbuffers.Builder, topics flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(topics), 0) +} +func MessageStartTopicsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MessageAddPayload(builder *flatbuffers.Builder, payload flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(payload), 0) +} +func MessageStartPayloadVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MessageEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go index cd70d9a7..e666f846 100644 --- a/plugins/websockets/validator/access_validator.go +++ b/plugins/websockets/validator/access_validator.go @@ -1,138 +1,76 @@ package validator import ( - "bytes" - "io" "net/http" "strings" + json "github.com/json-iterator/go" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/plugins/channel" + handler "github.com/spiral/roadrunner/v2/pkg/worker_handler" "github.com/spiral/roadrunner/v2/plugins/http/attributes" ) -type AccessValidator struct { - buffer *bytes.Buffer - header http.Header - status int -} - -func NewValidator() *AccessValidator { - return &AccessValidator{ - buffer: bytes.NewBuffer(nil), - header: make(http.Header), - } -} - -// Copy all content to parent response writer. -func (w *AccessValidator) Copy(rw http.ResponseWriter) { - rw.WriteHeader(w.status) - - for k, v := range w.header { - for _, vv := range v { - rw.Header().Add(k, vv) - } - } +type AccessValidatorFn = func(r *http.Request, channels ...string) (*AccessValidator, error) - _, _ = io.Copy(rw, w.buffer) -} - -// Header returns the header map that will be sent by WriteHeader. -func (w *AccessValidator) Header() http.Header { - return w.header -} - -// Write writes the data to the connection as part of an HTTP reply. -func (w *AccessValidator) Write(p []byte) (int, error) { - return w.buffer.Write(p) -} - -// WriteHeader sends an HTTP response header with the provided status code. -func (w *AccessValidator) WriteHeader(statusCode int) { - w.status = statusCode -} - -// IsOK returns true if response contained 200 status code. -func (w *AccessValidator) IsOK() bool { - return w.status == 200 -} - -// Body returns response body to rely to user. -func (w *AccessValidator) Body() []byte { - return w.buffer.Bytes() -} - -// Error contains server response. -func (w *AccessValidator) Error() string { - return w.buffer.String() +type AccessValidator struct { + Header http.Header `json:"headers"` + Status int `json:"status"` + Body []byte } -// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in -// case of error -func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error { +func ServerAccessValidator(r *http.Request) ([]byte, error) { const op = errors.Op("server_access_validator") + err := attributes.Set(r, "ws:joinServer", true) if err != nil { - return errors.E(op, err) + return nil, errors.E(op, err) } defer delete(attributes.All(r), "ws:joinServer") - // send payload to the worker - hub.ToWorker() <- struct { - RW http.ResponseWriter - Req *http.Request - }{ - w, - r, + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), } - resp := <-hub.FromWorker() - - rmsg := resp.(struct { - RW http.ResponseWriter - Req *http.Request - }) - - if !rmsg.RW.(*AccessValidator).IsOK() { - return w + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) } - return nil + return data, nil } -// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies. -// the decision to authorize user will be based on response code (200). -func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error { - const op = errors.Op("topics_access_validator") - - err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")) +func TopicsAccessValidator(r *http.Request, topics ...string) ([]byte, error) { + const op = errors.Op("topic_access_validator") + err := attributes.Set(r, "ws:joinTopics", strings.Join(topics, ",")) if err != nil { - return errors.E(op, err) + return nil, errors.E(op, err) } defer delete(attributes.All(r), "ws:joinTopics") - // send payload to worker - hub.ToWorker() <- struct { - RW http.ResponseWriter - Req *http.Request - }{ - w, - r, + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), } - // wait response - resp := <-hub.FromWorker() - - rmsg := resp.(struct { - RW http.ResponseWriter - Req *http.Request - }) - - if !rmsg.RW.(*AccessValidator).IsOK() { - return w + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) } - return nil + return data, nil } diff --git a/plugins/websockets/validator/access_validator_test.go b/plugins/websockets/validator/access_validator_test.go deleted file mode 100644 index 4a07b00f..00000000 --- a/plugins/websockets/validator/access_validator_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package validator - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestResponseWrapper_Body(t *testing.T) { - w := NewValidator() - _, _ = w.Write([]byte("hello")) - - assert.Equal(t, []byte("hello"), w.Body()) -} - -func TestResponseWrapper_Header(t *testing.T) { - w := NewValidator() - w.Header().Set("k", "value") - - assert.Equal(t, "value", w.Header().Get("k")) -} - -func TestResponseWrapper_StatusCode(t *testing.T) { - w := NewValidator() - w.WriteHeader(200) - - assert.True(t, w.IsOK()) -} - -func TestResponseWrapper_StatusCodeBad(t *testing.T) { - w := NewValidator() - w.WriteHeader(400) - - assert.False(t, w.IsOK()) -} diff --git a/tests/plugins/config/config_test.go b/tests/plugins/config/config_test.go index 3cf026bd..b6063cec 100755 --- a/tests/plugins/config/config_test.go +++ b/tests/plugins/config/config_test.go @@ -7,7 +7,6 @@ import ( "time" endure "github.com/spiral/endure/pkg/container" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/rpc" @@ -34,11 +33,6 @@ func TestViperProvider_Init(t *testing.T) { t.Fatal(err) } - err = container.Register(&channel.Plugin{}) - if err != nil { - t.Fatal(err) - } - err = container.Init() if err != nil { t.Fatal(err) @@ -88,7 +82,6 @@ func TestConfigOverwriteFail(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -110,7 +103,6 @@ func TestConfigOverwriteValid(t *testing.T) { &logger.ZapLogger{}, &rpc.Plugin{}, vp, - &channel.Plugin{}, &Foo2{}, ) assert.NoError(t, err) @@ -163,7 +155,6 @@ func TestConfigEnvVariables(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -215,7 +206,6 @@ func TestConfigEnvVariablesFail(t *testing.T) { &rpc.Plugin{}, vp, &Foo2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -247,11 +237,6 @@ func TestConfigProvider_GeneralSection(t *testing.T) { t.Fatal(err) } - err = container.Register(&channel.Plugin{}) - if err != nil { - t.Fatal(err) - } - err = container.Init() if err != nil { t.Fatal(err) diff --git a/tests/plugins/gzip/plugin_test.go b/tests/plugins/gzip/plugin_test.go index 5294e672..844fd411 100644 --- a/tests/plugins/gzip/plugin_test.go +++ b/tests/plugins/gzip/plugin_test.go @@ -11,7 +11,6 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/gzip" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -36,7 +35,6 @@ func TestGzipPlugin(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &gzip.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -130,7 +128,6 @@ func TestMiddlewareNotExist(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &gzip.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/headers/headers_plugin_test.go b/tests/plugins/headers/headers_plugin_test.go index e4903335..49d86b00 100644 --- a/tests/plugins/headers/headers_plugin_test.go +++ b/tests/plugins/headers/headers_plugin_test.go @@ -11,7 +11,6 @@ import ( "time" endure "github.com/spiral/endure/pkg/container" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/headers" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" @@ -35,7 +34,6 @@ func TestHeadersInit(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -102,7 +100,6 @@ func TestRequestHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -188,7 +185,6 @@ func TestResponseHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -275,7 +271,6 @@ func TestCORSHeaders(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &headers.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/http/http_plugin_test.go b/tests/plugins/http/http_plugin_test.go index aa57077d..128eec26 100644 --- a/tests/plugins/http/http_plugin_test.go +++ b/tests/plugins/http/http_plugin_test.go @@ -24,7 +24,6 @@ import ( goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" "github.com/spiral/roadrunner/v2/pkg/events" "github.com/spiral/roadrunner/v2/pkg/process" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/gzip" "github.com/spiral/roadrunner/v2/plugins/informer" @@ -63,7 +62,6 @@ func TestHTTPInit(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -128,7 +126,6 @@ func TestHTTPNoConfigSection(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -196,7 +193,6 @@ func TestHTTPInformerReset(t *testing.T) { &httpPlugin.Plugin{}, &informer.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -319,7 +315,6 @@ func TestSSL(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -456,7 +451,6 @@ func TestSSLRedirect(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -546,7 +540,6 @@ func TestSSLPushPipes(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -637,7 +630,6 @@ func TestFastCGI_RequestUri(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -732,7 +724,6 @@ func TestH2CUpgrade(t *testing.T) { mockLogger, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -824,7 +815,6 @@ func TestH2C(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -917,7 +907,6 @@ func TestHttpMiddleware(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1064,7 +1053,6 @@ logs: &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1150,7 +1138,6 @@ func TestHttpEnvVariables(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1238,7 +1225,6 @@ func TestHttpBrokenPipes(t *testing.T) { &httpPlugin.Plugin{}, &PluginMiddleware{}, &PluginMiddleware2{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1300,7 +1286,6 @@ func TestHTTPSupervisedPool(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &informer.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1503,7 +1488,6 @@ func TestHTTPBigRequestSize(t *testing.T) { &logger.ZapLogger{}, &server.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1596,7 +1580,6 @@ func TestStaticEtagPlugin(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1695,7 +1678,6 @@ func TestStaticPluginSecurity(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1845,7 +1827,6 @@ func TestStaticPlugin(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -1960,7 +1941,6 @@ func TestStaticDisabled_Error(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) assert.Error(t, cont.Init()) @@ -1982,7 +1962,6 @@ func TestStaticFilesDisabled(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -2075,7 +2054,6 @@ func TestStaticFilesForbid(t *testing.T) { &httpPlugin.Plugin{}, &gzip.Plugin{}, &static.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/logger/logger_test.go b/tests/plugins/logger/logger_test.go index f63a6a5d..d2877781 100644 --- a/tests/plugins/logger/logger_test.go +++ b/tests/plugins/logger/logger_test.go @@ -9,7 +9,6 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -99,7 +98,6 @@ func TestLoggerRawErr(t *testing.T) { mockLogger, &server.Plugin{}, &http.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -226,7 +224,6 @@ func TestLoggerNoConfig2(t *testing.T) { &logger.ZapLogger{}, &http.Plugin{}, &server.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/metrics/metrics_test.go b/tests/plugins/metrics/metrics_test.go index 48c01f24..8be567ec 100644 --- a/tests/plugins/metrics/metrics_test.go +++ b/tests/plugins/metrics/metrics_test.go @@ -15,7 +15,6 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -145,7 +144,6 @@ func TestMetricsIssue571(t *testing.T) { &server.Plugin{}, mockLogger, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/reload/reload_plugin_test.go b/tests/plugins/reload/reload_plugin_test.go index 41c9c92f..6db7b6d0 100644 --- a/tests/plugins/reload/reload_plugin_test.go +++ b/tests/plugins/reload/reload_plugin_test.go @@ -16,7 +16,6 @@ import ( "github.com/golang/mock/gomock" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/reload" @@ -66,7 +65,6 @@ func TestReloadInit(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -163,7 +161,6 @@ func TestReloadHugeNumberOfFiles(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -273,7 +270,6 @@ func TestReloadFilterFileExt(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -404,7 +400,6 @@ func TestReloadCopy100(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -682,7 +677,6 @@ func TestReloadNoRecursion(t *testing.T) { &httpPlugin.Plugin{}, &reload.Plugin{}, &resetter.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/status/plugin_test.go b/tests/plugins/status/plugin_test.go index 06983199..663f4ee3 100644 --- a/tests/plugins/status/plugin_test.go +++ b/tests/plugins/status/plugin_test.go @@ -14,7 +14,6 @@ import ( endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -39,7 +38,6 @@ func TestStatusHttp(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -127,7 +125,6 @@ func TestStatusRPC(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -207,7 +204,6 @@ func TestReadyHttp(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -295,7 +291,6 @@ func TestReadinessRPCWorkerNotReady(t *testing.T) { &server.Plugin{}, &httpPlugin.Plugin{}, &status.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) diff --git a/tests/plugins/websockets/websocket_plugin_test.go b/tests/plugins/websockets/websocket_plugin_test.go index 61ef186b..6b11f9e1 100644 --- a/tests/plugins/websockets/websocket_plugin_test.go +++ b/tests/plugins/websockets/websocket_plugin_test.go @@ -16,7 +16,6 @@ import ( json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" "github.com/spiral/roadrunner/v2/plugins/logger" @@ -60,7 +59,6 @@ func TestBroadcastInit(t *testing.T) { &redis.Plugin{}, &websockets.Plugin{}, &httpPlugin.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -170,7 +168,6 @@ func TestWSRedisAndMemory(t *testing.T) { &websockets.Plugin{}, &httpPlugin.Plugin{}, &memory.Plugin{}, - &channel.Plugin{}, ) assert.NoError(t, err) @@ -313,7 +310,9 @@ func RPCWsMemory(t *testing.T) { assert.NoError(t, err) defer func() { - _ = resp.Body.Close() + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } }() d, err := json.Marshal(message("join", "memory", []byte("hello websockets"), "foo", "foo2")) |