diff options
author | Brad Fitzpatrick <[email protected]> | 2017-06-22 15:14:37 -0700 |
---|---|---|
committer | Brad Fitzpatrick <[email protected]> | 2017-06-22 15:14:37 -0700 |
commit | 2eb0155fac2d41b022bc0a430d13aa3b45825f1d (patch) | |
tree | 9d1b725ff351a0a4d8e1a1382a5510df7618c1b2 |
Start of tcpproxy. No Listener or reverse dialing yet.
-rw-r--r-- | LICENSE | 27 | ||||
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | http.go | 97 | ||||
-rw-r--r-- | sni.go | 66 | ||||
-rw-r--r-- | tcpproxy.go | 308 | ||||
-rw-r--r-- | tcpproxy_test.go | 260 |
6 files changed, 761 insertions, 0 deletions
@@ -0,0 +1,27 @@ +Copyright (c) 2017 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..ffc7269 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# tcpproxy + +See https://godoc.org/github.com/bradfitz/tcpproxy/ @@ -0,0 +1,97 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "bufio" + "bytes" + "net/http" +) + +// AddHTTPHostRoute appends a route to the ipPort listener that says +// if the incoming HTTP/1.x Host header name is httpHost, the +// connection is given to dest. If it doesn't match, rule processing +// continues for any additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { + p.addRoute(ipPort, httpHostMatch(httpHost), dest) +} + +type httpHostMatch string + +func (host httpHostMatch) match(br *bufio.Reader) bool { + return httpHostHeader(br) == string(host) +} + +// httpHostHeader returns the HTTP Host header from br without +// consuming any of its bytes. It returns "" if it can't find one. +func httpHostHeader(br *bufio.Reader) string { + const maxPeek = 4 << 10 + peekSize := 0 + for { + peekSize++ + if peekSize > maxPeek { + b, _ := br.Peek(br.Buffered()) + return httpHostHeaderFromBytes(b) + } + b, err := br.Peek(peekSize) + if n := br.Buffered(); n > peekSize { + b, _ = br.Peek(n) + peekSize = n + } + if len(b) > 0 { + if b[0] < 'A' || b[0] > 'Z' { + // Doesn't look like an HTTP verb + // (GET, POST, etc). + return "" + } + if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 { + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b))) + if err != nil { + return "" + } + if len(req.Header["Host"]) > 1 { + // TODO(bradfitz): what does + // ReadRequest do if there are + // multiple Host headers? + return "" + } + return req.Host + } + } + if err != nil { + return httpHostHeaderFromBytes(b) + } + } +} + +var ( + lfHostColon = []byte("\nHost:") + lfhostColon = []byte("\nhost:") + crlf = []byte("\r\n") + lf = []byte("\n") + crlfcrlf = []byte("\r\n\r\n") + lflf = []byte("\n\n") +) + +func httpHostHeaderFromBytes(b []byte) string { + if i := bytes.Index(b, lfHostColon); i != -1 { + return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):]))) + } + if i := bytes.Index(b, lfhostColon); i != -1 { + return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):]))) + } + return "" +} + +// untilEOL returns v, truncated before the first '\n' byte, if any. +// The returned slice may include a '\r' at the end. +func untilEOL(v []byte) []byte { + if i := bytes.IndexByte(v, '\n'); i != -1 { + return v[:i] + } + return v +} @@ -0,0 +1,66 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net" +) + +// AddSNIRoute appends a route to the ipPort listener that says if the +// incoming TLS SNI server name is sni, the connection is given to +// dest. If it doesn't match, rule processing continues for any +// additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + p.addRoute(ipPort, sniMatch(sni), dest) +} + +type sniMatch string + +func (sni sniMatch) match(br *bufio.Reader) bool { + return clientHelloServerName(br) == string(sni) +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (sni string) { + const recordHeaderLen = 5 + hdr, err := br.Peek(recordHeaderLen) + if err != nil { + return "" + } + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + return "" // Not TLS. + } + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + return "" + } + tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }).Handshake() + return +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } diff --git a/tcpproxy.go b/tcpproxy.go new file mode 100644 index 0000000..fde438f --- /dev/null +++ b/tcpproxy.go @@ -0,0 +1,308 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tcpproxy lets users build TCP, HTTP/1, and TLS+SNI proxies. +// +// Typical usage: +// +// var p tcpproxy.Proxy +// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081")) +// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082")) +// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback +// p.AddSNIHostRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431")) +// p.AddSNIHostRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432")) +// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback +// log.Fatal(p.Run()) +package tcpproxy + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "log" + "net" + "time" +) + +// Proxy is a proxy. Its zero value is a valid proxy that does +// nothing. Call methods to add routes before calling Run. +type Proxy struct { + routes map[string][]route // ip:port => route + + lns []net.Listener + donec chan struct{} // closed before err + err error // any error from listening + + // ListenFunc optionally specifies an alternate listen + // function. If nil, net.Dial is used. + // The provided net is always "tcp". + ListenFunc func(net, laddr string) (net.Listener, error) +} + +type route struct { + matcher matcher + target Target +} + +type matcher interface { + match(*bufio.Reader) bool +} + +func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { + if p.ListenFunc != nil { + return p.ListenFunc + } + return net.Listen +} + +func (p *Proxy) addRoute(ipPort string, matcher matcher, dest Target) { + if p.routes == nil { + p.routes = make(map[string][]route) + } + p.routes[ipPort] = append(p.routes[ipPort], route{matcher, dest}) +} + +// AddRoute appends an always-matching route to the ipPort listener, +// directing any connection to dest. +// +// This is generally used as either the only rule (for simple TCP +// proxies), or as the final fallback rule for an ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddRoute(ipPort string, dest Target) { + p.addRoute(ipPort, alwaysMatch{}, dest) +} + +type alwaysMatch struct{} + +func (alwaysMatch) match(*bufio.Reader) bool { return true } + +// Run is calls Start, and then Wait. +// +// It blocks until there's an error. The return value is always +// non-nil. +func (p *Proxy) Run() error { + if err := p.Start(); err != nil { + return err + } + return p.Wait() +} + +// Wait waits for the Proxy to finish running. Currently this can only +// happen if a Listener is closed, or Close is called on the proxy. +// +// It is only valid to call Wait after a successful call to Start. +func (p *Proxy) Wait() error { + <-p.donec + return p.err +} + +// Close closes all the proxy's self-opened listeners. +func (p *Proxy) Close() error { + for _, c := range p.lns { + c.Close() + } + return nil +} + +// Start creates a TCP listener for each unique ipPort from the +// previously created routes and starts the proxy. It returns any +// error from starting listeners. +// +// If it returns a non-nil error, any successfully opened listeners +// are closed. +func (p *Proxy) Start() error { + if p.donec != nil { + return errors.New("already started") + } + p.donec = make(chan struct{}) + errc := make(chan error, len(p.routes)) + p.lns = make([]net.Listener, 0, len(p.routes)) + for ipPort, routes := range p.routes { + ln, err := p.netListen()("tcp", ipPort) + if err != nil { + p.Close() + return err + } + p.lns = append(p.lns, ln) + go p.serveListener(errc, ln, routes) + } + go p.awaitFirstError(errc) + return nil +} + +func (p *Proxy) awaitFirstError(errc <-chan error) { + p.err = <-errc + close(p.donec) +} + +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) { + for { + c, err := ln.Accept() + if err != nil { + ret <- err + return + } + go p.serveConn(c, routes) + } +} + +// serveConn runs in its own goroutine and matches c against routes. +// It returns whether it matched purely for testing. +func (p *Proxy) serveConn(c net.Conn, routes []route) bool { + br := bufio.NewReader(c) + for _, route := range routes { + if route.matcher.match(br) { + buffered, _ := br.Peek(br.Buffered()) + route.target.HandleConn(changeReaderConn{ + r: io.MultiReader(bytes.NewReader(buffered), c), + Conn: c, + }, c) + return true + } + } + // TODO: hook for this? + log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String()) + c.Close() + return false +} + +// changeReaderConn is a net.Conn wrapper with a separate reader function. +type changeReaderConn struct { + r io.Reader + net.Conn +} + +func (c changeReaderConn) Read(p []byte) (int, error) { return c.r.Read(p) } + +// Target is what an incoming matched connection is sent to. +type Target interface { + // HandleConn is called when an incoming connection is + // matched. After the call to HandleConn, the tcpproxy + // package never touches the conn again. Implementations are + // responsible for closing the conn when needed. + // + // The c Conn acts like a new conn, without any bytes consumed, + // but it has an unexported concrete type and cannot be type + // asserted to *net.TCPConn, etc. + // + // The rawConn represents the underlying connections (with + // some bytes removed) and should only be used for type + // assertions and setting deadlines, not reading. + HandleConn(c net.Conn, rawConn net.Conn) +} + +// To is shorthand way of writing &DialProxy{Addr: addr}. +func To(addr string) *DialProxy { + return &DialProxy{Addr: addr} +} + +// DialProxy implements Target by dialing a new connection to Addr +// and then proxying data back and forth. +// +// The To func is a shorthand way of creating a DialProxy. +type DialProxy struct { + // Addr is the TCP address to proxy to. + Addr string + + // KeepAlivePeriod sets the period between TCP keep alives. + // If zero, a default is used. To disable, use a negative number. + // The keep-alive is used for both the client connection and + KeepAlivePeriod time.Duration + + // DialTimeout optionally specifies a dial timeout. + // If zero, a default is used. + // If negative, the timeout is disabled. + DialTimeout time.Duration + + // DialContext optionally specifies an alternate dial function + // for TCP targets. If nil, the standard + // net.Dialer.DialContext method is used. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // OnDialError optionally specifies an alternate way to handle errors dialing Addr. + // If nil, the error is logged and src is closed. + // If non-nil, src is not closed automatically. + OnDialError func(src net.Conn, dstDialErr error) +} + +func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) { + ctx := context.Background() + var cancel context.CancelFunc + if dp.DialTimeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout()) + } + dst, err := dp.dialContext()(ctx, "tcp", dp.Addr) + if cancel != nil { + cancel() + } + if err != nil { + dp.onDialError()(src, err) + return + } + defer src.Close() + defer dst.Close() + if ka := dp.keepAlivePeriod(); ka > 0 { + if c, ok := rawSrc.(*net.TCPConn); ok { + c.SetKeepAlive(true) + c.SetKeepAlivePeriod(ka) + } + if c, ok := dst.(*net.TCPConn); ok { + c.SetKeepAlive(true) + c.SetKeepAlivePeriod(ka) + } + } + errc := make(chan error, 1) + go proxyCopy(errc, src, dst) + go proxyCopy(errc, dst, src) + <-errc +} + +// proxyCopy is the function that copies bytes around. +// It's a named function instead of a func literal so users get +// named goroutines in debug goroutine stack dumps. +func proxyCopy(errc chan<- error, dst io.Writer, src io.Reader) { + // TODO: make caller switch from src to rawSrc after N bytes (e.g. 4KB) + // if the io.Copy optimization to switch to Linux splice happens. + // TODO: if the runtime provides a way to wait for + // readability, use that to avoid stranding big blocks of + // memory blocked in idle reads. + _, err := io.Copy(dst, src) + errc <- err +} + +func (dp *DialProxy) keepAlivePeriod() time.Duration { + if dp.KeepAlivePeriod != 0 { + return dp.KeepAlivePeriod + } + return time.Minute +} + +func (dp *DialProxy) dialTimeout() time.Duration { + if dp.DialTimeout > 0 { + return dp.DialTimeout + } + return 10 * time.Second +} + +var defaultDialer = new(net.Dialer) + +func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) { + if dp.DialContext != nil { + return dp.DialContext + } + return defaultDialer.DialContext +} + +func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) { + if dp.OnDialError != nil { + return dp.OnDialError + } + return func(src net.Conn, dstDialErr error) { + log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr) + src.Close() + } +} diff --git a/tcpproxy_test.go b/tcpproxy_test.go new file mode 100644 index 0000000..31019b2 --- /dev/null +++ b/tcpproxy_test.go @@ -0,0 +1,260 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "strings" + "testing" +) + +func TestMatchHTTPHost(t *testing.T) { + tests := []struct { + name string + r io.Reader + host string + want bool + }{ + { + name: "match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "foo.com", + want: true, + }, + { + name: "no-match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "bar.com", + want: false, + }, + { + name: "match-huge-request", + r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')), + host: "foo.com", + want: true, + }, + } + for i, tt := range tests { + name := tt.name + if name == "" { + name = fmt.Sprintf("test_index_%d", i) + } + t.Run(name, func(t *testing.T) { + br := bufio.NewReader(tt.r) + var matcher matcher = httpHostMatch(tt.host) + got := matcher.match(br) + if got != tt.want { + t.Fatalf("match = %v; want %v", got, tt.want) + } + get := make([]byte, 3) + if _, err := io.ReadFull(br, get); err != nil { + t.Fatal(err) + } + if string(get) != "GET" { + t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get) + } + }) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type recordWritesConn struct { + buf bytes.Buffer + net.Conn +} + +func (c *recordWritesConn) Write(p []byte) (int, error) { + c.buf.Write(p) + return len(p), nil +} + +func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF } + +func clientHelloRecord(t *testing.T, hostName string) string { + rec := new(recordWritesConn) + cl := tls.Client(rec, &tls.Config{ServerName: hostName}) + cl.Handshake() + + s := rec.buf.String() + if !strings.Contains(s, hostName) { + t.Fatalf("clientHello sent in test didn't contain %q", hostName) + } + return s +} + +func TestSNI(t *testing.T) { + const hostName = "foo.com" + greeting := clientHelloRecord(t, hostName) + got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting))) + if got != hostName { + t.Errorf("got SNI %q; want %q", got, hostName) + } +} + +func TestProxyStartNone(t *testing.T) { + var p Proxy + if err := p.Start(); err != nil { + t.Fatal(err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp", "[::1]:0") + if err != nil { + t.Fatal(err) + } + } + return ln +} + +const testFrontAddr = "1.2.3.4:567" + +func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { + return func(network, laddr string) (net.Listener, error) { + if network != "tcp" { + t.Errorf("got Listen call with network %q, not tcp", network) + return nil, errors.New("invalid network") + } + if laddr != testFrontAddr { + t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr) + panic("bogus address") + } + return ln, nil + } +} + +func testProxy(t *testing.T, front net.Listener) *Proxy { + return &Proxy{ + ListenFunc: testListenFunc(t, front), + } +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + const msg = "message" + io.WriteString(toFront, msg) + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyHTTP(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxySNI(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} |