diff options
Diffstat (limited to 'tcpproxy_test.go')
-rw-r--r-- | tcpproxy_test.go | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 4849c68..b6135b2 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -17,6 +17,7 @@ package tcpproxy import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -287,6 +288,49 @@ func TestProxySNI(t *testing.T) { } } +func TestAddSNIRouteFunc(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) { + if sniName == "bar.com" { + return To(backBar.Addr().String()), true + } + t.Fatalf("failed to match %q", sniName) + return nil, false + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} func TestProxyPROXYOut(t *testing.T) { front := newLocalListener(t) defer front.Close() @@ -362,7 +406,7 @@ func (t *tlsServer) Close() { // cert creates a well-formed, but completely insecure self-signed // cert for domain. func cert(t *testing.T, domain string) tls.Certificate { - private, err := rsa.GenerateKey(rand.Reader, 512) + private, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { t.Fatal(err) } |