From 74ca1dc5d55168d202044c415dcf2e08d80c3fdc Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 16 Oct 2022 16:54:21 -0700 Subject: add Proxy.AddSNIRouteFunc to do lookups by SNI dynamically --- sni.go | 23 ++++++++++++++++++++++- tcpproxy_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/sni.go b/sni.go index 53b53c2..eb826d4 100644 --- a/sni.go +++ b/sni.go @@ -57,7 +57,16 @@ func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) { cfg.acmeTargets = append(cfg.acmeTargets, dest) } - p.addRoute(ipPort, sniMatch{matcher, dest}) + p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest}) +} + +// SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc. +type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool) + +// AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls +// fn to map its nap to a target. +func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) { + p.addRoute(ipPort, sniMatch{targetFunc: fn}) } // AddStopACMESearch prevents ACME probing of subsequent SNI routes. @@ -71,10 +80,22 @@ func (p *Proxy) AddStopACMESearch(ipPort string) { type sniMatch struct { matcher Matcher target Target + + // Alternatively, if targetFunc is non-nil, it's used instead: + targetFunc SNITargetFunc } func (m sniMatch) match(br *bufio.Reader) (Target, string) { sni := clientHelloServerName(br) + if sni == "" { + return nil, "" + } + if m.targetFunc != nil { + if t, ok := m.targetFunc(context.TODO(), sni); ok { + return t, sni + } + return nil, "" + } if m.matcher(context.TODO(), sni) { return m.target, sni } diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 4849c68..b6135b2 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -17,6 +17,7 @@ package tcpproxy import ( "bufio" "bytes" + "context" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -287,6 +288,49 @@ func TestProxySNI(t *testing.T) { } } +func TestAddSNIRouteFunc(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) { + if sniName == "bar.com" { + return To(backBar.Addr().String()), true + } + t.Fatalf("failed to match %q", sniName) + return nil, false + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} func TestProxyPROXYOut(t *testing.T) { front := newLocalListener(t) defer front.Close() @@ -362,7 +406,7 @@ func (t *tlsServer) Close() { // cert creates a well-formed, but completely insecure self-signed // cert for domain. func cert(t *testing.T, domain string) tls.Certificate { - private, err := rsa.GenerateKey(rand.Reader, 512) + private, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { t.Fatal(err) } -- cgit v1.2.3