diff options
author | Valery Piashchynski <[email protected]> | 2021-05-31 16:05:00 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-05-31 16:05:00 +0300 |
commit | 49703d70a3ede70ce9a0cab824cbcb96dbf824c0 (patch) | |
tree | 181d72a3321d52c960a519ba3a233e3e7fe8e86a /plugins/websockets | |
parent | 0ee91dc24d3e68706d89092c06b1c0d09dab0353 (diff) |
- Rework access_validators
- WS plugin uses it's own pool to handle requests on the /ws (or any
user-defined) endpoint
- Ability to write custom validators
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets')
-rw-r--r-- | plugins/websockets/config.go | 65 | ||||
-rw-r--r-- | plugins/websockets/executor/executor.go | 30 | ||||
-rw-r--r-- | plugins/websockets/plugin.go | 208 | ||||
-rw-r--r-- | plugins/websockets/schema/message.fbs | 10 | ||||
-rw-r--r-- | plugins/websockets/schema/message/Message.go | 118 | ||||
-rw-r--r-- | plugins/websockets/validator/access_validator.go | 142 | ||||
-rw-r--r-- | plugins/websockets/validator/access_validator_test.go | 35 |
7 files changed, 367 insertions, 241 deletions
diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go index f3cb8e12..be4aaa82 100644 --- a/plugins/websockets/config.go +++ b/plugins/websockets/config.go @@ -1,22 +1,16 @@ package websockets -import "time" +import ( + "time" + + "github.com/spiral/roadrunner/v2/pkg/pool" +) /* websockets: # pubsubs should implement PubSub interface to be collected via endure.Collects - # also, they should implement RPC methods to publish data into them - # pubsubs might use general config section or its own pubsubs:["redis", "amqp", "memory"] - - # sample of the own config section for the redis pubsub driver - redis: - address: - - localhost:1111 - .... the rest - - # path used as websockets path path: "/ws" */ @@ -26,33 +20,10 @@ type Config struct { // http path for the websocket Path string `mapstructure:"path"` // ["redis", "amqp", "memory"] - PubSubs []string `mapstructure:"pubsubs"` - Middleware []string `mapstructure:"middleware"` - Redis *RedisConfig `mapstructure:"redis"` -} + PubSubs []string `mapstructure:"pubsubs"` + Middleware []string `mapstructure:"middleware"` -type RedisConfig struct { - Addrs []string `mapstructure:"addrs"` - DB int `mapstructure:"db"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - MasterName string `mapstructure:"master_name"` - SentinelPassword string `mapstructure:"sentinel_password"` - RouteByLatency bool `mapstructure:"route_by_latency"` - RouteRandomly bool `mapstructure:"route_randomly"` - MaxRetries int `mapstructure:"max_retries"` - DialTimeout time.Duration `mapstructure:"dial_timeout"` - MinRetryBackoff time.Duration `mapstructure:"min_retry_backoff"` - MaxRetryBackoff time.Duration `mapstructure:"max_retry_backoff"` - PoolSize int `mapstructure:"pool_size"` - MinIdleConns int `mapstructure:"min_idle_conns"` - MaxConnAge time.Duration `mapstructure:"max_conn_age"` - ReadTimeout time.Duration `mapstructure:"read_timeout"` - WriteTimeout time.Duration `mapstructure:"write_timeout"` - PoolTimeout time.Duration `mapstructure:"pool_timeout"` - IdleTimeout time.Duration `mapstructure:"idle_timeout"` - IdleCheckFreq time.Duration `mapstructure:"idle_check_freq"` - ReadOnly bool `mapstructure:"read_only"` + Pool *pool.Config `mapstructure:"pool"` } // InitDefault initialize default values for the ws config @@ -64,4 +35,24 @@ func (c *Config) InitDefault() { // memory used by default c.PubSubs = append(c.PubSubs, "memory") } + + if c.Pool == nil { + c.Pool = &pool.Config{} + if c.Pool.NumWorkers == 0 { + // 2 workers by default + c.Pool.NumWorkers = 2 + } + + if c.Pool.AllocateTimeout == 0 { + c.Pool.AllocateTimeout = time.Minute + } + + if c.Pool.DestroyTimeout == 0 { + c.Pool.DestroyTimeout = time.Minute + } + if c.Pool.Supervisor == nil { + return + } + c.Pool.Supervisor.InitDefaults() + } } diff --git a/plugins/websockets/executor/executor.go b/plugins/websockets/executor/executor.go index 87fed3a6..24ea19ce 100644 --- a/plugins/websockets/executor/executor.go +++ b/plugins/websockets/executor/executor.go @@ -9,7 +9,6 @@ import ( json "github.com/json-iterator/go" "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/pkg/pubsub" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/logger" "github.com/spiral/roadrunner/v2/plugins/websockets/commands" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" @@ -35,21 +34,22 @@ type Executor struct { pubsub map[string]pubsub.PubSub actualTopics map[string]struct{} - hub channel.Hub - req *http.Request + req *http.Request + accessValidator validator.AccessValidatorFn } // NewExecutor creates protected connection and starts command loop -func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, connID string, pubsubs map[string]pubsub.PubSub, hub channel.Hub, r *http.Request) *Executor { +func NewExecutor(conn *connection.Connection, log logger.Logger, bst *storage.Storage, + connID string, pubsubs map[string]pubsub.PubSub, av validator.AccessValidatorFn, r *http.Request) *Executor { return &Executor{ - conn: conn, - connID: connID, - storage: bst, - log: log, - pubsub: pubsubs, - hub: hub, - actualTopics: make(map[string]struct{}, 10), - req: r, + conn: conn, + connID: connID, + storage: bst, + log: log, + pubsub: pubsubs, + accessValidator: av, + actualTopics: make(map[string]struct{}, 10), + req: r, } } @@ -85,8 +85,12 @@ func (e *Executor) StartCommandLoop() error { //nolint:gocognit case commands.Join: e.log.Debug("get join command", "msg", msg) - err := validator.NewValidator().AssertTopicsAccess(e.hub, e.req, msg.Topics...) + val, err := e.accessValidator(e.req, msg.Topics...) if err != nil { + if val != nil { + e.log.Debug("validation error", "status", val.Status, "headers", val.Header, "body", val.Body) + } + resp := &Response{ Topic: "#join", Payload: msg.Topics, diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index c51c7ca1..4c722860 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -1,19 +1,23 @@ package websockets import ( + "context" "net/http" "sync" "time" "github.com/fasthttp/websocket" "github.com/google/uuid" + json "github.com/json-iterator/go" endure "github.com/spiral/endure/pkg/container" "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/pkg/payload" + phpPool "github.com/spiral/roadrunner/v2/pkg/pool" "github.com/spiral/roadrunner/v2/pkg/pubsub" - "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/config" "github.com/spiral/roadrunner/v2/plugins/http/attributes" "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" "github.com/spiral/roadrunner/v2/plugins/websockets/connection" "github.com/spiral/roadrunner/v2/plugins/websockets/executor" "github.com/spiral/roadrunner/v2/plugins/websockets/pool" @@ -26,12 +30,12 @@ const ( ) type Plugin struct { - mu sync.RWMutex + sync.RWMutex // Collection with all available pubsubs pubsubs map[string]pubsub.PubSub - Config *Config - log logger.Logger + cfg *Config + log logger.Logger // global connections map connections sync.Map @@ -40,25 +44,38 @@ type Plugin struct { // GO workers pool workersPool *pool.WorkersPool - hub channel.Hub + wsUpgrade *websocket.Upgrader + serveExit chan struct{} + + phpPool phpPool.Pool + server server.Server + + // function used to validate access to the requested resource + accessValidator validator.AccessValidatorFn } -func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error { +func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { const op = errors.Op("websockets_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) } - err := cfg.UnmarshalKey(PluginName, &p.Config) + err := cfg.UnmarshalKey(PluginName, &p.cfg) if err != nil { return errors.E(op, err) } + p.cfg.InitDefault() + p.pubsubs = make(map[string]pubsub.PubSub) p.log = log p.storage = storage.NewStorage() p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) - p.hub = channel + p.wsUpgrade = &websocket.Upgrader{ + HandshakeTimeout: time.Second * 60, + } + p.serveExit = make(chan struct{}) + p.server = server return nil } @@ -66,25 +83,54 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel. func (p *Plugin) Serve() chan error { errCh := make(chan error) + go func() { + var err error + p.Lock() + defer p.Unlock() + + p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{ + Debug: p.cfg.Pool.Debug, + NumWorkers: p.cfg.Pool.NumWorkers, + MaxJobs: p.cfg.Pool.MaxJobs, + AllocateTimeout: p.cfg.Pool.AllocateTimeout, + DestroyTimeout: p.cfg.Pool.DestroyTimeout, + Supervisor: p.cfg.Pool.Supervisor, + }, map[string]string{"RR_MODE": "http"}) + if err != nil { + errCh <- err + } + + p.accessValidator = p.defaultAccessValidator(p.phpPool) + }() + // run all pubsubs drivers for _, v := range p.pubsubs { go func(ps pubsub.PubSub) { for { - data, err := ps.Next() - if err != nil { - errCh <- err + select { + case <-p.serveExit: return + default: + data, err := ps.Next() + if err != nil { + errCh <- err + return + } + p.workersPool.Queue(data) } - - p.workersPool.Queue(data) } }(v) } + return errCh } func (p *Plugin) Stop() error { + // close workers pool p.workersPool.Stop() + p.Lock() + p.phpPool.Destroy(context.Background()) + p.Unlock() return nil } @@ -114,84 +160,72 @@ func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) { func (p *Plugin) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != p.Config.Path { + if r.URL.Path != p.cfg.Path { next.ServeHTTP(w, r) return } - r = attributes.Init(r) - - err := validator.NewValidator().AssertServerAccess(p.hub, r) + // we need to lock here, because accessValidator might not be set in the Serve func at the moment + p.RLock() + // before we hijacked connection, we still can write to the response headers + val, err := p.accessValidator(r) + p.RUnlock() if err != nil { - // show the error to the user - if av, ok := err.(*validator.AccessValidator); ok { - av.Copy(w) - } else { - w.WriteHeader(400) - return - } + w.WriteHeader(400) + return } - // connection upgrader - upgraded := websocket.Upgrader{ - HandshakeTimeout: time.Second * 60, - ReadBufferSize: 0, - WriteBufferSize: 0, - WriteBufferPool: nil, - Subprotocols: nil, - Error: nil, - CheckOrigin: nil, - EnableCompression: false, + if val.Status != http.StatusOK { + _, _ = w.Write(val.Body) + w.WriteHeader(val.Status) + return } // upgrade connection to websocket connection - _conn, err := upgraded.Upgrade(w, r, nil) + _conn, err := p.wsUpgrade.Upgrade(w, r, nil) if err != nil { // connection hijacked, do not use response.writer or request - p.log.Error("upgrade connection error", "error", err) + p.log.Error("upgrade connection", "error", err) return } // construct safe connection protected by mutexes safeConn := connection.NewConnection(_conn, p.log) + // generate UUID from the connection + connectionID := uuid.NewString() + // store connection + p.connections.Store(connectionID, safeConn) + defer func() { // close the connection on exit err = safeConn.Close() if err != nil { - p.log.Error("connection close error", "error", err) + p.log.Error("connection close", "error", err) } - }() - // generate UUID from the connection - connectionID := uuid.NewString() - // store connection - p.connections.Store(connectionID, safeConn) - // when exiting - delete the connection - defer func() { + // when exiting - delete the connection p.connections.Delete(connectionID) }() - p.mu.Lock() // Executor wraps a connection to have a safe abstraction - e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r) - p.mu.Unlock() - + e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r) p.log.Info("websocket client connected", "uuid", connectionID) - defer e.CleanUp() err = e.StartCommandLoop() if err != nil { - p.log.Error("command loop error", "error", err.Error()) + p.log.Error("command loop error, disconnecting", "error", err.Error()) return } + + p.log.Info("disconnected", "connectionID", connectionID) }) } // Publish is an entry point to the websocket PUBSUB func (p *Plugin) Publish(msg []*pubsub.Message) error { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { @@ -210,8 +244,8 @@ func (p *Plugin) Publish(msg []*pubsub.Message) error { func (p *Plugin) PublishAsync(msg []*pubsub.Message) { go func() { - p.mu.Lock() - defer p.mu.Unlock() + p.Lock() + defer p.Unlock() for i := 0; i < len(msg); i++ { for j := 0; j < len(msg[i].Topics); j++ { err := p.pubsubs[msg[i].Broker].Publish(msg) @@ -223,3 +257,69 @@ func (p *Plugin) PublishAsync(msg []*pubsub.Message) { } }() } + +func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn { + return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) { + p.RLock() + defer p.RUnlock() + const op = errors.Op("access_validator") + + r = attributes.Init(r) + + // if channels len is eq to 0, we use serverValidator + if len(topics) == 0 { + ctx, err := validator.TopicsAccessValidator(r, topics...) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + return val, nil + } + + ctx, err := validator.ServerAccessValidator(r) + if err != nil { + return nil, errors.E(op, err) + } + + pd := payload.Payload{ + Context: ctx, + } + + resp, err := pool.Exec(pd) + if err != nil { + return nil, errors.E(op, err) + } + + val := &validator.AccessValidator{ + Body: resp.Body, + } + + err = json.Unmarshal(resp.Context, val) + if err != nil { + return nil, errors.E(op, err) + } + + if val.Status != http.StatusOK { + return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status)) + } + + return val, nil + } +} diff --git a/plugins/websockets/schema/message.fbs b/plugins/websockets/schema/message.fbs new file mode 100644 index 00000000..f2d92c78 --- /dev/null +++ b/plugins/websockets/schema/message.fbs @@ -0,0 +1,10 @@ +namespace message; + +table Message { + command:string; + broker:string; + topics:[string]; + payload:[byte]; +} + +root_type Message; diff --git a/plugins/websockets/schema/message/Message.go b/plugins/websockets/schema/message/Message.go new file mode 100644 index 00000000..26bbd12c --- /dev/null +++ b/plugins/websockets/schema/message/Message.go @@ -0,0 +1,118 @@ +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package message + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Message struct { + _tab flatbuffers.Table +} + +func GetRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Message{} + x.Init(buf, n+offset) + return x +} + +func GetSizePrefixedRootAsMessage(buf []byte, offset flatbuffers.UOffsetT) *Message { + n := flatbuffers.GetUOffsetT(buf[offset+flatbuffers.SizeUint32:]) + x := &Message{} + x.Init(buf, n+offset+flatbuffers.SizeUint32) + return x +} + +func (rcv *Message) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Message) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Message) Command() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Broker() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Message) Topics(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j*4)) + } + return nil +} + +func (rcv *Message) TopicsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) Payload(j int) int8 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt8(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Message) PayloadLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Message) MutatePayload(j int, n int8) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateInt8(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func MessageStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func MessageAddCommand(builder *flatbuffers.Builder, command flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(command), 0) +} +func MessageAddBroker(builder *flatbuffers.Builder, broker flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(broker), 0) +} +func MessageAddTopics(builder *flatbuffers.Builder, topics flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(topics), 0) +} +func MessageStartTopicsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MessageAddPayload(builder *flatbuffers.Builder, payload flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(payload), 0) +} +func MessageStartPayloadVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MessageEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go index cd70d9a7..e666f846 100644 --- a/plugins/websockets/validator/access_validator.go +++ b/plugins/websockets/validator/access_validator.go @@ -1,138 +1,76 @@ package validator import ( - "bytes" - "io" "net/http" "strings" + json "github.com/json-iterator/go" "github.com/spiral/errors" - "github.com/spiral/roadrunner/v2/plugins/channel" + handler "github.com/spiral/roadrunner/v2/pkg/worker_handler" "github.com/spiral/roadrunner/v2/plugins/http/attributes" ) -type AccessValidator struct { - buffer *bytes.Buffer - header http.Header - status int -} - -func NewValidator() *AccessValidator { - return &AccessValidator{ - buffer: bytes.NewBuffer(nil), - header: make(http.Header), - } -} - -// Copy all content to parent response writer. -func (w *AccessValidator) Copy(rw http.ResponseWriter) { - rw.WriteHeader(w.status) - - for k, v := range w.header { - for _, vv := range v { - rw.Header().Add(k, vv) - } - } +type AccessValidatorFn = func(r *http.Request, channels ...string) (*AccessValidator, error) - _, _ = io.Copy(rw, w.buffer) -} - -// Header returns the header map that will be sent by WriteHeader. -func (w *AccessValidator) Header() http.Header { - return w.header -} - -// Write writes the data to the connection as part of an HTTP reply. -func (w *AccessValidator) Write(p []byte) (int, error) { - return w.buffer.Write(p) -} - -// WriteHeader sends an HTTP response header with the provided status code. -func (w *AccessValidator) WriteHeader(statusCode int) { - w.status = statusCode -} - -// IsOK returns true if response contained 200 status code. -func (w *AccessValidator) IsOK() bool { - return w.status == 200 -} - -// Body returns response body to rely to user. -func (w *AccessValidator) Body() []byte { - return w.buffer.Bytes() -} - -// Error contains server response. -func (w *AccessValidator) Error() string { - return w.buffer.String() +type AccessValidator struct { + Header http.Header `json:"headers"` + Status int `json:"status"` + Body []byte } -// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in -// case of error -func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error { +func ServerAccessValidator(r *http.Request) ([]byte, error) { const op = errors.Op("server_access_validator") + err := attributes.Set(r, "ws:joinServer", true) if err != nil { - return errors.E(op, err) + return nil, errors.E(op, err) } defer delete(attributes.All(r), "ws:joinServer") - // send payload to the worker - hub.ToWorker() <- struct { - RW http.ResponseWriter - Req *http.Request - }{ - w, - r, + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), } - resp := <-hub.FromWorker() - - rmsg := resp.(struct { - RW http.ResponseWriter - Req *http.Request - }) - - if !rmsg.RW.(*AccessValidator).IsOK() { - return w + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) } - return nil + return data, nil } -// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies. -// the decision to authorize user will be based on response code (200). -func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error { - const op = errors.Op("topics_access_validator") - - err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")) +func TopicsAccessValidator(r *http.Request, topics ...string) ([]byte, error) { + const op = errors.Op("topic_access_validator") + err := attributes.Set(r, "ws:joinTopics", strings.Join(topics, ",")) if err != nil { - return errors.E(op, err) + return nil, errors.E(op, err) } defer delete(attributes.All(r), "ws:joinTopics") - // send payload to worker - hub.ToWorker() <- struct { - RW http.ResponseWriter - Req *http.Request - }{ - w, - r, + req := &handler.Request{ + RemoteAddr: handler.FetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: handler.URI(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), } - // wait response - resp := <-hub.FromWorker() - - rmsg := resp.(struct { - RW http.ResponseWriter - Req *http.Request - }) - - if !rmsg.RW.(*AccessValidator).IsOK() { - return w + data, err := json.Marshal(req) + if err != nil { + return nil, errors.E(op, err) } - return nil + return data, nil } diff --git a/plugins/websockets/validator/access_validator_test.go b/plugins/websockets/validator/access_validator_test.go deleted file mode 100644 index 4a07b00f..00000000 --- a/plugins/websockets/validator/access_validator_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package validator - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestResponseWrapper_Body(t *testing.T) { - w := NewValidator() - _, _ = w.Write([]byte("hello")) - - assert.Equal(t, []byte("hello"), w.Body()) -} - -func TestResponseWrapper_Header(t *testing.T) { - w := NewValidator() - w.Header().Set("k", "value") - - assert.Equal(t, "value", w.Header().Get("k")) -} - -func TestResponseWrapper_StatusCode(t *testing.T) { - w := NewValidator() - w.WriteHeader(200) - - assert.True(t, w.IsOK()) -} - -func TestResponseWrapper_StatusCodeBad(t *testing.T) { - w := NewValidator() - w.WriteHeader(400) - - assert.False(t, w.IsOK()) -} |