diff options
author | David Anderson <[email protected]> | 2017-07-06 01:08:29 -0700 |
---|---|---|
committer | Dave Anderson <[email protected]> | 2017-07-06 21:36:14 -0700 |
commit | c6a0996ce0f3db7b5c3e16e04c9e664936077c97 (patch) | |
tree | 74cda35a42bfd547c07032e9c37af6ddf4e2591e | |
parent | 815c9425f1ad46ffd3a3fb1bbefc05440072e4a4 (diff) |
Support configurable routing of ACME tls-sni-01 challenges.
By design, the tls-sni-01 challenge does not reveal information
about the domain being verified, so the proxy cannot "naively" route
such requests. Instead, it probes the Targets of all SNI routes, looking
for one that responds plausibly to the challenge hostname, and routes the
client connection to that.
ACME support can be turned off by inserting AddStopAcmeSearch in the route
chain. Subsequently registered SNI routes will not be probed by ACME challenges.
-rw-r--r-- | sni.go | 95 | ||||
-rw-r--r-- | tcpproxy.go | 47 | ||||
-rw-r--r-- | tcpproxy_test.go | 175 |
3 files changed, 302 insertions, 15 deletions
@@ -17,9 +17,11 @@ package tcpproxy import ( "bufio" "bytes" + "context" "crypto/tls" "io" "net" + "strings" ) // AddSNIRoute appends a route to the ipPort listener that says if the @@ -27,11 +29,31 @@ import ( // dest. If it doesn't match, rule processing continues for any // additional routes on ipPort. // +// By default, the proxy will route all ACME tls-sni-01 challenges +// received on ipPort to all SNI dests. You can disable ACME routing +// with AddStopACMESearch. +// // The ipPort is any valid net.Listen TCP address. func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + cfg := p.configFor(ipPort) + if !cfg.stopACME { + if len(cfg.acmeTargets) == 0 { + p.addRoute(ipPort, &acmeMatch{cfg}) + } + cfg.acmeTargets = append(cfg.acmeTargets, dest) + } + p.addRoute(ipPort, sniMatch{sni, dest}) } +// AddStopACMESearch prevents ACME probing of subsequent SNI routes. +// Any ACME challenges on ipPort for SNI routes previously added +// before this call will still be proxied to all possible SNI +// backends. +func (p *Proxy) AddStopACMESearch(ipPort string) { + p.configFor(ipPort).stopACME = true +} + type sniMatch struct { sni string target Target @@ -44,6 +66,79 @@ func (m sniMatch) match(br *bufio.Reader) Target { return nil } +// acmeMatch matches "*.acme.invalid" ACME tls-sni-01 challenges and +// searches for a Target in cfg.acmeTargets that has the challenge +// response. +type acmeMatch struct { + cfg *config +} + +func (m *acmeMatch) match(br *bufio.Reader) Target { + sni := clientHelloServerName(br) + if !strings.HasSuffix(sni, ".acme.invalid") { + return nil + } + + // TODO: cache. ACME issuers will hit multiple times in a short + // burst for each issuance event. A short TTL cache + singleflight + // should have an excellent hit rate. + // TODO: maybe an acme-specific timeout as well? + // TODO: plumb context upwards? + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan Target, len(m.cfg.acmeTargets)) + for _, target := range m.cfg.acmeTargets { + go tryACME(ctx, ch, target, sni) + } + for range m.cfg.acmeTargets { + if target := <-ch; target != nil { + return target + } + } + + // No target was happy with the provided challenge. + return nil +} + +func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) { + var ret Target + defer func() { ch <- ret }() + + conn, targetConn := net.Pipe() + defer conn.Close() + go dest.HandleConn(targetConn) + + deadline, ok := ctx.Deadline() + if ok { + conn.SetDeadline(deadline) + } + + client := tls.Client(conn, &tls.Config{ + ServerName: sni, + InsecureSkipVerify: true, + }) + if err := client.Handshake(); err != nil { + // TODO: log? + return + } + certs := client.ConnectionState().PeerCertificates + if len(certs) == 0 { + // TODO: log? + return + } + // acme says the first cert offered by the server must match the + // challenge hostname. + if err := certs[0].VerifyHostname(sni); err != nil { + // TODO: log? + return + } + + // Target presented what looks like a valid challenge + // response, send it back to the matcher. + ret = dest +} + // clientHelloServerName returns the SNI server name inside the TLS ClientHello, // without consuming any bytes from br. // On any error, the empty string is returned. diff --git a/tcpproxy.go b/tcpproxy.go index 1eee7ea..02b70f5 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -69,7 +69,7 @@ import ( // The order that routes are added in matters; each is matched in the order // registered. type Proxy struct { - routes map[string][]route // ip:port => routes + configs map[string]*config // ip:port => config lns []net.Listener donec chan struct{} // closed before err @@ -81,7 +81,22 @@ type Proxy struct { ListenFunc func(net, laddr string) (net.Listener, error) } +// config contains the proxying state for one listener. +type config struct { + routes []route + acmeTargets []Target // accumulates targets that should be probed for acme. + stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. +} + +// A route matches a connection to a target. type route interface { + // match examines the initial bytes of a connection, looking for a + // match. If a match is found, match returns a non-nil Target to + // which the stream should be proxied. match returns nil if the + // connection doesn't match. + // + // match must not consume bytes from the given bufio.Reader, it + // can only Peek. match(*bufio.Reader) Target } @@ -92,11 +107,19 @@ func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { return net.Listen } -func (p *Proxy) addRoute(ipPort string, r route) { - if p.routes == nil { - p.routes = make(map[string][]route) +func (p *Proxy) configFor(ipPort string) *config { + if p.configs == nil { + p.configs = make(map[string]*config) } - p.routes[ipPort] = append(p.routes[ipPort], r) + if p.configs[ipPort] == nil { + p.configs[ipPort] = &config{} + } + return p.configs[ipPort] +} + +func (p *Proxy) addRoute(ipPort string, r route) { + cfg := p.configFor(ipPort) + cfg.routes = append(cfg.routes, r) } // AddRoute appends an always-matching route to the ipPort listener, @@ -107,14 +130,14 @@ func (p *Proxy) addRoute(ipPort string, r route) { // // The ipPort is any valid net.Listen TCP address. func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, alwaysMatch{dest}) + p.addRoute(ipPort, fixedTarget{dest}) } -type alwaysMatch struct { +type fixedTarget struct { t Target } -func (m alwaysMatch) match(*bufio.Reader) Target { return m.t } +func (m fixedTarget) match(*bufio.Reader) Target { return m.t } // Run is calls Start, and then Wait. // @@ -155,16 +178,16 @@ func (p *Proxy) Start() error { return errors.New("already started") } p.donec = make(chan struct{}) - errc := make(chan error, len(p.routes)) - p.lns = make([]net.Listener, 0, len(p.routes)) - for ipPort, routes := range p.routes { + errc := make(chan error, len(p.configs)) + p.lns = make([]net.Listener, 0, len(p.configs)) + for ipPort, config := range p.configs { ln, err := p.netListen()("tcp", ipPort) if err != nil { p.Close() return err } p.lns = append(p.lns, ln) - go p.serveListener(errc, ln, routes) + go p.serveListener(errc, ln, config.routes) } go p.awaitFirstError(errc) return nil diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 7150372..ac7c917 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -17,19 +17,26 @@ package tcpproxy import ( "bufio" "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" "io/ioutil" + "math/big" "net" "strings" "testing" + "time" ) type noopTarget struct{} -func (t *noopTarget) HandleConn(net.Conn) {} +func (noopTarget) HandleConn(net.Conn) {} func TestMatchHTTPHost(t *testing.T) { tests := []struct { @@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) { want: true, }, } - target := &noopTarget{} for i, tt := range tests { name := tt.name if name == "" { @@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) { } t.Run(name, func(t *testing.T) { br := bufio.NewReader(tt.r) - r := httpHostMatch{tt.host, target} + r := httpHostMatch{tt.host, noopTarget{}} got := r.match(br) != nil if got != tt.want { t.Fatalf("match = %v; want %v", got, tt.want) @@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatalf("got %q; want %q", bs, want) } } + +type tlsServer struct { + Listener net.Listener + Domain string + Test *testing.T +} + +func (t *tlsServer) Start() { + cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid") + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + + go func() { + for { + rawConn, err := t.Listener.Accept() + if err != nil { + return // assume Close() + } + + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, t.Domain); err != nil { + t.Test.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() +} + +func (t *tlsServer) Close() { + t.Listener.Close() +} + +// cert creates a well-formed, but completely insecure self-signed +// cert for domain. +func cert(t *testing.T, domain string) tls.Certificate { + private, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domain, + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private) + if err != nil { + t.Fatal(err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + t.Fatal(err) + } + + return tlscert +} + +// newTLSServer starts a TLS server that serves a self-signed cert for +// domain, and a corresonding acme.invalid dummy domain. +func newTLSServer(t *testing.T, domain string) net.Listener { + cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid") + + l := newLocalListener(t) + go func() { + for { + rawConn, err := l.Accept() + if err != nil { + return // assume closed + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, domain); err != nil { + t.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() + + return l +} + +func readTLS(dest, domain string) (string, error) { + conn, err := tls.Dial("tcp", dest, &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + }) + if err != nil { + return "", err + } + defer conn.Close() + + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", err + } + return string(bs), nil +} + +func TestProxyACME(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newTLSServer(t, "foo.com") + defer backFoo.Close() + backBar := newTLSServer(t, "bar.com") + defer backBar.Close() + backQuux := newTLSServer(t, "quux.com") + defer backQuux.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + p.AddStopACMESearch(testFrontAddr) + p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + tests := []struct { + domain, want string + succeeds bool + }{ + {"foo.com", "foo.com", true}, + {"bar.com", "bar.com", true}, + {"quux.com", "quux.com", true}, + {"xyzzy.com", "", false}, + {"foo.com.acme.invalid", "foo.com", true}, + {"bar.com.acme.invalid", "bar.com", true}, + {"quux.com.acme.invalid", "", false}, + } + for _, test := range tests { + got, err := readTLS(front.Addr().String(), test.domain) + if test.succeeds { + if err != nil { + t.Fatalf("readTLS %q got error %q, want nil", test.domain, err) + } + if got != test.want { + t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want) + } + } else if err == nil { + t.Fatalf("readTLS %q unexpectedly succeeded", test.domain) + } + } +} |