summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tcpproxy.go40
-rw-r--r--tcpproxy_test.go40
2 files changed, 72 insertions, 8 deletions
diff --git a/tcpproxy.go b/tcpproxy.go
index 5d178c6..1f03e32 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn {
return c
}
+func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
+ if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
+ return c, ok
+ }
+ if c, ok := c.(*net.TCPConn); ok {
+ return c, ok
+ }
+ return nil, false
+}
+
func goCloseConn(c net.Conn) { go c.Close() }
+func closeRead(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseRead()
+ }
+}
+
+func closeWrite(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseWrite()
+ }
+}
+
// HandleConn implements the Target interface.
func (dp *DialProxy) HandleConn(src net.Conn) {
ctx := context.Background()
@@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
defer goCloseConn(src)
if ka := dp.keepAlivePeriod(); ka > 0 {
- if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
- c.SetKeepAlive(true)
- c.SetKeepAlivePeriod(ka)
- }
- if c, ok := dst.(*net.TCPConn); ok {
- c.SetKeepAlive(true)
- c.SetKeepAlivePeriod(ka)
+ for _, c := range []net.Conn{src, dst} {
+ if c, ok := tcpConn(c); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
}
}
- errc := make(chan error, 1)
+ errc := make(chan error, 2)
go proxyCopy(errc, src, dst)
go proxyCopy(errc, dst, src)
<-errc
+ <-errc
}
func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
@@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
// It's a named function instead of a func literal so users get
// named goroutines in debug goroutine stack dumps.
func proxyCopy(errc chan<- error, dst, src net.Conn) {
+ defer closeRead(src)
+ defer closeWrite(dst)
+
// Before we unwrap src and/or dst, copy any buffered data.
if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
if _, err := dst.Write(wc.Peeked); err != nil {
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 38feb06..0346a7a 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -174,6 +174,45 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
}
}
+func TestBufferedClose(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+ back := newLocalListener(t)
+ defer back.Close()
+
+ p := testProxy(t, front)
+ p.AddRoute(testFrontAddr, To(back.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+ const msg = "message"
+ if _, err := io.WriteString(toFront, msg); err != nil {
+ t.Fatal(err)
+ }
+ // actively close toFront, the write should still make to the back.
+ toFront.Close()
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
+
func TestProxyAlwaysMatch(t *testing.T) {
front := newLocalListener(t)
defer front.Close()
@@ -196,6 +235,7 @@ func TestProxyAlwaysMatch(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ defer fromProxy.Close()
const msg = "message"
io.WriteString(toFront, msg)