path: root/plugins/temporal/workflow
Diffstat (limited to 'plugins/temporal/workflow')
-rw-r--r--  plugins/temporal/workflow/canceller.go           41
-rw-r--r--  plugins/temporal/workflow/canceller_test.go      33
-rw-r--r--  plugins/temporal/workflow/id_registry.go         51
-rw-r--r--  plugins/temporal/workflow/message_queue.go       47
-rw-r--r--  plugins/temporal/workflow/message_queue_test.go  53
-rw-r--r--  plugins/temporal/workflow/plugin.go             203
-rw-r--r--  plugins/temporal/workflow/process.go            436
-rw-r--r--  plugins/temporal/workflow/workflow_pool.go      190
8 files changed, 1054 insertions, 0 deletions
diff --git a/plugins/temporal/workflow/canceller.go b/plugins/temporal/workflow/canceller.go
new file mode 100644
index 00000000..962c527f
--- /dev/null
+++ b/plugins/temporal/workflow/canceller.go
@@ -0,0 +1,41 @@
+package workflow
+
+import (
+ "sync"
+)
+
+type cancellable func() error
+
+type canceller struct {
+ ids sync.Map
+}
+
+func (c *canceller) register(id uint64, cancel cancellable) {
+ c.ids.Store(id, cancel)
+}
+
+func (c *canceller) discard(id uint64) {
+ c.ids.Delete(id)
+}
+
+func (c *canceller) cancel(ids ...uint64) error {
+ var err error
+ for _, id := range ids {
+ cancel, ok := c.ids.Load(id)
+ if !ok {
+ continue
+ }
+
+ // TODO: use LoadAndDelete once the minimum supported version is Go 1.15.
+ // go1.14 doesn't have the LoadAndDelete method;
+ // it was introduced only in go1.15.
+ c.ids.Delete(id)
+
+ err = cancel.(cancellable)()
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/plugins/temporal/workflow/canceller_test.go b/plugins/temporal/workflow/canceller_test.go
new file mode 100644
index 00000000..d6e846f8
--- /dev/null
+++ b/plugins/temporal/workflow/canceller_test.go
@@ -0,0 +1,33 @@
+package workflow
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_CancellerNoListeners(t *testing.T) {
+ c := &canceller{}
+
+ assert.NoError(t, c.cancel(1))
+}
+
+func Test_CancellerListenerError(t *testing.T) {
+ c := &canceller{}
+ c.register(1, func() error {
+ return errors.New("failed")
+ })
+
+ assert.Error(t, c.cancel(1))
+}
+
+func Test_CancellerListenerDiscarded(t *testing.T) {
+ c := &canceller{}
+ c.register(1, func() error {
+ return errors.New("failed")
+ })
+
+ c.discard(1)
+ assert.NoError(t, c.cancel(1))
+}
diff --git a/plugins/temporal/workflow/id_registry.go b/plugins/temporal/workflow/id_registry.go
new file mode 100644
index 00000000..ac75cbda
--- /dev/null
+++ b/plugins/temporal/workflow/id_registry.go
@@ -0,0 +1,51 @@
+package workflow
+
+import (
+ "sync"
+
+ bindings "go.temporal.io/sdk/internalbindings"
+)
+
+// idRegistry provides access to child workflow IDs after they become available via the callback result.
+type idRegistry struct {
+ mu sync.Mutex
+ ids map[uint64]entry
+ listeners map[uint64]listener
+}
+
+type listener func(w bindings.WorkflowExecution, err error)
+
+type entry struct {
+ w bindings.WorkflowExecution
+ err error
+}
+
+func newIDRegistry() *idRegistry {
+ return &idRegistry{
+ ids: map[uint64]entry{},
+ listeners: map[uint64]listener{},
+ }
+}
+
+func (c *idRegistry) listen(id uint64, cl listener) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.listeners[id] = cl
+
+ if e, ok := c.ids[id]; ok {
+ cl(e.w, e.err)
+ }
+}
+
+func (c *idRegistry) push(id uint64, w bindings.WorkflowExecution, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ e := entry{w: w, err: err}
+ c.ids[id] = e
+
+ if l, ok := c.listeners[id]; ok {
+ l(e.w, e.err)
+ }
+}
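
id_registry.go lands without a companion test in this change. A minimal sketch of the guarantee it provides, in the same package-test style as the other files here (the test name is illustrative, not part of this commit): a listener attached after the matching push still receives the cached entry immediately.

package workflow

import (
	"testing"

	"github.com/stretchr/testify/assert"
	bindings "go.temporal.io/sdk/internalbindings"
)

func Test_IDRegistryListenAfterPush(t *testing.T) {
	r := newIDRegistry()

	// the execution arrives before anyone is listening; the entry is cached
	r.push(1, bindings.WorkflowExecution{}, nil)

	fired := false
	r.listen(1, func(w bindings.WorkflowExecution, err error) {
		fired = true
		assert.NoError(t, err)
	})

	// listen fires the callback immediately with the cached entry
	assert.True(t, fired)
}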
diff --git a/plugins/temporal/workflow/message_queue.go b/plugins/temporal/workflow/message_queue.go
new file mode 100644
index 00000000..8f4409d1
--- /dev/null
+++ b/plugins/temporal/workflow/message_queue.go
@@ -0,0 +1,47 @@
+package workflow
+
+import (
+ rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+ "go.temporal.io/api/common/v1"
+ "go.temporal.io/api/failure/v1"
+)
+
+type messageQueue struct {
+ seqID func() uint64
+ queue []rrt.Message
+}
+
+func newMessageQueue(seqID func() uint64) *messageQueue {
+ return &messageQueue{
+ seqID: seqID,
+ queue: make([]rrt.Message, 0, 5),
+ }
+}
+
+func (mq *messageQueue) flush() {
+ mq.queue = mq.queue[0:0]
+}
+
+func (mq *messageQueue) allocateMessage(cmd interface{}, payloads *common.Payloads) (uint64, rrt.Message) {
+ msg := rrt.Message{
+ ID: mq.seqID(),
+ Command: cmd,
+ Payloads: payloads,
+ }
+
+ return msg.ID, msg
+}
+
+func (mq *messageQueue) pushCommand(cmd interface{}, payloads *common.Payloads) uint64 {
+ id, msg := mq.allocateMessage(cmd, payloads)
+ mq.queue = append(mq.queue, msg)
+ return id
+}
+
+func (mq *messageQueue) pushResponse(id uint64, payloads *common.Payloads) {
+ mq.queue = append(mq.queue, rrt.Message{ID: id, Payloads: payloads})
+}
+
+func (mq *messageQueue) pushError(id uint64, failure *failure.Failure) {
+ mq.queue = append(mq.queue, rrt.Message{ID: id, Failure: failure})
+}
diff --git a/plugins/temporal/workflow/message_queue_test.go b/plugins/temporal/workflow/message_queue_test.go
new file mode 100644
index 00000000..1fcd409f
--- /dev/null
+++ b/plugins/temporal/workflow/message_queue_test.go
@@ -0,0 +1,53 @@
+package workflow
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+ "github.com/stretchr/testify/assert"
+ "go.temporal.io/api/common/v1"
+ "go.temporal.io/api/failure/v1"
+)
+
+func Test_MessageQueueFlushError(t *testing.T) {
+ var index uint64
+ mq := newMessageQueue(func() uint64 {
+ return atomic.AddUint64(&index, 1)
+ })
+
+ mq.pushError(1, &failure.Failure{})
+ assert.Len(t, mq.queue, 1)
+
+ mq.flush()
+ assert.Len(t, mq.queue, 0)
+ assert.Equal(t, uint64(0), index)
+}
+
+func Test_MessageQueueFlushResponse(t *testing.T) {
+ var index uint64
+ mq := newMessageQueue(func() uint64 {
+ return atomic.AddUint64(&index, 1)
+ })
+
+ mq.pushResponse(1, &common.Payloads{})
+ assert.Len(t, mq.queue, 1)
+
+ mq.flush()
+ assert.Len(t, mq.queue, 0)
+ assert.Equal(t, uint64(0), index)
+}
+
+func Test_MessageQueueCommandID(t *testing.T) {
+ var index uint64
+ mq := newMessageQueue(func() uint64 {
+ return atomic.AddUint64(&index, 1)
+ })
+
+ n := mq.pushCommand(protocol.StartWorkflow{}, &common.Payloads{})
+ assert.Equal(t, n, index)
+ assert.Len(t, mq.queue, 1)
+
+ mq.flush()
+ assert.Len(t, mq.queue, 0)
+}
diff --git a/plugins/temporal/workflow/plugin.go b/plugins/temporal/workflow/plugin.go
new file mode 100644
index 00000000..572d9a3b
--- /dev/null
+++ b/plugins/temporal/workflow/plugin.go
@@ -0,0 +1,203 @@
+package workflow
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/events"
+ "github.com/spiral/roadrunner/v2/pkg/worker"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+ "github.com/spiral/roadrunner/v2/plugins/server"
+ "github.com/spiral/roadrunner/v2/plugins/temporal/client"
+)
+
+const (
+ // PluginName defines public service name.
+ PluginName = "workflows"
+
+ // RRMode is set as the RR_MODE env variable to let the worker know which mode to run in.
+ RRMode = "temporal/workflow"
+)
+
+// Plugin manages workflows and workers.
+type Plugin struct {
+ temporal client.Temporal
+ events events.Handler
+ server server.Server
+ log logger.Logger
+ mu sync.Mutex
+ reset chan struct{}
+ pool workflowPool
+ closing int64
+}
+
+// Init workflow plugin.
+func (p *Plugin) Init(temporal client.Temporal, server server.Server, log logger.Logger) error {
+ p.temporal = temporal
+ p.server = server
+ p.events = events.NewEventsHandler()
+ p.log = log
+ p.reset = make(chan struct{}, 1)
+
+ return nil
+}
+
+// Serve starts workflow service.
+func (p *Plugin) Serve() chan error {
+ const op = errors.Op("workflow_plugin_serve")
+ errCh := make(chan error, 1)
+
+ pool, err := p.startPool()
+ if err != nil {
+ errCh <- errors.E(op, err)
+ return errCh
+ }
+
+ p.pool = pool
+
+ go func() {
+ for {
+ select {
+ case <-p.reset:
+ if atomic.LoadInt64(&p.closing) == 1 {
+ return
+ }
+
+ err := p.replacePool()
+ if err == nil {
+ continue
+ }
+
+ bkoff := backoff.NewExponentialBackOff()
+ bkoff.InitialInterval = time.Second
+
+ err = backoff.Retry(p.replacePool, bkoff)
+ if err != nil {
+ errCh <- errors.E(op, err)
+ }
+ }
+ }
+ }()
+
+ return errCh
+}
+
+// Stop workflow service.
+func (p *Plugin) Stop() error {
+ const op = errors.Op("workflow_plugin_stop")
+ atomic.StoreInt64(&p.closing, 1)
+
+ pool := p.getPool()
+ if pool != nil {
+ p.pool = nil
+ err := pool.Destroy(context.Background())
+ if err != nil {
+ return errors.E(op, err)
+ }
+ return nil
+ }
+
+ return nil
+}
+
+// Name of the service.
+func (p *Plugin) Name() string {
+ return PluginName
+}
+
+// Workers returns list of available workflow workers.
+func (p *Plugin) Workers() []worker.BaseProcess {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.pool.Workers()
+}
+
+// WorkflowNames returns list of all available workflows.
+func (p *Plugin) WorkflowNames() []string {
+ return p.pool.WorkflowNames()
+}
+
+// Reset replaces the underlying workflow pool with a fresh copy.
+func (p *Plugin) Reset() error {
+ p.reset <- struct{}{}
+
+ return nil
+}
+
+// poolListener processes workflow pool events and schedules a pool reset when a worker exits.
+func (p *Plugin) poolListener(event interface{}) {
+ if ev, ok := event.(PoolEvent); ok {
+ if ev.Event == eventWorkerExit {
+ if ev.Caused != nil {
+ p.log.Error("Workflow pool error", "error", ev.Caused)
+ }
+ p.reset <- struct{}{}
+ }
+ }
+
+ p.events.Push(event)
+}
+
+// startPool creates and starts a new workflow pool.
+func (p *Plugin) startPool() (workflowPool, error) {
+ const op = errors.Op("workflow_plugin_start_pool")
+ pool, err := newWorkflowPool(
+ p.temporal.GetCodec().WithLogger(p.log),
+ p.poolListener,
+ p.server,
+ )
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ err = pool.Start(context.Background(), p.temporal)
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ p.log.Debug("Started workflow processing", "workflows", pool.WorkflowNames())
+
+ return pool, nil
+}
+
+func (p *Plugin) replacePool() error {
+ p.mu.Lock()
+ const op = errors.Op("workflow_plugin_replace_pool")
+ defer p.mu.Unlock()
+
+ if p.pool != nil {
+ err := p.pool.Destroy(context.Background())
+ p.pool = nil
+ if err != nil {
+ p.log.Error(
+ "Unable to destroy expired workflow pool",
+ "error",
+ errors.E(op, err),
+ )
+ return errors.E(op, err)
+ }
+ }
+
+ pool, err := p.startPool()
+ if err != nil {
+ p.log.Error("Replace workflow pool failed", "error", err)
+ return errors.E(op, err)
+ }
+
+ p.pool = pool
+ p.log.Debug("workflow pool successfully replaced")
+
+ return nil
+}
+
+// getPool returns the currently active pool.
+func (p *Plugin) getPool() workflowPool {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.pool
+}
diff --git a/plugins/temporal/workflow/process.go b/plugins/temporal/workflow/process.go
new file mode 100644
index 00000000..45e6885c
--- /dev/null
+++ b/plugins/temporal/workflow/process.go
@@ -0,0 +1,436 @@
+package workflow
+
+import (
+ "strconv"
+ "sync/atomic"
+ "time"
+
+ "github.com/spiral/errors"
+ rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+ commonpb "go.temporal.io/api/common/v1"
+ bindings "go.temporal.io/sdk/internalbindings"
+ "go.temporal.io/sdk/workflow"
+)
+
+// workflowProcess wraps a single workflow process.
+type workflowProcess struct {
+ codec rrt.Codec
+ pool workflowPool
+ env bindings.WorkflowEnvironment
+ header *commonpb.Header
+ mq *messageQueue
+ ids *idRegistry
+ seqID uint64
+ runID string
+ pipeline []rrt.Message
+ callbacks []func() error
+ canceller *canceller
+ inLoop bool
+}
+
+// Execute workflow, bootstraps process.
+func (wf *workflowProcess) Execute(env bindings.WorkflowEnvironment, header *commonpb.Header, input *commonpb.Payloads) {
+ wf.env = env
+ wf.header = header
+ wf.seqID = 0
+ wf.runID = env.WorkflowInfo().WorkflowExecution.RunID
+ wf.canceller = &canceller{}
+
+ // the sequence ID is shared across all workflows of the worker
+ wf.mq = newMessageQueue(wf.pool.SeqID)
+ wf.ids = newIDRegistry()
+
+ env.RegisterCancelHandler(wf.handleCancel)
+ env.RegisterSignalHandler(wf.handleSignal)
+ env.RegisterQueryHandler(wf.handleQuery)
+
+ var (
+ lastCompletion = bindings.GetLastCompletionResult(env)
+ lastCompletionOffset = 0
+ )
+
+ if lastCompletion != nil && len(lastCompletion.Payloads) != 0 {
+ if input == nil {
+ input = &commonpb.Payloads{Payloads: []*commonpb.Payload{}}
+ }
+
+ input.Payloads = append(input.Payloads, lastCompletion.Payloads...)
+ lastCompletionOffset = len(lastCompletion.Payloads)
+ }
+
+ _ = wf.mq.pushCommand(
+ rrt.StartWorkflow{
+ Info: env.WorkflowInfo(),
+ LastCompletion: lastCompletionOffset,
+ },
+ input,
+ )
+}
+
+// OnWorkflowTaskStarted handles a single workflow tick and the batch of pipelined messages from the Temporal server.
+func (wf *workflowProcess) OnWorkflowTaskStarted() {
+ wf.inLoop = true
+ defer func() { wf.inLoop = false }()
+
+ var err error
+ for _, callback := range wf.callbacks {
+ err = callback()
+ if err != nil {
+ panic(err)
+ }
+ }
+ wf.callbacks = nil
+
+ if err := wf.flushQueue(); err != nil {
+ panic(err)
+ }
+
+ for len(wf.pipeline) > 0 {
+ msg := wf.pipeline[0]
+ wf.pipeline = wf.pipeline[1:]
+
+ if msg.IsCommand() {
+ err = wf.handleMessage(msg)
+ }
+
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// StackTrace renders workflow stack trace.
+func (wf *workflowProcess) StackTrace() string {
+ result, err := wf.runCommand(
+ rrt.GetStackTrace{
+ RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID,
+ },
+ nil,
+ )
+
+ if err != nil {
+ return err.Error()
+ }
+
+ var stacktrace string
+ err = wf.env.GetDataConverter().FromPayload(result.Payloads.Payloads[0], &stacktrace)
+ if err != nil {
+ return err.Error()
+ }
+
+ return stacktrace
+}
+
+// Close the workflow.
+func (wf *workflowProcess) Close() {
+ // TODO: properly handle errors
+ // panic(err)
+
+ _ = wf.mq.pushCommand(
+ rrt.DestroyWorkflow{RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID},
+ nil,
+ )
+
+ _, _ = wf.discardQueue()
+}
+
+// getContext returns the current execution context.
+func (wf *workflowProcess) getContext() rrt.Context {
+ return rrt.Context{
+ TaskQueue: wf.env.WorkflowInfo().TaskQueueName,
+ TickTime: wf.env.Now().Format(time.RFC3339),
+ Replay: wf.env.IsReplaying(),
+ }
+}
+
+// schedule cancel command
+func (wf *workflowProcess) handleCancel() {
+ _ = wf.mq.pushCommand(
+ rrt.CancelWorkflow{RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID},
+ nil,
+ )
+}
+
+// schedule the signal processing
+func (wf *workflowProcess) handleSignal(name string, input *commonpb.Payloads) {
+ _ = wf.mq.pushCommand(
+ rrt.InvokeSignal{
+ RunID: wf.env.WorkflowInfo().WorkflowExecution.RunID,
+ Name: name,
+ },
+ input,
+ )
+}
+
+// Handle query in blocking mode.
+func (wf *workflowProcess) handleQuery(queryType string, queryArgs *commonpb.Payloads) (*commonpb.Payloads, error) {
+ result, err := wf.runCommand(
+ rrt.InvokeQuery{
+ RunID: wf.runID,
+ Name: queryType,
+ },
+ queryArgs,
+ )
+
+ if err != nil {
+ return nil, err
+ }
+
+ if result.Failure != nil {
+ return nil, bindings.ConvertFailureToError(result.Failure, wf.env.GetDataConverter())
+ }
+
+ return result.Payloads, nil
+}
+
+// process incoming command
+func (wf *workflowProcess) handleMessage(msg rrt.Message) error {
+ const op = errors.Op("handleMessage")
+ var err error
+
+ var (
+ id = msg.ID
+ cmd = msg.Command
+ payloads = msg.Payloads
+ )
+
+ switch cmd := cmd.(type) {
+ case *rrt.ExecuteActivity:
+ params := cmd.ActivityParams(wf.env, payloads)
+ activityID := wf.env.ExecuteActivity(params, wf.createCallback(id))
+
+ wf.canceller.register(id, func() error {
+ wf.env.RequestCancelActivity(activityID)
+ return nil
+ })
+
+ case *rrt.ExecuteChildWorkflow:
+ params := cmd.WorkflowParams(wf.env, payloads)
+
+ // always use deterministic id
+ if params.WorkflowID == "" {
+ nextID := atomic.AddUint64(&wf.seqID, 1)
+ params.WorkflowID = wf.env.WorkflowInfo().WorkflowExecution.RunID + "_" + strconv.Itoa(int(nextID))
+ }
+
+ wf.env.ExecuteChildWorkflow(params, wf.createCallback(id), func(r bindings.WorkflowExecution, e error) {
+ wf.ids.push(id, r, e)
+ })
+
+ wf.canceller.register(id, func() error {
+ wf.env.RequestCancelChildWorkflow(params.Namespace, params.WorkflowID)
+ return nil
+ })
+
+ case *rrt.GetChildWorkflowExecution:
+ wf.ids.listen(cmd.ID, func(w bindings.WorkflowExecution, err error) {
+ cl := wf.createCallback(id)
+
+ // TODO rewrite
+ if err != nil {
+ panic(err)
+ }
+
+ p, err := wf.env.GetDataConverter().ToPayloads(w)
+ if err != nil {
+ panic(err)
+ }
+
+ cl(p, err)
+ })
+
+ case *rrt.NewTimer:
+ timerID := wf.env.NewTimer(cmd.ToDuration(), wf.createCallback(id))
+ wf.canceller.register(id, func() error {
+ if timerID != nil {
+ wf.env.RequestCancelTimer(*timerID)
+ }
+ return nil
+ })
+
+ case *rrt.GetVersion:
+ version := wf.env.GetVersion(
+ cmd.ChangeID,
+ workflow.Version(cmd.MinSupported),
+ workflow.Version(cmd.MaxSupported),
+ )
+
+ result, err := wf.env.GetDataConverter().ToPayloads(version)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ wf.mq.pushResponse(id, result)
+ err = wf.flushQueue()
+ if err != nil {
+ panic(err)
+ }
+
+ case *rrt.SideEffect:
+ wf.env.SideEffect(
+ func() (*commonpb.Payloads, error) {
+ return payloads, nil
+ },
+ wf.createContinuableCallback(id),
+ )
+
+ case *rrt.CompleteWorkflow:
+ result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+ wf.mq.pushResponse(id, result)
+
+ if msg.Failure == nil {
+ wf.env.Complete(payloads, nil)
+ } else {
+ wf.env.Complete(nil, bindings.ConvertFailureToError(msg.Failure, wf.env.GetDataConverter()))
+ }
+
+ case *rrt.ContinueAsNew:
+ result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+ wf.mq.pushResponse(id, result)
+
+ wf.env.Complete(nil, &workflow.ContinueAsNewError{
+ WorkflowType: &bindings.WorkflowType{Name: cmd.Name},
+ Input: payloads,
+ Header: wf.header,
+ TaskQueueName: cmd.Options.TaskQueueName,
+ WorkflowExecutionTimeout: cmd.Options.WorkflowExecutionTimeout,
+ WorkflowRunTimeout: cmd.Options.WorkflowRunTimeout,
+ WorkflowTaskTimeout: cmd.Options.WorkflowTaskTimeout,
+ })
+
+ case *rrt.SignalExternalWorkflow:
+ wf.env.SignalExternalWorkflow(
+ cmd.Namespace,
+ cmd.WorkflowID,
+ cmd.RunID,
+ cmd.Signal,
+ payloads,
+ nil,
+ cmd.ChildWorkflowOnly,
+ wf.createCallback(id),
+ )
+
+ case *rrt.CancelExternalWorkflow:
+ wf.env.RequestCancelExternalWorkflow(cmd.Namespace, cmd.WorkflowID, cmd.RunID, wf.createCallback(id))
+
+ case *rrt.Cancel:
+ err = wf.canceller.cancel(cmd.CommandIDs...)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ result, _ := wf.env.GetDataConverter().ToPayloads("completed")
+ wf.mq.pushResponse(id, result)
+
+ err = wf.flushQueue()
+ if err != nil {
+ panic(err)
+ }
+
+ case *rrt.Panic:
+ panic(errors.E(cmd.Message))
+
+ default:
+ panic("undefined command")
+ }
+
+ return nil
+}
+
+func (wf *workflowProcess) createCallback(id uint64) bindings.ResultHandler {
+ callback := func(result *commonpb.Payloads, err error) error {
+ wf.canceller.discard(id)
+
+ if err != nil {
+ wf.mq.pushError(id, bindings.ConvertErrorToFailure(err, wf.env.GetDataConverter()))
+ return nil
+ }
+
+ // fetch original payload
+ wf.mq.pushResponse(id, result)
+ return nil
+ }
+
+ return func(result *commonpb.Payloads, err error) {
+ // timer cancel callback can happen inside the loop
+ if wf.inLoop {
+ err := callback(result, err)
+ if err != nil {
+ panic(err)
+ }
+
+ return
+ }
+
+ wf.callbacks = append(wf.callbacks, func() error {
+ return callback(result, err)
+ })
+ }
+}
+
+// createContinuableCallback returns a callback that is invoked during queue processing and appends new messages to the end of the queue.
+func (wf *workflowProcess) createContinuableCallback(id uint64) bindings.ResultHandler {
+ callback := func(result *commonpb.Payloads, err error) {
+ wf.canceller.discard(id)
+
+ if err != nil {
+ wf.mq.pushError(id, bindings.ConvertErrorToFailure(err, wf.env.GetDataConverter()))
+ return
+ }
+
+ wf.mq.pushResponse(id, result)
+ err = wf.flushQueue()
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ return func(result *commonpb.Payloads, err error) {
+ callback(result, err)
+ }
+}
+
+// Exchange messages between host and worker processes and add new commands to the queue.
+func (wf *workflowProcess) flushQueue() error {
+ const op = errors.Op("flush queue")
+ messages, err := wf.codec.Execute(wf.pool, wf.getContext(), wf.mq.queue...)
+ wf.mq.flush()
+
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ wf.pipeline = append(wf.pipeline, messages...)
+
+ return nil
+}
+
+// Exchange messages between host and worker processes without adding new commands to the queue.
+func (wf *workflowProcess) discardQueue() ([]rrt.Message, error) {
+ const op = errors.Op("discard queue")
+ messages, err := wf.codec.Execute(wf.pool, wf.getContext(), wf.mq.queue...)
+ wf.mq.flush()
+
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ return messages, nil
+}
+
+// Run single command and return single result.
+func (wf *workflowProcess) runCommand(cmd interface{}, payloads *commonpb.Payloads) (rrt.Message, error) {
+ const op = errors.Op("workflow_process_runcommand")
+ _, msg := wf.mq.allocateMessage(cmd, payloads)
+
+ result, err := wf.codec.Execute(wf.pool, wf.getContext(), msg)
+ if err != nil {
+ return rrt.Message{}, errors.E(op, err)
+ }
+
+ if len(result) != 1 {
+ return rrt.Message{}, errors.E(op, errors.Str("unexpected worker response"))
+ }
+
+ return result[0], nil
+}
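
The two delivery paths inside createCallback are easy to miss: inside OnWorkflowTaskStarted the result goes straight to the message queue, while outside the loop it is parked in wf.callbacks until the next tick. A minimal sketch of the deferred path, in the same package-test style as the other files here (the test name is illustrative, not part of this commit):

package workflow

import (
	"sync/atomic"
	"testing"

	"github.com/stretchr/testify/assert"
	commonpb "go.temporal.io/api/common/v1"
)

func Test_CreateCallbackDeferredOutsideLoop(t *testing.T) {
	var index uint64
	wf := &workflowProcess{
		canceller: &canceller{},
		mq: newMessageQueue(func() uint64 {
			return atomic.AddUint64(&index, 1)
		}),
	}

	cb := wf.createCallback(1)

	// wf.inLoop is false, so the result is parked as a deferred callback
	cb(&commonpb.Payloads{}, nil)
	assert.Len(t, wf.mq.queue, 0)
	assert.Len(t, wf.callbacks, 1)

	// draining the callback (as OnWorkflowTaskStarted does) pushes the response
	assert.NoError(t, wf.callbacks[0]())
	assert.Len(t, wf.mq.queue, 1)
}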
diff --git a/plugins/temporal/workflow/workflow_pool.go b/plugins/temporal/workflow/workflow_pool.go
new file mode 100644
index 00000000..b9ed46c8
--- /dev/null
+++ b/plugins/temporal/workflow/workflow_pool.go
@@ -0,0 +1,190 @@
+package workflow
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/pkg/events"
+ "github.com/spiral/roadrunner/v2/pkg/payload"
+ rrWorker "github.com/spiral/roadrunner/v2/pkg/worker"
+ "github.com/spiral/roadrunner/v2/plugins/server"
+ "github.com/spiral/roadrunner/v2/plugins/temporal/client"
+ rrt "github.com/spiral/roadrunner/v2/plugins/temporal/protocol"
+ bindings "go.temporal.io/sdk/internalbindings"
+ "go.temporal.io/sdk/worker"
+ "go.temporal.io/sdk/workflow"
+)
+
+const eventWorkerExit = 8390
+
+// RR_MODE env variable key
+const RR_MODE = "RR_MODE" //nolint
+
+// RR_CODEC env variable key
+const RR_CODEC = "RR_CODEC" //nolint
+
+type workflowPool interface {
+ SeqID() uint64
+ Exec(p payload.Payload) (payload.Payload, error)
+ Start(ctx context.Context, temporal client.Temporal) error
+ Destroy(ctx context.Context) error
+ Workers() []rrWorker.BaseProcess
+ WorkflowNames() []string
+}
+
+// PoolEvent is triggered on workflow pool worker events.
+type PoolEvent struct {
+ Event int
+ Context interface{}
+ Caused error
+}
+
+// workflowPoolImpl manages workflowProcess executions between worker restarts.
+type workflowPoolImpl struct {
+ codec rrt.Codec
+ seqID uint64
+ workflows map[string]rrt.WorkflowInfo
+ tWorkers []worker.Worker
+ mu sync.Mutex
+ worker rrWorker.SyncWorker
+ active bool
+}
+
+// newWorkflowPool creates a new workflow pool.
+func newWorkflowPool(codec rrt.Codec, listener events.Listener, factory server.Server) (workflowPool, error) {
+ const op = errors.Op("new_workflow_pool")
+ w, err := factory.NewWorker(
+ context.Background(),
+ map[string]string{RR_MODE: RRMode, RR_CODEC: codec.GetName()},
+ listener,
+ )
+ if err != nil {
+ return nil, errors.E(op, err)
+ }
+
+ go func() {
+ err := w.Wait()
+ listener(PoolEvent{Event: eventWorkerExit, Caused: err})
+ }()
+
+ return &workflowPoolImpl{codec: codec, worker: rrWorker.From(w)}, nil
+}
+
+// Start the pool in non-blocking mode.
+func (pool *workflowPoolImpl) Start(ctx context.Context, temporal client.Temporal) error {
+ const op = errors.Op("workflow_pool_start")
+ pool.mu.Lock()
+ pool.active = true
+ pool.mu.Unlock()
+
+ err := pool.initWorkers(ctx, temporal)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ for i := 0; i < len(pool.tWorkers); i++ {
+ err := pool.tWorkers[i].Start()
+ if err != nil {
+ return errors.E(op, err)
+ }
+ }
+
+ return nil
+}
+
+// Active reports whether the pool is active.
+func (pool *workflowPoolImpl) Active() bool {
+ return pool.active
+}
+
+// Destroy stops all Temporal workers and the application worker.
+func (pool *workflowPoolImpl) Destroy(ctx context.Context) error {
+ pool.mu.Lock()
+ defer pool.mu.Unlock()
+ const op = errors.Op("workflow_pool_destroy")
+
+ pool.active = false
+ for i := 0; i < len(pool.tWorkers); i++ {
+ pool.tWorkers[i].Stop()
+ }
+
+ worker.PurgeStickyWorkflowCache()
+
+ err := pool.worker.Stop()
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ return nil
+}
+
+// NewWorkflowDefinition initiates a new workflow process.
+func (pool *workflowPoolImpl) NewWorkflowDefinition() bindings.WorkflowDefinition {
+ return &workflowProcess{
+ codec: pool.codec,
+ pool: pool,
+ }
+}
+
+// SeqID returns the next unique sequence ID.
+func (pool *workflowPoolImpl) SeqID() uint64 {
+ return atomic.AddUint64(&pool.seqID, 1)
+}
+
+// Exec executes a set of commands in a thread-safe manner.
+func (pool *workflowPoolImpl) Exec(p payload.Payload) (payload.Payload, error) {
+ pool.mu.Lock()
+ defer pool.mu.Unlock()
+
+ if !pool.active {
+ return payload.Payload{}, nil
+ }
+
+ return pool.worker.Exec(p)
+}
+
+func (pool *workflowPoolImpl) Workers() []rrWorker.BaseProcess {
+ return []rrWorker.BaseProcess{pool.worker}
+}
+
+func (pool *workflowPoolImpl) WorkflowNames() []string {
+ names := make([]string, 0, len(pool.workflows))
+ for name := range pool.workflows {
+ names = append(names, name)
+ }
+
+ return names
+}
+
+// initWorkers requests the registered workflows from the underlying PHP worker and configures the Temporal workers linked to the pool.
+func (pool *workflowPoolImpl) initWorkers(ctx context.Context, temporal client.Temporal) error {
+ const op = errors.Op("workflow_pool_init_workers")
+ workerInfo, err := rrt.FetchWorkerInfo(pool.codec, pool, temporal.GetDataConverter())
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ pool.workflows = make(map[string]rrt.WorkflowInfo)
+ pool.tWorkers = make([]worker.Worker, 0)
+
+ for _, info := range workerInfo {
+ w, err := temporal.CreateWorker(info.TaskQueue, info.Options)
+ if err != nil {
+ return errors.E(op, err, pool.Destroy(ctx))
+ }
+
+ pool.tWorkers = append(pool.tWorkers, w)
+ for _, workflowInfo := range info.Workflows {
+ w.RegisterWorkflowWithOptions(pool, workflow.RegisterOptions{
+ Name: workflowInfo.Name,
+ DisableAlreadyRegisteredCheck: false,
+ })
+
+ pool.workflows[workflowInfo.Name] = workflowInfo
+ }
+ }
+
+ return nil
+}
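
Once Destroy has flipped active to false, Exec short-circuits and returns an empty payload without ever touching the worker. A minimal sketch of that behaviour, in the same package-test style as the other files here (the test name is illustrative, not part of this commit; it assumes payload.Payload exposes a Body field as in roadrunner v2):

package workflow

import (
	"testing"

	"github.com/spiral/roadrunner/v2/pkg/payload"
	"github.com/stretchr/testify/assert"
)

func Test_WorkflowPoolInactiveExec(t *testing.T) {
	// zero value: active is false and no worker is attached
	pool := &workflowPoolImpl{}

	// Exec never reaches pool.worker while the pool is inactive
	p, err := pool.Exec(payload.Payload{})
	assert.NoError(t, err)
	assert.Empty(t, p.Body)
}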