diff options
Diffstat (limited to 'plugins/broadcast')
38 files changed, 2868 insertions, 0 deletions
diff --git a/plugins/broadcast/config.go b/plugins/broadcast/config.go new file mode 100644 index 00000000..aa270f64 --- /dev/null +++ b/plugins/broadcast/config.go @@ -0,0 +1 @@ +package broadcast diff --git a/plugins/broadcast/doc/.rr-broadcast.yaml b/plugins/broadcast/doc/.rr-broadcast.yaml new file mode 100644 index 00000000..a0a2ad5e --- /dev/null +++ b/plugins/broadcast/doc/.rr-broadcast.yaml @@ -0,0 +1,10 @@ +# broadcast service configuration.rr.yaml +broadcast: + # path to enable web-socket handler middleware + path: /ws + + # optional, redis broker configuration + redis: + addr: "localhost:6379" + passsword: "" + db: 0 diff --git a/plugins/broadcast/doc/broadcast.drawio b/plugins/broadcast/doc/broadcast.drawio new file mode 100644 index 00000000..f9845dc8 --- /dev/null +++ b/plugins/broadcast/doc/broadcast.drawio @@ -0,0 +1 @@ +<mxfile host="Electron" modified="2021-05-03T16:49:17.087Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/14.5.1 Chrome/89.0.4389.128 Electron/12.0.6 Safari/537.36" etag="RZNtN_6682KfuWpR1T35" version="14.5.1" type="device"><diagram id="fD2kwGC0DAS2S_q_IsmE" name="Page-1">1ZhdU6MwFIZ/TWd2L9xJCVB72a9VZ6y6Vkd7mcIBsqaECcG2/voNEAqIdXSEddpelLzJCcl78pCGHp6st2eCRMGcu8B6BnK3PTztGUYfY6R+UmWXK0Nk5YIvqKsblcKCvoAWdZyfUBfiWkPJOZM0qosOD0NwZE0jQvBNvZnHWf2uEfGhISwcwprqA3VloNWhgcqKc6B+UNzaQLpmTYrWWogD4vJNRcKzHp4IzmV+td5OgKXuFcbkcb8P1O5HJiCUHwnYmfPLu+WYhM4dZk9/BEXh6GSoxyZ3xYzBVQboIhcy4D4PCZuV6ljwJHQh7RWpUtnmkvNIiX0l/gUpdzqbJJFcSYFcM10LWyofK9fLtKtfli5Nt7rnrLDThXyc6eAOTl9LMU+EA+/Mub83Xy1b4GuQYqfiBDAi6XO9f6LXj79vtw+94VTd2UB6rdtFnvVKN01U70IS4YPUUWWe1EVlGKWUZe8TmdQDfiYs0VMY316PppPR4k7J5/fKQfTj5vL+7OLqZyPp9ZRuAiphEZHMxY0Cu54+jzI24YyLLBa7BE49R+mxFPwJKjW2cworb5+8ZxAStu+nr5kWHWDiV/YWT4hNBcsiBUGFSBsdzmQtB5813DoWdJTDYvdYNkyLy2pdGZaVOkDO0I/uDIEO0fxSQo0GQQ+LVknxrPT7JinZJ43goazo+acdgvZ7qiZI7cUfI+j1g6w1w+3jI6jGT4lT9wThYyAINwi6nU0vWobI8wznze3GtVe2ZXcDi2l8NyyD44Olvt30/x8s5jHAYjZgubg6mc/m17fLloEB+wAwg+EKoW6AsezvBua04a+aOlOHxDj1NlmdRCzxaRg3zFYmyLqjhFE/VNeOMgOUeePUKqpOiSNdsaaumxMHMX0hq6yr1Nko/bufzcwa96xp2peCLM5567djPX5lPR580HqjK+uLAVU3ghvVFRrft7sZQN+1YPDW2h7aA0xa2gwaR7sOzx6qWL4SyM+G5ZsVPPsH</diagram></mxfile>
\ No newline at end of file diff --git a/plugins/broadcast/memory/memory.go b/plugins/broadcast/memory/memory.go new file mode 100644 index 00000000..5b85d68f --- /dev/null +++ b/plugins/broadcast/memory/memory.go @@ -0,0 +1,131 @@ +package memory + +import ( + "errors" + "sync/atomic" +) + +// Memory manages broadcasting in memory. +type Memory struct { + router *Router + messages chan *Message + join, leave chan subscriber + stop chan interface{} + stopped int32 +} + +// memoryBroker creates new memory based message broker. +func memoryBroker() *Memory { + return &Memory{ + router: NewRouter(), + messages: make(chan *Message), + join: make(chan subscriber), + leave: make(chan subscriber), + stop: make(chan interface{}), + stopped: 0, + } +} + +// Serve serves broker. +func (m *Memory) Serve() error { + for { + select { + case ctx := <-m.join: + ctx.done <- m.handleJoin(ctx) + case ctx := <-m.leave: + ctx.done <- m.handleLeave(ctx) + case msg := <-m.messages: + m.router.Dispatch(msg) + case <-m.stop: + return nil + } + } +} + +func (m *Memory) handleJoin(sub subscriber) (err error) { + if sub.pattern != "" { + _, err = m.router.SubscribePattern(sub.upstream, sub.pattern) + return err + } + + m.router.Subscribe(sub.upstream, sub.topics...) + return nil +} + +func (m *Memory) handleLeave(sub subscriber) error { + if sub.pattern != "" { + m.router.UnsubscribePattern(sub.upstream, sub.pattern) + return nil + } + + m.router.Unsubscribe(sub.upstream, sub.topics...) + return nil +} + +// Stop closes the consumption and disconnects broker. +func (m *Memory) Stop() { + if atomic.CompareAndSwapInt32(&m.stopped, 0, 1) { + close(m.stop) + } +} + +// Subscribe broker to one or multiple channels. +func (m *Memory) Subscribe(upstream chan *Message, topics ...string) error { + if atomic.LoadInt32(&m.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, topics: topics, done: make(chan error)} + + m.join <- ctx + return <-ctx.done +} + +// SubscribePattern broker to pattern. +func (m *Memory) SubscribePattern(upstream chan *Message, pattern string) error { + if atomic.LoadInt32(&m.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, pattern: pattern, done: make(chan error)} + + m.join <- ctx + return <-ctx.done +} + +// Unsubscribe broker from one or multiple channels. +func (m *Memory) Unsubscribe(upstream chan *Message, topics ...string) error { + if atomic.LoadInt32(&m.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, topics: topics, done: make(chan error)} + + m.leave <- ctx + return <-ctx.done +} + +// UnsubscribePattern broker from pattern. +func (m *Memory) UnsubscribePattern(upstream chan *Message, pattern string) error { + if atomic.LoadInt32(&m.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, pattern: pattern, done: make(chan error)} + + m.leave <- ctx + return <-ctx.done +} + +// Publish one or multiple Channel. +func (m *Memory) Publish(messages ...*Message) error { + if atomic.LoadInt32(&m.stopped) == 1 { + return errors.New("broker has been stopped") + } + + for _, msg := range messages { + m.messages <- msg + } + + return nil +} diff --git a/plugins/broadcast/memory/memory_test.go b/plugins/broadcast/memory/memory_test.go new file mode 100644 index 00000000..0eb8d03e --- /dev/null +++ b/plugins/broadcast/memory/memory_test.go @@ -0,0 +1,80 @@ +package memory + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMemory_Broadcast(t *testing.T) { + br, _, c := setup(`{}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) // must not be delivered + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello2"))) + assert.Equal(t, `hello2`, readStr(<-client.Channel())) + + assert.NoError(t, client.Unsubscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello3"))) + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello4"))) + assert.Equal(t, `hello4`, readStr(<-client.Channel())) +} + +func TestMemory_BroadcastPattern(t *testing.T) { + br, _, c := setup(`{}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) // must not be delivered + + assert.NoError(t, client.SubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/1", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, client.Publish(newMessage("topic/1", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/2", "hello2"))) + assert.Equal(t, `hello2`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("different", "hello4"))) + assert.NoError(t, br.Broker().Publish(newMessage("topic/2", "hello5"))) + + assert.Equal(t, `hello5`, readStr(<-client.Channel())) + + assert.NoError(t, client.UnsubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/3", "hello6"))) + + assert.NoError(t, client.SubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/4", "hello7"))) + assert.Equal(t, `hello7`, readStr(<-client.Channel())) +} + +func TestMemory_NotActive(t *testing.T) { + b := memoryBroker() + b.stopped = 1 + + assert.Error(t, b.Publish(nil)) + assert.Error(t, b.Subscribe(nil)) + assert.Error(t, b.Unsubscribe(nil)) + assert.Error(t, b.SubscribePattern(nil, "")) + assert.Error(t, b.UnsubscribePattern(nil, "")) +} diff --git a/plugins/broadcast/plugin.go b/plugins/broadcast/plugin.go new file mode 100644 index 00000000..3cedf555 --- /dev/null +++ b/plugins/broadcast/plugin.go @@ -0,0 +1,11 @@ +package broadcast + + +type Plugin struct { + +} + + +func (p *Plugin) Init() error { + return nil +} diff --git a/plugins/broadcast/redis/redis.go b/plugins/broadcast/redis/redis.go new file mode 100644 index 00000000..41f48658 --- /dev/null +++ b/plugins/broadcast/redis/redis.go @@ -0,0 +1,172 @@ +package redis + +import ( + "context" + "errors" + "sync/atomic" + + "github.com/go-redis/redis/v8" +) + +// Redis based broadcast Router. +type Redis struct { + client redis.UniversalClient + psClient redis.UniversalClient + router *Router + messages chan *Message + listen, leave chan subscriber + stop chan interface{} + stopped int32 +} + +// creates new redis broker +func redisBroker(cfg *RedisConfig) (*Redis, error) { + client := cfg.redisClient() + if _, err := client.Ping(context.Background()).Result(); err != nil { + return nil, err + } + + psClient := cfg.redisClient() + if _, err := psClient.Ping(context.Background()).Result(); err != nil { + return nil, err + } + + return &Redis{ + client: client, + psClient: psClient, + router: NewRouter(), + messages: make(chan *Message), + listen: make(chan subscriber), + leave: make(chan subscriber), + stop: make(chan interface{}), + stopped: 0, + }, nil +} + +// Serve serves broker. +func (r *Redis) Serve() error { + pubsub := r.psClient.Subscribe(context.Background()) + channel := pubsub.Channel() + + for { + select { + case ctx := <-r.listen: + ctx.done <- r.handleJoin(ctx, pubsub) + case ctx := <-r.leave: + ctx.done <- r.handleLeave(ctx, pubsub) + case msg := <-channel: + r.router.Dispatch(&Message{ + Topic: msg.Channel, + Payload: []byte(msg.Payload), + }) + case <-r.stop: + return nil + } + } +} + +func (r *Redis) handleJoin(sub subscriber, pubsub *redis.PubSub) error { + if sub.pattern != "" { + newPatterns, err := r.router.SubscribePattern(sub.upstream, sub.pattern) + if err != nil || len(newPatterns) == 0 { + return err + } + + return pubsub.PSubscribe(context.Background(), newPatterns...) + } + + newTopics := r.router.Subscribe(sub.upstream, sub.topics...) + if len(newTopics) == 0 { + return nil + } + + return pubsub.Subscribe(context.Background(), newTopics...) +} + +func (r *Redis) handleLeave(sub subscriber, pubsub *redis.PubSub) error { + if sub.pattern != "" { + dropPatterns := r.router.UnsubscribePattern(sub.upstream, sub.pattern) + if len(dropPatterns) == 0 { + return nil + } + + return pubsub.PUnsubscribe(context.Background(), dropPatterns...) + } + + dropTopics := r.router.Unsubscribe(sub.upstream, sub.topics...) + if len(dropTopics) == 0 { + return nil + } + + return pubsub.Unsubscribe(context.Background(), dropTopics...) +} + +// Stop closes the consumption and disconnects broker. +func (r *Redis) Stop() { + if atomic.CompareAndSwapInt32(&r.stopped, 0, 1) { + close(r.stop) + } +} + +// Subscribe broker to one or multiple channels. +func (r *Redis) Subscribe(upstream chan *Message, topics ...string) error { + if atomic.LoadInt32(&r.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, topics: topics, done: make(chan error)} + + r.listen <- ctx + return <-ctx.done +} + +// SubscribePattern broker to pattern. +func (r *Redis) SubscribePattern(upstream chan *Message, pattern string) error { + if atomic.LoadInt32(&r.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, pattern: pattern, done: make(chan error)} + + r.listen <- ctx + return <-ctx.done +} + +// Unsubscribe broker from one or multiple channels. +func (r *Redis) Unsubscribe(upstream chan *Message, topics ...string) error { + if atomic.LoadInt32(&r.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, topics: topics, done: make(chan error)} + + r.leave <- ctx + return <-ctx.done +} + +// UnsubscribePattern broker from pattern. +func (r *Redis) UnsubscribePattern(upstream chan *Message, pattern string) error { + if atomic.LoadInt32(&r.stopped) == 1 { + return errors.New("broker has been stopped") + } + + ctx := subscriber{upstream: upstream, pattern: pattern, done: make(chan error)} + + r.leave <- ctx + return <-ctx.done +} + +// Publish one or multiple Channel. +func (r *Redis) Publish(messages ...*Message) error { + if atomic.LoadInt32(&r.stopped) == 1 { + return errors.New("broker has been stopped") + } + + for _, msg := range messages { + if err := r.client.Publish(context.Background(), msg.Topic, []byte(msg.Payload)).Err(); err != nil { + return err + } + } + + return nil +} diff --git a/plugins/broadcast/redis/redis_test.go b/plugins/broadcast/redis/redis_test.go new file mode 100644 index 00000000..37027e01 --- /dev/null +++ b/plugins/broadcast/redis/redis_test.go @@ -0,0 +1,98 @@ +package redis + +import ( + "fmt" + "testing" + + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" +) + +func TestRedis_Error(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + //c := service.NewContainer(logger) + //c.Register(rpc.ID, &rpc.Service{}) + //c.Register(ID, &Service{}) + // + //err := c.Init(&testCfg{ + // broadcast: `{"redis":{"addr":"localhost:6372"}}`, + // rpc: fmt.Sprintf(`{"join":"tcp://:%v"}`, rpcPort), + //}) + + rpcPort++ + + assert.Error(t, err) +} + +func TestRedis_Broadcast(t *testing.T) { + br, _, c := setup(`{"redis":{"addr":"localhost:6379"}}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) // must not be delivered + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello2"))) + assert.Equal(t, `hello2`, readStr(<-client.Channel())) + + assert.NoError(t, client.Unsubscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello3"))) + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello4"))) + assert.Equal(t, `hello4`, readStr(<-client.Channel())) +} + +func TestRedis_BroadcastPattern(t *testing.T) { + br, _, c := setup(`{"redis":{"addr":"localhost:6379"}}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) // must not be delivered + + assert.NoError(t, client.SubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/1", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/2", "hello2"))) + assert.Equal(t, `hello2`, readStr(<-client.Channel())) + + assert.NoError(t, br.Broker().Publish(newMessage("different", "hello4"))) + assert.NoError(t, br.Broker().Publish(newMessage("topic/2", "hello5"))) + + assert.Equal(t, `hello5`, readStr(<-client.Channel())) + + assert.NoError(t, client.UnsubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/3", "hello6"))) + + assert.NoError(t, client.SubscribePattern("topic/*")) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/4", "hello7"))) + assert.Equal(t, `hello7`, readStr(<-client.Channel())) +} + +func TestRedis_NotActive(t *testing.T) { + b := &Redis{} + b.stopped = 1 + + assert.Error(t, b.Publish(nil)) + assert.Error(t, b.Subscribe(nil)) + assert.Error(t, b.Unsubscribe(nil)) + assert.Error(t, b.SubscribePattern(nil, "")) + assert.Error(t, b.UnsubscribePattern(nil, "")) +} diff --git a/plugins/broadcast/root/Makefile b/plugins/broadcast/root/Makefile new file mode 100644 index 00000000..d88312d2 --- /dev/null +++ b/plugins/broadcast/root/Makefile @@ -0,0 +1,9 @@ +clean: + rm -rf rr-jobbroadcast +install: all + cp rr-broadcast /usr/local/bin/rr-broadcast +uninstall: + rm -f /usr/local/bin/rr-broadcast +test: + composer update + go test -v -race -cover diff --git a/plugins/broadcast/root/broker.go b/plugins/broadcast/root/broker.go new file mode 100644 index 00000000..923c8105 --- /dev/null +++ b/plugins/broadcast/root/broker.go @@ -0,0 +1,36 @@ +package broadcast + +import "encoding/json" + +// Broker defines the ability to operate as message passing broker. +type Broker interface { + // Serve serves broker. + Serve() error + + // Stop closes the consumption and disconnects broker. + Stop() + + // Subscribe broker to one or multiple topics. + Subscribe(upstream chan *Message, topics ...string) error + + // SubscribePattern broker to pattern. + SubscribePattern(upstream chan *Message, pattern string) error + + // Unsubscribe broker from one or multiple topics. + Unsubscribe(upstream chan *Message, topics ...string) error + + // UnsubscribePattern broker from pattern. + UnsubscribePattern(upstream chan *Message, pattern string) error + + // Publish one or multiple Channel. + Publish(messages ...*Message) error +} + +// Message represent single message. +type Message struct { + // Topic message been pushed into. + Topic string `json:"topic"` + + // Payload to be broadcasted. Must be valid json when transferred over RPC. + Payload json.RawMessage `json:"payload"` +} diff --git a/plugins/broadcast/root/client.go b/plugins/broadcast/root/client.go new file mode 100644 index 00000000..c5761f94 --- /dev/null +++ b/plugins/broadcast/root/client.go @@ -0,0 +1,133 @@ +package broadcast + +import "sync" + +// Client subscribes to a given topic and consumes or publish messages to it. +type Client struct { + upstream chan *Message + broker Broker + mu sync.Mutex + topics []string + patterns []string +} + +// Channel returns incoming messages channel. +func (c *Client) Channel() chan *Message { + return c.upstream +} + +// Publish message into associated topic or topics. +func (c *Client) Publish(msg ...*Message) error { + return c.broker.Publish(msg...) +} + +// Subscribe client to specific topics. +func (c *Client) Subscribe(topics ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + + newTopics := make([]string, 0) + for _, topic := range topics { + found := false + for _, e := range c.topics { + if e == topic { + found = true + break + } + } + + if !found { + newTopics = append(newTopics, topic) + } + } + + if len(newTopics) == 0 { + return nil + } + + c.topics = append(c.topics, newTopics...) + + return c.broker.Subscribe(c.upstream, newTopics...) +} + +// SubscribePattern subscribe client to the specific topic pattern. +func (c *Client) SubscribePattern(pattern string) error { + c.mu.Lock() + defer c.mu.Unlock() + + for _, g := range c.patterns { + if g == pattern { + return nil + } + } + + c.patterns = append(c.patterns, pattern) + return c.broker.SubscribePattern(c.upstream, pattern) +} + +// Unsubscribe client from specific topics +func (c *Client) Unsubscribe(topics ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + + dropTopics := make([]string, 0) + for _, topic := range topics { + for i, e := range c.topics { + if e == topic { + c.topics = append(c.topics[:i], c.topics[i+1:]...) + dropTopics = append(dropTopics, topic) + } + } + } + + if len(dropTopics) == 0 { + return nil + } + + return c.broker.Unsubscribe(c.upstream, dropTopics...) +} + +// UnsubscribePattern client from topic pattern. +func (c *Client) UnsubscribePattern(pattern string) error { + c.mu.Lock() + defer c.mu.Unlock() + + for i := range c.patterns { + if c.patterns[i] == pattern { + c.patterns = append(c.patterns[:i], c.patterns[i+1:]...) + + return c.broker.UnsubscribePattern(c.upstream, pattern) + } + } + + return nil +} + +// Topics return all the topics client subscribed to. +func (c *Client) Topics() []string { + c.mu.Lock() + defer c.mu.Unlock() + + return c.topics +} + +// Patterns return all the patterns client subscribed to. +func (c *Client) Patterns() []string { + c.mu.Lock() + defer c.mu.Unlock() + + return c.patterns +} + +// Close the client and consumption. +func (c *Client) Close() (err error) { + c.mu.Lock() + defer c.mu.Unlock() + + if len(c.topics) != 0 { + err = c.broker.Unsubscribe(c.upstream, c.topics...) + } + + close(c.upstream) + return err +} diff --git a/plugins/broadcast/root/client_test.go b/plugins/broadcast/root/client_test.go new file mode 100644 index 00000000..52a50d57 --- /dev/null +++ b/plugins/broadcast/root/client_test.go @@ -0,0 +1,59 @@ +package broadcast + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Client_Topics(t *testing.T) { + br, _, c := setup(`{}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.Equal(t, []string{}, client.Topics()) + + assert.NoError(t, client.Subscribe("topic")) + assert.Equal(t, []string{"topic"}, client.Topics()) + + assert.NoError(t, client.Subscribe("topic")) + assert.Equal(t, []string{"topic"}, client.Topics()) + + assert.NoError(t, br.broker.Subscribe(client.upstream, "topic")) + assert.Equal(t, []string{"topic"}, client.Topics()) + + assert.NoError(t, br.Broker().Publish(newMessage("topic", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, client.Unsubscribe("topic")) + assert.NoError(t, client.Unsubscribe("topic")) + assert.NoError(t, br.broker.Unsubscribe(client.upstream, "topic")) + + assert.Equal(t, []string{}, client.Topics()) +} + +func Test_Client_Patterns(t *testing.T) { + br, _, c := setup(`{}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + assert.Equal(t, []string{}, client.Patterns()) + + assert.NoError(t, client.SubscribePattern("topic/*")) + assert.Equal(t, []string{"topic/*"}, client.Patterns()) + + assert.NoError(t, br.broker.SubscribePattern(client.upstream, "topic/*")) + assert.Equal(t, []string{"topic/*"}, client.Patterns()) + + assert.NoError(t, br.Broker().Publish(newMessage("topic/1", "hello1"))) + assert.Equal(t, `hello1`, readStr(<-client.Channel())) + + assert.NoError(t, client.UnsubscribePattern("topic/*")) + assert.NoError(t, br.broker.UnsubscribePattern(client.upstream, "topic/*")) + + assert.Equal(t, []string{}, client.Patterns()) +} diff --git a/plugins/broadcast/root/config.go b/plugins/broadcast/root/config.go new file mode 100644 index 00000000..8c732441 --- /dev/null +++ b/plugins/broadcast/root/config.go @@ -0,0 +1,61 @@ +package broadcast + +import ( + "errors" + + "github.com/go-redis/redis/v8" +) + +// Config configures the broadcast extension. +type Config struct { + // RedisConfig configures redis broker. + Redis *RedisConfig +} + +// Hydrate reads the configuration values from the source configuration. +//func (c *Config) Hydrate(cfg service.Config) error { +// if err := cfg.Unmarshal(c); err != nil { +// return err +// } +// +// if c.Redis != nil { +// return c.Redis.isValid() +// } +// +// return nil +//} + +// InitDefaults enables in memory broadcast configuration. +func (c *Config) InitDefaults() error { + return nil +} + +// RedisConfig configures redis broker. +type RedisConfig struct { + // Addr of the redis server. + Addr string + + // Password to redis server. + Password string + + // DB index. + DB int +} + +// clusterOptions +func (cfg *RedisConfig) redisClient() redis.UniversalClient { + return redis.NewClient(&redis.Options{ + Addr: cfg.Addr, + Password: cfg.Password, + PoolSize: 2, + }) +} + +// check if redis config is valid. +func (cfg *RedisConfig) isValid() error { + if cfg.Addr == "" { + return errors.New("redis addr is required") + } + + return nil +} diff --git a/plugins/broadcast/root/config_test.go b/plugins/broadcast/root/config_test.go new file mode 100644 index 00000000..28191c6b --- /dev/null +++ b/plugins/broadcast/root/config_test.go @@ -0,0 +1,60 @@ +package broadcast + +import ( + "encoding/json" + "testing" + + "github.com/spiral/roadrunner/service" + "github.com/spiral/roadrunner/service/rpc" + "github.com/stretchr/testify/assert" +) + +type testCfg struct { + rpc string + broadcast string + target string +} + +func (cfg *testCfg) Get(name string) service.Config { + if name == ID { + return &testCfg{target: cfg.broadcast} + } + + if name == rpc.ID { + return &testCfg{target: cfg.rpc} + } + + return nil +} + +func (cfg *testCfg) Unmarshal(out interface{}) error { + return json.Unmarshal([]byte(cfg.target), out) +} + +func Test_Config_Hydrate_Error(t *testing.T) { + cfg := &testCfg{target: `{"dead`} + c := &Config{} + + assert.Error(t, c.Hydrate(cfg)) +} + +func Test_Config_Hydrate_OK(t *testing.T) { + cfg := &testCfg{target: `{"path":"/path"}`} + c := &Config{} + + assert.NoError(t, c.Hydrate(cfg)) +} + +func Test_Config_Redis_Error(t *testing.T) { + cfg := &testCfg{target: `{"path":"/path","redis":{}}`} + c := &Config{} + + assert.Error(t, c.Hydrate(cfg)) +} + +func Test_Config_Redis_OK(t *testing.T) { + cfg := &testCfg{target: `{"path":"/path","redis":{"addr":"localhost:6379"}}`} + c := &Config{} + + assert.NoError(t, c.Hydrate(cfg)) +} diff --git a/plugins/broadcast/root/router.go b/plugins/broadcast/root/router.go new file mode 100644 index 00000000..91137f8b --- /dev/null +++ b/plugins/broadcast/root/router.go @@ -0,0 +1,170 @@ +package broadcast + +//import "github.com/gobwas/glob" + +// Router performs internal message routing to multiple subscribers. +type Router struct { + wildcard map[string]wildcard + routes map[string][]chan *Message +} + +// wildcard handles number of topics via glob pattern. +type wildcard struct { + //glob glob.Glob + upstream []chan *Message +} + +// helper for blocking join/leave flow +type subscriber struct { + upstream chan *Message + done chan error + topics []string + pattern string +} + +// NewRouter creates new topic and pattern router. +func NewRouter() *Router { + return &Router{ + wildcard: make(map[string]wildcard), + routes: make(map[string][]chan *Message), + } +} + +// Dispatch to all connected topics. +func (r *Router) Dispatch(msg *Message) { + for _, w := range r.wildcard { + if w.glob.Match(msg.Topic) { + for _, upstream := range w.upstream { + upstream <- msg + } + } + } + + if routes, ok := r.routes[msg.Topic]; ok { + for _, upstream := range routes { + upstream <- msg + } + } +} + +// Subscribe to topic and return list of newly assigned topics. +func (r *Router) Subscribe(upstream chan *Message, topics ...string) (newTopics []string) { + newTopics = make([]string, 0) + for _, topic := range topics { + if _, ok := r.routes[topic]; !ok { + r.routes[topic] = []chan *Message{upstream} + if !r.collapsed(topic) { + newTopics = append(newTopics, topic) + } + continue + } + + joined := false + for _, up := range r.routes[topic] { + if up == upstream { + joined = true + break + } + } + + if !joined { + r.routes[topic] = append(r.routes[topic], upstream) + } + } + + return newTopics +} + +// Unsubscribe from given list of topics and return list of topics which are no longer claimed. +func (r *Router) Unsubscribe(upstream chan *Message, topics ...string) (dropTopics []string) { + dropTopics = make([]string, 0) + for _, topic := range topics { + if _, ok := r.routes[topic]; !ok { + // no such topic, ignore + continue + } + + for i := range r.routes[topic] { + if r.routes[topic][i] == upstream { + r.routes[topic] = append(r.routes[topic][:i], r.routes[topic][i+1:]...) + break + } + } + + if len(r.routes[topic]) == 0 { + delete(r.routes, topic) + + // standalone empty subscription + if !r.collapsed(topic) { + dropTopics = append(dropTopics, topic) + } + } + } + + return dropTopics +} + +// SubscribePattern subscribes to glob parent and return true and return array of newly added patterns. Error in +// case if blob is invalid. +func (r *Router) SubscribePattern(upstream chan *Message, pattern string) (newPatterns []string, err error) { + if w, ok := r.wildcard[pattern]; ok { + joined := false + for _, up := range w.upstream { + if up == upstream { + joined = true + break + } + } + + if !joined { + w.upstream = append(w.upstream, upstream) + } + + return nil, nil + } + + g, err := glob.Compile(pattern) + if err != nil { + return nil, err + } + + r.wildcard[pattern] = wildcard{glob: g, upstream: []chan *Message{upstream}} + + return []string{pattern}, nil +} + +// UnsubscribePattern unsubscribe from the pattern and returns an array of patterns which are no longer claimed. +func (r *Router) UnsubscribePattern(upstream chan *Message, pattern string) (dropPatterns []string) { + // todo: store and return collapsed topics + + w, ok := r.wildcard[pattern] + if !ok { + // no such pattern + return nil + } + + for i, up := range w.upstream { + if up == upstream { + w.upstream[i] = w.upstream[len(w.upstream)-1] + w.upstream[len(w.upstream)-1] = nil + w.upstream = w.upstream[:len(w.upstream)-1] + + if len(w.upstream) == 0 { + delete(r.wildcard, pattern) + return []string{pattern} + } + } + } + + return nil +} + +func (r *Router) collapsed(topic string) bool { + for _, w := range r.wildcard { + if w.glob.Match(topic) { + return true + } + } + + return false +} diff --git a/plugins/broadcast/root/rpc.go b/plugins/broadcast/root/rpc.go new file mode 100644 index 00000000..5604a574 --- /dev/null +++ b/plugins/broadcast/root/rpc.go @@ -0,0 +1,25 @@ +package broadcast + +import "golang.org/x/sync/errgroup" + +type rpcService struct { + svc *Service +} + +// Publish Messages. +func (r *rpcService) Publish(msg []*Message, ok *bool) error { + *ok = true + return r.svc.Publish(msg...) +} + +// Publish Messages in async mode. Blocks until get an err or nil from publish +func (r *rpcService) PublishAsync(msg []*Message, ok *bool) error { + *ok = true + g := &errgroup.Group{} + + g.Go(func() error { + return r.svc.Publish(msg...) + }) + + return g.Wait() +} diff --git a/plugins/broadcast/root/rpc_test.go b/plugins/broadcast/root/rpc_test.go new file mode 100644 index 00000000..157c4e70 --- /dev/null +++ b/plugins/broadcast/root/rpc_test.go @@ -0,0 +1,72 @@ +package broadcast + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRPC_Broadcast(t *testing.T) { + br, rpc, c := setup(`{}`) + defer c.Stop() + + client := br.NewClient() + defer client.Close() + + rcpClient, err := rpc.Client() + assert.NoError(t, err) + + // must not be delivered + ok := false + assert.NoError(t, rcpClient.Call( + "broadcast.Publish", + []*Message{newMessage("topic", `"hello1"`)}, + &ok, + )) + assert.True(t, ok) + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, rcpClient.Call( + "broadcast.Publish", + []*Message{newMessage("topic", `"hello1"`)}, + &ok, + )) + assert.True(t, ok) + assert.Equal(t, `"hello1"`, readStr(<-client.Channel())) + + assert.NoError(t, rcpClient.Call( + "broadcast.Publish", + []*Message{newMessage("topic", `"hello2"`)}, + &ok, + )) + assert.True(t, ok) + assert.Equal(t, `"hello2"`, readStr(<-client.Channel())) + + assert.NoError(t, client.Unsubscribe("topic")) + + assert.NoError(t, rcpClient.Call( + "broadcast.Publish", + []*Message{newMessage("topic", `"hello3"`)}, + &ok, + )) + assert.True(t, ok) + + assert.NoError(t, client.Subscribe("topic")) + + assert.NoError(t, rcpClient.Call( + "broadcast.Publish", + []*Message{newMessage("topic", `"hello4"`)}, + &ok, + )) + assert.True(t, ok) + assert.Equal(t, `"hello4"`, readStr(<-client.Channel())) + + assert.NoError(t, rcpClient.Call( + "broadcast.PublishAsync", + []*Message{newMessage("topic", `"hello5"`)}, + &ok, + )) + assert.True(t, ok) + assert.Equal(t, `"hello5"`, readStr(<-client.Channel())) +} diff --git a/plugins/broadcast/root/service.go b/plugins/broadcast/root/service.go new file mode 100644 index 00000000..8b175b3e --- /dev/null +++ b/plugins/broadcast/root/service.go @@ -0,0 +1,85 @@ +package broadcast + +import ( + "errors" + "sync" + + "github.com/spiral/roadrunner/service/rpc" +) + +// ID defines public service name. +const ID = "broadcast" + +// Service manages even broadcasting and websocket interface. +type Service struct { + // service and broker configuration + cfg *Config + + // broker + mu sync.Mutex + broker Broker +} + +// Init service. +func (s *Service) Init(cfg *Config, rpc *rpc.Service) (ok bool, err error) { + s.cfg = cfg + + if rpc != nil { + if err := rpc.Register(ID, &rpcService{svc: s}); err != nil { + return false, err + } + } + + s.mu.Lock() + if s.cfg.Redis != nil { + if s.broker, err = redisBroker(s.cfg.Redis); err != nil { + return false, err + } + } else { + s.broker = memoryBroker() + } + s.mu.Unlock() + + return true, nil +} + +// Serve broadcast broker. +func (s *Service) Serve() (err error) { + return s.broker.Serve() +} + +// Stop closes broadcast broker. +func (s *Service) Stop() { + broker := s.Broker() + if broker != nil { + broker.Stop() + } +} + +// Broker returns associated broker. +func (s *Service) Broker() Broker { + s.mu.Lock() + defer s.mu.Unlock() + + return s.broker +} + +// NewClient returns single connected client with ability to consume or produce into associated topic(svc). +func (s *Service) NewClient() *Client { + return &Client{ + upstream: make(chan *Message), + broker: s.Broker(), + topics: make([]string, 0), + patterns: make([]string, 0), + } +} + +// Publish one or multiple Channel. +func (s *Service) Publish(msg ...*Message) error { + broker := s.Broker() + if broker == nil { + return errors.New("no stopped broker") + } + + return s.Broker().Publish(msg...) +} diff --git a/plugins/broadcast/root/service_test.go b/plugins/broadcast/root/service_test.go new file mode 100644 index 00000000..10b924cc --- /dev/null +++ b/plugins/broadcast/root/service_test.go @@ -0,0 +1,65 @@ +package broadcast + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiral/roadrunner/service" + "github.com/spiral/roadrunner/service/rpc" + "github.com/stretchr/testify/assert" +) + +var rpcPort = 6010 + +func setup(cfg string) (*Service, *rpc.Service, service.Container) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(ID, &Service{}) + + err := c.Init(&testCfg{ + broadcast: cfg, + rpc: fmt.Sprintf(`{"listen":"tcp://:%v"}`, rpcPort), + }) + + rpcPort++ + + if err != nil { + panic(err) + } + + go func() { + err = c.Serve() + if err != nil { + panic(err) + } + }() + time.Sleep(time.Millisecond * 100) + + b, _ := c.Get(ID) + br := b.(*Service) + + r, _ := c.Get(rpc.ID) + rp := r.(*rpc.Service) + + return br, rp, c +} + +func readStr(m *Message) string { + return strings.TrimRight(string(m.Payload), "\n") +} + +func newMessage(t, m string) *Message { + return &Message{Topic: t, Payload: []byte(m)} +} + +func TestService_Publish(t *testing.T) { + svc := &Service{} + assert.Error(t, svc.Publish(nil)) +} diff --git a/plugins/broadcast/root/tests/.rr.yaml b/plugins/broadcast/root/tests/.rr.yaml new file mode 100644 index 00000000..c35a12fc --- /dev/null +++ b/plugins/broadcast/root/tests/.rr.yaml @@ -0,0 +1,2 @@ +broadcast: + redis.addr: "localhost:6379"
\ No newline at end of file diff --git a/plugins/broadcast/root/tests/Broadcast/BroadcastTest.php b/plugins/broadcast/root/tests/Broadcast/BroadcastTest.php new file mode 100644 index 00000000..d6014bf0 --- /dev/null +++ b/plugins/broadcast/root/tests/Broadcast/BroadcastTest.php @@ -0,0 +1,56 @@ +<?php + +/** + * Spiral Framework. + * + * @license MIT + * @author Anton Titov (Wolfy-J) + */ + +declare(strict_types=1); + +namespace Spiral\Broadcast\Tests; + +use PHPUnit\Framework\TestCase; +use Spiral\Broadcast\Broadcast; +use Spiral\Broadcast\Exception\BroadcastException; +use Spiral\Broadcast\Message; +use Spiral\Goridge\RPC; +use Spiral\Goridge\SocketRelay; + +class BroadcastTest extends TestCase +{ + public function testBroadcast(): void + { + $rpc = new RPC(new SocketRelay('localhost', 6001)); + $br = new Broadcast($rpc); + + $br->publish( + new Message('tests/topic', 'hello'), + new Message('tests/123', ['key' => 'value']) + ); + + while (filesize(__DIR__ . '/../log.txt') < 40) { + clearstatcache(true, __DIR__ . '/../log.txt'); + usleep(1000); + } + + clearstatcache(true, __DIR__ . '/../log.txt'); + $content = file_get_contents(__DIR__ . '/../log.txt'); + + $this->assertSame('tests/topic: "hello" +tests/123: {"key":"value"} +', $content); + } + + public function testBroadcastException(): void + { + $rpc = new RPC(new SocketRelay('localhost', 6002)); + $br = new Broadcast($rpc); + + $this->expectException(BroadcastException::class); + $br->publish( + new Message('topic', 'hello') + ); + } +} diff --git a/plugins/broadcast/root/tests/Broadcast/MessageTest.php b/plugins/broadcast/root/tests/Broadcast/MessageTest.php new file mode 100644 index 00000000..dd9e1cc3 --- /dev/null +++ b/plugins/broadcast/root/tests/Broadcast/MessageTest.php @@ -0,0 +1,24 @@ +<?php + +/** + * Spiral Framework. + * + * @license MIT + * @author Anton Titov (Wolfy-J) + */ + +declare(strict_types=1); + +namespace Spiral\Broadcast\Tests; + +use PHPUnit\Framework\TestCase; +use Spiral\Broadcast\Message; + +class MessageTest extends TestCase +{ + public function testSerialize(): void + { + $m = new Message('topic', ['hello' => 'world']); + $this->assertSame('{"topic":"topic","payload":{"hello":"world"}}', json_encode($m)); + } +} diff --git a/plugins/broadcast/root/tests/bootstrap.php b/plugins/broadcast/root/tests/bootstrap.php new file mode 100644 index 00000000..d0dfb88b --- /dev/null +++ b/plugins/broadcast/root/tests/bootstrap.php @@ -0,0 +1,15 @@ +<?php + +/** + * Spiral Framework, SpiralScout LLC. + * + * @author Anton Titov (Wolfy-J) + */ + +declare(strict_types=1); + +error_reporting(E_ALL | E_STRICT); +ini_set('display_errors', 'stderr'); + +//Composer +require dirname(__DIR__) . '/vendor_php/autoload.php'; diff --git a/plugins/broadcast/root/tests/docker-compose.yml b/plugins/broadcast/root/tests/docker-compose.yml new file mode 100644 index 00000000..123aa9b9 --- /dev/null +++ b/plugins/broadcast/root/tests/docker-compose.yml @@ -0,0 +1,9 @@ +version: '3' + +services: + redis: + image: 'bitnami/redis:latest' + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379"
\ No newline at end of file diff --git a/plugins/broadcast/root/tests/go-client.go b/plugins/broadcast/root/tests/go-client.go new file mode 100644 index 00000000..21442a01 --- /dev/null +++ b/plugins/broadcast/root/tests/go-client.go @@ -0,0 +1,78 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spiral/broadcast/v2" + rr "github.com/spiral/roadrunner/cmd/rr/cmd" + "github.com/spiral/roadrunner/service/rpc" + "golang.org/x/sync/errgroup" +) + +type logService struct { + broadcast *broadcast.Service + stop chan interface{} +} + +func (l *logService) Init(service *broadcast.Service) (bool, error) { + l.broadcast = service + + return true, nil +} + +func (l *logService) Serve() error { + l.stop = make(chan interface{}) + + client := l.broadcast.NewClient() + if err := client.SubscribePattern("tests/*"); err != nil { + return err + } + + logFile, _ := os.Create("log.txt") + + g := &errgroup.Group{} + g.Go(func() error { + for msg := range client.Channel() { + _, err := logFile.Write([]byte(fmt.Sprintf( + "%s: %s\n", + msg.Topic, + string(msg.Payload), + ))) + if err != nil { + return err + } + + err = logFile.Sync() + if err != nil { + return err + } + } + return nil + }) + + <-l.stop + err := logFile.Close() + if err != nil { + return err + } + + err = client.Close() + if err != nil { + return err + } + + return g.Wait() +} + +func (l *logService) Stop() { + close(l.stop) +} + +func main() { + rr.Container.Register(rpc.ID, &rpc.Service{}) + rr.Container.Register(broadcast.ID, &broadcast.Service{}) + rr.Container.Register("log", &logService{}) + + rr.Execute() +} diff --git a/plugins/broadcast/rpc.go b/plugins/broadcast/rpc.go new file mode 100644 index 00000000..aa270f64 --- /dev/null +++ b/plugins/broadcast/rpc.go @@ -0,0 +1 @@ +package broadcast diff --git a/plugins/broadcast/websockets/Makefile b/plugins/broadcast/websockets/Makefile new file mode 100644 index 00000000..f32efbdb --- /dev/null +++ b/plugins/broadcast/websockets/Makefile @@ -0,0 +1,2 @@ +est: + go test -v -race -cover diff --git a/plugins/broadcast/websockets/access_validator.go b/plugins/broadcast/websockets/access_validator.go new file mode 100644 index 00000000..bf27386d --- /dev/null +++ b/plugins/broadcast/websockets/access_validator.go @@ -0,0 +1,102 @@ +package websockets + +import ( + "bytes" + "io" + "net/http" + "strings" + + "github.com/spiral/roadrunner/v2/plugins/http/attributes" +) + +type accessValidator struct { + buffer *bytes.Buffer + header http.Header + status int +} + +func newValidator() *accessValidator { + return &accessValidator{ + buffer: bytes.NewBuffer(nil), + header: make(http.Header), + } +} + +// copy all content to parent response writer. +func (w *accessValidator) copy(rw http.ResponseWriter) { + rw.WriteHeader(w.status) + + for k, v := range w.header { + for _, vv := range v { + rw.Header().Add(k, vv) + } + } + + _, _ = io.Copy(rw, w.buffer) +} + +// Header returns the header map that will be sent by WriteHeader. +func (w *accessValidator) Header() http.Header { + return w.header +} + +// Write writes the data to the connection as part of an HTTP reply. +func (w *accessValidator) Write(p []byte) (int, error) { + return w.buffer.Write(p) +} + +// WriteHeader sends an HTTP response header with the provided status code. +func (w *accessValidator) WriteHeader(statusCode int) { + w.status = statusCode +} + +// IsOK returns true if response contained 200 status code. +func (w *accessValidator) IsOK() bool { + return w.status == 200 +} + +// Body returns response body to rely to user. +func (w *accessValidator) Body() []byte { + return w.buffer.Bytes() +} + +// Error contains server response. +func (w *accessValidator) Error() string { + return w.buffer.String() +} + +// assertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in +// case of error +func (w *accessValidator) assertServerAccess(f http.HandlerFunc, r *http.Request) error { + if err := attributes.Set(r, "ws:joinServer", true); err != nil { + return err + } + + defer delete(attributes.All(r), "ws:joinServer") + + f(w, r) + + if !w.IsOK() { + return w + } + + return nil +} + +// assertAccess checks if user can access given upstream, the application will receive all user headers and cookies. +// the decision to authorize user will be based on response code (200). +func (w *accessValidator) assertTopicsAccess(f http.HandlerFunc, r *http.Request, channels ...string) error { + if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil { + return err + } + + defer delete(attributes.All(r), "ws:joinTopics") + + f(w, r) + + if !w.IsOK() { + return w + } + + return nil +} diff --git a/plugins/broadcast/websockets/access_validator_test.go b/plugins/broadcast/websockets/access_validator_test.go new file mode 100644 index 00000000..41372727 --- /dev/null +++ b/plugins/broadcast/websockets/access_validator_test.go @@ -0,0 +1,35 @@ +package websockets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResponseWrapper_Body(t *testing.T) { + w := newValidator() + _, _ =w.Write([]byte("hello")) + + assert.Equal(t, []byte("hello"), w.Body()) +} + +func TestResponseWrapper_Header(t *testing.T) { + w := newValidator() + w.Header().Set("k", "value") + + assert.Equal(t, "value", w.Header().Get("k")) +} + +func TestResponseWrapper_StatusCode(t *testing.T) { + w := newValidator() + w.WriteHeader(200) + + assert.True(t, w.IsOK()) +} + +func TestResponseWrapper_StatusCodeBad(t *testing.T) { + w := newValidator() + w.WriteHeader(400) + + assert.False(t, w.IsOK()) +} diff --git a/plugins/broadcast/websockets/config.go b/plugins/broadcast/websockets/config.go new file mode 100644 index 00000000..8a71c7af --- /dev/null +++ b/plugins/broadcast/websockets/config.go @@ -0,0 +1,21 @@ +package websockets + + +// Config defines the websocket service configuration. +type Config struct { + // Path defines on this URL the middleware must be activated. Same path must + // be handled by underlying application kernel to authorize the consumption. + Path string + + // NoOrigin disables origin check, only for debug. + NoOrigin bool +} + +// Hydrate reads the configuration values from the source configuration. +//func (c *Config) Hydrate(cfg service.Config) error { +// if err := cfg.Unmarshal(c); err != nil { +// return err +// } +// +// return nil +//} diff --git a/plugins/broadcast/websockets/config_test.go b/plugins/broadcast/websockets/config_test.go new file mode 100644 index 00000000..e646fdc4 --- /dev/null +++ b/plugins/broadcast/websockets/config_test.go @@ -0,0 +1,34 @@ +package websockets + +import ( + "encoding/json" + "testing" + + "github.com/spiral/roadrunner/service" + "github.com/stretchr/testify/assert" +) + +type mockCfg struct{ cfg string } + +func (cfg *mockCfg) Get(name string) service.Config { + if name == "same" || name == "jobs" { + return cfg + } + + return nil +} +func (cfg *mockCfg) Unmarshal(out interface{}) error { return json.Unmarshal([]byte(cfg.cfg), out) } + +func Test_Config_Hydrate_Error(t *testing.T) { + cfg := &mockCfg{cfg: `{"dead`} + c := &Config{} + + assert.Error(t, c.Hydrate(cfg)) +} + +func Test_Config_Hydrate_OK(t *testing.T) { + cfg := &mockCfg{cfg: `{"path":"/path"}`} + c := &Config{} + + assert.NoError(t, c.Hydrate(cfg)) +} diff --git a/plugins/broadcast/websockets/conn_context.go b/plugins/broadcast/websockets/conn_context.go new file mode 100644 index 00000000..f7d62833 --- /dev/null +++ b/plugins/broadcast/websockets/conn_context.go @@ -0,0 +1,66 @@ +package websockets + +import ( + "encoding/json" + + "github.com/gorilla/websocket" +) + +// ConnContext carries information about websocket connection and it's topics. +type ConnContext struct { + // Conn to the client. + Conn *websocket.Conn + + // Topics contain list of currently subscribed topics. + Topics []string + + // upstream to push messages into. + upstream chan *broadcast.Message +} + +// SendMessage message directly to the client. +func (ctx *ConnContext) SendMessage(topic string, payload interface{}) (err error) { + msg := &broadcast.Message{Topic: topic} + msg.Payload, err = json.Marshal(payload) + + if err == nil { + ctx.upstream <- msg + } + + return err +} + +func (ctx *ConnContext) serve(errHandler func(err error, conn *websocket.Conn)) { + for msg := range ctx.upstream { + if err := ctx.Conn.WriteJSON(msg); err != nil { + errHandler(err, ctx.Conn) + } + } +} + +func (ctx *ConnContext) addTopics(topics ...string) { + for _, topic := range topics { + found := false + for _, e := range ctx.Topics { + if e == topic { + found = true + break + } + } + + if !found { + ctx.Topics = append(ctx.Topics, topic) + } + } +} + +func (ctx *ConnContext) dropTopic(topics ...string) { + for _, topic := range topics { + for i, e := range ctx.Topics { + if e == topic { + ctx.Topics[i] = ctx.Topics[len(ctx.Topics)-1] + ctx.Topics = ctx.Topics[:len(ctx.Topics)-1] + } + } + } +} diff --git a/plugins/broadcast/websockets/conn_context_test.go b/plugins/broadcast/websockets/conn_context_test.go new file mode 100644 index 00000000..466aaa30 --- /dev/null +++ b/plugins/broadcast/websockets/conn_context_test.go @@ -0,0 +1,28 @@ +package websockets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnContext_ManageTopics(t *testing.T) { + ctx := &ConnContext{Topics: make([]string, 0)} + + assert.Equal(t, []string{}, ctx.Topics) + + ctx.addTopics("a", "b") + assert.Equal(t, []string{"a", "b"}, ctx.Topics) + + ctx.addTopics("a", "c") + assert.Equal(t, []string{"a", "b", "c"}, ctx.Topics) + + ctx.dropTopic("b", "c") + assert.Equal(t, []string{"a"}, ctx.Topics) + + ctx.dropTopic("b", "c") + assert.Equal(t, []string{"a"}, ctx.Topics) + + ctx.dropTopic("a") + assert.Equal(t, []string{}, ctx.Topics) +} diff --git a/plugins/broadcast/websockets/conn_pool.go b/plugins/broadcast/websockets/conn_pool.go new file mode 100644 index 00000000..80092a44 --- /dev/null +++ b/plugins/broadcast/websockets/conn_pool.go @@ -0,0 +1,125 @@ +package websockets + +import ( + "errors" + "sync" + + "github.com/gorilla/websocket" + "github.com/spiral/broadcast/v2" +) + +// manages a set of websocket connections +type connPool struct { + errHandler func(err error, conn *websocket.Conn) + + mur sync.Mutex + client *broadcast.Client + router *broadcast.Router + + mu sync.Mutex + conns map[*websocket.Conn]*ConnContext +} + +// create new connection pool +func newPool(client *broadcast.Client, errHandler func(err error, conn *websocket.Conn)) *connPool { + cp := &connPool{ + client: client, + router: broadcast.NewRouter(), + errHandler: errHandler, + conns: map[*websocket.Conn]*ConnContext{}, + } + + go func() { + for msg := range cp.client.Channel() { + cp.mur.Lock() + cp.router.Dispatch(msg) + cp.mur.Unlock() + } + }() + + return cp +} + +// connect the websocket and register client in message router +func (cp *connPool) connect(conn *websocket.Conn) (*ConnContext, error) { + ctx := &ConnContext{ + Conn: conn, + Topics: []string{}, + upstream: make(chan *broadcast.Message), + } + + cp.mu.Lock() + cp.conns[conn] = ctx + cp.mu.Unlock() + + go ctx.serve(cp.errHandler) + + return ctx, nil +} + +// disconnect the websocket +func (cp *connPool) disconnect(conn *websocket.Conn) error { + cp.mu.Lock() + defer cp.mu.Unlock() + + ctx, ok := cp.conns[conn] + if !ok { + return errors.New("no such connection") + } + + if err := cp.unsubscribe(ctx, ctx.Topics...); err != nil { + cp.errHandler(err, conn) + } + + delete(cp.conns, conn) + + return conn.Close() +} + +// subscribe the connection +func (cp *connPool) subscribe(ctx *ConnContext, topics ...string) error { + cp.mur.Lock() + defer cp.mur.Unlock() + + ctx.addTopics(topics...) + + newTopics := cp.router.Subscribe(ctx.upstream, topics...) + if len(newTopics) != 0 { + return cp.client.Subscribe(newTopics...) + } + + return nil +} + +// unsubscribe the connection +func (cp *connPool) unsubscribe(ctx *ConnContext, topics ...string) error { + cp.mur.Lock() + defer cp.mur.Unlock() + + ctx.dropTopic(topics...) + + dropTopics := cp.router.Unsubscribe(ctx.upstream, topics...) + if len(dropTopics) != 0 { + return cp.client.Unsubscribe(dropTopics...) + } + + return nil +} + +// close the connection pool and disconnect all listeners +func (cp *connPool) close() { + cp.mu.Lock() + defer cp.mu.Unlock() + + for conn, ctx := range cp.conns { + if err := cp.unsubscribe(ctx, ctx.Topics...); err != nil { + cp.errHandler(err, conn) + } + + delete(cp.conns, conn) + + if err := conn.Close(); err != nil { + cp.errHandler(err, conn) + } + } +} diff --git a/plugins/broadcast/websockets/event.go b/plugins/broadcast/websockets/event.go new file mode 100644 index 00000000..3634bb89 --- /dev/null +++ b/plugins/broadcast/websockets/event.go @@ -0,0 +1,40 @@ +package websockets + +import ( + "github.com/gorilla/websocket" +) + +const ( + // EventConnect fired when new client is connected, the context is *websocket.Conn. + EventConnect = iota + 2500 + + // EventDisconnect fired when websocket is disconnected, context is empty. + EventDisconnect + + // EventJoin caused when topics are being consumed, context if *TopicEvent. + EventJoin + + // EventLeave caused when topic consumption are stopped, context if *TopicEvent. + EventLeave + + // EventError when any broadcast error occurred, the context is *ErrorEvent. + EventError +) + +// ErrorEvent represents singular broadcast error event. +type ErrorEvent struct { + // Conn specific to the error. + Conn *websocket.Conn + + // Error contains job specific error. + Error error +} + +// TopicEvent caused when topic is joined or left. +type TopicEvent struct { + // Conn associated with topics. + Conn *websocket.Conn + + // Topics specific to event. + Topics []string +} diff --git a/plugins/broadcast/websockets/rpc.go b/plugins/broadcast/websockets/rpc.go new file mode 100644 index 00000000..1c62b902 --- /dev/null +++ b/plugins/broadcast/websockets/rpc.go @@ -0,0 +1,17 @@ +package websockets + +type rpcService struct { + svc *Service +} + +// Subscribe subscribes broadcast client to the given topic ahead of any websocket connections. +func (r *rpcService) Subscribe(topic string, ok *bool) error { + *ok = true + return r.svc.client.Subscribe(topic) +} + +// SubscribePattern subscribes broadcast client to +func (r *rpcService) SubscribePattern(pattern string, ok *bool) error { + *ok = true + return r.svc.client.SubscribePattern(pattern) +} diff --git a/plugins/broadcast/websockets/service.go b/plugins/broadcast/websockets/service.go new file mode 100644 index 00000000..f3c0c781 --- /dev/null +++ b/plugins/broadcast/websockets/service.go @@ -0,0 +1,228 @@ +package websockets + +import ( + "encoding/json" + "net/http" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" +) + +// ID defines service id. +const ID = "ws" + +// Service to manage websocket clients. +type Service struct { + cfg *Config + upgrade websocket.Upgrader + client *broadcast.Client + connPool *connPool + listeners []func(event int, ctx interface{}) + mu sync.Mutex + stopped int32 + stop chan error +} + +// AddListener attaches server event controller. +func (s *Service) AddListener(l func(event int, ctx interface{})) { + s.listeners = append(s.listeners, l) +} + +// Init the service. +func (s *Service) Init( + cfg *Config, + env env.Environment, + rttp *rhttp.Service, + rpc *rpc.Service, + broadcast *broadcast.Service, +) (bool, error) { + if broadcast == nil || rpc == nil { + // unable to activate + return false, nil + } + + s.cfg = cfg + s.client = broadcast.NewClient() + s.connPool = newPool(s.client, s.reportError) + s.stopped = 0 + + if err := rpc.Register(ID, &rpcService{svc: s}); err != nil { + return false, err + } + + if env != nil { + // ensure that underlying kernel knows what route to handle + env.SetEnv("RR_BROADCAST_PATH", cfg.Path) + } + + // init all this stuff + s.upgrade = websocket.Upgrader{} + + if s.cfg.NoOrigin { + s.upgrade.CheckOrigin = func(r *http.Request) bool { + return true + } + } + + rttp.AddMiddleware(s.middleware) + + return true, nil +} + +// Serve the websocket connections. +func (s *Service) Serve() error { + defer s.client.Close() + defer s.connPool.close() + + s.mu.Lock() + s.stop = make(chan error) + s.mu.Unlock() + + return <-s.stop +} + +// Stop the service and disconnect all connections. +func (s *Service) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + if atomic.CompareAndSwapInt32(&s.stopped, 0, 1) { + close(s.stop) + } +} + +// middleware intercepts websocket connections. +func (s *Service) middleware(f http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != s.cfg.Path { + f(w, r) + return + } + + // checking server access + if err := newValidator().assertServerAccess(f, r); err != nil { + // show the error to the user + if av, ok := err.(*accessValidator); ok { + av.copy(w) + } else { + w.WriteHeader(400) + } + return + } + + conn, err := s.upgrade.Upgrade(w, r, nil) + if err != nil { + s.reportError(err, nil) + return + } + + s.throw(EventConnect, conn) + + // manage connection + ctx, err := s.connPool.connect(conn) + if err != nil { + s.reportError(err, conn) + return + } + + s.serveConn(ctx, f, r) + } +} + +// send and receive messages over websocket +func (s *Service) serveConn(ctx *ConnContext, f http.HandlerFunc, r *http.Request) { + defer func() { + if err := s.connPool.disconnect(ctx.Conn); err != nil { + s.reportError(err, ctx.Conn) + } + s.throw(EventDisconnect, ctx.Conn) + }() + + s.handleCommands(ctx, f, r) +} + +func (s *Service) handleCommands(ctx *ConnContext, f http.HandlerFunc, r *http.Request) { + cmd := &broadcast.Message{} + for { + if err := ctx.Conn.ReadJSON(cmd); err != nil { + s.reportError(err, ctx.Conn) + return + } + + switch cmd.Topic { + case "join": + topics := make([]string, 0) + if err := unmarshalCommand(cmd, &topics); err != nil { + s.reportError(err, ctx.Conn) + return + } + + if len(topics) == 0 { + continue + } + + if err := newValidator().assertTopicsAccess(f, r, topics...); err != nil { + s.reportError(err, ctx.Conn) + + if err := ctx.SendMessage("#join", topics); err != nil { + s.reportError(err, ctx.Conn) + return + } + + continue + } + + if err := s.connPool.subscribe(ctx, topics...); err != nil { + s.reportError(err, ctx.Conn) + return + } + + if err := ctx.SendMessage("@join", topics); err != nil { + s.reportError(err, ctx.Conn) + return + } + + s.throw(EventJoin, &TopicEvent{Conn: ctx.Conn, Topics: topics}) + case "leave": + topics := make([]string, 0) + if err := unmarshalCommand(cmd, &topics); err != nil { + s.reportError(err, ctx.Conn) + return + } + + if len(topics) == 0 { + continue + } + + if err := s.connPool.unsubscribe(ctx, topics...); err != nil { + s.reportError(err, ctx.Conn) + return + } + + if err := ctx.SendMessage("@leave", topics); err != nil { + s.reportError(err, ctx.Conn) + return + } + + s.throw(EventLeave, &TopicEvent{Conn: ctx.Conn, Topics: topics}) + } + } +} + +// handle connection error +func (s *Service) reportError(err error, conn *websocket.Conn) { + s.throw(EventError, &ErrorEvent{Conn: conn, Error: err}) +} + +// throw handles service, server and pool events. +func (s *Service) throw(event int, ctx interface{}) { + for _, l := range s.listeners { + l(event, ctx) + } +} + +// unmarshalCommand command data. +func unmarshalCommand(msg *broadcast.Message, v interface{}) error { + return json.Unmarshal(msg.Payload, v) +} diff --git a/plugins/broadcast/websockets/service_test.go b/plugins/broadcast/websockets/service_test.go new file mode 100644 index 00000000..911efc38 --- /dev/null +++ b/plugins/broadcast/websockets/service_test.go @@ -0,0 +1,706 @@ +package websockets + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiral/broadcast/v2" + "github.com/spiral/roadrunner/service" + "github.com/spiral/roadrunner/service/env" + rrhttp "github.com/spiral/roadrunner/service/http" + "github.com/spiral/roadrunner/service/rpc" + "github.com/stretchr/testify/assert" +) + +type testCfg struct { + http string + rpc string + ws string + broadcast string + target string +} + +func (cfg *testCfg) Get(name string) service.Config { + if name == rrhttp.ID { + return &testCfg{target: cfg.http} + } + + if name == ID { + return &testCfg{target: cfg.ws} + } + + if name == rpc.ID { + return &testCfg{target: cfg.rpc} + } + + if name == broadcast.ID { + return &testCfg{target: cfg.broadcast} + } + + return nil +} +func (cfg *testCfg) Unmarshal(out interface{}) error { + return json.Unmarshal([]byte(cfg.target), out) +} + +func readStr(m interface{}) string { + return strings.TrimRight(string(m.([]byte)), "\n") +} + +func Test_HttpService_Echo(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(rrhttp.ID, &rrhttp.Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6041", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 3000) + defer c.Stop() + + req, err := http.NewRequest("GET", "http://localhost:6041/", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + + b, _ := ioutil.ReadAll(r.Body) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, []byte(""), b) +} + +func Test_HttpService_Echo400(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(rrhttp.ID, &rrhttp.Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6040", + "workers":{"command": "php tests/worker-stop.php", "pool.numWorkers": 1} + }`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 3000) + defer c.Stop() + + req, err := http.NewRequest("GET", "http://localhost:6040/", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + + assert.NoError(t, err) + assert.Equal(t, 401, r.StatusCode) +} + +func Test_Service_EnvPath(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6029", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6002"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 3000) + defer c.Stop() + + req, err := http.NewRequest("GET", "http://localhost:6029/", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + if err != nil { + panic(err) + } + defer func() { + _ = r.Body.Close() + }() + + b, _ := ioutil.ReadAll(r.Body) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, []byte("/ws"), b) +} + +func Test_Service_Disabled(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + _, s := c.Get(ID) + assert.Equal(t, service.StatusInactive, s) +} + +func Test_Service_JoinTopic(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6038", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6003"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6038", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + return + } + read <- message + } + }() + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["topic"]}`)) + assert.NoError(t, err) + + assert.Equal(t, `{"topic":"@join","payload":["topic"]}`, readStr(<-read)) +} + +func Test_Service_DenyJoin(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6037", + "workers":{"command": "php tests/worker-deny.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6004"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6037", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + read <- err + continue + } + read <- message + } + }() + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["topic"]}`)) + assert.NoError(t, err) + + assert.Equal(t, `{"topic":"#join","payload":["topic"]}`, readStr(<-read)) +} + +func Test_Service_DenyJoinServer(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6037", + "workers":{"command": "php tests/worker-stop.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6005"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6037", Path: "/ws"} + + _, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.Error(t, err) +} + +func Test_Service_EmptyTopics(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6036", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6006"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6036", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + read <- err + continue + } + read <- message + } + }() + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":[]}`))) + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["a"]}`))) + assert.Equal(t, `{"topic":"@join","payload":["a"]}`, readStr(<-read)) + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"leave", "payload":[]}`))) + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"leave", "payload":["a"]}`))) + assert.Equal(t, `{"topic":"@leave","payload":["a"]}`, readStr(<-read)) + + // must be automatically closed during service stop + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["a"]}`))) + assert.Equal(t, `{"topic":"@join","payload":["a"]}`, readStr(<-read)) +} + +func Test_Service_BadTopics(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6035", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6007"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6035", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + read <- err + continue + } + read <- message + } + }() + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":"hello"}`))) + assert.Error(t, (<-read).(error)) +} + +func Test_Service_BadTopicsLeave(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6034", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6008"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6034", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + read <- err + continue + } + read <- message + } + }() + + assert.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"leave", "payload":"hello"}`))) + assert.Error(t, (<-read).(error)) +} + +func Test_Service_Events(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6033", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6009"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + b, _ := c.Get(ID) + br := b.(*Service) + + done := make(chan interface{}) + br.AddListener(func(event int, ctx interface{}) { + if event == EventConnect { + close(done) + } + }) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + u := url.URL{Scheme: "ws", Host: "localhost:6033", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + <-done + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + return + } + read <- message + } + }() + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["topic"]}`)) + assert.NoError(t, err) + + assert.Equal(t, `{"topic":"@join","payload":["topic"]}`, readStr(<-read)) +} + +func Test_Service_Warmup(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6033", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6009"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + rp, _ := c.Get(rpc.ID) + + b, _ := c.Get(ID) + br := b.(*Service) + + done := make(chan interface{}) + br.AddListener(func(event int, ctx interface{}) { + if event == EventConnect { + close(done) + } + }) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + client, err := rp.(*rpc.Service).Client() + assert.NoError(t, err) + + var ok bool + assert.NoError(t, client.Call("ws.SubscribePattern", "test", &ok)) + assert.True(t, ok) + assert.NoError(t, client.Call("ws.Subscribe", "test", &ok)) + assert.True(t, ok) + + u := url.URL{Scheme: "ws", Host: "localhost:6033", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + <-done + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + return + } + read <- message + } + }() + + // not delivered + assert.NoError(t, br.client.Publish(&broadcast.Message{Topic: "topic", Payload: []byte(`"hello"`)})) + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["topic"]}`)) + assert.NoError(t, err) + + assert.Equal(t, `{"topic":"@join","payload":["topic"]}`, readStr(<-read)) + + assert.NoError(t, br.client.Publish(&broadcast.Message{Topic: "topic", Payload: []byte(`"hello"`)})) + assert.Equal(t, `{"topic":"topic","payload":"hello"}`, readStr(<-read)) +} + +func Test_Service_Stop(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(env.ID, &env.Service{}) + c.Register(rpc.ID, &rpc.Service{}) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(broadcast.ID, &broadcast.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + http: `{ + "address": ":6033", + "workers":{"command": "php tests/worker-ok.php", "pool.numWorkers": 1} + }`, + rpc: `{"listen":"tcp://127.0.0.1:6009"}`, + ws: `{"path":"/ws"}`, + broadcast: `{}`, + })) + + rp, _ := c.Get(rpc.ID) + + b, _ := c.Get(ID) + br := b.(*Service) + + done := make(chan interface{}) + br.AddListener(func(event int, ctx interface{}) { + if event == EventConnect { + close(done) + } + }) + + go func() { _ = c.Serve() }() + time.Sleep(time.Millisecond * 1000) + defer c.Stop() + + client, err := rp.(*rpc.Service).Client() + assert.NoError(t, err) + + var ok bool + assert.NoError(t, client.Call("ws.SubscribePattern", "test", &ok)) + assert.True(t, ok) + assert.NoError(t, client.Call("ws.Subscribe", "test", &ok)) + assert.True(t, ok) + + u := url.URL{Scheme: "ws", Host: "localhost:6033", Path: "/ws"} + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + <-done + + read := make(chan interface{}) + + go func() { + defer close(read) + for { + _, message, err := conn.ReadMessage() + if err != nil { + return + } + read <- message + } + }() + + // not delivered + assert.NoError(t, br.client.Publish(&broadcast.Message{Topic: "topic", Payload: []byte(`"hello"`)})) + + br.Stop() + + err = conn.WriteMessage(websocket.TextMessage, []byte(`{"topic":"join", "payload":["topic"]}`)) + assert.NoError(t, err) +} |