summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sni.go95
-rw-r--r--tcpproxy.go47
-rw-r--r--tcpproxy_test.go175
3 files changed, 302 insertions, 15 deletions
diff --git a/sni.go b/sni.go
index f0128bf..50ab599 100644
--- a/sni.go
+++ b/sni.go
@@ -17,9 +17,11 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "context"
"crypto/tls"
"io"
"net"
+ "strings"
)
// AddSNIRoute appends a route to the ipPort listener that says if the
@@ -27,11 +29,31 @@ import (
// dest. If it doesn't match, rule processing continues for any
// additional routes on ipPort.
//
+// By default, the proxy will route all ACME tls-sni-01 challenges
+// received on ipPort to all SNI dests. You can disable ACME routing
+// with AddStopACMESearch.
+//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ cfg := p.configFor(ipPort)
+ if !cfg.stopACME {
+ if len(cfg.acmeTargets) == 0 {
+ p.addRoute(ipPort, &acmeMatch{cfg})
+ }
+ cfg.acmeTargets = append(cfg.acmeTargets, dest)
+ }
+
p.addRoute(ipPort, sniMatch{sni, dest})
}
+// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
+// Any ACME challenges on ipPort for SNI routes previously added
+// before this call will still be proxied to all possible SNI
+// backends.
+func (p *Proxy) AddStopACMESearch(ipPort string) {
+ p.configFor(ipPort).stopACME = true
+}
+
type sniMatch struct {
sni string
target Target
@@ -44,6 +66,79 @@ func (m sniMatch) match(br *bufio.Reader) Target {
return nil
}
+// acmeMatch matches "*.acme.invalid" ACME tls-sni-01 challenges and
+// searches for a Target in cfg.acmeTargets that has the challenge
+// response.
+type acmeMatch struct {
+ cfg *config
+}
+
+func (m *acmeMatch) match(br *bufio.Reader) Target {
+ sni := clientHelloServerName(br)
+ if !strings.HasSuffix(sni, ".acme.invalid") {
+ return nil
+ }
+
+ // TODO: cache. ACME issuers will hit multiple times in a short
+ // burst for each issuance event. A short TTL cache + singleflight
+ // should have an excellent hit rate.
+ // TODO: maybe an acme-specific timeout as well?
+ // TODO: plumb context upwards?
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ch := make(chan Target, len(m.cfg.acmeTargets))
+ for _, target := range m.cfg.acmeTargets {
+ go tryACME(ctx, ch, target, sni)
+ }
+ for range m.cfg.acmeTargets {
+ if target := <-ch; target != nil {
+ return target
+ }
+ }
+
+ // No target was happy with the provided challenge.
+ return nil
+}
+
+func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) {
+ var ret Target
+ defer func() { ch <- ret }()
+
+ conn, targetConn := net.Pipe()
+ defer conn.Close()
+ go dest.HandleConn(targetConn)
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+
+ client := tls.Client(conn, &tls.Config{
+ ServerName: sni,
+ InsecureSkipVerify: true,
+ })
+ if err := client.Handshake(); err != nil {
+ // TODO: log?
+ return
+ }
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ // TODO: log?
+ return
+ }
+ // acme says the first cert offered by the server must match the
+ // challenge hostname.
+ if err := certs[0].VerifyHostname(sni); err != nil {
+ // TODO: log?
+ return
+ }
+
+ // Target presented what looks like a valid challenge
+ // response, send it back to the matcher.
+ ret = dest
+}
+
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
// without consuming any bytes from br.
// On any error, the empty string is returned.
diff --git a/tcpproxy.go b/tcpproxy.go
index 1eee7ea..02b70f5 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -69,7 +69,7 @@ import (
// The order that routes are added in matters; each is matched in the order
// registered.
type Proxy struct {
- routes map[string][]route // ip:port => routes
+ configs map[string]*config // ip:port => config
lns []net.Listener
donec chan struct{} // closed before err
@@ -81,7 +81,22 @@ type Proxy struct {
ListenFunc func(net, laddr string) (net.Listener, error)
}
+// config contains the proxying state for one listener.
+type config struct {
+ routes []route
+ acmeTargets []Target // accumulates targets that should be probed for acme.
+ stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets.
+}
+
+// A route matches a connection to a target.
type route interface {
+ // match examines the initial bytes of a connection, looking for a
+ // match. If a match is found, match returns a non-nil Target to
+ // which the stream should be proxied. match returns nil if the
+ // connection doesn't match.
+ //
+ // match must not consume bytes from the given bufio.Reader, it
+ // can only Peek.
match(*bufio.Reader) Target
}
@@ -92,11 +107,19 @@ func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
return net.Listen
}
-func (p *Proxy) addRoute(ipPort string, r route) {
- if p.routes == nil {
- p.routes = make(map[string][]route)
+func (p *Proxy) configFor(ipPort string) *config {
+ if p.configs == nil {
+ p.configs = make(map[string]*config)
}
- p.routes[ipPort] = append(p.routes[ipPort], r)
+ if p.configs[ipPort] == nil {
+ p.configs[ipPort] = &config{}
+ }
+ return p.configs[ipPort]
+}
+
+func (p *Proxy) addRoute(ipPort string, r route) {
+ cfg := p.configFor(ipPort)
+ cfg.routes = append(cfg.routes, r)
}
// AddRoute appends an always-matching route to the ipPort listener,
@@ -107,14 +130,14 @@ func (p *Proxy) addRoute(ipPort string, r route) {
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
- p.addRoute(ipPort, alwaysMatch{dest})
+ p.addRoute(ipPort, fixedTarget{dest})
}
-type alwaysMatch struct {
+type fixedTarget struct {
t Target
}
-func (m alwaysMatch) match(*bufio.Reader) Target { return m.t }
+func (m fixedTarget) match(*bufio.Reader) Target { return m.t }
// Run is calls Start, and then Wait.
//
@@ -155,16 +178,16 @@ func (p *Proxy) Start() error {
return errors.New("already started")
}
p.donec = make(chan struct{})
- errc := make(chan error, len(p.routes))
- p.lns = make([]net.Listener, 0, len(p.routes))
- for ipPort, routes := range p.routes {
+ errc := make(chan error, len(p.configs))
+ p.lns = make([]net.Listener, 0, len(p.configs))
+ for ipPort, config := range p.configs {
ln, err := p.netListen()("tcp", ipPort)
if err != nil {
p.Close()
return err
}
p.lns = append(p.lns, ln)
- go p.serveListener(errc, ln, routes)
+ go p.serveListener(errc, ln, config.routes)
}
go p.awaitFirstError(errc)
return nil
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 7150372..ac7c917 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -17,19 +17,26 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "crypto/rand"
+ "crypto/rsa"
"crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
+ "math/big"
"net"
"strings"
"testing"
+ "time"
)
type noopTarget struct{}
-func (t *noopTarget) HandleConn(net.Conn) {}
+func (noopTarget) HandleConn(net.Conn) {}
func TestMatchHTTPHost(t *testing.T) {
tests := []struct {
@@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) {
want: true,
},
}
- target := &noopTarget{}
for i, tt := range tests {
name := tt.name
if name == "" {
@@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) {
}
t.Run(name, func(t *testing.T) {
br := bufio.NewReader(tt.r)
- r := httpHostMatch{tt.host, target}
+ r := httpHostMatch{tt.host, noopTarget{}}
got := r.match(br) != nil
if got != tt.want {
t.Fatalf("match = %v; want %v", got, tt.want)
@@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) {
t.Fatalf("got %q; want %q", bs, want)
}
}
+
+type tlsServer struct {
+ Listener net.Listener
+ Domain string
+ Test *testing.T
+}
+
+func (t *tlsServer) Start() {
+ cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid")
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+
+ go func() {
+ for {
+ rawConn, err := t.Listener.Accept()
+ if err != nil {
+ return // assume Close()
+ }
+
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, t.Domain); err != nil {
+ t.Test.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+}
+
+func (t *tlsServer) Close() {
+ t.Listener.Close()
+}
+
+// cert creates a well-formed, but completely insecure self-signed
+// cert for domain.
+func cert(t *testing.T, domain string) tls.Certificate {
+ private, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ t.Fatal(err)
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domain,
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return tlscert
+}
+
+// newTLSServer starts a TLS server that serves a self-signed cert for
+// domain, and a corresonding acme.invalid dummy domain.
+func newTLSServer(t *testing.T, domain string) net.Listener {
+ cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid")
+
+ l := newLocalListener(t)
+ go func() {
+ for {
+ rawConn, err := l.Accept()
+ if err != nil {
+ return // assume closed
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, domain); err != nil {
+ t.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+
+ return l
+}
+
+func readTLS(dest, domain string) (string, error) {
+ conn, err := tls.Dial("tcp", dest, &tls.Config{
+ ServerName: domain,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return "", err
+ }
+ defer conn.Close()
+
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", err
+ }
+ return string(bs), nil
+}
+
+func TestProxyACME(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newTLSServer(t, "foo.com")
+ defer backFoo.Close()
+ backBar := newTLSServer(t, "bar.com")
+ defer backBar.Close()
+ backQuux := newTLSServer(t, "quux.com")
+ defer backQuux.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ p.AddStopACMESearch(testFrontAddr)
+ p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ domain, want string
+ succeeds bool
+ }{
+ {"foo.com", "foo.com", true},
+ {"bar.com", "bar.com", true},
+ {"quux.com", "quux.com", true},
+ {"xyzzy.com", "", false},
+ {"foo.com.acme.invalid", "foo.com", true},
+ {"bar.com.acme.invalid", "bar.com", true},
+ {"quux.com.acme.invalid", "", false},
+ }
+ for _, test := range tests {
+ got, err := readTLS(front.Addr().String(), test.domain)
+ if test.succeeds {
+ if err != nil {
+ t.Fatalf("readTLS %q got error %q, want nil", test.domain, err)
+ }
+ if got != test.want {
+ t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want)
+ }
+ } else if err == nil {
+ t.Fatalf("readTLS %q unexpectedly succeeded", test.domain)
+ }
+ }
+}