summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/pubsub/interface.go4
-rw-r--r--plugins/memory/pubsub.go34
-rw-r--r--plugins/redis/pubsub/channel.go4
-rw-r--r--plugins/redis/pubsub/pubsub.go10
-rw-r--r--plugins/websockets/plugin.go26
-rw-r--r--tests/plugins/broadcast/plugins/plugin1.go31
-rw-r--r--tests/plugins/broadcast/plugins/plugin2.go31
-rw-r--r--tests/plugins/broadcast/plugins/plugin3.go33
-rw-r--r--tests/plugins/broadcast/plugins/plugin4.go34
-rw-r--r--tests/plugins/broadcast/plugins/plugin5.go34
-rw-r--r--tests/plugins/broadcast/plugins/plugin6.go34
11 files changed, 144 insertions, 131 deletions
diff --git a/common/pubsub/interface.go b/common/pubsub/interface.go
index 06252d70..5b69d577 100644
--- a/common/pubsub/interface.go
+++ b/common/pubsub/interface.go
@@ -1,5 +1,7 @@
package pubsub
+import "context"
+
/*
This interface is in BETA. It might be changed.
*/
@@ -45,7 +47,7 @@ type Publisher interface {
// Reader interface should return next message
type Reader interface {
- Next() (*Message, error)
+ Next(ctx context.Context) (*Message, error)
}
// Constructor is a special pub-sub interface made to return a constructed PubSub type
diff --git a/plugins/memory/pubsub.go b/plugins/memory/pubsub.go
index c79f3eb0..fd30eb54 100644
--- a/plugins/memory/pubsub.go
+++ b/plugins/memory/pubsub.go
@@ -1,8 +1,10 @@
package memory
import (
+ "context"
"sync"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/pkg/bst"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -65,21 +67,25 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) {
}
}
-func (p *PubSubDriver) Next() (*pubsub.Message, error) {
- msg := <-p.pushCh
- if msg == nil {
- return nil, nil
- }
-
- p.RLock()
- defer p.RUnlock()
+func (p *PubSubDriver) Next(ctx context.Context) (*pubsub.Message, error) {
+ const op = errors.Op("pubsub_memory")
+ select {
+ case msg := <-p.pushCh:
+ if msg == nil {
+ return nil, nil
+ }
- // push only messages, which topics are subscibed
- // TODO better???
- // if we have active subscribers - send a message to a topic
- // or send nil instead
- if ok := p.storage.Contains(msg.Topic); ok {
- return msg, nil
+ p.RLock()
+ defer p.RUnlock()
+ // push only messages, which topics are subscibed
+ // TODO better???
+ // if we have active subscribers - send a message to a topic
+ // or send nil instead
+ if ok := p.storage.Contains(msg.Topic); ok {
+ return msg, nil
+ }
+ case <-ctx.Done():
+ return nil, errors.E(op, errors.TimeOut, ctx.Err())
}
return nil, nil
diff --git a/plugins/redis/pubsub/channel.go b/plugins/redis/pubsub/channel.go
index eef5a7b9..a1655ab2 100644
--- a/plugins/redis/pubsub/channel.go
+++ b/plugins/redis/pubsub/channel.go
@@ -92,6 +92,6 @@ func (r *redisChannel) stop() error {
return nil
}
-func (r *redisChannel) message() *pubsub.Message {
- return <-r.out
+func (r *redisChannel) message() chan *pubsub.Message {
+ return r.out
}
diff --git a/plugins/redis/pubsub/pubsub.go b/plugins/redis/pubsub/pubsub.go
index 95a9f6dd..c9ad3d58 100644
--- a/plugins/redis/pubsub/pubsub.go
+++ b/plugins/redis/pubsub/pubsub.go
@@ -172,6 +172,12 @@ func (p *PubSubDriver) Connections(topic string, res map[string]struct{}) {
}
// Next return next message
-func (p *PubSubDriver) Next() (*pubsub.Message, error) {
- return p.channel.message(), nil
+func (p *PubSubDriver) Next(ctx context.Context) (*pubsub.Message, error) {
+ const op = errors.Op("redis_driver_next")
+ select {
+ case msg := <-p.channel.message():
+ return msg, nil
+ case <-ctx.Done():
+ return nil, errors.E(op, errors.TimeOut, ctx.Err())
+ }
}
diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go
index a7db0f83..395b056f 100644
--- a/plugins/websockets/plugin.go
+++ b/plugins/websockets/plugin.go
@@ -58,6 +58,10 @@ type Plugin struct {
// server which produces commands to the pool
server server.Server
+ // stop receiving messages
+ cancel context.CancelFunc
+ ctx context.Context
+
// function used to validate access to the requested resource
accessValidator validator.AccessValidatorFn
}
@@ -90,6 +94,10 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se
p.server = server
p.log = log
p.broadcaster = b
+
+ ctx, cancel := context.WithCancel(context.Background())
+ p.ctx = ctx
+ p.cancel = cancel
return nil
}
@@ -130,17 +138,17 @@ func (p *Plugin) Serve() chan error {
// we need here only Reader part of the interface
go func(ps pubsub.Reader) {
for {
- select {
- case <-p.serveExit:
- return
- default:
- data, err := ps.Next()
- if err != nil {
- errCh <- errors.E(op, err)
+ data, err := ps.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
- p.workersPool.Queue(data)
+
+ errCh <- errors.E(op, err)
+ return
}
+
+ p.workersPool.Queue(data)
}
}(p.subReader)
@@ -150,6 +158,8 @@ func (p *Plugin) Serve() chan error {
func (p *Plugin) Stop() error {
// close workers pool
p.workersPool.Stop()
+ // cancel context
+ p.cancel()
p.Lock()
if p.phpPool == nil {
p.Unlock()
diff --git a/tests/plugins/broadcast/plugins/plugin1.go b/tests/plugins/broadcast/plugins/plugin1.go
index 01ad1479..ed5139a8 100644
--- a/tests/plugins/broadcast/plugins/plugin1.go
+++ b/tests/plugins/broadcast/plugins/plugin1.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,14 +16,14 @@ type Plugin1 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
-
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin1) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -42,22 +44,16 @@ func (p *Plugin1) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
-
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin1Name, *msg))
+ errCh <- err
+ return
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin1Name, *msg))
}
}()
@@ -68,8 +64,7 @@ func (p *Plugin1) Stop() error {
_ = p.driver.Unsubscribe("1", "foo")
_ = p.driver.Unsubscribe("1", "foo2")
_ = p.driver.Unsubscribe("1", "foo3")
-
- p.exit <- struct{}{}
+ p.cancel()
return nil
}
diff --git a/tests/plugins/broadcast/plugins/plugin2.go b/tests/plugins/broadcast/plugins/plugin2.go
index ee072ffe..20cc1b24 100644
--- a/tests/plugins/broadcast/plugins/plugin2.go
+++ b/tests/plugins/broadcast/plugins/plugin2.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,13 +16,14 @@ type Plugin2 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin2) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -40,22 +43,20 @@ func (p *Plugin2) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
+ errCh <- err
+ return
+ }
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin2Name, *msg))
+ if msg == nil {
+ continue
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin2Name, *msg))
}
}()
@@ -64,7 +65,7 @@ func (p *Plugin2) Serve() chan error {
func (p *Plugin2) Stop() error {
_ = p.driver.Unsubscribe("2", "foo")
- p.exit <- struct{}{}
+ p.cancel()
return nil
}
diff --git a/tests/plugins/broadcast/plugins/plugin3.go b/tests/plugins/broadcast/plugins/plugin3.go
index 288201d1..2f416d2e 100644
--- a/tests/plugins/broadcast/plugins/plugin3.go
+++ b/tests/plugins/broadcast/plugins/plugin3.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,15 +16,14 @@ type Plugin3 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
-
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin3) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
-
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -42,22 +43,20 @@ func (p *Plugin3) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
+ errCh <- err
+ return
+ }
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin3Name, *msg))
+ if msg == nil {
+ continue
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin3Name, *msg))
}
}()
@@ -66,7 +65,7 @@ func (p *Plugin3) Serve() chan error {
func (p *Plugin3) Stop() error {
_ = p.driver.Unsubscribe("3", "foo")
- p.exit <- struct{}{}
+ p.cancel()
return nil
}
diff --git a/tests/plugins/broadcast/plugins/plugin4.go b/tests/plugins/broadcast/plugins/plugin4.go
index 56f79c0f..e2209648 100644
--- a/tests/plugins/broadcast/plugins/plugin4.go
+++ b/tests/plugins/broadcast/plugins/plugin4.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,15 +16,14 @@ type Plugin4 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
-
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin4) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
-
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -42,22 +43,20 @@ func (p *Plugin4) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
+ errCh <- err
+ return
+ }
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin4Name, *msg))
+ if msg == nil {
+ continue
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin4Name, *msg))
}
}()
@@ -66,8 +65,7 @@ func (p *Plugin4) Serve() chan error {
func (p *Plugin4) Stop() error {
_ = p.driver.Unsubscribe("4", "foo")
-
- p.exit <- struct{}{}
+ p.cancel()
return nil
}
diff --git a/tests/plugins/broadcast/plugins/plugin5.go b/tests/plugins/broadcast/plugins/plugin5.go
index e7cd7e60..122046b8 100644
--- a/tests/plugins/broadcast/plugins/plugin5.go
+++ b/tests/plugins/broadcast/plugins/plugin5.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,15 +16,14 @@ type Plugin5 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
-
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin5) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
-
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -42,22 +43,20 @@ func (p *Plugin5) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
+ errCh <- err
+ return
+ }
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin5Name, *msg))
+ if msg == nil {
+ continue
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin5Name, *msg))
}
}()
@@ -66,8 +65,7 @@ func (p *Plugin5) Serve() chan error {
func (p *Plugin5) Stop() error {
_ = p.driver.Unsubscribe("5", "foo")
-
- p.exit <- struct{}{}
+ p.cancel()
return nil
}
diff --git a/tests/plugins/broadcast/plugins/plugin6.go b/tests/plugins/broadcast/plugins/plugin6.go
index 08272196..6ace0a79 100644
--- a/tests/plugins/broadcast/plugins/plugin6.go
+++ b/tests/plugins/broadcast/plugins/plugin6.go
@@ -1,8 +1,10 @@
package plugins
import (
+ "context"
"fmt"
+ "github.com/spiral/errors"
"github.com/spiral/roadrunner/v2/common/pubsub"
"github.com/spiral/roadrunner/v2/plugins/broadcast"
"github.com/spiral/roadrunner/v2/plugins/logger"
@@ -14,15 +16,14 @@ type Plugin6 struct {
log logger.Logger
b broadcast.Broadcaster
driver pubsub.SubReader
-
- exit chan struct{}
+ ctx context.Context
+ cancel context.CancelFunc
}
func (p *Plugin6) Init(log logger.Logger, b broadcast.Broadcaster) error {
p.log = log
p.b = b
-
- p.exit = make(chan struct{}, 1)
+ p.ctx, p.cancel = context.WithCancel(context.Background())
return nil
}
@@ -42,22 +43,20 @@ func (p *Plugin6) Serve() chan error {
go func() {
for {
- select {
- case <-p.exit:
- return
- default:
- msg, err := p.driver.Next()
- if err != nil {
- errCh <- err
+ msg, err := p.driver.Next(p.ctx)
+ if err != nil {
+ if errors.Is(errors.TimeOut, err) {
return
}
+ errCh <- err
+ return
+ }
- if msg == nil {
- continue
- }
-
- p.log.Info(fmt.Sprintf("%s: %s", Plugin6Name, *msg))
+ if msg == nil {
+ continue
}
+
+ p.log.Info(fmt.Sprintf("%s: %s", Plugin6Name, *msg))
}
}()
@@ -66,8 +65,7 @@ func (p *Plugin6) Serve() chan error {
func (p *Plugin6) Stop() error {
_ = p.driver.Unsubscribe("6", "foo")
-
- p.exit <- struct{}{}
+ p.cancel()
return nil
}