Diffstat (limited to 'plugins/temporal/workflow')
-rw-r--r-- | plugins/temporal/workflow/canceller.go | 41
-rw-r--r-- | plugins/temporal/workflow/canceller_test.go | 33
-rw-r--r-- | plugins/temporal/workflow/id_registry.go | 51
-rw-r--r-- | plugins/temporal/workflow/message_queue.go | 47
-rw-r--r-- | plugins/temporal/workflow/message_queue_test.go | 53
-rw-r--r-- | plugins/temporal/workflow/plugin.go | 203
-rw-r--r-- | plugins/temporal/workflow/process.go | 436
-rw-r--r-- | plugins/temporal/workflow/workflow_pool.go | 190
8 files changed, 1054 insertions, 0 deletions
diff --git a/plugins/temporal/workflow/canceller.go b/plugins/temporal/workflow/canceller.go
new file mode 100644
index 00000000..962c527f
--- /dev/null
+++ b/plugins/temporal/workflow/canceller.go
@@ -0,0 +1,41 @@
+package workflow
+
+import (
+    "sync"
+)
+
+type cancellable func() error
+
+type canceller struct {
+    ids sync.Map
+}
+
+func (c *canceller) register(id uint64, cancel cancellable) {
+    c.ids.Store(id, cancel)
+}
+
+func (c *canceller) discard(id uint64) {
+    c.ids.Delete(id)
+}
+
+func (c *canceller) cancel(ids ...uint64) error {
+    var err error
+    for _, id := range ids {
+        cancel, ok := c.ids.Load(id)
+        if ok == false {
+            continue
+        }
+
+        // TODO: switch to LoadAndDelete once the minimum supported version is go1.15;
+        // go1.14 does not have the LoadAndDelete method,
+        // it was introduced only in go1.15.
+        c.ids.Delete(id)
+
+        err = cancel.(cancellable)()
+        if err != nil {
+            return err
+        }
+    }
+
+    return nil
+}
diff --git a/plugins/temporal/workflow/canceller_test.go b/plugins/temporal/workflow/canceller_test.go
new file mode 100644
index 00000000..d6e846f8
--- /dev/null
+++ b/plugins/temporal/workflow/canceller_test.go
@@ -0,0 +1,33 @@
+package workflow
+
+import (
+    "errors"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func Test_CancellerNoListeners(t *testing.T) {
+    c := &canceller{}
+
+    assert.NoError(t, c.cancel(1))
+}
+
+func Test_CancellerListenerError(t *testing.T) {
+    c := &canceller{}
+    c.register(1, func() error {
+        return errors.New("failed")
+    })
+
+    assert.Error(t, c.cancel(1))
+}
+
+func Test_CancellerListenerDiscarded(t *testing.T) {
+    c := &canceller{}
+    c.register(1, func() error {
+        return errors.New("failed")
+    })
+
+    c.discard(1)
+    assert.NoError(t, c.cancel(1))
+}
diff --git a/plugins/temporal/workflow/id_registry.go b/plugins/temporal/workflow/id_registry.go
new file mode 100644
index 00000000..ac75cbda
--- /dev/null
+++ b/plugins/temporal/workflow/id_registry.go
@@ -0,0 +1,51 @@
+package workflow
+
+import (
+    "sync"
+
+    bindings "go.temporal.io/sdk/internalbindings"
+)
+
+// idRegistry is used to gain access to child workflow IDs after they become available via the callback result.
+type idRegistry struct {
+    mu        sync.Mutex
+    ids       map[uint64]entry
+    listeners map[uint64]listener
+}
+
+type listener func(w bindings.WorkflowExecution, err error)
+
+type entry struct {
+    w   bindings.WorkflowExecution
+    err error
+}
+
+func newIDRegistry() *idRegistry {
+    return &idRegistry{
+        ids:       map[uint64]entry{},
+        listeners: map[uint64]listener{},
+    }
+}
+
+func (c *idRegistry) listen(id uint64, cl listener) {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+
+    c.listeners[id] = cl
+
+    if e, ok := c.ids[id]; ok {
+        cl(e.w, e.err)
+    }
+}
+
+func (c *idRegistry) push(id uint64, w bindings.WorkflowExecution, err error) {
+    c.mu.Lock()
+    defer c.mu.Unlock()
+
+    e := entry{w: w, err: err}
+    c.ids[id] = e
+
+    if l, ok := c.listeners[id]; ok {
+        l(e.w, e.err)
+    }
+}
diff --git a/plugins/temporal/workflow/message_queue.go b/plugins/temporal/workflow/message_queue.go
new file mode 100644
index 00000000..8f4409d1
--- /dev/null
+++ b/plugins/temporal/workflow/message_queue.go
@@ -0,0 +1,47 @@
+package workflow
+
+import (
+    rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+    "go.temporal.io/api/common/v1"
+    "go.temporal.io/api/failure/v1"
+)
+
+type messageQueue struct {
+    seqID func() uint64
+    queue []rrt.Message
+}
+
+func newMessageQueue(seqID func() uint64) *messageQueue {
+    return &messageQueue{
+        seqID: seqID,
+        queue: make([]rrt.Message, 0, 5),
+    }
+}
+
+func (mq *messageQueue) flush() {
+    mq.queue = mq.queue[0:0]
+}
+
+func (mq *messageQueue) allocateMessage(cmd interface{}, payloads *common.Payloads) (uint64, rrt.Message) {
+    msg := rrt.Message{
+        ID:       mq.seqID(),
+        Command:  cmd,
+        Payloads: payloads,
+    }
+
+    return msg.ID, msg
+}
+
+func (mq *messageQueue) pushCommand(cmd interface{}, payloads *common.Payloads) uint64 {
+    id, msg := mq.allocateMessage(cmd, payloads)
+    mq.queue = append(mq.queue, msg)
+    return id
+}
+
+func (mq *messageQueue) pushResponse(id uint64, payloads *common.Payloads) {
+    mq.queue = append(mq.queue, rrt.Message{ID: id, Payloads: payloads})
+}
+
+func (mq *messageQueue) pushError(id uint64, failure *failure.Failure) {
+    mq.queue = append(mq.queue, rrt.Message{ID: id, Failure: failure})
+}
diff --git a/plugins/temporal/workflow/message_queue_test.go b/plugins/temporal/workflow/message_queue_test.go
new file mode 100644
index 00000000..1fcd409f
--- /dev/null
+++ b/plugins/temporal/workflow/message_queue_test.go
@@ -0,0 +1,53 @@
+package workflow
+
+import (
+    "sync/atomic"
+    "testing"
+
+    "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+    "github.com/stretchr/testify/assert"
+    "go.temporal.io/api/common/v1"
+    "go.temporal.io/api/failure/v1"
+)
+
+func Test_MessageQueueFlushError(t *testing.T) {
+    var index uint64
+    mq := newMessageQueue(func() uint64 {
+        return atomic.AddUint64(&index, 1)
+    })
+
+    mq.pushError(1, &failure.Failure{})
+    assert.Len(t, mq.queue, 1)
+
+    mq.flush()
+    assert.Len(t, mq.queue, 0)
+    assert.Equal(t, uint64(0), index)
+}
+
+func Test_MessageQueueFlushResponse(t *testing.T) {
+    var index uint64
+    mq := newMessageQueue(func() uint64 {
+        return atomic.AddUint64(&index, 1)
+    })
+
+    mq.pushResponse(1, &common.Payloads{})
+    assert.Len(t, mq.queue, 1)
+
+    mq.flush()
+    assert.Len(t, mq.queue, 0)
+    assert.Equal(t, uint64(0), index)
+}
+
+func Test_MessageQueueCommandID(t *testing.T) {
+    var index uint64
+    mq := newMessageQueue(func() uint64 {
+        return atomic.AddUint64(&index, 1)
+    })
+
+    n := mq.pushCommand(protocol.StartWorkflow{}, &common.Payloads{})
+    assert.Equal(t, n, index)
+    assert.Len(t, mq.queue, 1)
+
+    mq.flush()
+    assert.Len(t, mq.queue, 0)
+}
diff --git a/plugins/temporal/workflow/plugin.go b/plugins/temporal/workflow/plugin.go
new file mode 100644
index 00000000..572d9a3b
--- /dev/null
+++ b/plugins/temporal/workflow/plugin.go
@@ -0,0 +1,203 @@
+package workflow
+
+import (
+    "context"
+    "sync"
+    "sync/atomic"
+    "time"
+
+    "github.com/cenkalti/backoff/v4"
+    "github.com/spiral/errors"
+    "github.com/spiral/roadrunner/v2/pkg/events"
+    "github.com/spiral/roadrunner/v2/pkg/worker"
+    "github.com/spiral/roadrunner/v2/plugins/logger"
+    "github.com/spiral/roadrunner/v2/plugins/server"
+    "github.com/spiral/roadrunner/v2/plugins/temporal/client"
+)
+
+const (
+    // PluginName defines public service name.
+    PluginName = "workflows"
+
+    // RRMode is set as the RR_MODE env variable to let the worker know which mode to run in.
+    RRMode = "temporal/workflow"
+)
+
+// Plugin manages workflows and workers.
+type Plugin struct {
+    temporal client.Temporal
+    events   events.Handler
+    server   server.Server
+    log      logger.Logger
+    mu       sync.Mutex
+    reset    chan struct{}
+    pool     workflowPool
+    closing  int64
+}
+
+// Init initializes the workflow plugin.
+func (p *Plugin) Init(temporal client.Temporal, server server.Server, log logger.Logger) error {
+    p.temporal = temporal
+    p.server = server
+    p.events = events.NewEventsHandler()
+    p.log = log
+    p.reset = make(chan struct{}, 1)
+
+    return nil
+}
+
+// Serve starts the workflow service.
+func (p *Plugin) Serve() chan error {
+    const op = errors.Op("workflow_plugin_serve")
+    errCh := make(chan error, 1)
+
+    pool, err := p.startPool()
+    if err != nil {
+        errCh <- errors.E(op, err)
+        return errCh
+    }
+
+    p.pool = pool
+
+    go func() {
+        for {
+            select {
+            case <-p.reset:
+                if atomic.LoadInt64(&p.closing) == 1 {
+                    return
+                }
+
+                err := p.replacePool()
+                if err == nil {
+                    continue
+                }
+
+                bkoff := backoff.NewExponentialBackOff()
+                bkoff.InitialInterval = time.Second
+
+                err = backoff.Retry(p.replacePool, bkoff)
+                if err != nil {
+                    errCh <- errors.E(op, err)
+                }
+            }
+        }
+    }()
+
+    return errCh
+}
+
+// Stop stops the workflow service.
+func (p *Plugin) Stop() error {
+    const op = errors.Op("workflow_plugin_stop")
+    atomic.StoreInt64(&p.closing, 1)
+
+    pool := p.getPool()
+    if pool != nil {
+        p.pool = nil
+        err := pool.Destroy(context.Background())
+        if err != nil {
+            return errors.E(op, err)
+        }
+        return nil
+    }
+
+    return nil
+}
+
+// Name of the service.
+func (p *Plugin) Name() string {
+    return PluginName
+}
+
+// Workers returns the list of available workflow workers.
+func (p *Plugin) Workers() []worker.BaseProcess {
+    p.mu.Lock()
+    defer p.mu.Unlock()
+    return p.pool.Workers()
+}
+
+// WorkflowNames returns the list of all available workflows.
+func (p *Plugin) WorkflowNames() []string {
+    return p.pool.WorkflowNames()
+}
+
+// Reset replaces the underlying workflow pool with a new copy.
+func (p *Plugin) Reset() error {
+    p.reset <- struct{}{}
+
+    return nil
+}
+
+// poolListener handles events from the workflow pool.
+func (p *Plugin) poolListener(event interface{}) {
+    if ev, ok := event.(PoolEvent); ok {
+        if ev.Event == eventWorkerExit {
+            if ev.Caused != nil {
+                p.log.Error("Workflow pool error", "error", ev.Caused)
+            }
+            p.reset <- struct{}{}
+        }
+    }
+
+    p.events.Push(event)
+}
+
+// startPool creates and starts a new workflow pool.
+func (p *Plugin) startPool() (workflowPool, error) {
+    const op = errors.Op("workflow_plugin_start_pool")
+    pool, err := newWorkflowPool(
+        p.temporal.GetCodec().WithLogger(p.log),
+        p.poolListener,
+        p.server,
+    )
+    if err != nil {
+        return nil, errors.E(op, err)
+    }
+
+    err = pool.Start(context.Background(), p.temporal)
+    if err != nil {
+        return nil, errors.E(op, err)
+    }
+
+    p.log.Debug("Started workflow processing", "workflows", pool.WorkflowNames())
+
+    return pool, nil
+}
+
+func (p *Plugin) replacePool() error {
+    p.mu.Lock()
+    const op = errors.Op("workflow_plugin_replace_pool")
+    defer p.mu.Unlock()
+
+    if p.pool != nil {
+        err := p.pool.Destroy(context.Background())
+        p.pool = nil
+        if err != nil {
+            p.log.Error(
+                "Unable to destroy expired workflow pool",
+                "error",
+                errors.E(op, err),
+            )
+            return errors.E(op, err)
+        }
+    }
+
+    pool, err := p.startPool()
+    if err != nil {
+        p.log.Error("Replace workflow pool failed", "error", err)
+        return errors.E(op, err)
+    }
+
+    p.pool = pool
+    p.log.Debug("workflow pool successfully replaced")
+
+    return nil
+}
+
+// getPool returns the currently used pool.
+func (p *Plugin) getPool() workflowPool {
+    p.mu.Lock()
+    defer p.mu.Unlock()
+
+    return p.pool
+}
diff --git a/plugins/temporal/workflow/process.go b/plugins/temporal/workflow/process.go
new file mode 100644
index 00000000..45e6885c
--- /dev/null
+++ b/plugins/temporal/workflow/process.go
@@ -0,0 +1,436 @@
+package workflow
+
+import (
+    "strconv"
+    "sync/atomic"
+    "time"
+
+    "github.com/spiral/errors"
+    rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+    commonpb "go.temporal.io/api/common/v1"
+    bindings "go.temporal.io/sdk/internalbindings"
+    "go.temporal.io/sdk/workflow"
+)
+
+// workflowProcess wraps a single workflow process.
+type workflowProcess struct {
+    codec     rrt.Codec
+    pool      workflowPool
+    env       bindings.WorkflowEnvironment
+    header    *commonpb.Header
+    mq        *messageQueue
+    ids       *idRegistry
+    seqID     uint64
+    runID     string
+    pipeline  []rrt.Message
+    callbacks []func() error
+    canceller *canceller
+    inLoop    bool
+}
+
+// Execute starts the workflow and bootstraps the process.
+func (wf *workflowProcess) Execute(env bindings.WorkflowEnvironment, header *commonpb.Header, input *commonpb.Payloads) {
+    wf.env = env
+    wf.header = header
+    wf.seqID = 0
+    wf.runID = env.WorkflowInfo().WorkflowExecution.RunID
+    wf.canceller = &canceller{}
+
+    // the sequence ID is shared across all workflows on the worker
+    wf.mq = newMessageQueue(wf.pool.SeqID)
+    wf.ids = newIDRegistry()
+
+    env.RegisterCancelHandler(wf.handleCancel)
+    env.RegisterSignalHandler(wf.handleSignal)
+    env.RegisterQueryHandler(wf.handleQuery)
+
+    var (
+        lastCompletion       = bindings.GetLastCompletionResult(env)
+        lastCompletionOffset = 0
+    )
+
+    if lastCompletion != nil && len(lastCompletion.Payloads) != 0 {
+        if input == nil {
+            input = &commonpb.Payloads{Payloads: []*commonpb.Payload{}}
+        }
+
+        input.Payloads = append(input.Payloads, lastCompletion.Payloads...)
+        lastCompletionOffset = len(lastCompletion.Payloads)
+    }
+
+    _ = wf.mq.pushCommand(
+        rrt.StartWorkflow{
+            Info:           env.WorkflowInfo(),
+            LastCompletion: lastCompletionOffset,
+        },
+        input,
+    )
+}
+
+// OnWorkflowTaskStarted handles a single workflow tick and the batch of pipeline messages from the Temporal server.
+func (wf *workflowProcess) OnWorkflowTaskStarted() {
+    wf.inLoop = true
+    defer func() { wf.inLoop = false }()
+
+    var err error
+    for _, callback := range wf.callbacks {
+        err = callback()
+        if err != nil {
+            panic(err)
+        }
+    }
+    wf.callbacks = nil
+
+    if err := wf.flushQueue(); err != nil {
+        panic(err)
+    }
+
+    for len(wf.pipeline) > 0 {
+        msg := wf.pipeline[0]
+        wf.pipeline = wf.pipeline[1:]
+
+        if msg.IsCommand() {
+            err = wf.handleMessage(msg)
+        }
+
+        if err != nil {
+            panic(err)
+        }
+    }
+}
+
+// StackTrace renders the workflow stack trace.
+func (wf *workflowProcess) StackTrace() string {
+    result, err := wf.runCommand(
+        rrt.GetStackTrace{
+            RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID,
+        },
+        nil,
+    )
+
+    if err != nil {
+        return err.Error()
+    }
+
+    var stacktrace string
+    err = wf.env.GetDataConverter().FromPayload(result.Payloads.Payloads[0], &stacktrace)
+    if err != nil {
+        return err.Error()
+    }
+
+    return stacktrace
+}
+
+// Close the workflow.
+func (wf *workflowProcess) Close() {
+    // TODO: properly handle errors
+    // panic(err)
+
+    _ = wf.mq.pushCommand(
+        rrt.DestroyWorkflow{RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID},
+        nil,
+    )
+
+    _, _ = wf.discardQueue()
+}
+
+// getContext returns the current execution context.
+func (wf *workflowProcess) getContext() rrt.Context {
+    return rrt.Context{
+        TaskQueue: wf.env.WorkflowInfo().TaskQueueName,
+        TickTime:  wf.env.Now().Format(time.RFC3339),
+        Replay:    wf.env.IsReplaying(),
+    }
+}
+
+// schedule cancel command
+func (wf *workflowProcess) handleCancel() {
+    _ = wf.mq.pushCommand(
+        rrt.CancelWorkflow{RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID},
+        nil,
+    )
+}
+
+// schedule the signal processing
+func (wf *workflowProcess) handleSignal(name string, input *commonpb.Payloads) {
+    _ = wf.mq.pushCommand(
+        rrt.InvokeSignal{
+            RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID,
+            Name:  name,
+        },
+        input,
+    )
+}
+
+// Handle query in blocking mode.
+func (wf *workflowProcess) handleQuery(queryType string, queryArgs *commonpb.Payloads) (*commonpb.Payloads, error) {
+    result, err := wf.runCommand(
+        rrt.InvokeQuery{
+            RunID: wf.runID,
+            Name:  queryType,
+        },
+        queryArgs,
+    )
+
+    if err != nil {
+        return nil, err
+    }
+
+    if result.Failure != nil {
+        return nil, bindings.ConvertFailureToError(result.Failure, wf.env.GetDataConverter())
+    }
+
+    return result.Payloads, nil
+}
+
+// process incoming command
+func (wf *workflowProcess) handleMessage(msg rrt.Message) error {
+    const op = errors.Op("handleMessage")
+    var err error
+
+    var (
+        id       = msg.ID
+        cmd      = msg.Command
+        payloads = msg.Payloads
+    )
+
+    switch cmd := cmd.(type) {
+    case *rrt.ExecuteActivity:
+        params := cmd.ActivityParams(wf.env, payloads)
+        activityID := wf.env.ExecuteActivity(params, wf.createCallback(id))
+
+        wf.canceller.register(id, func() error {
+            wf.env.RequestCancelActivity(activityID)
+            return nil
+        })
+
+    case *rrt.ExecuteChildWorkflow:
+        params := cmd.WorkflowParams(wf.env, payloads)
+
+        // always use deterministic id
+        if params.WorkflowID == "" {
+            nextID := atomic.AddUint64(&wf.seqID, 1)
+            params.WorkflowID = wf.env.WorkflowInfo().WorkflowExecution.RunID + "_" + strconv.Itoa(int(nextID))
+        }
+
+        wf.env.ExecuteChildWorkflow(params, wf.createCallback(id), func(r bindings.WorkflowExecution, e error) {
+            wf.ids.push(id, r, e)
+        })
+
+        wf.canceller.register(id, func() error {
+            wf.env.RequestCancelChildWorkflow(params.Namespace, params.WorkflowID)
+            return nil
+        })
+
+    case *rrt.GetChildWorkflowExecution:
+        wf.ids.listen(cmd.ID, func(w bindings.WorkflowExecution, err error) {
+            cl := wf.createCallback(id)
+
+            // TODO rewrite
+            if err != nil {
+                panic(err)
+            }
+
+            p, err := wf.env.GetDataConverter().ToPayloads(w)
+            if err != nil {
+                panic(err)
+            }
+
+            cl(p, err)
+        })
+
+    case *rrt.NewTimer:
+        timerID := wf.env.NewTimer(cmd.ToDuration(), wf.createCallback(id))
+        wf.canceller.register(id, func() error {
+            if timerID != nil {
+                wf.env.RequestCancelTimer(*timerID)
+            }
+            return nil
+        })
+
+    case *rrt.GetVersion:
+        version := wf.env.GetVersion(
+            cmd.ChangeID,
+            workflow.Version(cmd.MinSupported),
+            workflow.Version(cmd.MaxSupported),
+        )
+
+        result, err := wf.env.GetDataConverter().ToPayloads(version)
+        if err != nil {
+            return errors.E(op, err)
+        }
+
+        wf.mq.pushResponse(id, result)
+        err = wf.flushQueue()
+        if err != nil {
+            panic(err)
+        }
+
+    case *rrt.SideEffect:
+        wf.env.SideEffect(
+            func() (*commonpb.Payloads, error) {
+                return payloads, nil
+            },
+            wf.createContinuableCallback(id),
+        )
+
+    case *rrt.CompleteWorkflow:
+        result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+        wf.mq.pushResponse(id, result)
+
+        if msg.Failure == nil {
+            wf.env.Complete(payloads, nil)
+        } else {
+            wf.env.Complete(nil, bindings.ConvertFailureToError(msg.Failure, wf.env.GetDataConverter()))
+        }
+
+    case *rrt.ContinueAsNew:
+        result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+        wf.mq.pushResponse(id, result)
+
+        wf.env.Complete(nil, &workflow.ContinueAsNewError{
+            WorkflowType:             &bindings.WorkflowType{Name: cmd.Name},
+            Input:                    payloads,
+            Header:                   wf.header,
+            TaskQueueName:            cmd.Options.TaskQueueName,
+            WorkflowExecutionTimeout: cmd.Options.WorkflowExecutionTimeout,
+            WorkflowRunTimeout:       cmd.Options.WorkflowRunTimeout,
+            WorkflowTaskTimeout:      cmd.Options.WorkflowTaskTimeout,
+        })
+
+    case *rrt.SignalExternalWorkflow:
+        wf.env.SignalExternalWorkflow(
+            cmd.Namespace,
+            cmd.WorkflowID,
+            cmd.RunID,
+            cmd.Signal,
+            payloads,
+            nil,
+            cmd.ChildWorkflowOnly,
+            wf.createCallback(id),
+        )
+
+    case *rrt.CancelExternalWorkflow:
+        wf.env.RequestCancelExternalWorkflow(cmd.Namespace, cmd.WorkflowID, cmd.RunID, wf.createCallback(id))
+
+    case *rrt.Cancel:
+        err = wf.canceller.cancel(cmd.CommandIDs...)
+        if err != nil {
+            return errors.E(op, err)
+        }
+
+        result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+        wf.mq.pushResponse(id, result)
+
+        err = wf.flushQueue()
+        if err != nil {
+            panic(err)
+        }
+
+    case *rrt.Panic:
+        panic(errors.E(cmd.Message))
+
+    default:
+        panic("undefined command")
+    }
+
+    return nil
+}
+
+func (wf *workflowProcess) createCallback(id uint64) bindings.ResultHandler {
+    callback := func(result *commonpb.Payloads, err error) error {
+        wf.canceller.discard(id)
+
+        if err != nil {
+            wf.mq.pushError(id, bindings.ConvertErrorToFailure(err, wf.env.GetDataConverter()))
+            return nil
+        }
+
+        // fetch original payload
+        wf.mq.pushResponse(id, result)
+        return nil
+    }
+
+    return func(result *commonpb.Payloads, err error) {
+        // timer cancel callback can happen inside the loop
+        if wf.inLoop {
+            err := callback(result, err)
+            if err != nil {
+                panic(err)
+            }
+
+            return
+        }
+
+        wf.callbacks = append(wf.callbacks, func() error {
+            return callback(result, err)
+        })
+    }
+}
+
+// callback to be called inside the queue processing, adds new messages at the end of the queue
+func (wf *workflowProcess) createContinuableCallback(id uint64) bindings.ResultHandler {
+    callback := func(result *commonpb.Payloads, err error) {
+        wf.canceller.discard(id)
+
+        if err != nil {
+            wf.mq.pushError(id, bindings.ConvertErrorToFailure(err, wf.env.GetDataConverter()))
+            return
+        }
+
+        wf.mq.pushResponse(id, result)
+        err = wf.flushQueue()
+        if err != nil {
+            panic(err)
+        }
+    }
+
+    return func(result *commonpb.Payloads, err error) {
+        callback(result, err)
+    }
+}
+
+// Exchange messages between host and worker processes and add new commands to the queue.
+func (wf *workflowProcess) flushQueue() error {
+    const op = errors.Op("flush queue")
+    messages, err := wf.codec.Execute(wf.pool, wf.getContext(), wf.mq.queue...)
+    wf.mq.flush()
+
+    if err != nil {
+        return errors.E(op, err)
+    }
+
+    wf.pipeline = append(wf.pipeline, messages...)
+
+    return nil
+}
+
+// Exchange messages between host and worker processes without adding new commands to the queue.
+func (wf *workflowProcess) discardQueue() ([]rrt.Message, error) {
+    const op = errors.Op("discard queue")
+    messages, err := wf.codec.Execute(wf.pool, wf.getContext(), wf.mq.queue...)
+    wf.mq.flush()
+
+    if err != nil {
+        return nil, errors.E(op, err)
+    }
+
+    return messages, nil
+}
+
+// runCommand runs a single command and returns a single result.
+func (wf *workflowProcess) runCommand(cmd interface{}, payloads *commonpb.Payloads) (rrt.Message, error) {
+    const op = errors.Op("workflow_process_runcommand")
+    _, msg := wf.mq.allocateMessage(cmd, payloads)
+
+    result, err := wf.codec.Execute(wf.pool, wf.getContext(), msg)
+    if err != nil {
+        return rrt.Message{}, errors.E(op, err)
+    }
+
+    if len(result) != 1 {
+        return rrt.Message{}, errors.E(op, errors.Str("unexpected worker response"))
+    }
+
+    return result[0], nil
+}
diff --git a/plugins/temporal/workflow/workflow_pool.go b/plugins/temporal/workflow/workflow_pool.go
new file mode 100644
index 00000000..b9ed46c8
--- /dev/null
+++ b/plugins/temporal/workflow/workflow_pool.go
@@ -0,0 +1,190 @@
+package workflow
+
+import (
+    "context"
+    "sync"
+    "sync/atomic"
+
+    "github.com/spiral/errors"
+    "github.com/spiral/roadrunner/v2/pkg/events"
+    "github.com/spiral/roadrunner/v2/pkg/payload"
+    rrWorker "github.com/spiral/roadrunner/v2/pkg/worker"
+    "github.com/spiral/roadrunner/v2/plugins/server"
+    "github.com/spiral/roadrunner/v2/plugins/temporal/client"
+    rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+    bindings "go.temporal.io/sdk/internalbindings"
+    "go.temporal.io/sdk/worker"
+    "go.temporal.io/sdk/workflow"
+)
+
+const eventWorkerExit = 8390
+
+// RR_MODE env variable key
+const RR_MODE = "RR_MODE" //nolint
+
+// RR_CODEC env variable key
+const RR_CODEC = "RR_CODEC" //nolint
+
+type workflowPool interface {
+    SeqID() uint64
+    Exec(p payload.Payload) (payload.Payload, error)
+    Start(ctx context.Context, temporal client.Temporal) error
+    Destroy(ctx context.Context) error
+    Workers() []rrWorker.BaseProcess
+    WorkflowNames() []string
+}
+
+// PoolEvent is triggered on workflow pool worker events.
+type PoolEvent struct {
+    Event   int
+    Context interface{}
+    Caused  error
+}
+
+// workflowPoolImpl manages workflowProcess executions between worker restarts.
+type workflowPoolImpl struct {
+    codec     rrt.Codec
+    seqID     uint64
+    workflows map[string]rrt.WorkflowInfo
+    tWorkers  []worker.Worker
+    mu        sync.Mutex
+    worker    rrWorker.SyncWorker
+    active    bool
+}
+
+// newWorkflowPool creates a new workflow pool.
+func newWorkflowPool(codec rrt.Codec, listener events.Listener, factory server.Server) (workflowPool, error) {
+    const op = errors.Op("new_workflow_pool")
+    w, err := factory.NewWorker(
+        context.Background(),
+        map[string]string{RR_MODE: RRMode, RR_CODEC: codec.GetName()},
+        listener,
+    )
+    if err != nil {
+        return nil, errors.E(op, err)
+    }
+
+    go func() {
+        err := w.Wait()
+        listener(PoolEvent{Event: eventWorkerExit, Caused: err})
+    }()
+
+    return &workflowPoolImpl{codec: codec, worker: rrWorker.From(w)}, nil
+}
+
+// Start starts the pool in non-blocking mode.
+func (pool *workflowPoolImpl) Start(ctx context.Context, temporal client.Temporal) error {
+    const op = errors.Op("workflow_pool_start")
+    pool.mu.Lock()
+    pool.active = true
+    pool.mu.Unlock()
+
+    err := pool.initWorkers(ctx, temporal)
+    if err != nil {
+        return errors.E(op, err)
+    }
+
+    for i := 0; i < len(pool.tWorkers); i++ {
+        err := pool.tWorkers[i].Start()
+        if err != nil {
+            return errors.E(op, err)
+        }
+    }
+
+    return nil
+}
+
+// Active reports whether the pool is active.
+func (pool *workflowPoolImpl) Active() bool {
+    return pool.active
+}
+
+// Destroy stops all Temporal workers and the application worker.
+func (pool *workflowPoolImpl) Destroy(ctx context.Context) error {
+    pool.mu.Lock()
+    defer pool.mu.Unlock()
+    const op = errors.Op("workflow_pool_destroy")
+
+    pool.active = false
+    for i := 0; i < len(pool.tWorkers); i++ {
+        pool.tWorkers[i].Stop()
+    }
+
+    worker.PurgeStickyWorkflowCache()
+
+    err := pool.worker.Stop()
+    if err != nil {
+        return errors.E(op, err)
+    }
+
+    return nil
+}
+
+// NewWorkflowDefinition initiates a new workflow process.
+func (pool *workflowPoolImpl) NewWorkflowDefinition() bindings.WorkflowDefinition {
+    return &workflowProcess{
+        codec: pool.codec,
+        pool:  pool,
+    }
+}
+
+// SeqID returns the next sequence ID shared across the pool.
+func (pool *workflowPoolImpl) SeqID() uint64 {
+    return atomic.AddUint64(&pool.seqID, 1)
+}
+
+// Exec executes a set of commands in a thread-safe manner.
+func (pool *workflowPoolImpl) Exec(p payload.Payload) (payload.Payload, error) {
+    pool.mu.Lock()
+    defer pool.mu.Unlock()
+
+    if !pool.active {
+        return payload.Payload{}, nil
+    }
+
+    return pool.worker.Exec(p)
+}
+
+func (pool *workflowPoolImpl) Workers() []rrWorker.BaseProcess {
+    return []rrWorker.BaseProcess{pool.worker}
+}
+
+func (pool *workflowPoolImpl) WorkflowNames() []string {
+    names := make([]string, 0, len(pool.workflows))
+    for name := range pool.workflows {
+        names = append(names, name)
+    }
+
+    return names
+}
+
+// initWorkers requests the list of workflows from the underlying PHP workers and configures the Temporal workers linked to the pool.
+func (pool *workflowPoolImpl) initWorkers(ctx context.Context, temporal client.Temporal) error {
+    const op = errors.Op("workflow_pool_init_workers")
+    workerInfo, err := rrt.FetchWorkerInfo(pool.codec, pool, temporal.GetDataConverter())
+    if err != nil {
+        return errors.E(op, err)
+    }
+
+    pool.workflows = make(map[string]rrt.WorkflowInfo)
+    pool.tWorkers = make([]worker.Worker, 0)
+
+    for _, info := range workerInfo {
+        w, err := temporal.CreateWorker(info.TaskQueue, info.Options)
+        if err != nil {
+            return errors.E(op, err, pool.Destroy(ctx))
+        }
+
+        pool.tWorkers = append(pool.tWorkers, w)
+        for _, workflowInfo := range info.Workflows {
+            w.RegisterWorkflowWithOptions(pool, workflow.RegisterOptions{
+                Name:                          workflowInfo.Name,
+                DisableAlreadyRegisteredCheck: false,
+            })
+
+            pool.workflows[workflowInfo.Name] = workflowInfo
+        }
+    }
+
+    return nil
+}