summaryrefslogtreecommitdiff
path: root/socket_factory.go
blob: a77758e9cd165e1911100af342b52c97a1f311b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package roadrunner

import (
	"fmt"
	"github.com/pkg/errors"
	"github.com/spiral/goridge"
	"net"
	"os/exec"
	"sync"
	"time"
)

// SocketFactory connects to external workers using a socket server.
//
// Each process started by SpawnWorker is expected to connect back to ls. The
// accepted connection is wrapped in a goridge.SocketRelay, the PID read from
// it (see fetchPID) identifies the SpawnWorker call waiting for that process,
// and the relay is handed over to that caller through a per-PID channel.
type SocketFactory struct {
	// listens for incoming connections from underlying processes;
	// consumed by the listen goroutine started in NewSocketFactory
	ls net.Listener

	// relay connection timeout: how long SpawnWorker waits (via findRelay)
	// for a started process to connect back
	tout time.Duration

	// protects socket mapping (the relays map below)
	mu sync.Mutex

	// sockets which are waiting for process association, keyed by PID.
	// Entries are created lazily by relayChan and removed by cleanChan.
	relays map[int]chan *goridge.SocketRelay
}

// NewSocketFactory builds a SocketFactory that accepts worker connections on
// ls and starts serving them in a background goroutine right away.
// tout bounds how long SpawnWorker waits for a started process to connect.
func NewSocketFactory(ls net.Listener, tout time.Duration) *SocketFactory {
	factory := new(SocketFactory)
	factory.ls = ls
	factory.tout = tout
	factory.relays = make(map[int]chan *goridge.SocketRelay)

	// the accept loop runs until ls.Accept returns an error
	go factory.listen()

	return factory
}

// SpawnWorker starts cmd as a new worker and pairs it with the relay the
// process opens back to the factory's listener. On success the worker is
// returned in StateReady with its relay attached. If the process cannot be
// created or started, or does not connect within the factory timeout, an
// error is returned. In the latter case the process is killed and waited on
// first.
func (f *SocketFactory) SpawnWorker(cmd *exec.Cmd) (w *Worker, err error) {
	w, err = newWorker(cmd)
	if err != nil {
		return nil, err
	}

	if startErr := w.Start(); startErr != nil {
		return nil, errors.Wrap(startErr, "process error")
	}

	relay, relayErr := f.findRelay(w, f.tout)
	if relayErr == nil {
		w.rl = relay
		w.state.set(StateReady)
		return w, nil
	}

	// Kill is issued from its own goroutine while this one blocks in Wait,
	// which reports how the process ended.
	go func(doomed *Worker) { doomed.Kill() }(w)

	if waitErr := w.Wait(); waitErr != nil {
		relayErr = errors.Wrap(waitErr, relayErr.Error())
	}

	return nil, errors.Wrap(relayErr, "unable to connect to worker")
}

// listen accepts connections on f.ls until Accept fails. For each one it
// reads the PID via fetchPID and hands the resulting relay to the findRelay
// call waiting for that PID.
//
// Delivery happens on a separate goroutine. A relay nobody is waiting for
// (e.g. from a worker whose findRelay already timed out, or an unexpected
// PID) therefore cannot block the accept loop and starve every later worker.
// Such a relay is given f.tout to be picked up; after that its connection is
// closed and its map entry dropped. Connections that fail the PID handshake
// are closed immediately instead of being leaked.
func (f *SocketFactory) listen() {
	for {
		conn, err := f.ls.Accept()
		if err != nil {
			return
		}

		rl := goridge.NewSocketRelay(conn)
		pid, err := fetchPID(rl)
		if err != nil {
			// best effort: the peer never identified itself, nobody can claim it
			_ = conn.Close()
			continue
		}

		go func(pid int) {
			timer := time.NewTimer(f.tout)
			defer timer.Stop()

			select {
			case f.relayChan(pid) <- rl:
				// picked up by findRelay, which owns the relay from now on
			case <-timer.C:
				// nobody claimed this relay in time; release the socket and
				// the channel relayChan may have created for it
				_ = conn.Close()
				f.cleanChan(pid)
			}
		}(pid)
	}
}

// findRelay waits for worker w to connect over the socket and returns the
// relay associated with its PID. It fails with "relay timeout" when tout
// elapses first, or with "worker is gone" when w.waitDone fires before a
// relay arrives.
//
// The PID's channel is removed from f.relays on every exit path, including
// timeout, so abandoned entries do not accumulate in the map.
func (f *SocketFactory) findRelay(w *Worker, tout time.Duration) (*goridge.SocketRelay, error) {
	pid := *w.Pid
	defer f.cleanChan(pid)

	timer := time.NewTimer(tout)
	defer timer.Stop()

	select {
	case rl := <-f.relayChan(pid):
		return rl, nil

	case <-timer.C:
		return nil, fmt.Errorf("relay timeout")

	case <-w.waitDone:
		return nil, fmt.Errorf("worker is gone")
	}
}

// relayChan returns the channel used to pass the relay for pid from the
// accept loop to findRelay, creating and registering an unbuffered one on
// first use. Safe for concurrent use; access to f.relays is guarded by f.mu.
func (f *SocketFactory) relayChan(pid int) chan *goridge.SocketRelay {
	f.mu.Lock()
	defer f.mu.Unlock()

	if existing, found := f.relays[pid]; found {
		return existing
	}

	created := make(chan *goridge.SocketRelay)
	f.relays[pid] = created
	return created
}

// cleanChan unregisters the relay channel stored for pid, if any.
// Access to f.relays is guarded by f.mu.
func (f *SocketFactory) cleanChan(pid int) {
	f.mu.Lock()
	delete(f.relays, pid)
	f.mu.Unlock()
}