summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-07-05 14:04:48 -0700
committerDavid Anderson <[email protected]>2017-07-05 14:04:48 -0700
commitbef9f6aa62487d4adc7d8ddf9e29b9f28810316f (patch)
treecfe400a24e8820af3866a4c5f008335a7090ce3a
parentd86e96a9d54bb62b297cf30dd2242b365fe33604 (diff)
parent9e73877b6b356885077a1b9c0ba349ce33c61438 (diff)
Merge bradfitz's tcpproxy codebase with the software formerly known as tlsrouter.
Brad's code will be the place for future development, and the base for the binary formerly known as tlsrouter. This merge is the first step towards converging the codebases.
-rw-r--r--README.md3
-rw-r--r--http.go107
-rw-r--r--listener.go103
-rw-r--r--listener_test.go50
-rw-r--r--sni.go76
-rw-r--r--tcpproxy.go374
-rw-r--r--tcpproxy_test.go270
7 files changed, 983 insertions, 0 deletions
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ffc7269
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# tcpproxy
+
+See https://godoc.org/github.com/bradfitz/tcpproxy/
diff --git a/http.go b/http.go
new file mode 100644
index 0000000..423c9a9
--- /dev/null
+++ b/http.go
@@ -0,0 +1,107 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "net/http"
+)
+
+// AddHTTPHostRoute appends a route to the ipPort listener that says
+// if the incoming HTTP/1.x Host header name is httpHost, the
+// connection is given to dest. If it doesn't match, rule processing
+// continues for any additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
+ p.addRoute(ipPort, httpHostMatch(httpHost), dest)
+}
+
+type httpHostMatch string
+
+func (host httpHostMatch) match(br *bufio.Reader) bool {
+ return httpHostHeader(br) == string(host)
+}
+
+// httpHostHeader returns the HTTP Host header from br without
+// consuming any of its bytes. It returns "" if it can't find one.
+func httpHostHeader(br *bufio.Reader) string {
+ const maxPeek = 4 << 10
+ peekSize := 0
+ for {
+ peekSize++
+ if peekSize > maxPeek {
+ b, _ := br.Peek(br.Buffered())
+ return httpHostHeaderFromBytes(b)
+ }
+ b, err := br.Peek(peekSize)
+ if n := br.Buffered(); n > peekSize {
+ b, _ = br.Peek(n)
+ peekSize = n
+ }
+ if len(b) > 0 {
+ if b[0] < 'A' || b[0] > 'Z' {
+ // Doesn't look like an HTTP verb
+ // (GET, POST, etc).
+ return ""
+ }
+ if bytes.Index(b, crlfcrlf) != -1 || bytes.Index(b, lflf) != -1 {
+ req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b)))
+ if err != nil {
+ return ""
+ }
+ if len(req.Header["Host"]) > 1 {
+ // TODO(bradfitz): what does
+ // ReadRequest do if there are
+ // multiple Host headers?
+ return ""
+ }
+ return req.Host
+ }
+ }
+ if err != nil {
+ return httpHostHeaderFromBytes(b)
+ }
+ }
+}
+
+var (
+ lfHostColon = []byte("\nHost:")
+ lfhostColon = []byte("\nhost:")
+ crlf = []byte("\r\n")
+ lf = []byte("\n")
+ crlfcrlf = []byte("\r\n\r\n")
+ lflf = []byte("\n\n")
+)
+
+func httpHostHeaderFromBytes(b []byte) string {
+ if i := bytes.Index(b, lfHostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfHostColon):])))
+ }
+ if i := bytes.Index(b, lfhostColon); i != -1 {
+ return string(bytes.TrimSpace(untilEOL(b[i+len(lfhostColon):])))
+ }
+ return ""
+}
+
+// untilEOL returns v, truncated before the first '\n' byte, if any.
+// The returned slice may include a '\r' at the end.
+func untilEOL(v []byte) []byte {
+ if i := bytes.IndexByte(v, '\n'); i != -1 {
+ return v[:i]
+ }
+ return v
+}
diff --git a/listener.go b/listener.go
new file mode 100644
index 0000000..2fd29eb
--- /dev/null
+++ b/listener.go
@@ -0,0 +1,103 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "io"
+ "net"
+ "sync"
+)
+
+// TargetListener implements both net.Listener and Target.
+// Matched Targets become accepted connections.
+type TargetListener struct {
+ Address string // Address is the string reported by TargetListener.Addr().String().
+
+ mu sync.Mutex
+ cond *sync.Cond
+ closed bool
+ nextConn net.Conn
+}
+
+var (
+ _ net.Listener = (*TargetListener)(nil)
+ _ Target = (*TargetListener)(nil)
+)
+
+func (tl *TargetListener) lock() {
+ tl.mu.Lock()
+ if tl.cond == nil {
+ tl.cond = sync.NewCond(&tl.mu)
+ }
+}
+
+type tcpAddr string
+
+func (a tcpAddr) Network() string { return "tcp" }
+func (a tcpAddr) String() string { return string(a) }
+
+func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) }
+
+func (tl *TargetListener) Close() error {
+ tl.lock()
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil
+ }
+ tl.closed = true
+ tl.mu.Unlock()
+ tl.cond.Broadcast()
+ return nil
+}
+
+// HandleConn implements the Target interface. It blocks until tl is
+// closed or another goroutine has called Accept and received c.
+func (tl *TargetListener) HandleConn(c net.Conn) {
+ tl.lock()
+ defer tl.mu.Unlock()
+ for tl.nextConn != nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+ tl.nextConn = c
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+ for tl.nextConn == c && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+}
+
+func (tl *TargetListener) Accept() (net.Conn, error) {
+ tl.lock()
+ for tl.nextConn == nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil, io.EOF
+ }
+ c := tl.nextConn
+ tl.nextConn = nil
+ tl.mu.Unlock()
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+
+ return c, nil
+}
diff --git a/listener_test.go b/listener_test.go
new file mode 100644
index 0000000..35f888e
--- /dev/null
+++ b/listener_test.go
@@ -0,0 +1,50 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "io"
+ "testing"
+)
+
+func TestListenerAccept(t *testing.T) {
+ tl := new(TargetListener)
+ ch := make(chan interface{}, 1)
+ go func() {
+ for {
+ conn, err := tl.Accept()
+ if err != nil {
+ ch <- err
+ return
+ } else {
+ ch <- conn
+ }
+ }
+ }()
+
+ for i := 0; i < 3; i++ {
+ conn := new(Conn)
+ tl.HandleConn(conn)
+ got := <-ch
+ if got != conn {
+ t.Errorf("Accept conn = %v; want %v", got, conn)
+ }
+ }
+ tl.Close()
+ got := <-ch
+ if got != io.EOF {
+ t.Errorf("Accept error post-Close = %v; want io.EOF", got)
+ }
+}
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..e12c744
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,76 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net"
+)
+
+// AddSNIRoute appends a route to the ipPort listener that says if the
+// incoming TLS SNI server name is sni, the connection is given to
+// dest. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ p.addRoute(ipPort, sniMatch(sni), dest)
+}
+
+type sniMatch string
+
+func (sni sniMatch) match(br *bufio.Reader) bool {
+ return clientHelloServerName(br) == string(sni)
+}
+
+// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
+// without consuming any bytes from br.
+// On any error, the empty string is returned.
+func clientHelloServerName(br *bufio.Reader) (sni string) {
+ const recordHeaderLen = 5
+ hdr, err := br.Peek(recordHeaderLen)
+ if err != nil {
+ return ""
+ }
+ const recordTypeHandshake = 0x16
+ if hdr[0] != recordTypeHandshake {
+ return "" // Not TLS.
+ }
+ recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
+ helloBytes, err := br.Peek(recordHeaderLen + recLen)
+ if err != nil {
+ return ""
+ }
+ tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
+ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sni = hello.ServerName
+ return nil, nil
+ },
+ }).Handshake()
+ return
+}
+
+// sniSniffConn is a net.Conn that reads from r, fails on Writes,
+// and crashes otherwise.
+type sniSniffConn struct {
+ r io.Reader
+ net.Conn // nil; crash on any unexpected use
+}
+
+func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }
diff --git a/tcpproxy.go b/tcpproxy.go
new file mode 100644
index 0000000..8a520f6
--- /dev/null
+++ b/tcpproxy.go
@@ -0,0 +1,374 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpproxy lets users build TCP proxies, optionally making
+// routing decisions based on HTTP/1 Host headers and the SNI hostname
+// in TLS connections.
+//
+// Typical usage:
+//
+// var p tcpproxy.Proxy
+// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081"))
+// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082"))
+// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback
+// p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431"))
+// p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
+// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
+// log.Fatal(p.Run())
+//
+// Calling Run (or Start) on a proxy also starts all the necessary
+// listeners.
+//
+// For each accepted connection, the rules for that ipPort are
+// matched, in order. If one matches (currently HTTP Host, SNI, or
+// always), then the connection is handed to the target.
+//
+// The two predefined Target implementations are:
+//
+// 1) DialProxy, proxying to another address (use the To func to return a
+// DialProxy value),
+//
+// 2) TargetListener, making the matched connection available via a
+// net.Listener.Accept call.
+//
+// But Target is an interface, so you can also write your own.
+//
+// Note that tcpproxy does not do any TLS encryption or decryption. It
+// only (via DialProxy) copies bytes around. The SNI hostname in the TLS
+// header is unencrypted, for better or worse.
+//
+// This package makes no API stability promises. If you depend on it,
+// vendor it.
+package tcpproxy
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "io"
+ "log"
+ "net"
+ "time"
+)
+
+// Proxy is a proxy. Its zero value is a valid proxy that does
+// nothing. Call methods to add routes before calling Start or Run.
+//
+// The order that routes are added in matters; each is matched in the order
+// registered.
+type Proxy struct {
+ routes map[string][]route // ip:port => route
+
+ lns []net.Listener
+ donec chan struct{} // closed before err
+ err error // any error from listening
+
+ // ListenFunc optionally specifies an alternate listen
+ // function. If nil, net.Dial is used.
+ // The provided net is always "tcp".
+ ListenFunc func(net, laddr string) (net.Listener, error)
+}
+
+type route struct {
+ matcher matcher
+ target Target
+}
+
+type matcher interface {
+ match(*bufio.Reader) bool
+}
+
+func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
+ if p.ListenFunc != nil {
+ return p.ListenFunc
+ }
+ return net.Listen
+}
+
+func (p *Proxy) addRoute(ipPort string, matcher matcher, dest Target) {
+ if p.routes == nil {
+ p.routes = make(map[string][]route)
+ }
+ p.routes[ipPort] = append(p.routes[ipPort], route{matcher, dest})
+}
+
+// AddRoute appends an always-matching route to the ipPort listener,
+// directing any connection to dest.
+//
+// This is generally used as either the only rule (for simple TCP
+// proxies), or as the final fallback rule for an ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddRoute(ipPort string, dest Target) {
+ p.addRoute(ipPort, alwaysMatch{}, dest)
+}
+
+type alwaysMatch struct{}
+
+func (alwaysMatch) match(*bufio.Reader) bool { return true }
+
+// Run is calls Start, and then Wait.
+//
+// It blocks until there's an error. The return value is always
+// non-nil.
+func (p *Proxy) Run() error {
+ if err := p.Start(); err != nil {
+ return err
+ }
+ return p.Wait()
+}
+
+// Wait waits for the Proxy to finish running. Currently this can only
+// happen if a Listener is closed, or Close is called on the proxy.
+//
+// It is only valid to call Wait after a successful call to Start.
+func (p *Proxy) Wait() error {
+ <-p.donec
+ return p.err
+}
+
+// Close closes all the proxy's self-opened listeners.
+func (p *Proxy) Close() error {
+ for _, c := range p.lns {
+ c.Close()
+ }
+ return nil
+}
+
+// Start creates a TCP listener for each unique ipPort from the
+// previously created routes and starts the proxy. It returns any
+// error from starting listeners.
+//
+// If it returns a non-nil error, any successfully opened listeners
+// are closed.
+func (p *Proxy) Start() error {
+ if p.donec != nil {
+ return errors.New("already started")
+ }
+ p.donec = make(chan struct{})
+ errc := make(chan error, len(p.routes))
+ p.lns = make([]net.Listener, 0, len(p.routes))
+ for ipPort, routes := range p.routes {
+ ln, err := p.netListen()("tcp", ipPort)
+ if err != nil {
+ p.Close()
+ return err
+ }
+ p.lns = append(p.lns, ln)
+ go p.serveListener(errc, ln, routes)
+ }
+ go p.awaitFirstError(errc)
+ return nil
+}
+
+func (p *Proxy) awaitFirstError(errc <-chan error) {
+ p.err = <-errc
+ close(p.donec)
+}
+
+func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ ret <- err
+ return
+ }
+ go p.serveConn(c, routes)
+ }
+}
+
+// serveConn runs in its own goroutine and matches c against routes.
+// It returns whether it matched purely for testing.
+func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
+ br := bufio.NewReader(c)
+ for _, route := range routes {
+ if route.matcher.match(br) {
+ if n := br.Buffered(); n > 0 {
+ peeked, _ := br.Peek(br.Buffered())
+ c = &Conn{
+ Peeked: peeked,
+ Conn: c,
+ }
+ }
+ route.target.HandleConn(c)
+ return true
+ }
+ }
+ // TODO: hook for this?
+ log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
+ c.Close()
+ return false
+}
+
+// Conn is an incoming connection that has had some bytes read from it
+// to determine how to route the connection. The Read method stitches
+// the peeked bytes and unread bytes back together.
+type Conn struct {
+ // Peeked are the bytes that have been read from Conn for the
+ // purposes of route matching, but have not yet been consumed
+ // by Read calls. It set to nil by Read when fully consumed.
+ Peeked []byte
+
+ // Conn is the underlying connection.
+ // It can be type asserted against *net.TCPConn or other types
+ // as needed. It should not be read from directly unless
+ // Peeked is nil.
+ net.Conn
+}
+
+func (c *Conn) Read(p []byte) (n int, err error) {
+ if len(c.Peeked) > 0 {
+ n = copy(p, c.Peeked)
+ c.Peeked = c.Peeked[n:]
+ if len(c.Peeked) == 0 {
+ c.Peeked = nil
+ }
+ return n, nil
+ }
+ return c.Conn.Read(p)
+}
+
+// Target is what an incoming matched connection is sent to.
+type Target interface {
+ // HandleConn is called when an incoming connection is
+ // matched. After the call to HandleConn, the tcpproxy
+ // package never touches the conn again. Implementations are
+ // responsible for closing the connection when needed.
+ //
+ // The concrete type of conn will be of type *Conn if any
+ // bytes have been consumed for the purposes of route
+ // matching.
+ HandleConn(net.Conn)
+}
+
+// To is shorthand way of writing &tlsproxy.DialProxy{Addr: addr}.
+func To(addr string) *DialProxy {
+ return &DialProxy{Addr: addr}
+}
+
+// DialProxy implements Target by dialing a new connection to Addr
+// and then proxying data back and forth.
+//
+// The To func is a shorthand way of creating a DialProxy.
+type DialProxy struct {
+ // Addr is the TCP address to proxy to.
+ Addr string
+
+ // KeepAlivePeriod sets the period between TCP keep alives.
+ // If zero, a default is used. To disable, use a negative number.
+ // The keep-alive is used for both the client connection and
+ KeepAlivePeriod time.Duration
+
+ // DialTimeout optionally specifies a dial timeout.
+ // If zero, a default is used.
+ // If negative, the timeout is disabled.
+ DialTimeout time.Duration
+
+ // DialContext optionally specifies an alternate dial function
+ // for TCP targets. If nil, the standard
+ // net.Dialer.DialContext method is used.
+ DialContext func(ctx context.Context, network, address string) (net.Conn, error)
+
+ // OnDialError optionally specifies an alternate way to handle errors dialing Addr.
+ // If nil, the error is logged and src is closed.
+ // If non-nil, src is not closed automatically.
+ OnDialError func(src net.Conn, dstDialErr error)
+}
+
+// UnderlyingConn returns c.Conn if c of type *Conn,
+// otherwise it returns c.
+func UnderlyingConn(c net.Conn) net.Conn {
+ if wrap, ok := c.(*Conn); ok {
+ return wrap.Conn
+ }
+ return c
+}
+
+func (dp *DialProxy) HandleConn(src net.Conn) {
+ ctx := context.Background()
+ var cancel context.CancelFunc
+ if dp.DialTimeout >= 0 {
+ ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
+ }
+ dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
+ if cancel != nil {
+ cancel()
+ }
+ if err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer src.Close()
+ defer dst.Close()
+ if ka := dp.keepAlivePeriod(); ka > 0 {
+ if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
+ if c, ok := dst.(*net.TCPConn); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
+ }
+ errc := make(chan error, 1)
+ go proxyCopy(errc, src, dst)
+ go proxyCopy(errc, dst, src)
+ <-errc
+}
+
+// proxyCopy is the function that copies bytes around.
+// It's a named function instead of a func literal so users get
+// named goroutines in debug goroutine stack dumps.
+func proxyCopy(errc chan<- error, dst io.Writer, src io.Reader) {
+ // TODO: make caller switch from src to rawSrc after N bytes (e.g. 4KB)
+ // if the io.Copy optimization to switch to Linux splice happens.
+ // TODO: if the runtime provides a way to wait for
+ // readability, use that to avoid stranding big blocks of
+ // memory blocked in idle reads.
+ _, err := io.Copy(dst, src)
+ errc <- err
+}
+
+func (dp *DialProxy) keepAlivePeriod() time.Duration {
+ if dp.KeepAlivePeriod != 0 {
+ return dp.KeepAlivePeriod
+ }
+ return time.Minute
+}
+
+func (dp *DialProxy) dialTimeout() time.Duration {
+ if dp.DialTimeout > 0 {
+ return dp.DialTimeout
+ }
+ return 10 * time.Second
+}
+
+var defaultDialer = new(net.Dialer)
+
+func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
+ if dp.DialContext != nil {
+ return dp.DialContext
+ }
+ return defaultDialer.DialContext
+}
+
+func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
+ if dp.OnDialError != nil {
+ return dp.OnDialError
+ }
+ return func(src net.Conn, dstDialErr error) {
+ log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr)
+ src.Close()
+ }
+}
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
new file mode 100644
index 0000000..4cfb0ab
--- /dev/null
+++ b/tcpproxy_test.go
@@ -0,0 +1,270 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "testing"
+)
+
+func TestMatchHTTPHost(t *testing.T) {
+ tests := []struct {
+ name string
+ r io.Reader
+ host string
+ want bool
+ }{
+ {
+ name: "match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "foo.com",
+ want: true,
+ },
+ {
+ name: "no-match",
+ r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"),
+ host: "bar.com",
+ want: false,
+ },
+ {
+ name: "match-huge-request",
+ r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')),
+ host: "foo.com",
+ want: true,
+ },
+ }
+ for i, tt := range tests {
+ name := tt.name
+ if name == "" {
+ name = fmt.Sprintf("test_index_%d", i)
+ }
+ t.Run(name, func(t *testing.T) {
+ br := bufio.NewReader(tt.r)
+ var matcher matcher = httpHostMatch(tt.host)
+ got := matcher.match(br)
+ if got != tt.want {
+ t.Fatalf("match = %v; want %v", got, tt.want)
+ }
+ get := make([]byte, 3)
+ if _, err := io.ReadFull(br, get); err != nil {
+ t.Fatal(err)
+ }
+ if string(get) != "GET" {
+ t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get)
+ }
+ })
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type recordWritesConn struct {
+ buf bytes.Buffer
+ net.Conn
+}
+
+func (c *recordWritesConn) Write(p []byte) (int, error) {
+ c.buf.Write(p)
+ return len(p), nil
+}
+
+func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF }
+
+func clientHelloRecord(t *testing.T, hostName string) string {
+ rec := new(recordWritesConn)
+ cl := tls.Client(rec, &tls.Config{ServerName: hostName})
+ cl.Handshake()
+
+ s := rec.buf.String()
+ if !strings.Contains(s, hostName) {
+ t.Fatalf("clientHello sent in test didn't contain %q", hostName)
+ }
+ return s
+}
+
+func TestSNI(t *testing.T) {
+ const hostName = "foo.com"
+ greeting := clientHelloRecord(t, hostName)
+ got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting)))
+ if got != hostName {
+ t.Errorf("got SNI %q; want %q", got, hostName)
+ }
+}
+
+func TestProxyStartNone(t *testing.T) {
+ var p Proxy
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ return ln
+}
+
+const testFrontAddr = "1.2.3.4:567"
+
+func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) {
+ return func(network, laddr string) (net.Listener, error) {
+ if network != "tcp" {
+ t.Errorf("got Listen call with network %q, not tcp", network)
+ return nil, errors.New("invalid network")
+ }
+ if laddr != testFrontAddr {
+ t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr)
+ panic("bogus address")
+ }
+ return ln, nil
+ }
+}
+
+func testProxy(t *testing.T, front net.Listener) *Proxy {
+ return &Proxy{
+ ListenFunc: testListenFunc(t, front),
+ }
+}
+
+func TestProxyAlwaysMatch(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, To(back.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ const msg = "message"
+ io.WriteString(toFront, msg)
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxyHTTP(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
+func TestProxySNI(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ msg := clientHelloRecord(t, "bar.com")
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}