summaryrefslogtreecommitdiff
path: root/plugins/websockets
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-05-29 11:27:49 +0300
committerValery Piashchynski <[email protected]>2021-05-29 11:27:49 +0300
commit09b982813f8825f776abf20fb16c6085439ca4ba (patch)
tree1c4593bdc42503616b06f32bb6ee676cca38515a /plugins/websockets
parentfcda08498e8f914bbd0798da898818cd5d0e4348 (diff)
- Update channel plugin interfaces
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets')
-rw-r--r--plugins/websockets/plugin.go12
-rw-r--r--plugins/websockets/validator/access_validator.go25
2 files changed, 26 insertions, 11 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index 2a060716..b3495e77 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -3,6 +3,7 @@ package websockets
import (
"net/http"
"sync"
+ "sync/atomic"
"time"
"github.com/fasthttp/websocket"
@@ -39,6 +40,7 @@ type Plugin struct {
// GO workers pool
workersPool *pool.WorkersPool
+ stopped uint64
hub channel.Hub
}
@@ -59,6 +61,7 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.
p.storage = storage.NewStorage()
p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log)
p.hub = channel
+ p.stopped = 0
return nil
}
@@ -84,6 +87,7 @@ func (p *Plugin) Serve() chan error {
}
func (p *Plugin) Stop() error {
+ atomic.AddUint64(&p.stopped, 1)
p.workersPool.Stop()
return nil
}
@@ -118,7 +122,11 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
return
}
- p.mu.Lock()
+
+ if atomic.CompareAndSwapUint64(&p.stopped, 1, 1) {
+ // plugin stopped
+ return
+ }
r = attributes.Init(r)
@@ -133,8 +141,6 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
}
}
- p.mu.Unlock()
-
// connection upgrader
upgraded := websocket.Upgrader{
HandshakeTimeout: time.Second * 60,
diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go
index e3fde3d0..cd70d9a7 100644
--- a/plugins/websockets/validator/access_validator.go
+++ b/plugins/websockets/validator/access_validator.go
@@ -6,6 +6,7 @@ import (
"net/http"
"strings"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
)
@@ -69,13 +70,16 @@ func (w *AccessValidator) Error() string {
// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in
// case of error
func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error {
- if err := attributes.Set(r, "ws:joinServer", true); err != nil {
- return err
+ const op = errors.Op("server_access_validator")
+ err := attributes.Set(r, "ws:joinServer", true)
+ if err != nil {
+ return errors.E(op, err)
}
defer delete(attributes.All(r), "ws:joinServer")
- hub.ReceiveCh() <- struct {
+ // send payload to the worker
+ hub.ToWorker() <- struct {
RW http.ResponseWriter
Req *http.Request
}{
@@ -83,7 +87,7 @@ func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) e
r,
}
- resp := <-hub.SendCh()
+ resp := <-hub.FromWorker()
rmsg := resp.(struct {
RW http.ResponseWriter
@@ -100,13 +104,17 @@ func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) e
// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies.
// the decision to authorize user will be based on response code (200).
func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error {
- if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil {
- return err
+ const op = errors.Op("topics_access_validator")
+
+ err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ","))
+ if err != nil {
+ return errors.E(op, err)
}
defer delete(attributes.All(r), "ws:joinTopics")
- hub.ReceiveCh() <- struct {
+ // send payload to worker
+ hub.ToWorker() <- struct {
RW http.ResponseWriter
Req *http.Request
}{
@@ -114,7 +122,8 @@ func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, c
r,
}
- resp := <-hub.SendCh()
+ // wait response
+ resp := <-hub.FromWorker()
rmsg := resp.(struct {
RW http.ResponseWriter