summaryrefslogtreecommitdiff
path: root/socket_factory.go
diff options
context:
space:
mode:
Diffstat (limited to 'socket_factory.go')
-rw-r--r--socket_factory.go237
1 files changed, 159 insertions, 78 deletions
diff --git a/socket_factory.go b/socket_factory.go
index 42196588..27558cce 100644
--- a/socket_factory.go
+++ b/socket_factory.go
@@ -1,16 +1,20 @@
package roadrunner
import (
- "fmt"
- "github.com/pkg/errors"
- "github.com/spiral/goridge/v2"
+ "context"
"net"
"os/exec"
+ "strings"
"sync"
"time"
+
+ "github.com/pkg/errors"
+ "github.com/spiral/goridge/v2"
+ "go.uber.org/multierr"
+ "golang.org/x/sync/errgroup"
)
-// SocketFactory connects to external workers using socket server.
+// SocketFactory connects to external stack using socket server.
type SocketFactory struct {
// listens for incoming connections from underlying processes
ls net.Listener
@@ -18,122 +22,199 @@ type SocketFactory struct {
// relay connection timeout
tout time.Duration
- // protects socket mapping
- mu sync.Mutex
-
// sockets which are waiting for process association
- relays map[int]chan *goridge.SocketRelay
+ // relays map[int64]*goridge.SocketRelay
+ relays sync.Map
+
+ ErrCh chan error
}
-// NewSocketFactory returns SocketFactory attached to a given socket lsn.
+// todo: review
+
+// NewSocketServer returns SocketFactory attached to a given socket listener.
// tout specifies for how long factory should serve for incoming relay connection
-func NewSocketFactory(ls net.Listener, tout time.Duration) *SocketFactory {
+func NewSocketServer(ls net.Listener, tout time.Duration) Factory {
f := &SocketFactory{
ls: ls,
tout: tout,
- relays: make(map[int]chan *goridge.SocketRelay),
+ relays: sync.Map{},
+ ErrCh: make(chan error, 10),
}
- go f.listen()
+ // Be careful
+ // https://github.com/go101/go101/wiki/About-memory-ordering-guarantees-made-by-atomic-operations-in-Go
+ // https://github.com/golang/go/issues/5045
+ go func() {
+ f.ErrCh <- f.listen()
+ }()
return f
}
-// SpawnWorker creates worker and connects it to appropriate relay or returns error
-func (f *SocketFactory) SpawnWorker(cmd *exec.Cmd) (w *Worker, err error) {
- if w, err = newWorker(cmd); err != nil {
+// blocking operation, returns an error
+func (f *SocketFactory) listen() error {
+ errGr := &errgroup.Group{}
+ errGr.Go(func() error {
+ for {
+ conn, err := f.ls.Accept()
+ if err != nil {
+ return err
+ }
+
+ rl := goridge.NewSocketRelay(conn)
+ pid, err := fetchPID(rl)
+ if err != nil {
+ return err
+ }
+
+ f.attachRelayToPid(pid, rl)
+ }
+ })
+
+ return errGr.Wait()
+}
+
+type socketSpawn struct {
+ w WorkerBase
+ err error
+}
+
+// SpawnWorker creates WorkerProcess and connects it to appropriate relay or returns error
+func (f *SocketFactory) SpawnWorkerWithContext(ctx context.Context, cmd *exec.Cmd) (WorkerBase, error) {
+ c := make(chan socketSpawn)
+ go func() {
+ ctx, cancel := context.WithTimeout(ctx, f.tout)
+ defer cancel()
+ w, err := InitBaseWorker(cmd)
+ if err != nil {
+ c <- socketSpawn{
+ w: nil,
+ err: err,
+ }
+ return
+ }
+
+ err = w.Start()
+ if err != nil {
+ c <- socketSpawn{
+ w: nil,
+ err: errors.Wrap(err, "process error"),
+ }
+ return
+ }
+
+ rl, err := f.findRelayWithContext(ctx, w)
+ if err != nil {
+ err = multierr.Combine(
+ err,
+ w.Kill(context.Background()),
+ w.Wait(context.Background()),
+ )
+ c <- socketSpawn{
+ w: nil,
+ err: err,
+ }
+ return
+ }
+
+ w.AttachRelay(rl)
+ w.State().Set(StateReady)
+
+ c <- socketSpawn{
+ w: w,
+ err: nil,
+ }
+ return
+ }()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case res := <-c:
+ if res.err != nil {
+ return nil, res.err
+ }
+
+ return res.w, nil
+ }
+}
+
+func (f *SocketFactory) SpawnWorker(cmd *exec.Cmd) (WorkerBase, error) {
+ ctx := context.Background()
+ w, err := InitBaseWorker(cmd)
+ if err != nil {
return nil, err
}
- if err := w.start(); err != nil {
+ err = w.Start()
+ if err != nil {
return nil, errors.Wrap(err, "process error")
}
- rl, err := f.findRelay(w, f.tout)
+ var errs []string
+ rl, err := f.findRelay(w)
if err != nil {
- go func(w *Worker) {
- err := w.Kill()
- if err != nil {
- fmt.Println(fmt.Errorf("error killing the worker %v", err))
- }
- }(w)
-
- if wErr := w.Wait(); wErr != nil {
- if _, ok := wErr.(*exec.ExitError); ok {
- err = errors.Wrap(wErr, err.Error())
- } else {
- err = wErr
- }
+ errs = append(errs, err.Error())
+ err = w.Kill(ctx)
+ if err != nil {
+ errs = append(errs, err.Error())
}
-
- return nil, errors.Wrap(err, "unable to connect to worker")
+ if err = w.Wait(ctx); err != nil {
+ errs = append(errs, err.Error())
+ }
+ return nil, errors.New(strings.Join(errs, "/"))
}
- w.rl = rl
- w.state.set(StateReady)
+ w.AttachRelay(rl)
+ w.State().Set(StateReady)
return w, nil
}
// Close socket factory and underlying socket connection.
-func (f *SocketFactory) Close() error {
+func (f *SocketFactory) Close(ctx context.Context) error {
return f.ls.Close()
}
-// listens for incoming socket connections
-func (f *SocketFactory) listen() {
+// waits for WorkerProcess to connect over socket and returns associated relay of timeout
+func (f *SocketFactory) findRelayWithContext(ctx context.Context, w WorkerBase) (*goridge.SocketRelay, error) {
for {
- conn, err := f.ls.Accept()
- if err != nil {
- return
- }
-
- rl := goridge.NewSocketRelay(conn)
- if pid, err := fetchPID(rl); err == nil {
- f.relayChan(pid) <- rl
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ tmp, ok := f.relays.Load(w.Pid())
+ if !ok {
+ continue
+ }
+ return tmp.(*goridge.SocketRelay), nil
}
}
}
-// waits for worker to connect over socket and returns associated relay of timeout
-func (f *SocketFactory) findRelay(w *Worker, tout time.Duration) (*goridge.SocketRelay, error) {
- timer := time.NewTimer(tout)
+func (f *SocketFactory) findRelay(w WorkerBase) (*goridge.SocketRelay, error) {
+ // poll every 1ms for the relay
+ pollDone := time.NewTimer(f.tout)
for {
select {
- case rl := <-f.relayChan(*w.Pid):
- timer.Stop()
- f.cleanChan(*w.Pid)
- return rl, nil
-
- case <-timer.C:
- return nil, fmt.Errorf("relay timeout")
-
- case <-w.waitDone:
- timer.Stop()
- f.cleanChan(*w.Pid)
- return nil, fmt.Errorf("worker is gone")
+ case <-pollDone.C:
+ return nil, errors.New("relay timeout")
+ default:
+ tmp, ok := f.relays.Load(w.Pid())
+ if !ok {
+ continue
+ }
+ return tmp.(*goridge.SocketRelay), nil
}
}
}
-// chan to store relay associated with specific Pid
-func (f *SocketFactory) relayChan(pid int) chan *goridge.SocketRelay {
- f.mu.Lock()
- defer f.mu.Unlock()
-
- rl, ok := f.relays[pid]
- if !ok {
- f.relays[pid] = make(chan *goridge.SocketRelay)
- return f.relays[pid]
- }
-
- return rl
+// chan to store relay associated with specific pid
+func (f *SocketFactory) attachRelayToPid(pid int64, relay *goridge.SocketRelay) {
+ f.relays.Store(pid, relay)
}
-// deletes relay chan associated with specific Pid
-func (f *SocketFactory) cleanChan(pid int) {
- f.mu.Lock()
- defer f.mu.Unlock()
-
- delete(f.relays, pid)
+// deletes relay chan associated with specific pid
+func (f *SocketFactory) removeRelayFromPid(pid int64) {
+ f.relays.Delete(pid)
}