diff options
-rw-r--r-- | sni.go | 95 | ||||
-rw-r--r-- | tcpproxy.go | 47 | ||||
-rw-r--r-- | tcpproxy_test.go | 175 |
3 files changed, 302 insertions, 15 deletions
@@ -17,9 +17,11 @@ package tcpproxy import ( "bufio" "bytes" + "context" "crypto/tls" "io" "net" + "strings" ) // AddSNIRoute appends a route to the ipPort listener that says if the @@ -27,11 +29,31 @@ import ( // dest. If it doesn't match, rule processing continues for any // additional routes on ipPort. // +// By default, the proxy will route all ACME tls-sni-01 challenges +// received on ipPort to all SNI dests. You can disable ACME routing +// with AddStopACMESearch. +// // The ipPort is any valid net.Listen TCP address. func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + cfg := p.configFor(ipPort) + if !cfg.stopACME { + if len(cfg.acmeTargets) == 0 { + p.addRoute(ipPort, &acmeMatch{cfg}) + } + cfg.acmeTargets = append(cfg.acmeTargets, dest) + } + p.addRoute(ipPort, sniMatch{sni, dest}) } +// AddStopACMESearch prevents ACME probing of subsequent SNI routes. +// Any ACME challenges on ipPort for SNI routes previously added +// before this call will still be proxied to all possible SNI +// backends. +func (p *Proxy) AddStopACMESearch(ipPort string) { + p.configFor(ipPort).stopACME = true +} + type sniMatch struct { sni string target Target @@ -44,6 +66,79 @@ func (m sniMatch) match(br *bufio.Reader) Target { return nil } +// acmeMatch matches "*.acme.invalid" ACME tls-sni-01 challenges and +// searches for a Target in cfg.acmeTargets that has the challenge +// response. +type acmeMatch struct { + cfg *config +} + +func (m *acmeMatch) match(br *bufio.Reader) Target { + sni := clientHelloServerName(br) + if !strings.HasSuffix(sni, ".acme.invalid") { + return nil + } + + // TODO: cache. ACME issuers will hit multiple times in a short + // burst for each issuance event. A short TTL cache + singleflight + // should have an excellent hit rate. + // TODO: maybe an acme-specific timeout as well? + // TODO: plumb context upwards? + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := make(chan Target, len(m.cfg.acmeTargets)) + for _, target := range m.cfg.acmeTargets { + go tryACME(ctx, ch, target, sni) + } + for range m.cfg.acmeTargets { + if target := <-ch; target != nil { + return target + } + } + + // No target was happy with the provided challenge. + return nil +} + +func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) { + var ret Target + defer func() { ch <- ret }() + + conn, targetConn := net.Pipe() + defer conn.Close() + go dest.HandleConn(targetConn) + + deadline, ok := ctx.Deadline() + if ok { + conn.SetDeadline(deadline) + } + + client := tls.Client(conn, &tls.Config{ + ServerName: sni, + InsecureSkipVerify: true, + }) + if err := client.Handshake(); err != nil { + // TODO: log? + return + } + certs := client.ConnectionState().PeerCertificates + if len(certs) == 0 { + // TODO: log? + return + } + // acme says the first cert offered by the server must match the + // challenge hostname. + if err := certs[0].VerifyHostname(sni); err != nil { + // TODO: log? + return + } + + // Target presented what looks like a valid challenge + // response, send it back to the matcher. + ret = dest +} + // clientHelloServerName returns the SNI server name inside the TLS ClientHello, // without consuming any bytes from br. // On any error, the empty string is returned. diff --git a/tcpproxy.go b/tcpproxy.go index 1eee7ea..02b70f5 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -69,7 +69,7 @@ import ( // The order that routes are added in matters; each is matched in the order // registered. type Proxy struct { - routes map[string][]route // ip:port => routes + configs map[string]*config // ip:port => config lns []net.Listener donec chan struct{} // closed before err @@ -81,7 +81,22 @@ type Proxy struct { ListenFunc func(net, laddr string) (net.Listener, error) } +// config contains the proxying state for one listener. +type config struct { + routes []route + acmeTargets []Target // accumulates targets that should be probed for acme. + stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets. +} + +// A route matches a connection to a target. type route interface { + // match examines the initial bytes of a connection, looking for a + // match. If a match is found, match returns a non-nil Target to + // which the stream should be proxied. match returns nil if the + // connection doesn't match. + // + // match must not consume bytes from the given bufio.Reader, it + // can only Peek. match(*bufio.Reader) Target } @@ -92,11 +107,19 @@ func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) { return net.Listen } -func (p *Proxy) addRoute(ipPort string, r route) { - if p.routes == nil { - p.routes = make(map[string][]route) +func (p *Proxy) configFor(ipPort string) *config { + if p.configs == nil { + p.configs = make(map[string]*config) } - p.routes[ipPort] = append(p.routes[ipPort], r) + if p.configs[ipPort] == nil { + p.configs[ipPort] = &config{} + } + return p.configs[ipPort] +} + +func (p *Proxy) addRoute(ipPort string, r route) { + cfg := p.configFor(ipPort) + cfg.routes = append(cfg.routes, r) } // AddRoute appends an always-matching route to the ipPort listener, @@ -107,14 +130,14 @@ func (p *Proxy) addRoute(ipPort string, r route) { // // The ipPort is any valid net.Listen TCP address. func (p *Proxy) AddRoute(ipPort string, dest Target) { - p.addRoute(ipPort, alwaysMatch{dest}) + p.addRoute(ipPort, fixedTarget{dest}) } -type alwaysMatch struct { +type fixedTarget struct { t Target } -func (m alwaysMatch) match(*bufio.Reader) Target { return m.t } +func (m fixedTarget) match(*bufio.Reader) Target { return m.t } // Run is calls Start, and then Wait. // @@ -155,16 +178,16 @@ func (p *Proxy) Start() error { return errors.New("already started") } p.donec = make(chan struct{}) - errc := make(chan error, len(p.routes)) - p.lns = make([]net.Listener, 0, len(p.routes)) - for ipPort, routes := range p.routes { + errc := make(chan error, len(p.configs)) + p.lns = make([]net.Listener, 0, len(p.configs)) + for ipPort, config := range p.configs { ln, err := p.netListen()("tcp", ipPort) if err != nil { p.Close() return err } p.lns = append(p.lns, ln) - go p.serveListener(errc, ln, routes) + go p.serveListener(errc, ln, config.routes) } go p.awaitFirstError(errc) return nil diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 7150372..ac7c917 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -17,19 +17,26 @@ package tcpproxy import ( "bufio" "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" "io/ioutil" + "math/big" "net" "strings" "testing" + "time" ) type noopTarget struct{} -func (t *noopTarget) HandleConn(net.Conn) {} +func (noopTarget) HandleConn(net.Conn) {} func TestMatchHTTPHost(t *testing.T) { tests := []struct { @@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) { want: true, }, } - target := &noopTarget{} for i, tt := range tests { name := tt.name if name == "" { @@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) { } t.Run(name, func(t *testing.T) { br := bufio.NewReader(tt.r) - r := httpHostMatch{tt.host, target} + r := httpHostMatch{tt.host, noopTarget{}} got := r.match(br) != nil if got != tt.want { t.Fatalf("match = %v; want %v", got, tt.want) @@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatalf("got %q; want %q", bs, want) } } + +type tlsServer struct { + Listener net.Listener + Domain string + Test *testing.T +} + +func (t *tlsServer) Start() { + cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid") + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + + go func() { + for { + rawConn, err := t.Listener.Accept() + if err != nil { + return // assume Close() + } + + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, t.Domain); err != nil { + t.Test.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() +} + +func (t *tlsServer) Close() { + t.Listener.Close() +} + +// cert creates a well-formed, but completely insecure self-signed +// cert for domain. +func cert(t *testing.T, domain string) tls.Certificate { + private, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domain, + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private) + if err != nil { + t.Fatal(err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + t.Fatal(err) + } + + return tlscert +} + +// newTLSServer starts a TLS server that serves a self-signed cert for +// domain, and a corresonding acme.invalid dummy domain. +func newTLSServer(t *testing.T, domain string) net.Listener { + cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid") + + l := newLocalListener(t) + go func() { + for { + rawConn, err := l.Accept() + if err != nil { + return // assume closed + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, domain); err != nil { + t.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() + + return l +} + +func readTLS(dest, domain string) (string, error) { + conn, err := tls.Dial("tcp", dest, &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + }) + if err != nil { + return "", err + } + defer conn.Close() + + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", err + } + return string(bs), nil +} + +func TestProxyACME(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newTLSServer(t, "foo.com") + defer backFoo.Close() + backBar := newTLSServer(t, "bar.com") + defer backBar.Close() + backQuux := newTLSServer(t, "quux.com") + defer backQuux.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + p.AddStopACMESearch(testFrontAddr) + p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + tests := []struct { + domain, want string + succeeds bool + }{ + {"foo.com", "foo.com", true}, + {"bar.com", "bar.com", true}, + {"quux.com", "quux.com", true}, + {"xyzzy.com", "", false}, + {"foo.com.acme.invalid", "foo.com", true}, + {"bar.com.acme.invalid", "bar.com", true}, + {"quux.com.acme.invalid", "", false}, + } + for _, test := range tests { + got, err := readTLS(front.Addr().String(), test.domain) + if test.succeeds { + if err != nil { + t.Fatalf("readTLS %q got error %q, want nil", test.domain, err) + } + if got != test.want { + t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want) + } + } else if err == nil { + t.Fatalf("readTLS %q unexpectedly succeeded", test.domain) + } + } +} |