summaryrefslogtreecommitdiff
path: root/tcpproxy.go
diff options
context:
space:
mode:
Diffstat (limited to 'tcpproxy.go')
-rw-r--r--tcpproxy.go40
1 files changed, 32 insertions, 8 deletions
diff --git a/tcpproxy.go b/tcpproxy.go
index 5d178c6..1f03e32 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -345,8 +345,30 @@ func UnderlyingConn(c net.Conn) net.Conn {
return c
}
+func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
+ if c, ok := UnderlyingConn(c).(*net.TCPConn); ok {
+ return c, ok
+ }
+ if c, ok := c.(*net.TCPConn); ok {
+ return c, ok
+ }
+ return nil, false
+}
+
func goCloseConn(c net.Conn) { go c.Close() }
+func closeRead(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseRead()
+ }
+}
+
+func closeWrite(c net.Conn) {
+ if c, ok := tcpConn(c); ok {
+ c.CloseWrite()
+ }
+}
+
// HandleConn implements the Target interface.
func (dp *DialProxy) HandleConn(src net.Conn) {
ctx := context.Background()
@@ -371,20 +393,19 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
defer goCloseConn(src)
if ka := dp.keepAlivePeriod(); ka > 0 {
- if c, ok := UnderlyingConn(src).(*net.TCPConn); ok {
- c.SetKeepAlive(true)
- c.SetKeepAlivePeriod(ka)
- }
- if c, ok := dst.(*net.TCPConn); ok {
- c.SetKeepAlive(true)
- c.SetKeepAlivePeriod(ka)
+ for _, c := range []net.Conn{src, dst} {
+ if c, ok := tcpConn(c); ok {
+ c.SetKeepAlive(true)
+ c.SetKeepAlivePeriod(ka)
+ }
}
}
- errc := make(chan error, 1)
+ errc := make(chan error, 2)
go proxyCopy(errc, src, dst)
go proxyCopy(errc, dst, src)
<-errc
+ <-errc
}
func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
@@ -420,6 +441,9 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
// It's a named function instead of a func literal so users get
// named goroutines in debug goroutine stack dumps.
func proxyCopy(errc chan<- error, dst, src net.Conn) {
+ defer closeRead(src)
+ defer closeWrite(dst)
+
// Before we unwrap src and/or dst, copy any buffered data.
if wc, ok := src.(*Conn); ok && len(wc.Peeked) > 0 {
if _, err := dst.Write(wc.Peeked); err != nil {