diff options
Diffstat (limited to 'tcpproxy.go')
-rw-r--r-- | tcpproxy.go | 48 |
1 files changed, 47 insertions, 1 deletions
diff --git a/tcpproxy.go b/tcpproxy.go index 8a316c7..bada4d3 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -56,6 +56,7 @@ import ( "bufio" "context" "errors" + "fmt" "io" "log" "net" @@ -284,6 +285,15 @@ type DialProxy struct { // If nil, the error is logged and src is closed. // If non-nil, src is not closed automatically. OnDialError func(src net.Conn, dstDialErr error) + + // ProxyProtocolVersion optionally specifies the version of + // HAProxy's PROXY protocol to use. The PROXY protocol provides + // connection metadata to the DialProxy target, via a header + // inserted ahead of the client's traffic. The DialProxy target + // must explicitly support and expect the PROXY header; there is + // no graceful downgrade. + // If zero, no PROXY header is sent. Currently, version 1 is supported. + ProxyProtocolVersion int } // UnderlyingConn returns c.Conn if c of type *Conn, @@ -310,8 +320,14 @@ func (dp *DialProxy) HandleConn(src net.Conn) { dp.onDialError()(src, err) return } - defer src.Close() defer dst.Close() + + if err = dp.sendProxyHeader(dst, src); err != nil { + dp.onDialError()(src, err) + return + } + defer src.Close() + if ka := dp.keepAlivePeriod(); ka > 0 { if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { c.SetKeepAlive(true) @@ -322,12 +338,42 @@ func (dp *DialProxy) HandleConn(src net.Conn) { c.SetKeepAlivePeriod(ka) } } + errc := make(chan error, 1) go proxyCopy(errc, src, dst) go proxyCopy(errc, dst, src) <-errc } +func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { + switch dp.ProxyProtocolVersion { + case 0: + return nil + case 1: + var srcAddr, dstAddr *net.TCPAddr + if a, ok := src.RemoteAddr().(*net.TCPAddr); ok { + srcAddr = a + } + if a, ok := src.LocalAddr().(*net.TCPAddr); ok { + dstAddr = a + } + + if srcAddr == nil || dstAddr == nil { + _, err := io.WriteString(w, "PROXY UNKNOWN\r\n") + return err + } + + family := "TCP4" + if srcAddr.IP.To4() == nil { + family = "TCP6" + } + _, err := fmt.Fprintf(w, "PROXY %s %s %d %s %d\r\n", family, srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port) + return err + default: + return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion) + } +} + // proxyCopy is the function that copies bytes around. // It's a named function instead of a func literal so users get // named goroutines in debug goroutine stack dumps. |