diff options
-rw-r--r-- | tcpproxy.go | 40 | ||||
-rw-r--r-- | tcpproxy_test.go | 40 |
2 files changed, 72 insertions, 8 deletions
diff --git a/tcpproxy.go b/tcpproxy.go index 5d178c6..1f03e32 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn { return c } +func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) { + if c, ok := UnderlyingConn(c).(*net.TCPConn); ok { + return c, ok + } + if c, ok := c.(*net.TCPConn); ok { + return c, ok + } + return nil, false +} + func goCloseConn(c net.Conn) { go c.Close() } +func closeRead(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseRead() + } +} + +func closeWrite(c net.Conn) { + if c, ok := tcpConn(c); ok { + c.CloseWrite() + } +} + // HandleConn implements the Target interface. func (dp *DialProxy) HandleConn(src net.Conn) { ctx := context.Background() @@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) { defer goCloseConn(src) if ka := dp.keepAlivePeriod(); ka > 0 { - if c, ok := UnderlyingConn(src).(*net.TCPConn); ok { - c.SetKeepAlive(true) - c.SetKeepAlivePeriod(ka) - } - if c, ok := dst.(*net.TCPConn); ok { - c.SetKeepAlive(true) - c.SetKeepAlivePeriod(ka) + for _, c := range []net.Conn{src, dst} { + if c, ok := tcpConn(c); ok { + c.SetKeepAlive(true) + c.SetKeepAlivePeriod(ka) + } } } - errc := make(chan error, 1) + errc := make(chan error, 2) go proxyCopy(errc, src, dst) go proxyCopy(errc, dst, src) <-errc + <-errc } func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { @@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error { // It's a named function instead of a func literal so users get // named goroutines in debug goroutine stack dumps. func proxyCopy(errc chan<- error, dst, src net.Conn) { + defer closeRead(src) + defer closeWrite(dst) + // Before we unwrap src and/or dst, copy any buffered data. if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 { if _, err := dst.Write(wc.Peeked); err != nil { diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 38feb06..0346a7a 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -174,6 +174,45 @@ func testProxy(t *testing.T, front net.Listener) *Proxy { } } +func TestBufferedClose(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + const msg = "message" + if _, err := io.WriteString(toFront, msg); err != nil { + t.Fatal(err) + } + // actively close toFront, the write should still make to the back. + toFront.Close() + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + func TestProxyAlwaysMatch(t *testing.T) { front := newLocalListener(t) defer front.Close() @@ -196,6 +235,7 @@ func TestProxyAlwaysMatch(t *testing.T) { if err != nil { t.Fatal(err) } + defer fromProxy.Close() const msg = "message" io.WriteString(toFront, msg) |