summaryrefslogtreecommitdiff
path: root/plugins/websockets/plugin.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/websockets/plugin.go')
-rw-r--r--plugins/websockets/plugin.go208
1 files changed, 154 insertions, 54 deletions
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index c51c7ca1..4c722860 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -1,19 +1,23 @@
package websockets
import (
+ "context"
"net/http"
"sync"
"time"
"github.com/fasthttp/websocket"
"github.com/google/uuid"
+ json "github.com/json-iterator/go"
endure "github.com/spiral/endure/pkg/container"
"github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/payload"
+ phpPool "github.com/spiral/roadrunner/v2/pkg/pool"
"github.com/spiral/roadrunner/v2/pkg/pubsub"
- "github.com/spiral/roadrunner/v2/plugins/channel"
"github.com/spiral/roadrunner/v2/plugins/config"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
"github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/server"
"github.com/spiral/roadrunner/v2/plugins/websockets/connection"
"github.com/spiral/roadrunner/v2/plugins/websockets/executor"
"github.com/spiral/roadrunner/v2/plugins/websockets/pool"
@@ -26,12 +30,12 @@ const (
)
type Plugin struct {
- mu sync.RWMutex
+ sync.RWMutex
// Collection with all available pubsubs
pubsubs map[string]pubsub.PubSub
- Config *Config
- log logger.Logger
+ cfg *Config
+ log logger.Logger
// global connections map
connections sync.Map
@@ -40,25 +44,38 @@ type Plugin struct {
// GO workers pool
workersPool *pool.WorkersPool
- hub channel.Hub
+ wsUpgrade *websocket.Upgrader
+ serveExit chan struct{}
+
+ phpPool phpPool.Pool
+ server server.Server
+
+ // function used to validate access to the requested resource
+ accessValidator validator.AccessValidatorFn
}
-func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.Hub) error {
+func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error {
const op = errors.Op("websockets_plugin_init")
if !cfg.Has(PluginName) {
return errors.E(op, errors.Disabled)
}
- err := cfg.UnmarshalKey(PluginName, &p.Config)
+ err := cfg.UnmarshalKey(PluginName, &p.cfg)
if err != nil {
return errors.E(op, err)
}
+ p.cfg.InitDefault()
+
p.pubsubs = make(map[string]pubsub.PubSub)
p.log = log
p.storage = storage.NewStorage()
p.workersPool = pool.NewWorkersPool(p.storage, &p.connections, log)
- p.hub = channel
+ p.wsUpgrade = &websocket.Upgrader{
+ HandshakeTimeout: time.Second * 60,
+ }
+ p.serveExit = make(chan struct{})
+ p.server = server
return nil
}
@@ -66,25 +83,54 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, channel channel.
func (p *Plugin) Serve() chan error {
errCh := make(chan error)
+ go func() {
+ var err error
+ p.Lock()
+ defer p.Unlock()
+
+ p.phpPool, err = p.server.NewWorkerPool(context.Background(), phpPool.Config{
+ Debug: p.cfg.Pool.Debug,
+ NumWorkers: p.cfg.Pool.NumWorkers,
+ MaxJobs: p.cfg.Pool.MaxJobs,
+ AllocateTimeout: p.cfg.Pool.AllocateTimeout,
+ DestroyTimeout: p.cfg.Pool.DestroyTimeout,
+ Supervisor: p.cfg.Pool.Supervisor,
+ }, map[string]string{"RR_MODE": "http"})
+ if err != nil {
+ errCh <- err
+ }
+
+ p.accessValidator = p.defaultAccessValidator(p.phpPool)
+ }()
+
// run all pubsubs drivers
for _, v := range p.pubsubs {
go func(ps pubsub.PubSub) {
for {
- data, err := ps.Next()
- if err != nil {
- errCh <- err
+ select {
+ case <-p.serveExit:
return
+ default:
+ data, err := ps.Next()
+ if err != nil {
+ errCh <- err
+ return
+ }
+ p.workersPool.Queue(data)
}
-
- p.workersPool.Queue(data)
}
}(v)
}
+
return errCh
}
func (p *Plugin) Stop() error {
+ // close workers pool
p.workersPool.Stop()
+ p.Lock()
+ p.phpPool.Destroy(context.Background())
+ p.Unlock()
return nil
}
@@ -114,84 +160,72 @@ func (p *Plugin) GetPublishers(name endure.Named, pub pubsub.PubSub) {
func (p *Plugin) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != p.Config.Path {
+ if r.URL.Path != p.cfg.Path {
next.ServeHTTP(w, r)
return
}
- r = attributes.Init(r)
-
- err := validator.NewValidator().AssertServerAccess(p.hub, r)
+ // we need to lock here, because accessValidator might not be set in the Serve func at the moment
+ p.RLock()
+ // before we hijacked connection, we still can write to the response headers
+ val, err := p.accessValidator(r)
+ p.RUnlock()
if err != nil {
- // show the error to the user
- if av, ok := err.(*validator.AccessValidator); ok {
- av.Copy(w)
- } else {
- w.WriteHeader(400)
- return
- }
+ w.WriteHeader(400)
+ return
}
- // connection upgrader
- upgraded := websocket.Upgrader{
- HandshakeTimeout: time.Second * 60,
- ReadBufferSize: 0,
- WriteBufferSize: 0,
- WriteBufferPool: nil,
- Subprotocols: nil,
- Error: nil,
- CheckOrigin: nil,
- EnableCompression: false,
+ if val.Status != http.StatusOK {
+ _, _ = w.Write(val.Body)
+ w.WriteHeader(val.Status)
+ return
}
// upgrade connection to websocket connection
- _conn, err := upgraded.Upgrade(w, r, nil)
+ _conn, err := p.wsUpgrade.Upgrade(w, r, nil)
if err != nil {
// connection hijacked, do not use response.writer or request
- p.log.Error("upgrade connection error", "error", err)
+ p.log.Error("upgrade connection", "error", err)
return
}
// construct safe connection protected by mutexes
safeConn := connection.NewConnection(_conn, p.log)
+ // generate UUID from the connection
+ connectionID := uuid.NewString()
+ // store connection
+ p.connections.Store(connectionID, safeConn)
+
defer func() {
// close the connection on exit
err = safeConn.Close()
if err != nil {
- p.log.Error("connection close error", "error", err)
+ p.log.Error("connection close", "error", err)
}
- }()
- // generate UUID from the connection
- connectionID := uuid.NewString()
- // store connection
- p.connections.Store(connectionID, safeConn)
- // when exiting - delete the connection
- defer func() {
+ // when exiting - delete the connection
p.connections.Delete(connectionID)
}()
- p.mu.Lock()
// Executor wraps a connection to have a safe abstraction
- e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.hub, r)
- p.mu.Unlock()
-
+ e := executor.NewExecutor(safeConn, p.log, p.storage, connectionID, p.pubsubs, p.accessValidator, r)
p.log.Info("websocket client connected", "uuid", connectionID)
-
defer e.CleanUp()
err = e.StartCommandLoop()
if err != nil {
- p.log.Error("command loop error", "error", err.Error())
+ p.log.Error("command loop error, disconnecting", "error", err.Error())
return
}
+
+ p.log.Info("disconnected", "connectionID", connectionID)
})
}
// Publish is an entry point to the websocket PUBSUB
func (p *Plugin) Publish(msg []*pubsub.Message) error {
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.Lock()
+ defer p.Unlock()
for i := 0; i < len(msg); i++ {
for j := 0; j < len(msg[i].Topics); j++ {
@@ -210,8 +244,8 @@ func (p *Plugin) Publish(msg []*pubsub.Message) error {
func (p *Plugin) PublishAsync(msg []*pubsub.Message) {
go func() {
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.Lock()
+ defer p.Unlock()
for i := 0; i < len(msg); i++ {
for j := 0; j < len(msg[i].Topics); j++ {
err := p.pubsubs[msg[i].Broker].Publish(msg)
@@ -223,3 +257,69 @@ func (p *Plugin) PublishAsync(msg []*pubsub.Message) {
}
}()
}
+
+func (p *Plugin) defaultAccessValidator(pool phpPool.Pool) validator.AccessValidatorFn {
+ return func(r *http.Request, topics ...string) (*validator.AccessValidator, error) {
+ p.RLock()
+ defer p.RUnlock()
+ const op = errors.Op("access_validator")
+
+ r = attributes.Init(r)
+
+ // if channels len is eq to 0, we use serverValidator
+ if len(topics) == 0 {
+ ctx, err := validator.TopicsAccessValidator(r, topics...)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ pd := payload.Payload{
+ Context: ctx,
+ }
+
+ resp, err := pool.Exec(pd)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val := &validator.AccessValidator{
+ Body: resp.Body,
+ }
+ err = json.Unmarshal(resp.Context, val)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return val, nil
+ }
+
+ ctx, err := validator.ServerAccessValidator(r)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ pd := payload.Payload{
+ Context: ctx,
+ }
+
+ resp, err := pool.Exec(pd)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ val := &validator.AccessValidator{
+ Body: resp.Body,
+ }
+
+ err = json.Unmarshal(resp.Context, val)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ if val.Status != http.StatusOK {
+ return val, errors.E(op, errors.Errorf("access forbidden, code: %d", val.Status))
+ }
+
+ return val, nil
+ }
+}