diff options
-rw-r--r-- | http.go | 9 | ||||
-rw-r--r-- | sni.go | 17 | ||||
-rw-r--r-- | tcpproxy.go | 22 | ||||
-rw-r--r-- | tcpproxy_test.go | 6 |
4 files changed, 36 insertions, 18 deletions
@@ -46,11 +46,12 @@ type httpHostMatch struct { target Target } -func (m httpHostMatch) match(br *bufio.Reader) Target { - if m.matcher(context.TODO(), httpHostHeader(br)) { - return m.target +func (m httpHostMatch) match(br *bufio.Reader) (Target, string) { + hh := httpHostHeader(br) + if m.matcher(context.TODO(), hh) { + return m.target, hh } - return nil + return nil, "" } // httpHostHeader returns the HTTP Host header from br without @@ -73,11 +73,12 @@ type sniMatch struct { target Target } -func (m sniMatch) match(br *bufio.Reader) Target { - if m.matcher(context.TODO(), clientHelloServerName(br)) { - return m.target +func (m sniMatch) match(br *bufio.Reader) (Target, string) { + sni := clientHelloServerName(br) + if m.matcher(context.TODO(), sni) { + return m.target, sni } - return nil + return nil, "" } // acmeMatch matches "*.acme.invalid" ACME tls-sni-01 challenges and @@ -87,10 +88,10 @@ type acmeMatch struct { cfg *config } -func (m *acmeMatch) match(br *bufio.Reader) Target { +func (m *acmeMatch) match(br *bufio.Reader) (Target, string) { sni := clientHelloServerName(br) if !strings.HasSuffix(sni, ".acme.invalid") { - return nil + return nil, "" } // TODO: cache. ACME issuers will hit multiple times in a short @@ -107,12 +108,12 @@ func (m *acmeMatch) match(br *bufio.Reader) Target { } for range m.cfg.acmeTargets { if target := <-ch; target != nil { - return target + return target, sni } } // No target was happy with the provided challenge. - return nil + return nil, "" } func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) { diff --git a/tcpproxy.go b/tcpproxy.go index 8c33604..40a6c2c 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -107,7 +107,10 @@ type route interface { // // match must not consume bytes from the given bufio.Reader, it // can only Peek. - match(*bufio.Reader) Target + // + // If an sni or host header was parsed successfully, that will be + // returned as the second parameter. + match(*bufio.Reader) (Target, string) } func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { @@ -147,7 +150,7 @@ type fixedTarget struct { t Target } -func (m fixedTarget) match(*bufio.Reader) Target { return m.t } +func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" } // Run is calls Start, and then Wait. // @@ -224,12 +227,13 @@ func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) func (p *Proxy) serveConn(c net.Conn, routes []route) bool { br := bufio.NewReader(c) for _, route := range routes { - if target := route.match(br); target != nil { + if target, hostName := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) c = &Conn{ - Peeked: peeked, - Conn: c, + HostName: hostName, + Peeked: peeked, + Conn: c, } } target.HandleConn(c) @@ -246,6 +250,14 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool { // to determine how to route the connection. The Read method stitches // the peeked bytes and unread bytes back together. type Conn struct { + // HostName is the hostname field that was sent to the request router. + // In the case of TLS, this is the SNI header, in the case of HTTPHost + // route, it will be the host header. In the case of a fixed + // route, i.e. those created with AddRoute(), this will always be + // empty. This can be useful in the case where further routing decisions + // need to be made in the Target impementation. + HostName string + // Peeked are the bytes that have been read from Conn for the // purposes of route matching, but have not yet been consumed // by Read calls. It set to nil by Read when fully consumed. diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 682214d..83729e3 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -72,10 +72,14 @@ func TestMatchHTTPHost(t *testing.T) { t.Run(name, func(t *testing.T) { br := bufio.NewReader(tt.r) r := httpHostMatch{equals(tt.host), noopTarget{}} - got := r.match(br) != nil + m, name := r.match(br) + got := m != nil if got != tt.want { t.Fatalf("match = %v; want %v", got, tt.want) } + if tt.want && name != tt.host { + t.Fatalf("host = %s; want %s", name, tt.host) + } get := make([]byte, 3) if _, err := io.ReadFull(br, get); err != nil { t.Fatal(err) |