summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-07-06 01:08:29 -0700
committerDave Anderson <[email protected]>2017-07-06 21:36:14 -0700
commitc6a0996ce0f3db7b5c3e16e04c9e664936077c97 (patch)
tree74cda35a42bfd547c07032e9c37af6ddf4e2591e
parent815c9425f1ad46ffd3a3fb1bbefc05440072e4a4 (diff)
Support configurable routing of ACME tls-sni-01 challenges.
By design, the tls-sni-01 challenge does not reveal information about the domain being verified, so the proxy cannot "naively" route such requests. Instead, it probes the Targets of all SNI routes, looking for one that responds plausibly to the challenge hostname, and routes the client connection to that. ACME support can be turned off by inserting AddStopAcmeSearch in the route chain. Subsequently registered SNI routes will not be probed by ACME challenges.
-rw-r--r--sni.go95
-rw-r--r--tcpproxy.go47
-rw-r--r--tcpproxy_test.go175
3 files changed, 302 insertions, 15 deletions
diff --git a/sni.go b/sni.go
index f0128bf..50ab599 100644
--- a/sni.go
+++ b/sni.go
@@ -17,9 +17,11 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "context"
"crypto/tls"
"io"
"net"
+ "strings"
)
// AddSNIRoute appends a route to the ipPort listener that says if the
@@ -27,11 +29,31 @@ import (
// dest. If it doesn't match, rule processing continues for any
// additional routes on ipPort.
//
+// By default, the proxy will route all ACME tls-sni-01 challenges
+// received on ipPort to all SNI dests. You can disable ACME routing
+// with AddStopACMESearch.
+//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ cfg := p.configFor(ipPort)
+ if !cfg.stopACME {
+ if len(cfg.acmeTargets) == 0 {
+ p.addRoute(ipPort, &acmeMatch{cfg})
+ }
+ cfg.acmeTargets = append(cfg.acmeTargets, dest)
+ }
+
p.addRoute(ipPort, sniMatch{sni, dest})
}
+// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
+// Any ACME challenges on ipPort for SNI routes previously added
+// before this call will still be proxied to all possible SNI
+// backends.
+func (p *Proxy) AddStopACMESearch(ipPort string) {
+ p.configFor(ipPort).stopACME = true
+}
+
type sniMatch struct {
sni string
target Target
@@ -44,6 +66,79 @@ func (m sniMatch) match(br *bufio.Reader) Target {
return nil
}
+// acmeMatch matches "*.acme.invalid" ACME tls-sni-01 challenges and
+// searches for a Target in cfg.acmeTargets that has the challenge
+// response.
+type acmeMatch struct {
+ cfg *config
+}
+
+func (m *acmeMatch) match(br *bufio.Reader) Target {
+ sni := clientHelloServerName(br)
+ if !strings.HasSuffix(sni, ".acme.invalid") {
+ return nil
+ }
+
+ // TODO: cache. ACME issuers will hit multiple times in a short
+ // burst for each issuance event. A short TTL cache + singleflight
+ // should have an excellent hit rate.
+ // TODO: maybe an acme-specific timeout as well?
+ // TODO: plumb context upwards?
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ch := make(chan Target, len(m.cfg.acmeTargets))
+ for _, target := range m.cfg.acmeTargets {
+ go tryACME(ctx, ch, target, sni)
+ }
+ for range m.cfg.acmeTargets {
+ if target := <-ch; target != nil {
+ return target
+ }
+ }
+
+ // No target was happy with the provided challenge.
+ return nil
+}
+
+func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) {
+ var ret Target
+ defer func() { ch <- ret }()
+
+ conn, targetConn := net.Pipe()
+ defer conn.Close()
+ go dest.HandleConn(targetConn)
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+
+ client := tls.Client(conn, &tls.Config{
+ ServerName: sni,
+ InsecureSkipVerify: true,
+ })
+ if err := client.Handshake(); err != nil {
+ // TODO: log?
+ return
+ }
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ // TODO: log?
+ return
+ }
+ // acme says the first cert offered by the server must match the
+ // challenge hostname.
+ if err := certs[0].VerifyHostname(sni); err != nil {
+ // TODO: log?
+ return
+ }
+
+ // Target presented what looks like a valid challenge
+ // response, send it back to the matcher.
+ ret = dest
+}
+
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
// without consuming any bytes from br.
// On any error, the empty string is returned.
diff --git a/tcpproxy.go b/tcpproxy.go
index 1eee7ea..02b70f5 100644
--- a/tcpproxy.go
+++ b/tcpproxy.go
@@ -69,7 +69,7 @@ import (
// The order that routes are added in matters; each is matched in the order
// registered.
type Proxy struct {
- routes map[string][]route // ip:port => routes
+ configs map[string]*config // ip:port => config
lns []net.Listener
donec chan struct{} // closed before err
@@ -81,7 +81,22 @@ type Proxy struct {
ListenFunc func(net, laddr string) (net.Listener, error)
}
+// config contains the proxying state for one listener.
+type config struct {
+ routes []route
+ acmeTargets []Target // accumulates targets that should be probed for acme.
+ stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets.
+}
+
+// A route matches a connection to a target.
type route interface {
+ // match examines the initial bytes of a connection, looking for a
+ // match. If a match is found, match returns a non-nil Target to
+ // which the stream should be proxied. match returns nil if the
+ // connection doesn't match.
+ //
+ // match must not consume bytes from the given bufio.Reader, it
+ // can only Peek.
match(*bufio.Reader) Target
}
@@ -92,11 +107,19 @@ func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
return net.Listen
}
-func (p *Proxy) addRoute(ipPort string, r route) {
- if p.routes == nil {
- p.routes = make(map[string][]route)
+func (p *Proxy) configFor(ipPort string) *config {
+ if p.configs == nil {
+ p.configs = make(map[string]*config)
}
- p.routes[ipPort] = append(p.routes[ipPort], r)
+ if p.configs[ipPort] == nil {
+ p.configs[ipPort] = &config{}
+ }
+ return p.configs[ipPort]
+}
+
+func (p *Proxy) addRoute(ipPort string, r route) {
+ cfg := p.configFor(ipPort)
+ cfg.routes = append(cfg.routes, r)
}
// AddRoute appends an always-matching route to the ipPort listener,
@@ -107,14 +130,14 @@ func (p *Proxy) addRoute(ipPort string, r route) {
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
- p.addRoute(ipPort, alwaysMatch{dest})
+ p.addRoute(ipPort, fixedTarget{dest})
}
-type alwaysMatch struct {
+type fixedTarget struct {
t Target
}
-func (m alwaysMatch) match(*bufio.Reader) Target { return m.t }
+func (m fixedTarget) match(*bufio.Reader) Target { return m.t }
// Run is calls Start, and then Wait.
//
@@ -155,16 +178,16 @@ func (p *Proxy) Start() error {
return errors.New("already started")
}
p.donec = make(chan struct{})
- errc := make(chan error, len(p.routes))
- p.lns = make([]net.Listener, 0, len(p.routes))
- for ipPort, routes := range p.routes {
+ errc := make(chan error, len(p.configs))
+ p.lns = make([]net.Listener, 0, len(p.configs))
+ for ipPort, config := range p.configs {
ln, err := p.netListen()("tcp", ipPort)
if err != nil {
p.Close()
return err
}
p.lns = append(p.lns, ln)
- go p.serveListener(errc, ln, routes)
+ go p.serveListener(errc, ln, config.routes)
}
go p.awaitFirstError(errc)
return nil
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 7150372..ac7c917 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -17,19 +17,26 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "crypto/rand"
+ "crypto/rsa"
"crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
+ "math/big"
"net"
"strings"
"testing"
+ "time"
)
type noopTarget struct{}
-func (t *noopTarget) HandleConn(net.Conn) {}
+func (noopTarget) HandleConn(net.Conn) {}
func TestMatchHTTPHost(t *testing.T) {
tests := []struct {
@@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) {
want: true,
},
}
- target := &noopTarget{}
for i, tt := range tests {
name := tt.name
if name == "" {
@@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) {
}
t.Run(name, func(t *testing.T) {
br := bufio.NewReader(tt.r)
- r := httpHostMatch{tt.host, target}
+ r := httpHostMatch{tt.host, noopTarget{}}
got := r.match(br) != nil
if got != tt.want {
t.Fatalf("match = %v; want %v", got, tt.want)
@@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) {
t.Fatalf("got %q; want %q", bs, want)
}
}
+
+type tlsServer struct {
+ Listener net.Listener
+ Domain string
+ Test *testing.T
+}
+
+func (t *tlsServer) Start() {
+ cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid")
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+
+ go func() {
+ for {
+ rawConn, err := t.Listener.Accept()
+ if err != nil {
+ return // assume Close()
+ }
+
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, t.Domain); err != nil {
+ t.Test.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+}
+
+func (t *tlsServer) Close() {
+ t.Listener.Close()
+}
+
+// cert creates a well-formed, but completely insecure self-signed
+// cert for domain.
+func cert(t *testing.T, domain string) tls.Certificate {
+ private, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ t.Fatal(err)
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domain,
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return tlscert
+}
+
+// newTLSServer starts a TLS server that serves a self-signed cert for
+// domain, and a corresonding acme.invalid dummy domain.
+func newTLSServer(t *testing.T, domain string) net.Listener {
+ cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid")
+
+ l := newLocalListener(t)
+ go func() {
+ for {
+ rawConn, err := l.Accept()
+ if err != nil {
+ return // assume closed
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, domain); err != nil {
+ t.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+
+ return l
+}
+
+func readTLS(dest, domain string) (string, error) {
+ conn, err := tls.Dial("tcp", dest, &tls.Config{
+ ServerName: domain,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return "", err
+ }
+ defer conn.Close()
+
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", err
+ }
+ return string(bs), nil
+}
+
+func TestProxyACME(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newTLSServer(t, "foo.com")
+ defer backFoo.Close()
+ backBar := newTLSServer(t, "bar.com")
+ defer backBar.Close()
+ backQuux := newTLSServer(t, "quux.com")
+ defer backQuux.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ p.AddStopACMESearch(testFrontAddr)
+ p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ domain, want string
+ succeeds bool
+ }{
+ {"foo.com", "foo.com", true},
+ {"bar.com", "bar.com", true},
+ {"quux.com", "quux.com", true},
+ {"xyzzy.com", "", false},
+ {"foo.com.acme.invalid", "foo.com", true},
+ {"bar.com.acme.invalid", "bar.com", true},
+ {"quux.com.acme.invalid", "", false},
+ }
+ for _, test := range tests {
+ got, err := readTLS(front.Addr().String(), test.domain)
+ if test.succeeds {
+ if err != nil {
+ t.Fatalf("readTLS %q got error %q, want nil", test.domain, err)
+ }
+ if got != test.want {
+ t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want)
+ }
+ } else if err == nil {
+ t.Fatalf("readTLS %q unexpectedly succeeded", test.domain)
+ }
+ }
+}