diff options
author | Brad Fitzpatrick <[email protected]> | 2017-06-22 16:33:27 -0700 |
---|---|---|
committer | Brad Fitzpatrick <[email protected]> | 2017-06-22 16:33:27 -0700 |
commit | e3878897bde4f5d532f67738009cf1b9fcd2f408 (patch) | |
tree | 02f1bc882d95a041a1bd9a85d95ae367ba7930bd | |
parent | 2eb0155fac2d41b022bc0a430d13aa3b45825f1d (diff) |
Add TargetListener
-rw-r--r-- | listener.go | 93 | ||||
-rw-r--r-- | listener_test.go | 40 | ||||
-rw-r--r-- | tcpproxy.go | 101 |
3 files changed, 210 insertions, 24 deletions
diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..6b0442d --- /dev/null +++ b/listener.go @@ -0,0 +1,93 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "io" + "net" + "sync" +) + +// TargetListener implements both net.Listener and Target. +// Matched Targets become accepted connections. +type TargetListener struct { + Address string // Address is the string reported by TargetListener.Addr().String(). + + mu sync.Mutex + cond *sync.Cond + closed bool + nextConn net.Conn +} + +var ( + _ net.Listener = (*TargetListener)(nil) + _ Target = (*TargetListener)(nil) +) + +func (tl *TargetListener) lock() { + tl.mu.Lock() + if tl.cond == nil { + tl.cond = sync.NewCond(&tl.mu) + } +} + +type tcpAddr string + +func (a tcpAddr) Network() string { return "tcp" } +func (a tcpAddr) String() string { return string(a) } + +func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) } + +func (tl *TargetListener) Close() error { + tl.lock() + if tl.closed { + tl.mu.Unlock() + return nil + } + tl.closed = true + tl.mu.Unlock() + tl.cond.Broadcast() + return nil +} + +// HandleConn implements the Target interface. It blocks until tl is +// closed or another goroutine has called Accept and received c. +func (tl *TargetListener) HandleConn(c net.Conn) { + tl.lock() + defer tl.mu.Unlock() + for tl.nextConn != nil && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + c.Close() + return + } + tl.nextConn = c + tl.cond.Broadcast() // Signal might be sufficient; verify. + for tl.nextConn == c && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + c.Close() + return + } +} + +func (tl *TargetListener) Accept() (net.Conn, error) { + tl.lock() + for tl.nextConn == nil && !tl.closed { + tl.cond.Wait() + } + if tl.closed { + tl.mu.Unlock() + return nil, io.EOF + } + c := tl.nextConn + tl.nextConn = nil + tl.mu.Unlock() + tl.cond.Broadcast() // Signal might be sufficient; verify. + + return c, nil +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 0000000..6b713f1 --- /dev/null +++ b/listener_test.go @@ -0,0 +1,40 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "io" + "testing" +) + +func TestListenerAccept(t *testing.T) { + tl := new(TargetListener) + ch := make(chan interface{}, 1) + go func() { + for { + conn, err := tl.Accept() + if err != nil { + ch <- err + return + } else { + ch <- conn + } + } + }() + + for i := 0; i < 3; i++ { + conn := new(Conn) + tl.HandleConn(conn) + got := <-ch + if got != conn { + t.Errorf("Accept conn = %v; want %v", got, conn) + } + } + tl.Close() + got := <-ch + if got != io.EOF { + t.Errorf("Accept error post-Close = %v; want io.EOF", got) + } +} diff --git a/tcpproxy.go b/tcpproxy.go index fde438f..012684c 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -2,7 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package tcpproxy lets users build TCP, HTTP/1, and TLS+SNI proxies. +// Package tcpproxy lets users build TCP proxies, optionally making +// routing decisions based on HTTP/1 Host headers and the SNI hostname +// in TLS connections. // // Typical usage: // @@ -14,11 +16,31 @@ // p.AddSNIHostRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432")) // p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback // log.Fatal(p.Run()) +// +// Calling Run (or Start) on a proxy also starts all the necessary +// listeners. +// +// For each accepted connection, the rules for that ipPort are +// matched, in order. If one matches (currently HTTP Host, SNI, or +// always), then the connection is handed to the target. +// +// The two predefined Target implementations are: +// +// 1) DialProxy, proxying to another address (use the To func to return a +// DialProxy value), +// +// 2) TargetListener, making the matched connection available via a +// net.Listener.Accept call. +// +// But Target is an interface, so you can also write your own. +// +// Note that tcpproxy does not do any TLS encryption or decryption. It +// only (via DialProxy) copies bytes around. The SNI hostname in the TLS +// header is unencrypted, for better or worse. package tcpproxy import ( "bufio" - "bytes" "context" "errors" "io" @@ -28,7 +50,10 @@ import ( ) // Proxy is a proxy. Its zero value is a valid proxy that does -// nothing. Call methods to add routes before calling Run. +// nothing. Call methods to add routes before calling Start or Run. +// +// The order that routes are added in matters; each is matched in the order +// registered. type Proxy struct { routes map[string][]route // ip:port => route @@ -156,11 +181,14 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool { br := bufio.NewReader(c) for _, route := range routes { if route.matcher.match(br) { - buffered, _ := br.Peek(br.Buffered()) - route.target.HandleConn(changeReaderConn{ - r: io.MultiReader(bytes.NewReader(buffered), c), - Conn: c, - }, c) + if n := br.Buffered(); n > 0 { + peeked, _ := br.Peek(br.Buffered()) + c = &Conn{ + Peeked: peeked, + Conn: c, + } + } + route.target.HandleConn(c) return true } } @@ -170,32 +198,48 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool { return false } -// changeReaderConn is a net.Conn wrapper with a separate reader function. -type changeReaderConn struct { - r io.Reader +// Conn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type Conn struct { + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. net.Conn } -func (c changeReaderConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (c *Conn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} // Target is what an incoming matched connection is sent to. type Target interface { // HandleConn is called when an incoming connection is // matched. After the call to HandleConn, the tcpproxy // package never touches the conn again. Implementations are - // responsible for closing the conn when needed. - // - // The c Conn acts like a new conn, without any bytes consumed, - // but it has an unexported concrete type and cannot be type - // asserted to *net.TCPConn, etc. + // responsible for closing the connection when needed. // - // The rawConn represents the underlying connections (with - // some bytes removed) and should only be used for type - // assertions and setting deadlines, not reading. - HandleConn(c net.Conn, rawConn net.Conn) + // The concrete type of conn will be of type *Conn if any + // bytes have been consumed for the purposes of route + // matching. + HandleConn(net.Conn) } -// To is shorthand way of writing &DialProxy{Addr: addr}. +// To is shorthand way of writing &tlsproxy.DialProxy{Addr: addr}. func To(addr string) *DialProxy { return &DialProxy{Addr: addr} } @@ -229,7 +273,16 @@ type DialProxy struct { OnDialError func(src net.Conn, dstDialErr error) } -func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) { +// UnderlyingConn returns c.Conn if c of type *Conn, +// otherwise it returns c. +func UnderlyingConn(c net.Conn) net.Conn { + if wrap, ok := c.(*Conn); ok { + return wrap.Conn + } + return c +} + +func (dp *DialProxy) HandleConn(src net.Conn) { ctx := context.Background() var cancel context.CancelFunc if dp.DialTimeout >= 0 { @@ -246,7 +299,7 @@ func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) { defer src.Close() defer dst.Close() if ka := dp.keepAlivePeriod(); ka > 0 { - if c, ok := rawSrc.(*net.TCPConn); ok { + if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { c.SetKeepAlive(true) c.SetKeepAlivePeriod(ka) } |