diff options
Diffstat (limited to 'tcpproxy.go')
-rw-r--r-- | tcpproxy.go | 40 |
1 files changed, 32 insertions, 8 deletions
diff --git a/tcpproxy.go b/tcpproxy.go index 5d178c6..1f03e32 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn { return c } +func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) { + if c, ok := UnderlyingConn(c).(*net.TCPConn); ok { + return c, ok + } + if c, ok := c.(*net.TCPConn); ok { + return c, ok + } + return nil, false +} + func goCloseConn(c net.Conn) { go c.Close() } +func closeRead(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseRead() + } +} + +func closeWrite(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseWrite() + } +} + // HandleConn implements the Target interface. func (dp *DialProxy) HandleConn(src net.Conn) { ctx := context.Background() @@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) { defer goCloseConn(src) if ka := dp.keepAlivePeriod(); ka > 0 { - if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { - c.SetKeepAlive(true) - c.SetKeepAlivePeriod(ka) - } - if c, ok := dst.(*net.TCPConn); ok { - c.SetKeepAlive(true) - c.SetKeepAlivePeriod(ka) + for _, c := range []net.Conn{src, dst} { + if c, ok := tcpConn(c); ok { + c.SetKeepAlive(true) + c.SetKeepAlivePeriod(ka) + } } } - errc := make(chan error, 1) + errc := make(chan error, 2) go proxyCopy(errc, src, dst) go proxyCopy(errc, dst, src) <-errc + <-errc } func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { @@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { // It's a named function instead of a func literal so users get // named goroutines in debug goroutine stack dumps. func proxyCopy(errc chan<- error, dst, src net.Conn) { + defer closeRead(src) + defer closeWrite(dst) + // Before we unwrap src and/or dst, copy any buffered data. if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 { if _, err := dst.Write(wc.Peeked); err != nil { |