summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tcpproxy.go48
-rw-r--r--tcpproxy_test.go40
2 files changed, 87 insertions, 1 deletions
diff --git a/tcpproxy.go b/tcpproxy.go
index 8a316c7..bada4d3 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -56,6 +56,7 @@ import (
"bufio"
"context"
"errors"
+ "fmt"
"io"
"log"
"net"
@@ -284,6 +285,15 @@ type DialProxy struct {
// If nil, the error is logged and src is closed.
// If non-nil, src is not closed automatically.
OnDialError func(src net.Conn, dstDialErr error)
+
+ // ProxyProtocolVersion optionally specifies the version of
+ // HAProxy's PROXY protocol to use. The PROXY protocol provides
+ // connection metadata to the DialProxy target, via a header
+ // inserted ahead of the client's traffic. The DialProxy target
+ // must explicitly support and expect the PROXY header; there is
+ // no graceful downgrade.
+ // If zero, no PROXY header is sent. Currently, version 1 is supported.
+ ProxyProtocolVersion int
}
// UnderlyingConn returns c.Conn if c of type *Conn,
@@ -310,8 +320,14 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
dp.onDialError()(src, err)
return
}
- defer src.Close()
defer dst.Close()
+
+ if err = dp.sendProxyHeader(dst, src); err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer src.Close()
+
if ka := dp.keepAlivePeriod(); ka > 0 {
if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
c.SetKeepAlive(true)
@@ -322,12 +338,42 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
c.SetKeepAlivePeriod(ka)
}
}
+
errc := make(chan error, 1)
go proxyCopy(errc, src, dst)
go proxyCopy(errc, dst, src)
<-errc
}
+func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
+ switch dp.ProxyProtocolVersion {
+ case 0:
+ return nil
+ case 1:
+ var srcAddr, dstAddr *net.TCPAddr
+ if a, ok := src.RemoteAddr().(*net.TCPAddr); ok {
+ srcAddr = a
+ }
+ if a, ok := src.LocalAddr().(*net.TCPAddr); ok {
+ dstAddr = a
+ }
+
+ if srcAddr == nil || dstAddr == nil {
+ _, err := io.WriteString(w, "PROXY UNKNOWN\r\n")
+ return err
+ }
+
+ family := "TCP4"
+ if srcAddr.IP.To4() == nil {
+ family = "TCP6"
+ }
+ _, err := fmt.Fprintf(w, "PROXY %s %s %d %s %d\r\n", family, srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port)
+ return err
+ default:
+ return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
+ }
+}
+
// proxyCopy is the function that copies bytes around.
// It's a named function instead of a func literal so users get
// named goroutines in debug goroutine stack dumps.
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 4cfb0ab..45d8b0e 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"io"
+ "io/ioutil"
"net"
"strings"
"testing"
@@ -268,3 +269,42 @@ func TestProxySNI(t *testing.T) {
t.Fatalf("got %q; want %q", buf, msg)
}
}
+
+func TestProxyPROXYOut(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, &DialProxy{
+ Addr: back.Addr().String(),
+ ProxyProtocolVersion: 1,
+ })
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ io.WriteString(toFront, "foo")
+ toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ bs, err := ioutil.ReadAll(fromProxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port)
+ if string(bs) != want {
+ t.Fatalf("got %q; want %q", bs, want)
+ }
+}