diff options
-rw-r--r-- | acme.go | 105 | ||||
-rw-r--r-- | config.go | 18 | ||||
-rw-r--r-- | e2e_test.go | 193 | ||||
-rw-r--r-- | main.go | 40 |
4 files changed, 346 insertions, 10 deletions
@@ -0,0 +1,105 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "time" +) + +type acmeCacheEntry struct { + backend string + expires time.Time +} + +type ACME struct { + backends []string + // *.acme.invalid domain to cache entry + cache map[string]acmeCacheEntry +} + +func (s *ACME) Match(hostname string) string { + c := s.cache[hostname] + if time.Now().Before(c.expires) { + return c.backend + } + + // Cache entry is either expired or invalid, need to figure out + // which backend is the right one. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ch := make(chan string, len(s.backends)) + for _, backend := range s.backends { + go tryAcme(ctx, ch, backend, hostname) + } + for range s.backends { + backend := <-ch + if backend != "" { + s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)} + return backend + } + } + + // No usable backends found :( + s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)} + return "" +} + +func tryAcme(ctx context.Context, ch chan string, backend, hostname string) { + var res string + var err error + defer func() { ch <- res }() + defer func() { + if err != nil { + fmt.Println(err) + } + }() + + dialer := net.Dialer{Timeout: 10 * time.Second} + conn, err := dialer.DialContext(ctx, "tcp", backend) + if err != nil { + return + } + defer conn.Close() + + deadline, ok := ctx.Deadline() + if ok { + conn.SetDeadline(deadline) + } + client := tls.Client(conn, &tls.Config{ + ServerName: hostname, + InsecureSkipVerify: true, + }) + if err != nil { + return + } + if err = client.Handshake(); err != nil { + return + } + + certs := client.ConnectionState().PeerCertificates + if len(certs) == 0 { + return + } + if err = certs[0].VerifyHostname(hostname); err != nil { + return + } + + res = backend +} @@ -16,6 +16,7 @@ package main import ( "bufio" + "bytes" "fmt" "io" "os" @@ -34,6 +35,7 @@ type Route struct { type Config struct { mu sync.Mutex routes []Route + acme *ACME } func dnsRegex(s string) (*regexp.Regexp, error) { @@ -59,6 +61,11 @@ func dnsRegex(s string) (*regexp.Regexp, error) { func (c *Config) Match(hostname string) string { c.mu.Lock() defer c.mu.Unlock() + + if strings.HasSuffix(hostname, ".acme.invalid") { + return c.acme.Match(hostname) + } + for _, r := range c.routes { if r.match.MatchString(hostname) { return r.backend @@ -70,6 +77,7 @@ func (c *Config) Match(hostname string) string { // Read replaces the current Config with one read from r. func (c *Config) Read(r io.Reader) error { var routes []Route + var backends []string s := bufio.NewScanner(r) for s.Scan() { @@ -90,6 +98,7 @@ func (c *Config) Read(r io.Reader) error { return err } routes = append(routes, Route{re, fs[1]}) + backends = append(backends, fs[1]) default: // TODO: multiple backends? return fmt.Errorf("too many fields on line: %q", s.Text()) @@ -102,6 +111,10 @@ func (c *Config) Read(r io.Reader) error { c.mu.Lock() defer c.mu.Unlock() c.routes = routes + c.acme = &ACME{ + backends: backends, + cache: make(map[string]acmeCacheEntry), + } return nil } @@ -113,3 +126,8 @@ func (c *Config) ReadFile(path string) error { } return c.Read(f) } + +func (c *Config) ReadString(cfg string) error { + b := bytes.NewBufferString(cfg) + return c.Read(b) +} diff --git a/e2e_test.go b/e2e_test.go new file mode 100644 index 0000000..769e60c --- /dev/null +++ b/e2e_test.go @@ -0,0 +1,193 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "sync/atomic" + "testing" + "time" +) + +func TestRouting(t *testing.T) { + // Two backend servers + s1, err := serveTLS(t, "server1", "test.com") + if err != nil { + t.Fatalf("serve TLS server1: %s", err) + } + defer s1.Close() + + s2, err := serveTLS(t, "server2", "foo.net") + if err != nil { + t.Fatalf("serve TLS server2: %s", err) + } + defer s2.Close() + + s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid") + if err != nil { + t.Fatalf("server TLS server3: %s", err) + } + defer s3.Close() + + // One proxy + var p Proxy + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("create listener: %s", err) + } + defer l.Close() + go p.Serve(l) + + if err := p.Config.ReadString(fmt.Sprintf(` +test.com %s +foo.net %s +borkbork.tf %s +`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil { + t.Fatalf("configure proxy: %s", err) + } + + for _, test := range []struct { + N, V string + P *x509.CertPool + OK bool + }{ + {"test.com", "server1", s1.Pool, true}, + {"foo.net", "server2", s2.Pool, true}, + {"bar.org", "", s1.Pool, false}, + {"blarghblargh.acme.invalid", "server3", s3.Pool, true}, + } { + res, err := getTLS(l.Addr().String(), test.N, test.P) + switch { + case test.OK && err != nil: + t.Fatalf("get %q failed: %s", test.N, err) + case !test.OK && err == nil: + t.Fatalf("get %q should have failed, but returned %q", test.N, res) + case test.OK && res != test.V: + t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + } + } +} + +func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) { + cfg := tls.Config{ + RootCAs: pool, + ServerName: domain, + } + conn, err := tls.Dial("tcp", addr, &cfg) + if err != nil { + return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + } + defer conn.Close() + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + return string(bs), nil +} + +type tlsServer struct { + Domains []string + Value string + Pool *x509.CertPool + Test *testing.T + NumHits uint32 + l net.Listener +} + +func (s *tlsServer) Serve() { + for { + c, err := s.l.Accept() + if err != nil { + s.Test.Logf("accept failed on %q: %s", s.Domains, err) + return + } + atomic.AddUint32(&s.NumHits, 1) + c.Write([]byte(s.Value)) + c.Close() + } +} + +func (s *tlsServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tlsServer) Close() error { + return s.l.Close() +} + +func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) { + cert, pool, err := selfSignedCert(domains) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + + l, err := tls.Listen("tcp", "localhost:0", cfg) + if err != nil { + return nil, err + } + + ret := &tlsServer{ + Domains: domains, + Value: value, + Pool: pool, + Test: t, + l: l, + } + go ret.Serve() + return ret, nil +} + +func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { + pkey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + return tls.Certificate{}, nil, err + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domains[0], + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: domains[1:], + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey) + if err != nil { + return tls.Certificate{}, nil, err + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + return tls.Certificate{}, nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Bytes()) { + return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) + } + + return tlscert, pool, nil +} @@ -31,34 +31,52 @@ var ( helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") ) -var config Config - func main() { flag.Parse() - if err := config.ReadFile(*cfgFile); err != nil { + p := &Proxy{} + if err := p.Config.ReadFile(*cfgFile); err != nil { log.Fatalf("Failed to read config %q: %s", *cfgFile, err) } - l, err := net.Listen("tcp", *listen) - if err != nil { - log.Fatalf("Failed to listen: %s", err) - } + log.Fatalf("%s", p.ListenAndServe(*listen)) +} + +// Proxy routes connections to backends based on a Config. +type Proxy struct { + Config Config + l net.Listener +} +// Serve accepts connections from l and routes them according to TLS SNI. +func (p *Proxy) Serve(l net.Listener) error { for { c, err := l.Accept() if err != nil { - log.Fatalf("Error while accepting: %s", err) + return fmt.Errorf("accept new conn: %s", err) } - conn := &Conn{TCPConn: c.(*net.TCPConn)} + conn := &Conn{ + TCPConn: c.(*net.TCPConn), + config: &p.Config, + } go conn.proxy() } } +// ListenAndServe creates a listener on addr calls Serve on it. +func (p *Proxy) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("create listener: %s", err) + } + return p.Serve(l) +} + // A Conn handles the TLS proxying of one user connection. type Conn struct { *net.TCPConn + config *Config tlsMinor int hostname string @@ -109,12 +127,14 @@ func (c *Conn) proxy() { return } + c.logf("extracted SNI %s", c.hostname) + if err = c.SetReadDeadline(time.Time{}); err != nil { c.internalError("Clearing read deadline for ClientHello: %s", err) return } - c.backend = config.Match(c.hostname) + c.backend = c.config.Match(c.hostname) if c.backend == "" { c.sniFailed("no backend found for %q", c.hostname) return |