summaryrefslogtreecommitdiff
path: root/plugins/broadcast/websockets/service.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/broadcast/websockets/service.go')
-rw-r--r--plugins/broadcast/websockets/service.go228
1 files changed, 228 insertions, 0 deletions
diff --git a/plugins/broadcast/websockets/service.go b/plugins/broadcast/websockets/service.go
new file mode 100644
index 00000000..f3c0c781
--- /dev/null
+++ b/plugins/broadcast/websockets/service.go
@@ -0,0 +1,228 @@
+package websockets
+
+import (
+ "encoding/json"
+ "net/http"
+ "sync"
+ "sync/atomic"
+
+ "github.com/gorilla/websocket"
+)
+
+// ID defines service id.
+const ID = "ws"
+
+// Service to manage websocket clients.
+type Service struct {
+ cfg *Config
+ upgrade websocket.Upgrader
+ client *broadcast.Client
+ connPool *connPool
+ listeners []func(event int, ctx interface{})
+ mu sync.Mutex
+ stopped int32
+ stop chan error
+}
+
+// AddListener attaches server event controller.
+func (s *Service) AddListener(l func(event int, ctx interface{})) {
+ s.listeners = append(s.listeners, l)
+}
+
+// Init the service.
+func (s *Service) Init(
+ cfg *Config,
+ env env.Environment,
+ rttp *rhttp.Service,
+ rpc *rpc.Service,
+ broadcast *broadcast.Service,
+) (bool, error) {
+ if broadcast == nil || rpc == nil {
+ // unable to activate
+ return false, nil
+ }
+
+ s.cfg = cfg
+ s.client = broadcast.NewClient()
+ s.connPool = newPool(s.client, s.reportError)
+ s.stopped = 0
+
+ if err := rpc.Register(ID, &rpcService{svc: s}); err != nil {
+ return false, err
+ }
+
+ if env != nil {
+ // ensure that underlying kernel knows what route to handle
+ env.SetEnv("RR_BROADCAST_PATH", cfg.Path)
+ }
+
+ // init all this stuff
+ s.upgrade = websocket.Upgrader{}
+
+ if s.cfg.NoOrigin {
+ s.upgrade.CheckOrigin = func(r *http.Request) bool {
+ return true
+ }
+ }
+
+ rttp.AddMiddleware(s.middleware)
+
+ return true, nil
+}
+
+// Serve the websocket connections.
+func (s *Service) Serve() error {
+ defer s.client.Close()
+ defer s.connPool.close()
+
+ s.mu.Lock()
+ s.stop = make(chan error)
+ s.mu.Unlock()
+
+ return <-s.stop
+}
+
+// Stop the service and disconnect all connections.
+func (s *Service) Stop() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if atomic.CompareAndSwapInt32(&s.stopped, 0, 1) {
+ close(s.stop)
+ }
+}
+
+// middleware intercepts websocket connections.
+func (s *Service) middleware(f http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != s.cfg.Path {
+ f(w, r)
+ return
+ }
+
+ // checking server access
+ if err := newValidator().assertServerAccess(f, r); err != nil {
+ // show the error to the user
+ if av, ok := err.(*accessValidator); ok {
+ av.copy(w)
+ } else {
+ w.WriteHeader(400)
+ }
+ return
+ }
+
+ conn, err := s.upgrade.Upgrade(w, r, nil)
+ if err != nil {
+ s.reportError(err, nil)
+ return
+ }
+
+ s.throw(EventConnect, conn)
+
+ // manage connection
+ ctx, err := s.connPool.connect(conn)
+ if err != nil {
+ s.reportError(err, conn)
+ return
+ }
+
+ s.serveConn(ctx, f, r)
+ }
+}
+
+// send and receive messages over websocket
+func (s *Service) serveConn(ctx *ConnContext, f http.HandlerFunc, r *http.Request) {
+ defer func() {
+ if err := s.connPool.disconnect(ctx.Conn); err != nil {
+ s.reportError(err, ctx.Conn)
+ }
+ s.throw(EventDisconnect, ctx.Conn)
+ }()
+
+ s.handleCommands(ctx, f, r)
+}
+
+func (s *Service) handleCommands(ctx *ConnContext, f http.HandlerFunc, r *http.Request) {
+ cmd := &broadcast.Message{}
+ for {
+ if err := ctx.Conn.ReadJSON(cmd); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ switch cmd.Topic {
+ case "join":
+ topics := make([]string, 0)
+ if err := unmarshalCommand(cmd, &topics); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ if len(topics) == 0 {
+ continue
+ }
+
+ if err := newValidator().assertTopicsAccess(f, r, topics...); err != nil {
+ s.reportError(err, ctx.Conn)
+
+ if err := ctx.SendMessage("#join", topics); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ continue
+ }
+
+ if err := s.connPool.subscribe(ctx, topics...); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ if err := ctx.SendMessage("@join", topics); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ s.throw(EventJoin, &TopicEvent{Conn: ctx.Conn, Topics: topics})
+ case "leave":
+ topics := make([]string, 0)
+ if err := unmarshalCommand(cmd, &topics); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ if len(topics) == 0 {
+ continue
+ }
+
+ if err := s.connPool.unsubscribe(ctx, topics...); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ if err := ctx.SendMessage("@leave", topics); err != nil {
+ s.reportError(err, ctx.Conn)
+ return
+ }
+
+ s.throw(EventLeave, &TopicEvent{Conn: ctx.Conn, Topics: topics})
+ }
+ }
+}
+
+// handle connection error
+func (s *Service) reportError(err error, conn *websocket.Conn) {
+ s.throw(EventError, &ErrorEvent{Conn: conn, Error: err})
+}
+
+// throw handles service, server and pool events.
+func (s *Service) throw(event int, ctx interface{}) {
+ for _, l := range s.listeners {
+ l(event, ctx)
+ }
+}
+
+// unmarshalCommand command data.
+func unmarshalCommand(msg *broadcast.Message, v interface{}) error {
+ return json.Unmarshal(msg.Payload, v)
+}