summaryrefslogtreecommitdiff
path: root/tcpproxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'tcpproxy.go')
-rw-r--r--tcpproxy.go48
1 files changed, 47 insertions, 1 deletions
diff --git a/tcpproxy.go b/tcpproxy.go
index 8a316c7..bada4d3 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -56,6 +56,7 @@ import (
"bufio"
"context"
"errors"
+ "fmt"
"io"
"log"
"net"
@@ -284,6 +285,15 @@ type DialProxy struct {
// If nil, the error is logged and src is closed.
// If non-nil, src is not closed automatically.
OnDialError func(src net.Conn, dstDialErr error)
+
+ // ProxyProtocolVersion optionally specifies the version of
+ // HAProxy's PROXY protocol to use. The PROXY protocol provides
+ // connection metadata to the DialProxy target, via a header
+ // inserted ahead of the client's traffic. The DialProxy target
+ // must explicitly support and expect the PROXY header; there is
+ // no graceful downgrade.
+ // If zero, no PROXY header is sent. Currently, version 1 is supported.
+ ProxyProtocolVersion int
}
// UnderlyingConn returns c.Conn if c of type *Conn,
@@ -310,8 +320,14 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
dp.onDialError()(src, err)
return
}
- defer src.Close()
defer dst.Close()
+
+ if err = dp.sendProxyHeader(dst, src); err != nil {
+ dp.onDialError()(src, err)
+ return
+ }
+ defer src.Close()
+
if ka := dp.keepAlivePeriod(); ka > 0 {
if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
c.SetKeepAlive(true)
@@ -322,12 +338,42 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
c.SetKeepAlivePeriod(ka)
}
}
+
errc := make(chan error, 1)
go proxyCopy(errc, src, dst)
go proxyCopy(errc, dst, src)
<-errc
}
+func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
+ switch dp.ProxyProtocolVersion {
+ case 0:
+ return nil
+ case 1:
+ var srcAddr, dstAddr *net.TCPAddr
+ if a, ok := src.RemoteAddr().(*net.TCPAddr); ok {
+ srcAddr = a
+ }
+ if a, ok := src.LocalAddr().(*net.TCPAddr); ok {
+ dstAddr = a
+ }
+
+ if srcAddr == nil || dstAddr == nil {
+ _, err := io.WriteString(w, "PROXY UNKNOWN\r\n")
+ return err
+ }
+
+ family := "TCP4"
+ if srcAddr.IP.To4() == nil {
+ family = "TCP6"
+ }
+ _, err := fmt.Fprintf(w, "PROXY %s %s %d %s %d\r\n", family, srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port)
+ return err
+ default:
+ return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
+ }
+}
+
// proxyCopy is the function that copies bytes around.
// It's a named function instead of a func literal so users get
// named goroutines in debug goroutine stack dumps.