diff options
author | Valery Piashchynski <[email protected]> | 2021-07-27 12:39:01 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-07-27 12:39:01 +0300 |
commit | 1e59ec2755a9cdafd26864ba532fa4d3eff46ecd (patch) | |
tree | 68c7c7e8d9f4d99debc4895ab8469e323c60f47b /plugins | |
parent | d72181126867c7e8fc05e5ac927bd90d01e0dbc7 (diff) |
Initial support for the cancellation via context
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins')
-rw-r--r-- | plugins/jobs/config.go | 7 | ||||
-rw-r--r-- | plugins/jobs/drivers/amqp/consumer.go | 156 | ||||
-rw-r--r-- | plugins/jobs/drivers/amqp/redial.go | 11 | ||||
-rw-r--r-- | plugins/jobs/drivers/beanstalk/connection.go | 25 | ||||
-rw-r--r-- | plugins/jobs/drivers/beanstalk/consumer.go | 23 | ||||
-rw-r--r-- | plugins/jobs/drivers/beanstalk/listen.go | 10 | ||||
-rw-r--r-- | plugins/jobs/drivers/ephemeral/consumer.go | 13 | ||||
-rw-r--r-- | plugins/jobs/drivers/sqs/consumer.go | 21 | ||||
-rw-r--r-- | plugins/jobs/drivers/sqs/listener.go | 4 | ||||
-rw-r--r-- | plugins/jobs/plugin.go | 44 |
10 files changed, 186 insertions, 128 deletions
diff --git a/plugins/jobs/config.go b/plugins/jobs/config.go index 1b613231..dfcfcb95 100644 --- a/plugins/jobs/config.go +++ b/plugins/jobs/config.go @@ -23,6 +23,9 @@ type Config struct { // Driver pipeline might be much larger than a main jobs queue PipelineSize uint64 `mapstructure:"pipeline_size"` + // Timeout in seconds is the per-push limit to put the job into queue + Timeout int `mapstructure:"timeout"` + // Pool configures roadrunner workers pool. Pool *poolImpl.Config `mapstructure:"Pool"` @@ -51,5 +54,9 @@ func (c *Config) InitDefaults() { c.Pipelines[k].With(pipelineName, k) } + if c.Timeout == 0 { + c.Timeout = 10 + } + c.Pool.InitDefaults() } diff --git a/plugins/jobs/drivers/amqp/consumer.go b/plugins/jobs/drivers/amqp/consumer.go index 8c55399c..714a714a 100644 --- a/plugins/jobs/drivers/amqp/consumer.go +++ b/plugins/jobs/drivers/amqp/consumer.go @@ -1,6 +1,7 @@ package amqp import ( + "context" "fmt" "sync" "sync/atomic" @@ -28,12 +29,18 @@ type JobConsumer struct { // amqp connection conn *amqp.Connection consumeChan *amqp.Channel - publishChan *amqp.Channel + publishChan chan *amqp.Channel consumeID string connStr string - retryTimeout time.Duration - prefetch int + retryTimeout time.Duration + // + // prefetch QoS AMQP + // + prefetch int + // + // pipeline's priority + // priority int64 exchangeName string queue string @@ -95,6 +102,7 @@ func NewAMQPConsumer(configKey string, log logger.Logger, cfg config.Configurer, delayCache: make(map[string]struct{}, 100), priority: pipeCfg.Priority, + publishChan: make(chan *amqp.Channel, 1), routingKey: pipeCfg.RoutingKey, queue: pipeCfg.Queue, exchangeType: pipeCfg.ExchangeType, @@ -118,11 +126,13 @@ func NewAMQPConsumer(configKey string, log logger.Logger, cfg config.Configurer, return nil, errors.E(op, err) } - jb.publishChan, err = jb.conn.Channel() + pch, err := jb.conn.Channel() if err != nil { return nil, errors.E(op, err) } + jb.publishChan <- pch + // run redialer for the connection jb.redialer() @@ -161,6 +171,7 @@ func FromPipeline(pipeline *pipeline.Pipeline, log logger.Logger, cfg config.Con retryTimeout: time.Minute * 5, delayCache: make(map[string]struct{}, 100), + publishChan: make(chan *amqp.Channel, 1), routingKey: pipeline.String(routingKey, ""), queue: pipeline.String(queue, "default"), exchangeType: pipeline.String(exchangeType, "direct"), @@ -185,14 +196,16 @@ func FromPipeline(pipeline *pipeline.Pipeline, log logger.Logger, cfg config.Con return nil, errors.E(op, err) } - jb.publishChan, err = jb.conn.Channel() + pch, err := jb.conn.Channel() if err != nil { return nil, errors.E(op, err) } + jb.publishChan <- pch + // register the pipeline // error here is always nil - _ = jb.Register(pipeline) + _ = jb.Register(context.Background(), pipeline) // run redialer for the connection jb.redialer() @@ -200,7 +213,7 @@ func FromPipeline(pipeline *pipeline.Pipeline, log logger.Logger, cfg config.Con return jb, nil } -func (j *JobConsumer) Push(job *job.Job) error { +func (j *JobConsumer) Push(ctx context.Context, job *job.Job) error { const op = errors.Op("rabbitmq_push") // check if the pipeline registered @@ -212,32 +225,69 @@ func (j *JobConsumer) Push(job *job.Job) error { // lock needed here to protect redial concurrent operation // we may be in the redial state here - j.Lock() - defer j.Unlock() - // convert - msg := fromJob(job) - p, err := pack(job.Ident, msg) - if err != nil { - return errors.E(op, err) - } + select { + case pch := <-j.publishChan: + // return the channel back + defer func() { + j.publishChan <- pch + }() + + // convert + msg := fromJob(job) + p, err := pack(job.Ident, msg) + if err != nil { + return errors.E(op, err) + } + + // handle timeouts + if msg.Options.DelayDuration() > 0 { + // TODO declare separate method for this if condition + delayMs := int64(msg.Options.DelayDuration().Seconds() * 1000) + tmpQ := fmt.Sprintf("delayed-%d.%s.%s", delayMs, j.exchangeName, j.queue) + + // delay cache optimization. + // If user already declared a queue with a delay, do not redeclare and rebind the queue + // Before -> 2.5k RPS with redeclaration + // After -> 30k RPS + if _, exists := j.delayCache[tmpQ]; exists { + // insert to the local, limited pipeline + err = pch.Publish(j.exchangeName, tmpQ, false, false, amqp.Publishing{ + Headers: p, + ContentType: contentType, + Timestamp: time.Now().UTC(), + DeliveryMode: amqp.Persistent, + Body: msg.Body(), + }) + + if err != nil { + return errors.E(op, err) + } + + return nil + } + + _, err = pch.QueueDeclare(tmpQ, true, false, false, false, amqp.Table{ + dlx: j.exchangeName, + dlxRoutingKey: j.routingKey, + dlxTTL: delayMs, + dlxExpires: delayMs * 2, + }) + + if err != nil { + return errors.E(op, err) + } + + err = pch.QueueBind(tmpQ, tmpQ, j.exchangeName, false, nil) + if err != nil { + return errors.E(op, err) + } - // handle timeouts - if msg.Options.DelayDuration() > 0 { - // TODO declare separate method for this if condition - delayMs := int64(msg.Options.DelayDuration().Seconds() * 1000) - tmpQ := fmt.Sprintf("delayed-%d.%s.%s", delayMs, j.exchangeName, j.queue) - - // delay cache optimization. - // If user already declared a queue with a delay, do not redeclare and rebind the queue - // Before -> 2.5k RPS with redeclaration - // After -> 30k RPS - if _, exists := j.delayCache[tmpQ]; exists { // insert to the local, limited pipeline - err = j.publishChan.Publish(j.exchangeName, tmpQ, false, false, amqp.Publishing{ + err = pch.Publish(j.exchangeName, tmpQ, false, false, amqp.Publishing{ Headers: p, ContentType: contentType, - Timestamp: time.Now(), + Timestamp: time.Now().UTC(), DeliveryMode: amqp.Persistent, Body: msg.Body(), }) @@ -246,64 +296,36 @@ func (j *JobConsumer) Push(job *job.Job) error { return errors.E(op, err) } - return nil - } - - _, err = j.publishChan.QueueDeclare(tmpQ, true, false, false, false, amqp.Table{ - dlx: j.exchangeName, - dlxRoutingKey: j.routingKey, - dlxTTL: delayMs, - dlxExpires: delayMs * 2, - }) + j.delayCache[tmpQ] = struct{}{} - if err != nil { - return errors.E(op, err) - } - - err = j.publishChan.QueueBind(tmpQ, tmpQ, j.exchangeName, false, nil) - if err != nil { - return errors.E(op, err) + return nil } // insert to the local, limited pipeline - err = j.publishChan.Publish(j.exchangeName, tmpQ, false, false, amqp.Publishing{ + err = pch.Publish(j.exchangeName, j.routingKey, false, false, amqp.Publishing{ Headers: p, ContentType: contentType, Timestamp: time.Now(), DeliveryMode: amqp.Persistent, Body: msg.Body(), }) - if err != nil { return errors.E(op, err) } - j.delayCache[tmpQ] = struct{}{} - return nil - } - // insert to the local, limited pipeline - err = j.publishChan.Publish(j.exchangeName, j.routingKey, false, false, amqp.Publishing{ - Headers: p, - ContentType: contentType, - Timestamp: time.Now(), - DeliveryMode: amqp.Persistent, - Body: msg.Body(), - }) - if err != nil { - return errors.E(op, err) + case <-ctx.Done(): + return errors.E(op, errors.TimeOut, ctx.Err()) } - - return nil } -func (j *JobConsumer) Register(pipeline *pipeline.Pipeline) error { - j.pipeline.Store(pipeline) +func (j *JobConsumer) Register(_ context.Context, p *pipeline.Pipeline) error { + j.pipeline.Store(p) return nil } -func (j *JobConsumer) Run(p *pipeline.Pipeline) error { +func (j *JobConsumer) Run(ctx context.Context, p *pipeline.Pipeline) error { const op = errors.Op("rabbit_consume") pipe := j.pipeline.Load().(*pipeline.Pipeline) @@ -353,7 +375,7 @@ func (j *JobConsumer) Run(p *pipeline.Pipeline) error { return nil } -func (j *JobConsumer) Pause(p string) { +func (j *JobConsumer) Pause(ctx context.Context, p string) { pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { j.log.Error("no such pipeline", "requested pause on: ", p) @@ -391,7 +413,7 @@ func (j *JobConsumer) Pause(p string) { }) } -func (j *JobConsumer) Resume(p string) { +func (j *JobConsumer) Resume(ctx context.Context, p string) { pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { j.log.Error("no such pipeline", "requested resume on: ", p) @@ -450,7 +472,7 @@ func (j *JobConsumer) Resume(p string) { }) } -func (j *JobConsumer) Stop() error { +func (j *JobConsumer) Stop(context.Context) error { j.stopCh <- struct{}{} pipe := j.pipeline.Load().(*pipeline.Pipeline) diff --git a/plugins/jobs/drivers/amqp/redial.go b/plugins/jobs/drivers/amqp/redial.go index d61c75b2..fd19f1ce 100644 --- a/plugins/jobs/drivers/amqp/redial.go +++ b/plugins/jobs/drivers/amqp/redial.go @@ -24,6 +24,9 @@ func (j *JobConsumer) redialer() { //nolint:gocognit j.Lock() + // trash the broken publish channel + <-j.publishChan + t := time.Now() pipe := j.pipeline.Load().(*pipeline.Pipeline) @@ -63,8 +66,7 @@ func (j *JobConsumer) redialer() { //nolint:gocognit } // redeclare publish channel - var errPubCh error - j.publishChan, errPubCh = j.conn.Channel() + pch, errPubCh := j.conn.Channel() if errPubCh != nil { return errors.E(op, errPubCh) } @@ -83,10 +85,12 @@ func (j *JobConsumer) redialer() { //nolint:gocognit return errors.E(op, err) } + j.publishChan <- pch // restart listener j.listener(deliv) j.log.Info("queues and subscribers redeclared successfully") + return nil } @@ -109,7 +113,8 @@ func (j *JobConsumer) redialer() { //nolint:gocognit case <-j.stopCh: if j.publishChan != nil { - err := j.publishChan.Close() + pch := <-j.publishChan + err := pch.Close() if err != nil { j.log.Error("publish channel close", "error", err) } diff --git a/plugins/jobs/drivers/beanstalk/connection.go b/plugins/jobs/drivers/beanstalk/connection.go index 6cc50c07..797b4821 100644 --- a/plugins/jobs/drivers/beanstalk/connection.go +++ b/plugins/jobs/drivers/beanstalk/connection.go @@ -1,6 +1,7 @@ package beanstalk import ( + "context" "net" "sync" "time" @@ -54,14 +55,14 @@ func NewConnPool(network, address, tName string, tout time.Duration, log logger. }, nil } -func (cp *ConnPool) Put(body []byte, pri uint32, delay, ttr time.Duration) (uint64, error) { +func (cp *ConnPool) Put(ctx context.Context, body []byte, pri uint32, delay, ttr time.Duration) (uint64, error) { cp.RLock() defer cp.RUnlock() id, err := cp.t.Put(body, pri, delay, ttr) if err != nil { // errN contains both, err and internal checkAndRedial error - errN := cp.checkAndRedial(err) + errN := cp.checkAndRedial(ctx, err) if errN != nil { return 0, errN } else { @@ -80,14 +81,14 @@ func (cp *ConnPool) Put(body []byte, pri uint32, delay, ttr time.Duration) (uint // Typically, a client will reserve a job, perform some work, then delete // the job with Conn.Delete. -func (cp *ConnPool) Reserve(reserveTimeout time.Duration) (uint64, []byte, error) { +func (cp *ConnPool) Reserve(ctx context.Context, reserveTimeout time.Duration) (uint64, []byte, error) { cp.RLock() defer cp.RUnlock() id, body, err := cp.ts.Reserve(reserveTimeout) if err != nil { // errN contains both, err and internal checkAndRedial error - errN := cp.checkAndRedial(err) + errN := cp.checkAndRedial(ctx, err) if errN != nil { return 0, nil, errN } else { @@ -99,14 +100,14 @@ func (cp *ConnPool) Reserve(reserveTimeout time.Duration) (uint64, []byte, error return id, body, nil } -func (cp *ConnPool) Delete(id uint64) error { +func (cp *ConnPool) Delete(ctx context.Context, id uint64) error { cp.RLock() defer cp.RUnlock() err := cp.conn.Delete(id) if err != nil { // errN contains both, err and internal checkAndRedial error - errN := cp.checkAndRedial(err) + errN := cp.checkAndRedial(ctx, err) if errN != nil { return errN } else { @@ -117,14 +118,12 @@ func (cp *ConnPool) Delete(id uint64) error { return nil } -func (cp *ConnPool) redial() error { +func (cp *ConnPool) redial(ctx context.Context) error { const op = errors.Op("connection_pool_redial") cp.Lock() // backoff here - expb := backoff.NewExponentialBackOff() - // set the retry timeout (minutes) - expb.MaxElapsedTime = time.Minute * 5 + expb := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) operation := func() error { connT, err := beanstalk.DialTimeout(cp.network, cp.address, cp.tout) @@ -165,7 +164,7 @@ func (cp *ConnPool) redial() error { var connErrors = map[string]struct{}{"EOF": {}} -func (cp *ConnPool) checkAndRedial(err error) error { +func (cp *ConnPool) checkAndRedial(ctx context.Context, err error) error { const op = errors.Op("connection_pool_check_redial") switch et := err.(type) { //nolint:gocritic // check if the error @@ -173,7 +172,7 @@ func (cp *ConnPool) checkAndRedial(err error) error { switch bErr := et.Err.(type) { case *net.OpError: cp.RUnlock() - errR := cp.redial() + errR := cp.redial(ctx) cp.RLock() // if redial failed - return if errR != nil { @@ -186,7 +185,7 @@ func (cp *ConnPool) checkAndRedial(err error) error { if _, ok := connErrors[et.Err.Error()]; ok { // if error is related to the broken connection - redial cp.RUnlock() - errR := cp.redial() + errR := cp.redial(ctx) cp.RLock() // if redial failed - return if errR != nil { diff --git a/plugins/jobs/drivers/beanstalk/consumer.go b/plugins/jobs/drivers/beanstalk/consumer.go index dec54426..90da3801 100644 --- a/plugins/jobs/drivers/beanstalk/consumer.go +++ b/plugins/jobs/drivers/beanstalk/consumer.go @@ -2,6 +2,7 @@ package beanstalk import ( "bytes" + "context" "strings" "sync/atomic" "time" @@ -139,7 +140,7 @@ func FromPipeline(pipe *pipeline.Pipeline, log logger.Logger, cfg config.Configu return jc, nil } -func (j *JobConsumer) Push(jb *job.Job) error { +func (j *JobConsumer) Push(ctx context.Context, jb *job.Job) error { const op = errors.Op("beanstalk_push") // check if the pipeline registered @@ -173,9 +174,9 @@ func (j *JobConsumer) Push(jb *job.Job) error { // <ttr> seconds, the job will time out and the server will release the job. // The minimum ttr is 1. If the client sends 0, the server will silently // increase the ttr to 1. Maximum ttr is 2**32-1. - id, err := j.pool.Put(bb.Bytes(), j.tubePriority, item.Options.DelayDuration(), item.Options.TimeoutDuration()) + id, err := j.pool.Put(ctx, bb.Bytes(), j.tubePriority, item.Options.DelayDuration(), item.Options.TimeoutDuration()) if err != nil { - errD := j.pool.Delete(id) + errD := j.pool.Delete(ctx, id) if errD != nil { return errors.E(op, errors.Errorf("%s:%s", err.Error(), errD.Error())) } @@ -185,13 +186,13 @@ func (j *JobConsumer) Push(jb *job.Job) error { return nil } -func (j *JobConsumer) Register(pipeline *pipeline.Pipeline) error { +func (j *JobConsumer) Register(ctx context.Context, p *pipeline.Pipeline) error { // register the pipeline - j.pipeline.Store(pipeline) + j.pipeline.Store(p) return nil } -func (j *JobConsumer) Run(p *pipeline.Pipeline) error { +func (j *JobConsumer) Run(ctx context.Context, p *pipeline.Pipeline) error { const op = errors.Op("beanstalk_run") // check if the pipeline registered @@ -203,7 +204,7 @@ func (j *JobConsumer) Run(p *pipeline.Pipeline) error { atomic.AddUint32(&j.listeners, 1) - go j.listen() + go j.listen(ctx) j.eh.Push(events.JobEvent{ Event: events.EventPipeActive, @@ -215,7 +216,7 @@ func (j *JobConsumer) Run(p *pipeline.Pipeline) error { return nil } -func (j *JobConsumer) Stop() error { +func (j *JobConsumer) Stop(context.Context) error { pipe := j.pipeline.Load().(*pipeline.Pipeline) if atomic.LoadUint32(&j.listeners) == 1 { @@ -232,7 +233,7 @@ func (j *JobConsumer) Stop() error { return nil } -func (j *JobConsumer) Pause(p string) { +func (j *JobConsumer) Pause(ctx context.Context, p string) { // load atomic value pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { @@ -259,7 +260,7 @@ func (j *JobConsumer) Pause(p string) { }) } -func (j *JobConsumer) Resume(p string) { +func (j *JobConsumer) Resume(ctx context.Context, p string) { // load atomic value pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { @@ -275,7 +276,7 @@ func (j *JobConsumer) Resume(p string) { } // start listener - go j.listen() + go j.listen(ctx) // increase num of listeners atomic.AddUint32(&j.listeners, 1) diff --git a/plugins/jobs/drivers/beanstalk/listen.go b/plugins/jobs/drivers/beanstalk/listen.go index 3e9061a3..b872cbd4 100644 --- a/plugins/jobs/drivers/beanstalk/listen.go +++ b/plugins/jobs/drivers/beanstalk/listen.go @@ -1,15 +1,19 @@ package beanstalk -import "github.com/beanstalkd/go-beanstalk" +import ( + "context" -func (j *JobConsumer) listen() { + "github.com/beanstalkd/go-beanstalk" +) + +func (j *JobConsumer) listen(ctx context.Context) { for { select { case <-j.stopCh: j.log.Warn("beanstalk listener stopped") return default: - id, body, err := j.pool.Reserve(j.reserveTimeout) + id, body, err := j.pool.Reserve(ctx, j.reserveTimeout) if err != nil { if errB, ok := err.(beanstalk.ConnError); ok { switch errB.Err { //nolint:gocritic diff --git a/plugins/jobs/drivers/ephemeral/consumer.go b/plugins/jobs/drivers/ephemeral/consumer.go index 9de64b82..043da118 100644 --- a/plugins/jobs/drivers/ephemeral/consumer.go +++ b/plugins/jobs/drivers/ephemeral/consumer.go @@ -1,6 +1,7 @@ package ephemeral import ( + "context" "sync" "sync/atomic" "time" @@ -82,7 +83,7 @@ func FromPipeline(pipeline *pipeline.Pipeline, log logger.Logger, eh events.Hand return jb, nil } -func (j *JobConsumer) Push(jb *job.Job) error { +func (j *JobConsumer) Push(ctx context.Context, jb *job.Job) error { const op = errors.Op("ephemeral_push") // check if the pipeline registered @@ -139,7 +140,7 @@ func (j *JobConsumer) consume() { } } -func (j *JobConsumer) Register(pipeline *pipeline.Pipeline) error { +func (j *JobConsumer) Register(ctx context.Context, pipeline *pipeline.Pipeline) error { const op = errors.Op("ephemeral_register") if _, ok := j.pipeline.Load(pipeline.Name()); ok { return errors.E(op, errors.Errorf("queue %s has already been registered", pipeline)) @@ -150,7 +151,7 @@ func (j *JobConsumer) Register(pipeline *pipeline.Pipeline) error { return nil } -func (j *JobConsumer) Pause(pipeline string) { +func (j *JobConsumer) Pause(ctx context.Context, pipeline string) { if q, ok := j.pipeline.Load(pipeline); ok { if q == true { // mark pipeline as turned off @@ -166,7 +167,7 @@ func (j *JobConsumer) Pause(pipeline string) { }) } -func (j *JobConsumer) Resume(pipeline string) { +func (j *JobConsumer) Resume(ctx context.Context, pipeline string) { if q, ok := j.pipeline.Load(pipeline); ok { if q == false { // mark pipeline as turned on @@ -183,7 +184,7 @@ func (j *JobConsumer) Resume(pipeline string) { } // Run is no-op for the ephemeral -func (j *JobConsumer) Run(pipe *pipeline.Pipeline) error { +func (j *JobConsumer) Run(ctx context.Context, pipe *pipeline.Pipeline) error { j.eh.Push(events.JobEvent{ Event: events.EventPipeActive, Driver: pipe.Driver(), @@ -193,7 +194,7 @@ func (j *JobConsumer) Run(pipe *pipeline.Pipeline) error { return nil } -func (j *JobConsumer) Stop() error { +func (j *JobConsumer) Stop(context.Context) error { var pipe string j.pipeline.Range(func(key, _ interface{}) bool { pipe = key.(string) diff --git a/plugins/jobs/drivers/sqs/consumer.go b/plugins/jobs/drivers/sqs/consumer.go index 08a6170e..b81d08e5 100644 --- a/plugins/jobs/drivers/sqs/consumer.go +++ b/plugins/jobs/drivers/sqs/consumer.go @@ -238,7 +238,7 @@ func FromPipeline(pipe *pipeline.Pipeline, log logger.Logger, cfg cfgPlugin.Conf return jb, nil } -func (j *JobConsumer) Push(jb *job.Job) error { +func (j *JobConsumer) Push(ctx context.Context, jb *job.Job) error { const op = errors.Op("sqs_push") // check if the pipeline registered @@ -256,9 +256,6 @@ func (j *JobConsumer) Push(jb *job.Job) error { msg := fromJob(jb) - // 10 seconds deadline to make a request TODO ??? - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) - defer cancel() // The new value for the message's visibility timeout (in seconds). Values range: 0 // to 43200. Maximum: 12 hours. _, err := j.client.SendMessage(ctx, msg.pack(j.queueURL)) @@ -269,12 +266,12 @@ func (j *JobConsumer) Push(jb *job.Job) error { return nil } -func (j *JobConsumer) Register(pipeline *pipeline.Pipeline) error { - j.pipeline.Store(pipeline) +func (j *JobConsumer) Register(_ context.Context, p *pipeline.Pipeline) error { + j.pipeline.Store(p) return nil } -func (j *JobConsumer) Run(p *pipeline.Pipeline) error { +func (j *JobConsumer) Run(_ context.Context, p *pipeline.Pipeline) error { const op = errors.Op("rabbit_consume") j.Lock() @@ -288,7 +285,7 @@ func (j *JobConsumer) Run(p *pipeline.Pipeline) error { atomic.AddUint32(&j.listeners, 1) // start listener - go j.listen() + go j.listen(context.Background()) j.eh.Push(events.JobEvent{ Event: events.EventPipeActive, @@ -300,7 +297,7 @@ func (j *JobConsumer) Run(p *pipeline.Pipeline) error { return nil } -func (j *JobConsumer) Stop() error { +func (j *JobConsumer) Stop(context.Context) error { j.pauseCh <- struct{}{} pipe := j.pipeline.Load().(*pipeline.Pipeline) @@ -313,7 +310,7 @@ func (j *JobConsumer) Stop() error { return nil } -func (j *JobConsumer) Pause(p string) { +func (j *JobConsumer) Pause(ctx context.Context, p string) { // load atomic value pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { @@ -341,7 +338,7 @@ func (j *JobConsumer) Pause(p string) { }) } -func (j *JobConsumer) Resume(p string) { +func (j *JobConsumer) Resume(_ context.Context, p string) { // load atomic value pipe := j.pipeline.Load().(*pipeline.Pipeline) if pipe.Name() != p { @@ -357,7 +354,7 @@ func (j *JobConsumer) Resume(p string) { } // start listener - go j.listen() + go j.listen(context.Background()) // increase num of listeners atomic.AddUint32(&j.listeners, 1) diff --git a/plugins/jobs/drivers/sqs/listener.go b/plugins/jobs/drivers/sqs/listener.go index 887f8358..e2323fa3 100644 --- a/plugins/jobs/drivers/sqs/listener.go +++ b/plugins/jobs/drivers/sqs/listener.go @@ -18,14 +18,14 @@ const ( NonExistentQueue string = "AWS.SimpleQueueService.NonExistentQueue" ) -func (j *JobConsumer) listen() { //nolint:gocognit +func (j *JobConsumer) listen(ctx context.Context) { //nolint:gocognit for { select { case <-j.pauseCh: j.log.Warn("sqs listener stopped") return default: - message, err := j.client.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ + message, err := j.client.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ QueueUrl: j.queueURL, MaxNumberOfMessages: j.prefetch, AttributeNames: []types.QueueAttributeName{types.QueueAttributeName(ApproximateReceiveCount)}, diff --git a/plugins/jobs/plugin.go b/plugins/jobs/plugin.go index 5779b368..d2d2ed9f 100644 --- a/plugins/jobs/plugin.go +++ b/plugins/jobs/plugin.go @@ -146,7 +146,7 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit p.consumers[name] = initializedDriver // register pipeline for the initialized driver - err = initializedDriver.Register(pipe) + err = initializedDriver.Register(context.Background(), pipe) if err != nil { errCh <- errors.E(op, errors.Errorf("pipe register failed for the driver: %s with pipe name: %s", pipe.Driver(), pipe.Name())) return false @@ -154,7 +154,9 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit // if pipeline initialized to be consumed, call Run on it if _, ok := p.consume[name]; ok { - err = initializedDriver.Run(pipe) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() + err = initializedDriver.Run(ctx, pipe) if err != nil { errCh <- errors.E(op, err) return false @@ -265,11 +267,14 @@ func (p *Plugin) Serve() chan error { //nolint:gocognit func (p *Plugin) Stop() error { for k, v := range p.consumers { - err := v.Stop() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + err := v.Stop(ctx) if err != nil { + cancel() p.log.Error("stop job driver", "driver", k) continue } + cancel() } // this function can block forever, but we don't care, because we might have a chance to exit from the pollers, @@ -347,11 +352,17 @@ func (p *Plugin) Push(j *job.Job) error { j.Options.Priority = ppl.Priority() } - err := d.Push(j) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() + + err := d.Push(ctx, j) if err != nil { + cancel() return errors.E(op, err) } + cancel() + return nil } @@ -377,10 +388,14 @@ func (p *Plugin) PushBatch(j []*job.Job) error { j[i].Options.Priority = ppl.Priority() } - err := d.Push(j[i]) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + err := d.Push(ctx, j[i]) if err != nil { + cancel() return errors.E(op, err) } + + cancel() } return nil @@ -400,9 +415,10 @@ func (p *Plugin) Pause(pp string) { p.log.Warn("driver for the pipeline not found", "pipeline", pp) return } - + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() // redirect call to the underlying driver - d.Pause(ppl.Name()) + d.Pause(ctx, ppl.Name()) } func (p *Plugin) Resume(pp string) { @@ -419,8 +435,10 @@ func (p *Plugin) Resume(pp string) { return } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() // redirect call to the underlying driver - d.Resume(ppl.Name()) + d.Resume(ctx, ppl.Name()) } // Declare a pipeline. @@ -445,7 +463,7 @@ func (p *Plugin) Declare(pipeline *pipeline.Pipeline) error { p.consumers[pipeline.Name()] = initializedDriver // register pipeline for the initialized driver - err = initializedDriver.Register(pipeline) + err = initializedDriver.Register(context.Background(), pipeline) if err != nil { return errors.E(op, errors.Errorf("pipe register failed for the driver: %s with pipe name: %s", pipeline.Driver(), pipeline.Name())) } @@ -453,7 +471,9 @@ func (p *Plugin) Declare(pipeline *pipeline.Pipeline) error { // if pipeline initialized to be consumed, call Run on it // but likely for the dynamic pipelines it should be started manually if _, ok := p.consume[pipeline.Name()]; ok { - err = initializedDriver.Run(pipeline) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() + err = initializedDriver.Run(ctx, pipeline) if err != nil { return errors.E(op, err) } @@ -485,8 +505,10 @@ func (p *Plugin) Destroy(pp string) error { // delete consumer delete(p.consumers, ppl.Name()) p.pipelines.Delete(pp) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(p.cfg.Timeout)) + defer cancel() - return d.Stop() + return d.Stop(ctx) } func (p *Plugin) List() []string { |