diff options
Diffstat (limited to 'socket_factory.go')
-rw-r--r-- | socket_factory.go | 237 |
1 files changed, 159 insertions, 78 deletions
diff --git a/socket_factory.go b/socket_factory.go index 42196588..27558cce 100644 --- a/socket_factory.go +++ b/socket_factory.go @@ -1,16 +1,20 @@ package roadrunner import ( - "fmt" - "github.com/pkg/errors" - "github.com/spiral/goridge/v2" + "context" "net" "os/exec" + "strings" "sync" "time" + + "github.com/pkg/errors" + "github.com/spiral/goridge/v2" + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" ) -// SocketFactory connects to external workers using socket server. +// SocketFactory connects to external stack using socket server. type SocketFactory struct { // listens for incoming connections from underlying processes ls net.Listener @@ -18,122 +22,199 @@ type SocketFactory struct { // relay connection timeout tout time.Duration - // protects socket mapping - mu sync.Mutex - // sockets which are waiting for process association - relays map[int]chan *goridge.SocketRelay + // relays map[int64]*goridge.SocketRelay + relays sync.Map + + ErrCh chan error } -// NewSocketFactory returns SocketFactory attached to a given socket lsn. +// todo: review + +// NewSocketServer returns SocketFactory attached to a given socket listener. // tout specifies for how long factory should serve for incoming relay connection -func NewSocketFactory(ls net.Listener, tout time.Duration) *SocketFactory { +func NewSocketServer(ls net.Listener, tout time.Duration) Factory { f := &SocketFactory{ ls: ls, tout: tout, - relays: make(map[int]chan *goridge.SocketRelay), + relays: sync.Map{}, + ErrCh: make(chan error, 10), } - go f.listen() + // Be careful + // https://github.com/go101/go101/wiki/About-memory-ordering-guarantees-made-by-atomic-operations-in-Go + // https://github.com/golang/go/issues/5045 + go func() { + f.ErrCh <- f.listen() + }() return f } -// SpawnWorker creates worker and connects it to appropriate relay or returns error -func (f *SocketFactory) SpawnWorker(cmd *exec.Cmd) (w *Worker, err error) { - if w, err = newWorker(cmd); err != nil { +// blocking operation, returns an error +func (f *SocketFactory) listen() error { + errGr := &errgroup.Group{} + errGr.Go(func() error { + for { + conn, err := f.ls.Accept() + if err != nil { + return err + } + + rl := goridge.NewSocketRelay(conn) + pid, err := fetchPID(rl) + if err != nil { + return err + } + + f.attachRelayToPid(pid, rl) + } + }) + + return errGr.Wait() +} + +type socketSpawn struct { + w WorkerBase + err error +} + +// SpawnWorker creates WorkerProcess and connects it to appropriate relay or returns error +func (f *SocketFactory) SpawnWorkerWithContext(ctx context.Context, cmd *exec.Cmd) (WorkerBase, error) { + c := make(chan socketSpawn) + go func() { + ctx, cancel := context.WithTimeout(ctx, f.tout) + defer cancel() + w, err := InitBaseWorker(cmd) + if err != nil { + c <- socketSpawn{ + w: nil, + err: err, + } + return + } + + err = w.Start() + if err != nil { + c <- socketSpawn{ + w: nil, + err: errors.Wrap(err, "process error"), + } + return + } + + rl, err := f.findRelayWithContext(ctx, w) + if err != nil { + err = multierr.Combine( + err, + w.Kill(context.Background()), + w.Wait(context.Background()), + ) + c <- socketSpawn{ + w: nil, + err: err, + } + return + } + + w.AttachRelay(rl) + w.State().Set(StateReady) + + c <- socketSpawn{ + w: w, + err: nil, + } + return + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case res := <-c: + if res.err != nil { + return nil, res.err + } + + return res.w, nil + } +} + +func (f *SocketFactory) SpawnWorker(cmd *exec.Cmd) (WorkerBase, error) { + ctx := context.Background() + w, err := InitBaseWorker(cmd) + if err != nil { return nil, err } - if err := w.start(); err != nil { + err = w.Start() + if err != nil { return nil, errors.Wrap(err, "process error") } - rl, err := f.findRelay(w, f.tout) + var errs []string + rl, err := f.findRelay(w) if err != nil { - go func(w *Worker) { - err := w.Kill() - if err != nil { - fmt.Println(fmt.Errorf("error killing the worker %v", err)) - } - }(w) - - if wErr := w.Wait(); wErr != nil { - if _, ok := wErr.(*exec.ExitError); ok { - err = errors.Wrap(wErr, err.Error()) - } else { - err = wErr - } + errs = append(errs, err.Error()) + err = w.Kill(ctx) + if err != nil { + errs = append(errs, err.Error()) } - - return nil, errors.Wrap(err, "unable to connect to worker") + if err = w.Wait(ctx); err != nil { + errs = append(errs, err.Error()) + } + return nil, errors.New(strings.Join(errs, "/")) } - w.rl = rl - w.state.set(StateReady) + w.AttachRelay(rl) + w.State().Set(StateReady) return w, nil } // Close socket factory and underlying socket connection. -func (f *SocketFactory) Close() error { +func (f *SocketFactory) Close(ctx context.Context) error { return f.ls.Close() } -// listens for incoming socket connections -func (f *SocketFactory) listen() { +// waits for WorkerProcess to connect over socket and returns associated relay of timeout +func (f *SocketFactory) findRelayWithContext(ctx context.Context, w WorkerBase) (*goridge.SocketRelay, error) { for { - conn, err := f.ls.Accept() - if err != nil { - return - } - - rl := goridge.NewSocketRelay(conn) - if pid, err := fetchPID(rl); err == nil { - f.relayChan(pid) <- rl + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + tmp, ok := f.relays.Load(w.Pid()) + if !ok { + continue + } + return tmp.(*goridge.SocketRelay), nil } } } -// waits for worker to connect over socket and returns associated relay of timeout -func (f *SocketFactory) findRelay(w *Worker, tout time.Duration) (*goridge.SocketRelay, error) { - timer := time.NewTimer(tout) +func (f *SocketFactory) findRelay(w WorkerBase) (*goridge.SocketRelay, error) { + // poll every 1ms for the relay + pollDone := time.NewTimer(f.tout) for { select { - case rl := <-f.relayChan(*w.Pid): - timer.Stop() - f.cleanChan(*w.Pid) - return rl, nil - - case <-timer.C: - return nil, fmt.Errorf("relay timeout") - - case <-w.waitDone: - timer.Stop() - f.cleanChan(*w.Pid) - return nil, fmt.Errorf("worker is gone") + case <-pollDone.C: + return nil, errors.New("relay timeout") + default: + tmp, ok := f.relays.Load(w.Pid()) + if !ok { + continue + } + return tmp.(*goridge.SocketRelay), nil } } } -// chan to store relay associated with specific Pid -func (f *SocketFactory) relayChan(pid int) chan *goridge.SocketRelay { - f.mu.Lock() - defer f.mu.Unlock() - - rl, ok := f.relays[pid] - if !ok { - f.relays[pid] = make(chan *goridge.SocketRelay) - return f.relays[pid] - } - - return rl +// chan to store relay associated with specific pid +func (f *SocketFactory) attachRelayToPid(pid int64, relay *goridge.SocketRelay) { + f.relays.Store(pid, relay) } -// deletes relay chan associated with specific Pid -func (f *SocketFactory) cleanChan(pid int) { - f.mu.Lock() - defer f.mu.Unlock() - - delete(f.relays, pid) +// deletes relay chan associated with specific pid +func (f *SocketFactory) removeRelayFromPid(pid int64) { + f.relays.Delete(pid) } |