summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <[email protected]>2017-06-22 15:14:37 -0700
committerBrad Fitzpatrick <[email protected]>2017-06-22 15:14:37 -0700
commit2eb0155fac2d41b022bc0a430d13aa3b45825f1d (patch)
tree9d1b725ff351a0a4d8e1a1382a5510df7618c1b2
Start of tcpproxy. No Listener or reverse dialing yet.
-rw-r--r--LICENSE27
-rw-r--r--README.md3
-rw-r--r--http.go97
-rw-r--r--sni.go66
-rw-r--r--tcpproxy.go308
-rw-r--r--tcpproxy_test.go260
6 files changed, 761 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..32017f8
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2017 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ffc7269
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# tcpproxy
+
+See https://godoc.org/github.com/bradfitz/tcpproxy/
diff --git a/http.go b/http.go
new file mode 100644
index 0000000..69c771b
--- /dev/null
+++ b/http.go
@@ -0,0 +1,97 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "net/http"
+)
+
+// AddHTTPHostRoute appends a route to the ipPort listener that says
+// if the incoming HTTP/1.x Host header name is httpHost, the
+// connection is given to dest. If it doesn't match, rule processing
+// continues for any additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
+ p.addRoute(ipPort, httpHostMatch(httpHost), dest)
+}
+
+type httpHostMatch string
+
+func (host httpHostMatch) match(br *bufio.Reader) bool {
+ return httpHostHeader(br) == string(host)
+}
+
+// httpHostHeader returns the HTTP Host header from br without
+// consuming any of its bytes. It returns "" if it can't find one.
+func httpHostHeader(br *bufio.Reader) string {
+ const maxPeek = 4 << 10
+ peekSize := 0
+ for {
+ peekSize++
+ if peekSize > maxPeek {
+ b, _ := br.Peek(br.Buffered())
+ return httpHostHeaderFromBytes(b)
+ }
+ b, err := br.Peek(peekSize)
+ if n := br.Buffered(); n > peekSize {
+ b, _ = br.Peek(n)
+ peekSize = n
+ }
+ if len(b) > 0 {
+ if b[0] < 'A' || b[0] > 'Z' {
+ // Doesn't look like an HTTP verb
+ // (GET, POST, etc).
+ return ""
+ }
+ if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 {
+ req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b)))
+ if err != nil {
+ return ""
+ }
+ if len(req.Header["Host"]) > 1 {
+ // TODO(bradfitz): what does
+ // ReadRequest do if there are
+ // multiple Host headers?
+ return ""
+ }
+ return req.Host
+ }
+ }
+ if err != nil {
+ return httpHostHeaderFromBytes(b)
+ }
+ }
+}
+
+var (
+ lfHostColon = []byte("\nHost:")
+ lfhostColon = []byte("\nhost:")
+ crlf = []byte("\r\n")
+ lf = []byte("\n")
+ crlfcrlf = []byte("\r\n\r\n")
+ lflf = []byte("\n\n")
+)
+
+func httpHostHeaderFromBytes(b []byte) string {
+ if i := bytes.Index(b, lfHostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):])))
+ }
+ if i := bytes.Index(b, lfhostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):])))
+ }
+ return ""
+}
+
+// untilEOL returns v, truncated before the first '\n' byte, if any.
+// The returned slice may include a '\r' at the end.
+func untilEOL(v []byte) []byte {
+ if i := bytes.IndexByte(v, '\n'); i != -1 {
+ return v[:i]
+ }
+ return v
+}
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..d57dd31
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,66 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net"
+)
+
+// AddSNIRoute appends a route to the ipPort listener that says if the
+// incoming TLS SNI server name is sni, the connection is given to
+// dest. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ p.addRoute(ipPort, sniMatch(sni), dest)
+}
+
+type sniMatch string
+
+func (sni sniMatch) match(br *bufio.Reader) bool {
+ return clientHelloServerName(br) == string(sni)
+}
+
+// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
+// without consuming any bytes from br.
+// On any error, the empty string is returned.
+func clientHelloServerName(br *bufio.Reader) (sni string) {
+ const recordHeaderLen = 5
+ hdr, err := br.Peek(recordHeaderLen)
+ if err != nil {
+ return ""
+ }
+ const recordTypeHandshake = 0x16
+ if hdr[0] != recordTypeHandshake {
+ return "" // Not TLS.
+ }
+ recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
+ helloBytes, err := br.Peek(recordHeaderLen + recLen)
+ if err != nil {
+ return ""
+ }
+ tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
+ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sni = hello.ServerName
+ return nil, nil
+ },
+ }).Handshake()
+ return
+}
+
+// sniSniffConn is a net.Conn that reads from r, fails on Writes,
+// and crashes otherwise.
+type sniSniffConn struct {
+ r io.Reader
+ net.Conn // nil; crash on any unexpected use
+}
+
+func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }
diff --git a/tcpproxy.go b/tcpproxy.go
new file mode 100644
index 0000000..fde438f
--- /dev/null
+++ b/tcpproxy.go
@@ -0,0 +1,308 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tcpproxy lets users build TCP, HTTP/1, and TLS+SNI proxies.
+//
+// Typical usage:
+//
+// var p tcpproxy.Proxy
+// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081"))
+// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082"))
+// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback
+// p.AddSNIHostRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431"))
+// p.AddSNIHostRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
+// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
+// log.Fatal(p.Run())
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "log"
+ "net"
+ "time"
+)
+
+// Proxy is a proxy. Its zero value is a valid proxy that does
+// nothing. Call methods to add routes before calling Run.
+type Proxy struct {
+ routes map[string][]route // ip:port => route
+
+ lns []net.Listener
+ donec chan struct{} // closed before err
+ err error // any error from listening
+
+ // ListenFunc optionally specifies an alternate listen
+ // function. If nil, net.Dial is used.
+ // The provided net is always "tcp".
+ ListenFunc func(net, laddr string) (net.Listener, error)
+}
+
+type route struct {
+ matcher matcher
+ target Target
+}
+
+type matcher interface {
+ match(*bufio.Reader) bool
+}
+
+func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
+ if p.ListenFunc != nil {
+ return p.ListenFunc
+ }
+ return net.Listen
+}
+
+func (p *Proxy) addRoute(ipPort string, matcher matcher, dest Target) {
+ if p.routes == nil {
+ p.routes = make(map[string][]route)
+ }
+ p.routes[ipPort] = append(p.routes[ipPort], route{matcher, dest})
+}
+
+// AddRoute appends an always-matching route to the ipPort listener,
+// directing any connection to dest.
+//
+// This is generally used as either the only rule (for simple TCP
+// proxies), or as the final fallback rule for an ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddRoute(ipPort string, dest Target) {
+ p.addRoute(ipPort, alwaysMatch{}, dest)
+}
+
+type alwaysMatch struct{}
+
+func (alwaysMatch) match(*bufio.Reader) bool { return true }
+
+// Run is calls Start, and then Wait.
+//
+// It blocks until there's an error. The return value is always
+// non-nil.
+func (p *Proxy) Run() error {
+ if err := p.Start(); err != nil {
+ return err
+ }
+ return p.Wait()
+}
+
+// Wait waits for the Proxy to finish running. Currently this can only
+// happen if a Listener is closed, or Close is called on the proxy.
+//
+// It is only valid to call Wait after a successful call to Start.
+func (p *Proxy) Wait() error {
+ <-p.donec
+ return p.err
+}
+
+// Close closes all the proxy's self-opened listeners.
+func (p *Proxy) Close() error {
+ for _, c := range p.lns {
+ c.Close()
+ }
+ return nil
+}
+
+// Start creates a TCP listener for each unique ipPort from the
+// previously created routes and starts the proxy. It returns any
+// error from starting listeners.
+//
+// If it returns a non-nil error, any successfully opened listeners
+// are closed.
+func (p *Proxy) Start() error {
+ if p.donec != nil {
+ return errors.New("already started")
+ }
+ p.donec = make(chan struct{})
+ errc := make(chan error, len(p.routes))
+ p.lns = make([]net.Listener, 0, len(p.routes))
+ for ipPort, routes := range p.routes {
+ ln, err := p.netListen()("tcp", ipPort)
+ if err != nil {
+ p.Close()
+ return err
+ }
+ p.lns = append(p.lns, ln)
+ go p.serveListener(errc, ln, routes)
+ }
+ go p.awaitFirstError(errc)
+ return nil
+}
+
+func (p *Proxy) awaitFirstError(errc <-chan error) {
+ p.err = <-errc
+ close(p.donec)
+}
+
+func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ ret <- err
+ return
+ }
+ go p.serveConn(c, routes)
+ }
+}
+
+// serveConn runs in its own goroutine and matches c against routes.
+// It returns whether it matched purely for testing.
+func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
+ br := bufio.NewReader(c)
+ for _, route := range routes {
+ if route.matcher.match(br) {
+ buffered, _ := br.Peek(br.Buffered())
+ route.target.HandleConn(changeReaderConn{
+ r: io.MultiReader(bytes.NewReader(buffered), c),
+ Conn: c,
+ }, c)
+ return true
+ }
+ }
+ // TODO: hook for this?
+ log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
+ c.Close()
+ return false
+}
+
+// changeReaderConn is a net.Conn wrapper with a separate reader function.
+type changeReaderConn struct {
+ r io.Reader
+ net.Conn
+}
+
+func (c changeReaderConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+
+// Target is what an incoming matched connection is sent to.
+type Target interface {
+ // HandleConn is called when an incoming connection is
+ // matched. After the call to HandleConn, the tcpproxy
+ // package never touches the conn again. Implementations are
+ // responsible for closing the conn when needed.
+ //
+ // The c Conn acts like a new conn, without any bytes consumed,
+ // but it has an unexported concrete type and cannot be type
+ // asserted to *net.TCPConn, etc.
+ //
+ // The rawConn represents the underlying connections (with
+ // some bytes removed) and should only be used for type
+ // assertions and setting deadlines, not reading.
+ HandleConn(c net.Conn, rawConn net.Conn)
+}
+
+// To is shorthand way of writing &DialProxy{Addr: addr}.
+func To(addr string) *DialProxy {
+ return &DialProxy{Addr: addr}
+}
+
+// DialProxy implements Target by dialing a new connection to Addr
+// and then proxying data back and forth.
+//
+// The To func is a shorthand way of creating a DialProxy.
+type DialProxy struct {
+ // Addr is the TCP address to proxy to.
+ Addr string
+
+ // KeepAlivePeriod sets the period between TCP keep alives.
+ // If zero, a default is used. To disable, use a negative number.
+ // The keep-alive is used for both the client connection and
+ KeepAlivePeriod time.Duration
+
+ // DialTimeout optionally specifies a dial timeout.
+ // If zero, a default is used.
+ // If negative, the timeout is disabled.
+ DialTimeout time.Duration
+
+ // DialContext optionally specifies an alternate dial function
+ // for TCP targets. If nil, the standard
+ // net.Dialer.DialContext method is used.
+ DialContext func(ctx context.Context, network, address string) (net.Conn, error)
+
+ // OnDialError optionally specifies an alternate way to handle errors dialing Addr.
+ // If nil, the error is logged and src is closed.
+ // If non-nil, src is not closed automatically.
+ OnDialError func(src net.Conn, dstDialErr error)
+}
+
+func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) {
+ ctx := context.Background()
+ var cancel context.CancelFunc
+ if dp.DialTimeout >= 0 {
+ ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
+ }
+ dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
+ if cancel != nil {
+ cancel()
+ }
+ if err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer src.Close()
+ defer dst.Close()
+ if ka := dp.keepAlivePeriod(); ka > 0 {
+ if c, ok := rawSrc.(*net.TCPConn); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
+ if c, ok := dst.(*net.TCPConn); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
+ }
+ errc := make(chan error, 1)
+ go proxyCopy(errc, src, dst)
+ go proxyCopy(errc, dst, src)
+ <-errc
+}
+
+// proxyCopy is the function that copies bytes around.
+// It's a named function instead of a func literal so users get
+// named goroutines in debug goroutine stack dumps.
+func proxyCopy(errc chan<- error, dst io.Writer, src io.Reader) {
+ // TODO: make caller switch from src to rawSrc after N bytes (e.g. 4KB)
+ // if the io.Copy optimization to switch to Linux splice happens.
+ // TODO: if the runtime provides a way to wait for
+ // readability, use that to avoid stranding big blocks of
+ // memory blocked in idle reads.
+ _, err := io.Copy(dst, src)
+ errc <- err
+}
+
+func (dp *DialProxy) keepAlivePeriod() time.Duration {
+ if dp.KeepAlivePeriod != 0 {
+ return dp.KeepAlivePeriod
+ }
+ return time.Minute
+}
+
+func (dp *DialProxy) dialTimeout() time.Duration {
+ if dp.DialTimeout > 0 {
+ return dp.DialTimeout
+ }
+ return 10 * time.Second
+}
+
+var defaultDialer = new(net.Dialer)
+
+func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
+ if dp.DialContext != nil {
+ return dp.DialContext
+ }
+ return defaultDialer.DialContext
+}
+
+func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
+ if dp.OnDialError != nil {
+ return dp.OnDialError
+ }
+ return func(src net.Conn, dstDialErr error) {
+ log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr)
+ src.Close()
+ }
+}
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
new file mode 100644
index 0000000..31019b2
--- /dev/null
+++ b/tcpproxy_test.go
@@ -0,0 +1,260 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "testing"
+)
+
+func TestMatchHTTPHost(t *testing.T) {
+ tests := []struct {
+ name string
+ r io.Reader
+ host string
+ want bool
+ }{
+ {
+ name: "match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "foo.com",
+ want: true,
+ },
+ {
+ name: "no-match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "bar.com",
+ want: false,
+ },
+ {
+ name: "match-huge-request",
+ r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')),
+ host: "foo.com",
+ want: true,
+ },
+ }
+ for i, tt := range tests {
+ name := tt.name
+ if name == "" {
+ name = fmt.Sprintf("test_index_%d", i)
+ }
+ t.Run(name, func(t *testing.T) {
+ br := bufio.NewReader(tt.r)
+ var matcher matcher = httpHostMatch(tt.host)
+ got := matcher.match(br)
+ if got != tt.want {
+ t.Fatalf("match = %v; want %v", got, tt.want)
+ }
+ get := make([]byte, 3)
+ if _, err := io.ReadFull(br, get); err != nil {
+ t.Fatal(err)
+ }
+ if string(get) != "GET" {
+ t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get)
+ }
+ })
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type recordWritesConn struct {
+ buf bytes.Buffer
+ net.Conn
+}
+
+func (c *recordWritesConn) Write(p []byte) (int, error) {
+ c.buf.Write(p)
+ return len(p), nil
+}
+
+func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF }
+
+func clientHelloRecord(t *testing.T, hostName string) string {
+ rec := new(recordWritesConn)
+ cl := tls.Client(rec, &tls.Config{ServerName: hostName})
+ cl.Handshake()
+
+ s := rec.buf.String()
+ if !strings.Contains(s, hostName) {
+ t.Fatalf("clientHello sent in test didn't contain %q", hostName)
+ }
+ return s
+}
+
+func TestSNI(t *testing.T) {
+ const hostName = "foo.com"
+ greeting := clientHelloRecord(t, hostName)
+ got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting)))
+ if got != hostName {
+ t.Errorf("got SNI %q; want %q", got, hostName)
+ }
+}
+
+func TestProxyStartNone(t *testing.T) {
+ var p Proxy
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ return ln
+}
+
+const testFrontAddr = "1.2.3.4:567"
+
+func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) {
+ return func(network, laddr string) (net.Listener, error) {
+ if network != "tcp" {
+ t.Errorf("got Listen call with network %q, not tcp", network)
+ return nil, errors.New("invalid network")
+ }
+ if laddr != testFrontAddr {
+ t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr)
+ panic("bogus address")
+ }
+ return ln, nil
+ }
+}
+
+func testProxy(t *testing.T, front net.Listener) *Proxy {
+ return &Proxy{
+ ListenFunc: testListenFunc(t, front),
+ }
+}
+
+func TestProxyAlwaysMatch(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, To(back.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ const msg = "message"
+ io.WriteString(toFront, msg)
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxyHTTP(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxySNI(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ msg := clientHelloRecord(t, "bar.com")
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}