summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-07-06 00:29:51 -0700
committerDave Anderson <[email protected]>2017-07-06 21:36:14 -0700
commit815c9425f1ad46ffd3a3fb1bbefc05440072e4a4 (patch)
tree199b42d5595dad7c0e2aef8ea7500abe3fa4d55b
parent2065af4b1e2d181a987a23f64c66f43e474469ff (diff)
Merge matcher and route into an interface that yields a Target.
This allows routes to compute a target at match time, instead of being statically mapped to a Target at register time.
-rw-r--r--http.go14
-rw-r--r--sni.go14
-rw-r--r--tcpproxy.go27
-rw-r--r--tcpproxy_test.go9
4 files changed, 39 insertions, 25 deletions
diff --git a/http.go b/http.go
index 423c9a9..601d535 100644
--- a/http.go
+++ b/http.go
@@ -27,13 +27,19 @@ import (
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
- p.addRoute(ipPort, httpHostMatch(httpHost), dest)
+ p.addRoute(ipPort, httpHostMatch{httpHost, dest})
}
-type httpHostMatch string
+type httpHostMatch struct {
+ host string
+ target Target
+}
-func (host httpHostMatch) match(br *bufio.Reader) bool {
- return httpHostHeader(br) == string(host)
+func (m httpHostMatch) match(br *bufio.Reader) Target {
+ if httpHostHeader(br) == m.host {
+ return m.target
+ }
+ return nil
}
// httpHostHeader returns the HTTP Host header from br without
diff --git a/sni.go b/sni.go
index e12c744..f0128bf 100644
--- a/sni.go
+++ b/sni.go
@@ -29,13 +29,19 @@ import (
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
- p.addRoute(ipPort, sniMatch(sni), dest)
+ p.addRoute(ipPort, sniMatch{sni, dest})
}
-type sniMatch string
+type sniMatch struct {
+ sni string
+ target Target
+}
-func (sni sniMatch) match(br *bufio.Reader) bool {
- return clientHelloServerName(br) == string(sni)
+func (m sniMatch) match(br *bufio.Reader) Target {
+ if clientHelloServerName(br) == string(m.sni) {
+ return m.target
+ }
+ return nil
}
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
diff --git a/tcpproxy.go b/tcpproxy.go
index bada4d3..1eee7ea 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -69,7 +69,7 @@ import (
// The order that routes are added in matters; each is matched in the order
// registered.
type Proxy struct {
- routes map[string][]route // ip:port => route
+ routes map[string][]route // ip:port => routes
lns []net.Listener
donec chan struct{} // closed before err
@@ -81,13 +81,8 @@ type Proxy struct {
ListenFunc func(net, laddr string) (net.Listener, error)
}
-type route struct {
- matcher matcher
- target Target
-}
-
-type matcher interface {
- match(*bufio.Reader) bool
+type route interface {
+ match(*bufio.Reader) Target
}
func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
@@ -97,11 +92,11 @@ func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
return net.Listen
}
-func (p *Proxy) addRoute(ipPort string, matcher matcher, dest Target) {
+func (p *Proxy) addRoute(ipPort string, r route) {
if p.routes == nil {
p.routes = make(map[string][]route)
}
- p.routes[ipPort] = append(p.routes[ipPort], route{matcher, dest})
+ p.routes[ipPort] = append(p.routes[ipPort], r)
}
// AddRoute appends an always-matching route to the ipPort listener,
@@ -112,12 +107,14 @@ func (p *Proxy) addRoute(ipPort string, matcher matcher, dest Target) {
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
- p.addRoute(ipPort, alwaysMatch{}, dest)
+ p.addRoute(ipPort, alwaysMatch{dest})
}
-type alwaysMatch struct{}
+type alwaysMatch struct {
+ t Target
+}
-func (alwaysMatch) match(*bufio.Reader) bool { return true }
+func (m alwaysMatch) match(*bufio.Reader) Target { return m.t }
// Run is calls Start, and then Wait.
//
@@ -194,7 +191,7 @@ func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route)
func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
br := bufio.NewReader(c)
for _, route := range routes {
- if route.matcher.match(br) {
+ if target := route.match(br); target != nil {
if n := br.Buffered(); n > 0 {
peeked, _ := br.Peek(br.Buffered())
c = &Conn{
@@ -202,7 +199,7 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
Conn: c,
}
}
- route.target.HandleConn(c)
+ target.HandleConn(c)
return true
}
}
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 45d8b0e..7150372 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -27,6 +27,10 @@ import (
"testing"
)
+type noopTarget struct{}
+
+func (t *noopTarget) HandleConn(net.Conn) {}
+
func TestMatchHTTPHost(t *testing.T) {
tests := []struct {
name string
@@ -53,6 +57,7 @@ func TestMatchHTTPHost(t *testing.T) {
want: true,
},
}
+ target := &noopTarget{}
for i, tt := range tests {
name := tt.name
if name == "" {
@@ -60,8 +65,8 @@ func TestMatchHTTPHost(t *testing.T) {
}
t.Run(name, func(t *testing.T) {
br := bufio.NewReader(tt.r)
- var matcher matcher = httpHostMatch(tt.host)
- got := matcher.match(br)
+ r := httpHostMatch{tt.host, target}
+ got := r.match(br) != nil
if got != tt.want {
t.Fatalf("match = %v; want %v", got, tt.want)
}