summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <[email protected]>2017-06-22 16:33:27 -0700
committerBrad Fitzpatrick <[email protected]>2017-06-22 16:33:27 -0700
commite3878897bde4f5d532f67738009cf1b9fcd2f408 (patch)
tree02f1bc882d95a041a1bd9a85d95ae367ba7930bd
parent2eb0155fac2d41b022bc0a430d13aa3b45825f1d (diff)
Add TargetListener
-rw-r--r--listener.go93
-rw-r--r--listener_test.go40
-rw-r--r--tcpproxy.go101
3 files changed, 210 insertions, 24 deletions
diff --git a/listener.go b/listener.go
new file mode 100644
index 0000000..6b0442d
--- /dev/null
+++ b/listener.go
@@ -0,0 +1,93 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "io"
+ "net"
+ "sync"
+)
+
+// TargetListener implements both net.Listener and Target.
+// Matched Targets become accepted connections.
+type TargetListener struct {
+ Address string // Address is the string reported by TargetListener.Addr().String().
+
+ mu sync.Mutex
+ cond *sync.Cond
+ closed bool
+ nextConn net.Conn
+}
+
+var (
+ _ net.Listener = (*TargetListener)(nil)
+ _ Target = (*TargetListener)(nil)
+)
+
+func (tl *TargetListener) lock() {
+ tl.mu.Lock()
+ if tl.cond == nil {
+ tl.cond = sync.NewCond(&tl.mu)
+ }
+}
+
+type tcpAddr string
+
+func (a tcpAddr) Network() string { return "tcp" }
+func (a tcpAddr) String() string { return string(a) }
+
+func (tl *TargetListener) Addr() net.Addr { return tcpAddr(tl.Address) }
+
+func (tl *TargetListener) Close() error {
+ tl.lock()
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil
+ }
+ tl.closed = true
+ tl.mu.Unlock()
+ tl.cond.Broadcast()
+ return nil
+}
+
+// HandleConn implements the Target interface. It blocks until tl is
+// closed or another goroutine has called Accept and received c.
+func (tl *TargetListener) HandleConn(c net.Conn) {
+ tl.lock()
+ defer tl.mu.Unlock()
+ for tl.nextConn != nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+ tl.nextConn = c
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+ for tl.nextConn == c && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ c.Close()
+ return
+ }
+}
+
+func (tl *TargetListener) Accept() (net.Conn, error) {
+ tl.lock()
+ for tl.nextConn == nil && !tl.closed {
+ tl.cond.Wait()
+ }
+ if tl.closed {
+ tl.mu.Unlock()
+ return nil, io.EOF
+ }
+ c := tl.nextConn
+ tl.nextConn = nil
+ tl.mu.Unlock()
+ tl.cond.Broadcast() // Signal might be sufficient; verify.
+
+ return c, nil
+}
diff --git a/listener_test.go b/listener_test.go
new file mode 100644
index 0000000..6b713f1
--- /dev/null
+++ b/listener_test.go
@@ -0,0 +1,40 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "io"
+ "testing"
+)
+
+func TestListenerAccept(t *testing.T) {
+ tl := new(TargetListener)
+ ch := make(chan interface{}, 1)
+ go func() {
+ for {
+ conn, err := tl.Accept()
+ if err != nil {
+ ch <- err
+ return
+ } else {
+ ch <- conn
+ }
+ }
+ }()
+
+ for i := 0; i < 3; i++ {
+ conn := new(Conn)
+ tl.HandleConn(conn)
+ got := <-ch
+ if got != conn {
+ t.Errorf("Accept conn = %v; want %v", got, conn)
+ }
+ }
+ tl.Close()
+ got := <-ch
+ if got != io.EOF {
+ t.Errorf("Accept error post-Close = %v; want io.EOF", got)
+ }
+}
diff --git a/tcpproxy.go b/tcpproxy.go
index fde438f..012684c 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -2,7 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package tcpproxy lets users build TCP, HTTP/1, and TLS+SNI proxies.
+// Package tcpproxy lets users build TCP proxies, optionally making
+// routing decisions based on HTTP/1 Host headers and the SNI hostname
+// in TLS connections.
//
// Typical usage:
//
@@ -14,11 +16,31 @@
// p.AddSNIHostRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
// log.Fatal(p.Run())
+//
+// Calling Run (or Start) on a proxy also starts all the necessary
+// listeners.
+//
+// For each accepted connection, the rules for that ipPort are
+// matched, in order. If one matches (currently HTTP Host, SNI, or
+// always), then the connection is handed to the target.
+//
+// The two predefined Target implementations are:
+//
+// 1) DialProxy, proxying to another address (use the To func to return a
+// DialProxy value),
+//
+// 2) TargetListener, making the matched connection available via a
+// net.Listener.Accept call.
+//
+// But Target is an interface, so you can also write your own.
+//
+// Note that tcpproxy does not do any TLS encryption or decryption. It
+// only (via DialProxy) copies bytes around. The SNI hostname in the TLS
+// header is unencrypted, for better or worse.
package tcpproxy
import (
"bufio"
- "bytes"
"context"
"errors"
"io"
@@ -28,7 +50,10 @@ import (
)
// Proxy is a proxy. Its zero value is a valid proxy that does
-// nothing. Call methods to add routes before calling Run.
+// nothing. Call methods to add routes before calling Start or Run.
+//
+// The order that routes are added in matters; each is matched in the order
+// registered.
type Proxy struct {
routes map[string][]route // ip:port => route
@@ -156,11 +181,14 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
br := bufio.NewReader(c)
for _, route := range routes {
if route.matcher.match(br) {
- buffered, _ := br.Peek(br.Buffered())
- route.target.HandleConn(changeReaderConn{
- r: io.MultiReader(bytes.NewReader(buffered), c),
- Conn: c,
- }, c)
+ if n := br.Buffered(); n > 0 {
+ peeked, _ := br.Peek(br.Buffered())
+ c = &Conn{
+ Peeked: peeked,
+ Conn: c,
+ }
+ }
+ route.target.HandleConn(c)
return true
}
}
@@ -170,32 +198,48 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
return false
}
-// changeReaderConn is a net.Conn wrapper with a separate reader function.
-type changeReaderConn struct {
- r io.Reader
+// Conn is an incoming connection that has had some bytes read from it
+// to determine how to route the connection. The Read method stitches
+// the peeked bytes and unread bytes back together.
+type Conn struct {
+ // Peeked are the bytes that have been read from Conn for the
+ // purposes of route matching, but have not yet been consumed
+ // by Read calls. It set to nil by Read when fully consumed.
+ Peeked []byte
+
+ // Conn is the underlying connection.
+ // It can be type asserted against *net.TCPConn or other types
+ // as needed. It should not be read from directly unless
+ // Peeked is nil.
net.Conn
}
-func (c changeReaderConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (c *Conn) Read(p []byte) (n int, err error) {
+ if len(c.Peeked) > 0 {
+ n = copy(p, c.Peeked)
+ c.Peeked = c.Peeked[n:]
+ if len(c.Peeked) == 0 {
+ c.Peeked = nil
+ }
+ return n, nil
+ }
+ return c.Conn.Read(p)
+}
// Target is what an incoming matched connection is sent to.
type Target interface {
// HandleConn is called when an incoming connection is
// matched. After the call to HandleConn, the tcpproxy
// package never touches the conn again. Implementations are
- // responsible for closing the conn when needed.
- //
- // The c Conn acts like a new conn, without any bytes consumed,
- // but it has an unexported concrete type and cannot be type
- // asserted to *net.TCPConn, etc.
+ // responsible for closing the connection when needed.
//
- // The rawConn represents the underlying connections (with
- // some bytes removed) and should only be used for type
- // assertions and setting deadlines, not reading.
- HandleConn(c net.Conn, rawConn net.Conn)
+ // The concrete type of conn will be of type *Conn if any
+ // bytes have been consumed for the purposes of route
+ // matching.
+ HandleConn(net.Conn)
}
-// To is shorthand way of writing &DialProxy{Addr: addr}.
+// To is shorthand way of writing &tlsproxy.DialProxy{Addr: addr}.
func To(addr string) *DialProxy {
return &DialProxy{Addr: addr}
}
@@ -229,7 +273,16 @@ type DialProxy struct {
OnDialError func(src net.Conn, dstDialErr error)
}
-func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) {
+// UnderlyingConn returns c.Conn if c of type *Conn,
+// otherwise it returns c.
+func UnderlyingConn(c net.Conn) net.Conn {
+ if wrap, ok := c.(*Conn); ok {
+ return wrap.Conn
+ }
+ return c
+}
+
+func (dp *DialProxy) HandleConn(src net.Conn) {
ctx := context.Background()
var cancel context.CancelFunc
if dp.DialTimeout >= 0 {
@@ -246,7 +299,7 @@ func (dp *DialProxy) HandleConn(src net.Conn, rawSrc net.Conn) {
defer src.Close()
defer dst.Close()
if ka := dp.keepAlivePeriod(); ka > 0 {
- if c, ok := rawSrc.(*net.TCPConn); ok {
+ if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
c.SetKeepAlive(true)
c.SetKeepAlivePeriod(ka)
}