summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-02-08 21:55:23 -0800
committerDavid Anderson <[email protected]>2017-02-08 21:55:23 -0800
commit728b8bce14d8241b090ecf89c7f48224d5ba2c74 (patch)
treefe410c0f29616cae6d2b37cbcf002f8e88f1c797
parenta5c2ccd532db7f26e6f6caff9570f126b9f58713 (diff)
Add ACME routing support.
TLS connections that look like ACME verification get fanned out to all known backends, and the one that responds with the right cert to continue ACME verification is the winner.
-rw-r--r--acme.go105
-rw-r--r--config.go18
-rw-r--r--e2e_test.go193
-rw-r--r--main.go40
4 files changed, 346 insertions, 10 deletions
diff --git a/acme.go b/acme.go
new file mode 100644
index 0000000..bbde2e9
--- /dev/null
+++ b/acme.go
@@ -0,0 +1,105 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net"
+ "time"
+)
+
+type acmeCacheEntry struct {
+ backend string
+ expires time.Time
+}
+
+type ACME struct {
+ backends []string
+ // *.acme.invalid domain to cache entry
+ cache map[string]acmeCacheEntry
+}
+
+func (s *ACME) Match(hostname string) string {
+ c := s.cache[hostname]
+ if time.Now().Before(c.expires) {
+ return c.backend
+ }
+
+ // Cache entry is either expired or invalid, need to figure out
+ // which backend is the right one.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ch := make(chan string, len(s.backends))
+ for _, backend := range s.backends {
+ go tryAcme(ctx, ch, backend, hostname)
+ }
+ for range s.backends {
+ backend := <-ch
+ if backend != "" {
+ s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)}
+ return backend
+ }
+ }
+
+ // No usable backends found :(
+ s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)}
+ return ""
+}
+
+func tryAcme(ctx context.Context, ch chan string, backend, hostname string) {
+ var res string
+ var err error
+ defer func() { ch <- res }()
+ defer func() {
+ if err != nil {
+ fmt.Println(err)
+ }
+ }()
+
+ dialer := net.Dialer{Timeout: 10 * time.Second}
+ conn, err := dialer.DialContext(ctx, "tcp", backend)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+ client := tls.Client(conn, &tls.Config{
+ ServerName: hostname,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return
+ }
+ if err = client.Handshake(); err != nil {
+ return
+ }
+
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ return
+ }
+ if err = certs[0].VerifyHostname(hostname); err != nil {
+ return
+ }
+
+ res = backend
+}
diff --git a/config.go b/config.go
index c3e6d86..c205c22 100644
--- a/config.go
+++ b/config.go
@@ -16,6 +16,7 @@ package main
import (
"bufio"
+ "bytes"
"fmt"
"io"
"os"
@@ -34,6 +35,7 @@ type Route struct {
type Config struct {
mu sync.Mutex
routes []Route
+ acme *ACME
}
func dnsRegex(s string) (*regexp.Regexp, error) {
@@ -59,6 +61,11 @@ func dnsRegex(s string) (*regexp.Regexp, error) {
func (c *Config) Match(hostname string) string {
c.mu.Lock()
defer c.mu.Unlock()
+
+ if strings.HasSuffix(hostname, ".acme.invalid") {
+ return c.acme.Match(hostname)
+ }
+
for _, r := range c.routes {
if r.match.MatchString(hostname) {
return r.backend
@@ -70,6 +77,7 @@ func (c *Config) Match(hostname string) string {
// Read replaces the current Config with one read from r.
func (c *Config) Read(r io.Reader) error {
var routes []Route
+ var backends []string
s := bufio.NewScanner(r)
for s.Scan() {
@@ -90,6 +98,7 @@ func (c *Config) Read(r io.Reader) error {
return err
}
routes = append(routes, Route{re, fs[1]})
+ backends = append(backends, fs[1])
default:
// TODO: multiple backends?
return fmt.Errorf("too many fields on line: %q", s.Text())
@@ -102,6 +111,10 @@ func (c *Config) Read(r io.Reader) error {
c.mu.Lock()
defer c.mu.Unlock()
c.routes = routes
+ c.acme = &ACME{
+ backends: backends,
+ cache: make(map[string]acmeCacheEntry),
+ }
return nil
}
@@ -113,3 +126,8 @@ func (c *Config) ReadFile(path string) error {
}
return c.Read(f)
}
+
+func (c *Config) ReadString(cfg string) error {
+ b := bytes.NewBufferString(cfg)
+ return c.Read(b)
+}
diff --git a/e2e_test.go b/e2e_test.go
new file mode 100644
index 0000000..769e60c
--- /dev/null
+++ b/e2e_test.go
@@ -0,0 +1,193 @@
+package main
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestRouting(t *testing.T) {
+ // Two backend servers
+ s1, err := serveTLS(t, "server1", "test.com")
+ if err != nil {
+ t.Fatalf("serve TLS server1: %s", err)
+ }
+ defer s1.Close()
+
+ s2, err := serveTLS(t, "server2", "foo.net")
+ if err != nil {
+ t.Fatalf("serve TLS server2: %s", err)
+ }
+ defer s2.Close()
+
+ s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid")
+ if err != nil {
+ t.Fatalf("server TLS server3: %s", err)
+ }
+ defer s3.Close()
+
+ // One proxy
+ var p Proxy
+ l, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("create listener: %s", err)
+ }
+ defer l.Close()
+ go p.Serve(l)
+
+ if err := p.Config.ReadString(fmt.Sprintf(`
+test.com %s
+foo.net %s
+borkbork.tf %s
+`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil {
+ t.Fatalf("configure proxy: %s", err)
+ }
+
+ for _, test := range []struct {
+ N, V string
+ P *x509.CertPool
+ OK bool
+ }{
+ {"test.com", "server1", s1.Pool, true},
+ {"foo.net", "server2", s2.Pool, true},
+ {"bar.org", "", s1.Pool, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true},
+ } {
+ res, err := getTLS(l.Addr().String(), test.N, test.P)
+ switch {
+ case test.OK && err != nil:
+ t.Fatalf("get %q failed: %s", test.N, err)
+ case !test.OK && err == nil:
+ t.Fatalf("get %q should have failed, but returned %q", test.N, res)
+ case test.OK && res != test.V:
+ t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ }
+ }
+}
+
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) {
+ cfg := tls.Config{
+ RootCAs: pool,
+ ServerName: domain,
+ }
+ conn, err := tls.Dial("tcp", addr, &cfg)
+ if err != nil {
+ return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ }
+ defer conn.Close()
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ return string(bs), nil
+}
+
+type tlsServer struct {
+ Domains []string
+ Value string
+ Pool *x509.CertPool
+ Test *testing.T
+ NumHits uint32
+ l net.Listener
+}
+
+func (s *tlsServer) Serve() {
+ for {
+ c, err := s.l.Accept()
+ if err != nil {
+ s.Test.Logf("accept failed on %q: %s", s.Domains, err)
+ return
+ }
+ atomic.AddUint32(&s.NumHits, 1)
+ c.Write([]byte(s.Value))
+ c.Close()
+ }
+}
+
+func (s *tlsServer) Addr() string {
+ return s.l.Addr().String()
+}
+
+func (s *tlsServer) Close() error {
+ return s.l.Close()
+}
+
+func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) {
+ cert, pool, err := selfSignedCert(domains)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ l, err := tls.Listen("tcp", "localhost:0", cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ ret := &tlsServer{
+ Domains: domains,
+ Value: value,
+ Pool: pool,
+ Test: t,
+ l: l,
+ }
+ go ret.Serve()
+ return ret, nil
+}
+
+func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) {
+ pkey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domains[0],
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: domains[1:],
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert.Bytes()) {
+ return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains)
+ }
+
+ return tlscert, pool, nil
+}
diff --git a/main.go b/main.go
index c1ac324..cfdf5cc 100644
--- a/main.go
+++ b/main.go
@@ -31,34 +31,52 @@ var (
helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
)
-var config Config
-
func main() {
flag.Parse()
- if err := config.ReadFile(*cfgFile); err != nil {
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
}
- l, err := net.Listen("tcp", *listen)
- if err != nil {
- log.Fatalf("Failed to listen: %s", err)
- }
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
for {
c, err := l.Accept()
if err != nil {
- log.Fatalf("Error while accepting: %s", err)
+ return fmt.Errorf("accept new conn: %s", err)
}
- conn := &Conn{TCPConn: c.(*net.TCPConn)}
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
go conn.proxy()
}
}
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
// A Conn handles the TLS proxying of one user connection.
type Conn struct {
*net.TCPConn
+ config *Config
tlsMinor int
hostname string
@@ -109,12 +127,14 @@ func (c *Conn) proxy() {
return
}
+ c.logf("extracted SNI %s", c.hostname)
+
if err = c.SetReadDeadline(time.Time{}); err != nil {
c.internalError("Clearing read deadline for ClientHello: %s", err)
return
}
- c.backend = config.Match(c.hostname)
+ c.backend = c.config.Match(c.hostname)
if c.backend == "" {
c.sniFailed("no backend found for %q", c.hostname)
return