summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <[email protected]>2022-10-16 16:54:21 -0700
committerBrad Fitzpatrick <[email protected]>2022-10-16 16:54:21 -0700
commit74ca1dc5d55168d202044c415dcf2e08d80c3fdc (patch)
tree688ff5d5dc36d752fe5c406cb67e74e6e83250a5
parent4e04b92f29ea8f8a10d28528a47ecc0f93814473 (diff)
add Proxy.AddSNIRouteFunc to do lookups by SNI dynamically
-rw-r--r--sni.go23
-rw-r--r--tcpproxy_test.go46
2 files changed, 67 insertions, 2 deletions
diff --git a/sni.go b/sni.go
index 53b53c2..eb826d4 100644
--- a/sni.go
+++ b/sni.go
@@ -57,7 +57,16 @@ func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
cfg.acmeTargets = append(cfg.acmeTargets, dest)
}
- p.addRoute(ipPort, sniMatch{matcher, dest})
+ p.addRoute(ipPort, sniMatch{matcher: matcher, target: dest})
+}
+
+// SNITargetFunc is the func callback used by Proxy.AddSNIRouteFunc.
+type SNITargetFunc func(ctx context.Context, sniName string) (t Target, ok bool)
+
+// AddSNIRouteFunc adds a route to ipPort that matches an SNI request and calls
+// fn to map its nap to a target.
+func (p *Proxy) AddSNIRouteFunc(ipPort string, fn SNITargetFunc) {
+ p.addRoute(ipPort, sniMatch{targetFunc: fn})
}
// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
@@ -71,10 +80,22 @@ func (p *Proxy) AddStopACMESearch(ipPort string) {
type sniMatch struct {
matcher Matcher
target Target
+
+ // Alternatively, if targetFunc is non-nil, it's used instead:
+ targetFunc SNITargetFunc
}
func (m sniMatch) match(br *bufio.Reader) (Target, string) {
sni := clientHelloServerName(br)
+ if sni == "" {
+ return nil, ""
+ }
+ if m.targetFunc != nil {
+ if t, ok := m.targetFunc(context.TODO(), sni); ok {
+ return t, sni
+ }
+ return nil, ""
+ }
if m.matcher(context.TODO(), sni) {
return m.target, sni
}
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 4849c68..b6135b2 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -17,6 +17,7 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -287,6 +288,49 @@ func TestProxySNI(t *testing.T) {
}
}
+func TestAddSNIRouteFunc(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newLocalListener(t)
+ defer backFoo.Close()
+ backBar := newLocalListener(t)
+ defer backBar.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRouteFunc(testFrontAddr, func(ctx context.Context, sniName string) (_ Target, ok bool) {
+ if sniName == "bar.com" {
+ return To(backBar.Addr().String()), true
+ }
+ t.Fatalf("failed to match %q", sniName)
+ return nil, false
+ })
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ toFront, err := net.Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer toFront.Close()
+
+ msg := clientHelloRecord(t, "bar.com")
+ io.WriteString(toFront, msg)
+
+ fromProxy, err := backBar.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf := make([]byte, len(msg))
+ if _, err := io.ReadFull(fromProxy, buf); err != nil {
+ t.Fatal(err)
+ }
+ if string(buf) != msg {
+ t.Fatalf("got %q; want %q", buf, msg)
+ }
+}
func TestProxyPROXYOut(t *testing.T) {
front := newLocalListener(t)
defer front.Close()
@@ -362,7 +406,7 @@ func (t *tlsServer) Close() {
// cert creates a well-formed, but completely insecure self-signed
// cert for domain.
func cert(t *testing.T, domain string) tls.Certificate {
- private, err := rsa.GenerateKey(rand.Reader, 512)
+ private, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}