diff options
-rw-r--r-- | plugins/channel/interface.go | 4 | ||||
-rw-r--r-- | plugins/channel/plugin.go | 24 | ||||
-rw-r--r-- | plugins/http/channel.go | 5 | ||||
-rw-r--r-- | plugins/http/plugin.go | 7 | ||||
-rw-r--r-- | plugins/websockets/plugin.go | 12 | ||||
-rw-r--r-- | plugins/websockets/validator/access_validator.go | 25 |
6 files changed, 45 insertions, 32 deletions
diff --git a/plugins/channel/interface.go b/plugins/channel/interface.go index 9a405e89..50fc9c96 100644 --- a/plugins/channel/interface.go +++ b/plugins/channel/interface.go @@ -3,6 +3,6 @@ package channel // Hub used as a channel between two or more plugins // this is not a PUBLIC plugin, API might be changed at any moment type Hub interface { - SendCh() chan interface{} - ReceiveCh() chan interface{} + FromWorker() chan interface{} + ToWorker() chan interface{} } diff --git a/plugins/channel/plugin.go b/plugins/channel/plugin.go index 8901c42e..9afd1264 100644 --- a/plugins/channel/plugin.go +++ b/plugins/channel/plugin.go @@ -10,15 +10,16 @@ const ( type Plugin struct { sync.Mutex - send chan interface{} - receive chan interface{} + fromCh chan interface{} + toCh chan interface{} } func (p *Plugin) Init() error { p.Lock() defer p.Unlock() - p.send = make(chan interface{}) - p.receive = make(chan interface{}) + + p.fromCh = make(chan interface{}) + p.toCh = make(chan interface{}) return nil } @@ -27,26 +28,23 @@ func (p *Plugin) Serve() chan error { } func (p *Plugin) Stop() error { - close(p.receive) return nil } -func (p *Plugin) SendCh() chan interface{} { +func (p *Plugin) FromWorker() chan interface{} { p.Lock() defer p.Unlock() - // bi-directional queue - return p.send + // one-directional queue + return p.fromCh } -func (p *Plugin) ReceiveCh() chan interface{} { +func (p *Plugin) ToWorker() chan interface{} { p.Lock() defer p.Unlock() - // bi-directional queue - return p.receive + // one-directional queue + return p.toCh } -func (p *Plugin) Available() {} - func (p *Plugin) Name() string { return PluginName } diff --git a/plugins/http/channel.go b/plugins/http/channel.go index 42b73730..23b5ff3e 100644 --- a/plugins/http/channel.go +++ b/plugins/http/channel.go @@ -6,7 +6,7 @@ import ( // messages method used to read messages from the ws plugin with the auth requests for the topics and server func (p *Plugin) messages() { - for msg := range p.hub.ReceiveCh() { + for msg := range p.hub.ToWorker() { p.RLock() // msg here is the structure with http.ResponseWriter and http.Request rmsg := msg.(struct { @@ -14,9 +14,10 @@ func (p *Plugin) messages() { Req *http.Request }) + // invoke handler with redirected responsewriter and request p.handler.ServeHTTP(rmsg.RW, rmsg.Req) - p.hub.SendCh() <- struct { + p.hub.FromWorker() <- struct { RW http.ResponseWriter Req *http.Request }{ diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go index 38b3621f..397de7ae 100644 --- a/plugins/http/plugin.go +++ b/plugins/http/plugin.go @@ -75,7 +75,7 @@ type Plugin struct { // Init must return configure svc and return true if svc hasStatus enabled. Must return error in case of // misconfiguration. Services must not be used without proper configuration pushed first. -func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server, channel channel.Hub) error { +func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server server.Server, hub channel.Hub) error { const op = errors.Op("http_plugin_init") if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) @@ -109,9 +109,7 @@ func (p *Plugin) Init(cfg config.Configurer, rrLogger logger.Logger, server serv p.cfg.Env[RrMode] = "http" p.server = server - p.hub = channel - - go p.messages() + p.hub = hub return nil } @@ -128,6 +126,7 @@ func (p *Plugin) logCallback(event interface{}) { // Serve serves the svc. func (p *Plugin) Serve() chan error { errCh := make(chan error, 2) + go p.messages() // run whole process in the goroutine go func() { // protect http initialization diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index 2a060716..b3495e77 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -3,6 +3,7 @@ package websockets import ( "net/http" "sync" + "sync/atomic" "time" "github.com/fasthttp/websocket" @@ -39,6 +40,7 @@ type Plugin struct { // GO workers pool workersPool *pool.WorkersPool + stopped uint64 hub channel.Hub } @@ -59,6 +61,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel. p.storage = storage.NewStorage() p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log) p.hub = channel + p.stopped = 0 return nil } @@ -84,6 +87,7 @@ func (p *Plugin) Serve() chan error { } func (p *Plugin) Stop() error { + atomic.AddUint64(&p.stopped, 1) p.workersPool.Stop() return nil } @@ -118,7 +122,11 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } - p.mu.Lock() + + if atomic.CompareAndSwapUint64(&p.stopped, 1, 1) { + // plugin stopped + return + } r = attributes.Init(r) @@ -133,8 +141,6 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler { } } - p.mu.Unlock() - // connection upgrader upgraded := websocket.Upgrader{ HandshakeTimeout: time.Second * 60, diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go index e3fde3d0..cd70d9a7 100644 --- a/plugins/websockets/validator/access_validator.go +++ b/plugins/websockets/validator/access_validator.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/spiral/errors" "github.com/spiral/roadrunner/v2/plugins/channel" "github.com/spiral/roadrunner/v2/plugins/http/attributes" ) @@ -69,13 +70,16 @@ func (w *AccessValidator) Error() string { // AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in // case of error func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error { - if err := attributes.Set(r, "ws:joinServer", true); err != nil { - return err + const op = errors.Op("server_access_validator") + err := attributes.Set(r, "ws:joinServer", true) + if err != nil { + return errors.E(op, err) } defer delete(attributes.All(r), "ws:joinServer") - hub.ReceiveCh() <- struct { + // send payload to the worker + hub.ToWorker() <- struct { RW http.ResponseWriter Req *http.Request }{ @@ -83,7 +87,7 @@ func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) e r, } - resp := <-hub.SendCh() + resp := <-hub.FromWorker() rmsg := resp.(struct { RW http.ResponseWriter @@ -100,13 +104,17 @@ func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) e // AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies. // the decision to authorize user will be based on response code (200). func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error { - if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil { - return err + const op = errors.Op("topics_access_validator") + + err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")) + if err != nil { + return errors.E(op, err) } defer delete(attributes.All(r), "ws:joinTopics") - hub.ReceiveCh() <- struct { + // send payload to worker + hub.ToWorker() <- struct { RW http.ResponseWriter Req *http.Request }{ @@ -114,7 +122,8 @@ func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, c r, } - resp := <-hub.SendCh() + // wait response + resp := <-hub.FromWorker() rmsg := resp.(struct { RW http.ResponseWriter |