diff options
author | David Anderson <[email protected]> | 2017-02-08 21:55:23 -0800 |
---|---|---|
committer | David Anderson <[email protected]> | 2017-02-08 21:55:23 -0800 |
commit | 728b8bce14d8241b090ecf89c7f48224d5ba2c74 (patch) | |
tree | fe410c0f29616cae6d2b37cbcf002f8e88f1c797 | |
parent | a5c2ccd532db7f26e6f6caff9570f126b9f58713 (diff) |
Add ACME routing support.
TLS connections that look like ACME verification get fanned out to
all known backends, and the one that responds with the right cert
to continue ACME verification is the winner.
-rw-r--r-- | acme.go | 105 | ||||
-rw-r--r-- | config.go | 18 | ||||
-rw-r--r-- | e2e_test.go | 193 | ||||
-rw-r--r-- | main.go | 40 |
4 files changed, 346 insertions, 10 deletions
@@ -0,0 +1,105 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "time" +) + +type acmeCacheEntry struct { + backend string + expires time.Time +} + +type ACME struct { + backends []string + // *.acme.invalid domain to cache entry + cache map[string]acmeCacheEntry +} + +func (s *ACME) Match(hostname string) string { + c := s.cache[hostname] + if time.Now().Before(c.expires) { + return c.backend + } + + // Cache entry is either expired or invalid, need to figure out + // which backend is the right one. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ch := make(chan string, len(s.backends)) + for _, backend := range s.backends { + go tryAcme(ctx, ch, backend, hostname) + } + for range s.backends { + backend := <-ch + if backend != "" { + s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)} + return backend + } + } + + // No usable backends found :( + s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)} + return "" +} + +func tryAcme(ctx context.Context, ch chan string, backend, hostname string) { + var res string + var err error + defer func() { ch <- res }() + defer func() { + if err != nil { + fmt.Println(err) + } + }() + + dialer := net.Dialer{Timeout: 10 * time.Second} + conn, err := dialer.DialContext(ctx, "tcp", backend) + if err != nil { + return + } + defer conn.Close() + + deadline, ok := ctx.Deadline() + if ok { + conn.SetDeadline(deadline) + } + client := tls.Client(conn, &tls.Config{ + ServerName: hostname, + InsecureSkipVerify: true, + }) + if err != nil { + return + } + if err = client.Handshake(); err != nil { + return + } + + certs := client.ConnectionState().PeerCertificates + if len(certs) == 0 { + return + } + if err = certs[0].VerifyHostname(hostname); err != nil { + return + } + + res = backend +} @@ -16,6 +16,7 @@ package main import ( "bufio" + "bytes" "fmt" "io" "os" @@ -34,6 +35,7 @@ type Route struct { type Config struct { mu sync.Mutex routes []Route + acme *ACME } func dnsRegex(s string) (*regexp.Regexp, error) { @@ -59,6 +61,11 @@ func dnsRegex(s string) (*regexp.Regexp, error) { func (c *Config) Match(hostname string) string { c.mu.Lock() defer c.mu.Unlock() + + if strings.HasSuffix(hostname, ".acme.invalid") { + return c.acme.Match(hostname) + } + for _, r := range c.routes { if r.match.MatchString(hostname) { return r.backend @@ -70,6 +77,7 @@ func (c *Config) Match(hostname string) string { // Read replaces the current Config with one read from r. func (c *Config) Read(r io.Reader) error { var routes []Route + var backends []string s := bufio.NewScanner(r) for s.Scan() { @@ -90,6 +98,7 @@ func (c *Config) Read(r io.Reader) error { return err } routes = append(routes, Route{re, fs[1]}) + backends = append(backends, fs[1]) default: // TODO: multiple backends? return fmt.Errorf("too many fields on line: %q", s.Text()) @@ -102,6 +111,10 @@ func (c *Config) Read(r io.Reader) error { c.mu.Lock() defer c.mu.Unlock() c.routes = routes + c.acme = &ACME{ + backends: backends, + cache: make(map[string]acmeCacheEntry), + } return nil } @@ -113,3 +126,8 @@ func (c *Config) ReadFile(path string) error { } return c.Read(f) } + +func (c *Config) ReadString(cfg string) error { + b := bytes.NewBufferString(cfg) + return c.Read(b) +} diff --git a/e2e_test.go b/e2e_test.go new file mode 100644 index 0000000..769e60c --- /dev/null +++ b/e2e_test.go @@ -0,0 +1,193 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "sync/atomic" + "testing" + "time" +) + +func TestRouting(t *testing.T) { + // Two backend servers + s1, err := serveTLS(t, "server1", "test.com") + if err != nil { + t.Fatalf("serve TLS server1: %s", err) + } + defer s1.Close() + + s2, err := serveTLS(t, "server2", "foo.net") + if err != nil { + t.Fatalf("serve TLS server2: %s", err) + } + defer s2.Close() + + s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid") + if err != nil { + t.Fatalf("server TLS server3: %s", err) + } + defer s3.Close() + + // One proxy + var p Proxy + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("create listener: %s", err) + } + defer l.Close() + go p.Serve(l) + + if err := p.Config.ReadString(fmt.Sprintf(` +test.com %s +foo.net %s +borkbork.tf %s +`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil { + t.Fatalf("configure proxy: %s", err) + } + + for _, test := range []struct { + N, V string + P *x509.CertPool + OK bool + }{ + {"test.com", "server1", s1.Pool, true}, + {"foo.net", "server2", s2.Pool, true}, + {"bar.org", "", s1.Pool, false}, + {"blarghblargh.acme.invalid", "server3", s3.Pool, true}, + } { + res, err := getTLS(l.Addr().String(), test.N, test.P) + switch { + case test.OK && err != nil: + t.Fatalf("get %q failed: %s", test.N, err) + case !test.OK && err == nil: + t.Fatalf("get %q should have failed, but returned %q", test.N, res) + case test.OK && res != test.V: + t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + } + } +} + +func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) { + cfg := tls.Config{ + RootCAs: pool, + ServerName: domain, + } + conn, err := tls.Dial("tcp", addr, &cfg) + if err != nil { + return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + } + defer conn.Close() + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + return string(bs), nil +} + +type tlsServer struct { + Domains []string + Value string + Pool *x509.CertPool + Test *testing.T + NumHits uint32 + l net.Listener +} + +func (s *tlsServer) Serve() { + for { + c, err := s.l.Accept() + if err != nil { + s.Test.Logf("accept failed on %q: %s", s.Domains, err) + return + } + atomic.AddUint32(&s.NumHits, 1) + c.Write([]byte(s.Value)) + c.Close() + } +} + +func (s *tlsServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tlsServer) Close() error { + return s.l.Close() +} + +func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) { + cert, pool, err := selfSignedCert(domains) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + + l, err := tls.Listen("tcp", "localhost:0", cfg) + if err != nil { + return nil, err + } + + ret := &tlsServer{ + Domains: domains, + Value: value, + Pool: pool, + Test: t, + l: l, + } + go ret.Serve() + return ret, nil +} + +func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { + pkey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + return tls.Certificate{}, nil, err + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domains[0], + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: domains[1:], + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey) + if err != nil { + return tls.Certificate{}, nil, err + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + return tls.Certificate{}, nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Bytes()) { + return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) + } + + return tlscert, pool, nil +} @@ -31,34 +31,52 @@ var ( helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") ) -var config Config - func main() { flag.Parse() - if err := config.ReadFile(*cfgFile); err != nil { + p := &Proxy{} + if err := p.Config.ReadFile(*cfgFile); err != nil { log.Fatalf("Failed to read config %q: %s", *cfgFile, err) } - l, err := net.Listen("tcp", *listen) - if err != nil { - log.Fatalf("Failed to listen: %s", err) - } + log.Fatalf("%s", p.ListenAndServe(*listen)) +} + +// Proxy routes connections to backends based on a Config. +type Proxy struct { + Config Config + l net.Listener +} +// Serve accepts connections from l and routes them according to TLS SNI. +func (p *Proxy) Serve(l net.Listener) error { for { c, err := l.Accept() if err != nil { - log.Fatalf("Error while accepting: %s", err) + return fmt.Errorf("accept new conn: %s", err) } - conn := &Conn{TCPConn: c.(*net.TCPConn)} + conn := &Conn{ + TCPConn: c.(*net.TCPConn), + config: &p.Config, + } go conn.proxy() } } +// ListenAndServe creates a listener on addr calls Serve on it. +func (p *Proxy) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("create listener: %s", err) + } + return p.Serve(l) +} + // A Conn handles the TLS proxying of one user connection. type Conn struct { *net.TCPConn + config *Config tlsMinor int hostname string @@ -109,12 +127,14 @@ func (c *Conn) proxy() { return } + c.logf("extracted SNI %s", c.hostname) + if err = c.SetReadDeadline(time.Time{}); err != nil { c.internalError("Clearing read deadline for ClientHello: %s", err) return } - c.backend = config.Match(c.hostname) + c.backend = c.config.Match(c.hostname) if c.backend == "" { c.sniFailed("no backend found for %q", c.hostname) return |