diff options
author | David Anderson <[email protected]> | 2017-07-05 18:18:48 -0700 |
---|---|---|
committer | David Anderson <[email protected]> | 2017-07-05 19:23:29 -0700 |
commit | 28de75ab2159e28b7e0318fd2ac9617625cc0788 (patch) | |
tree | 70ef7d0d85f2095ec3d63a03d9ad011ecc5d74b4 | |
parent | e03035937341374a9be6eb8459ffe4f23bacd185 (diff) |
Support HAProxy's PROXY protocol v1 in DialProxy.proxy-protocol
-rw-r--r-- | tcpproxy.go | 48 | ||||
-rw-r--r-- | tcpproxy_test.go | 40 |
2 files changed, 87 insertions, 1 deletions
diff --git a/tcpproxy.go b/tcpproxy.go index 8a316c7..bada4d3 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -56,6 +56,7 @@ import ( "bufio" "context" "errors" + "fmt" "io" "log" "net" @@ -284,6 +285,15 @@ type DialProxy struct { // If nil, the error is logged and src is closed. // If non-nil, src is not closed automatically. OnDialError func(src net.Conn, dstDialErr error) + + // ProxyProtocolVersion optionally specifies the version of + // HAProxy's PROXY protocol to use. The PROXY protocol provides + // connection metadata to the DialProxy target, via a header + // inserted ahead of the client's traffic. The DialProxy target + // must explicitly support and expect the PROXY header; there is + // no graceful downgrade. + // If zero, no PROXY header is sent. Currently, version 1 is supported. + ProxyProtocolVersion int } // UnderlyingConn returns c.Conn if c of type *Conn, @@ -310,8 +320,14 @@ func (dp *DialProxy) HandleConn(src net.Conn) { dp.onDialError()(src, err) return } - defer src.Close() defer dst.Close() + + if err = dp.sendProxyHeader(dst, src); err != nil { + dp.onDialError()(src, err) + return + } + defer src.Close() + if ka := dp.keepAlivePeriod(); ka > 0 { if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { c.SetKeepAlive(true) @@ -322,12 +338,42 @@ func (dp *DialProxy) HandleConn(src net.Conn) { c.SetKeepAlivePeriod(ka) } } + errc := make(chan error, 1) go proxyCopy(errc, src, dst) go proxyCopy(errc, dst, src) <-errc } +func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { + switch dp.ProxyProtocolVersion { + case 0: + return nil + case 1: + var srcAddr, dstAddr *net.TCPAddr + if a, ok := src.RemoteAddr().(*net.TCPAddr); ok { + srcAddr = a + } + if a, ok := src.LocalAddr().(*net.TCPAddr); ok { + dstAddr = a + } + + if srcAddr == nil || dstAddr == nil { + _, err := io.WriteString(w, "PROXY UNKNOWN\r\n") + return err + } + + family := "TCP4" + if srcAddr.IP.To4() == nil { + family = "TCP6" + } + _, err := fmt.Fprintf(w, "PROXY %s %s %d %s %d\r\n", family, srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port) + return err + default: + return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + } +} + // proxyCopy is the function that copies bytes around. // It's a named function instead of a func literal so users get // named goroutines in debug goroutine stack dumps. diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 4cfb0ab..45d8b0e 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "strings" "testing" @@ -268,3 +269,42 @@ func TestProxySNI(t *testing.T) { t.Fatalf("got %q; want %q", buf, msg) } } + +func TestProxyPROXYOut(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, &DialProxy{ + Addr: back.Addr().String(), + ProxyProtocolVersion: 1, + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + + io.WriteString(toFront, "foo") + toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + + bs, err := ioutil.ReadAll(fromProxy) + if err != nil { + t.Fatal(err) + } + + want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port) + if string(bs) != want { + t.Fatalf("got %q; want %q", bs, want) + } +} |