summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--acme.go105
-rw-r--r--config.go18
-rw-r--r--e2e_test.go193
-rw-r--r--main.go40
4 files changed, 346 insertions, 10 deletions
diff --git a/acme.go b/acme.go
new file mode 100644
index 0000000..bbde2e9
--- /dev/null
+++ b/acme.go
@@ -0,0 +1,105 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net"
+ "time"
+)
+
+type acmeCacheEntry struct {
+ backend string
+ expires time.Time
+}
+
+type ACME struct {
+ backends []string
+ // *.acme.invalid domain to cache entry
+ cache map[string]acmeCacheEntry
+}
+
+func (s *ACME) Match(hostname string) string {
+ c := s.cache[hostname]
+ if time.Now().Before(c.expires) {
+ return c.backend
+ }
+
+ // Cache entry is either expired or invalid, need to figure out
+ // which backend is the right one.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ch := make(chan string, len(s.backends))
+ for _, backend := range s.backends {
+ go tryAcme(ctx, ch, backend, hostname)
+ }
+ for range s.backends {
+ backend := <-ch
+ if backend != "" {
+ s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)}
+ return backend
+ }
+ }
+
+ // No usable backends found :(
+ s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)}
+ return ""
+}
+
+func tryAcme(ctx context.Context, ch chan string, backend, hostname string) {
+ var res string
+ var err error
+ defer func() { ch <- res }()
+ defer func() {
+ if err != nil {
+ fmt.Println(err)
+ }
+ }()
+
+ dialer := net.Dialer{Timeout: 10 * time.Second}
+ conn, err := dialer.DialContext(ctx, "tcp", backend)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+ client := tls.Client(conn, &tls.Config{
+ ServerName: hostname,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return
+ }
+ if err = client.Handshake(); err != nil {
+ return
+ }
+
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ return
+ }
+ if err = certs[0].VerifyHostname(hostname); err != nil {
+ return
+ }
+
+ res = backend
+}
diff --git a/config.go b/config.go
index c3e6d86..c205c22 100644
--- a/config.go
+++ b/config.go
@@ -16,6 +16,7 @@ package main
import (
"bufio"
+ "bytes"
"fmt"
"io"
"os"
@@ -34,6 +35,7 @@ type Route struct {
type Config struct {
mu sync.Mutex
routes []Route
+ acme *ACME
}
func dnsRegex(s string) (*regexp.Regexp, error) {
@@ -59,6 +61,11 @@ func dnsRegex(s string) (*regexp.Regexp, error) {
func (c *Config) Match(hostname string) string {
c.mu.Lock()
defer c.mu.Unlock()
+
+ if strings.HasSuffix(hostname, ".acme.invalid") {
+ return c.acme.Match(hostname)
+ }
+
for _, r := range c.routes {
if r.match.MatchString(hostname) {
return r.backend
@@ -70,6 +77,7 @@ func (c *Config) Match(hostname string) string {
// Read replaces the current Config with one read from r.
func (c *Config) Read(r io.Reader) error {
var routes []Route
+ var backends []string
s := bufio.NewScanner(r)
for s.Scan() {
@@ -90,6 +98,7 @@ func (c *Config) Read(r io.Reader) error {
return err
}
routes = append(routes, Route{re, fs[1]})
+ backends = append(backends, fs[1])
default:
// TODO: multiple backends?
return fmt.Errorf("too many fields on line: %q", s.Text())
@@ -102,6 +111,10 @@ func (c *Config) Read(r io.Reader) error {
c.mu.Lock()
defer c.mu.Unlock()
c.routes = routes
+ c.acme = &ACME{
+ backends: backends,
+ cache: make(map[string]acmeCacheEntry),
+ }
return nil
}
@@ -113,3 +126,8 @@ func (c *Config) ReadFile(path string) error {
}
return c.Read(f)
}
+
+func (c *Config) ReadString(cfg string) error {
+ b := bytes.NewBufferString(cfg)
+ return c.Read(b)
+}
diff --git a/e2e_test.go b/e2e_test.go
new file mode 100644
index 0000000..769e60c
--- /dev/null
+++ b/e2e_test.go
@@ -0,0 +1,193 @@
+package main
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestRouting(t *testing.T) {
+ // Two backend servers
+ s1, err := serveTLS(t, "server1", "test.com")
+ if err != nil {
+ t.Fatalf("serve TLS server1: %s", err)
+ }
+ defer s1.Close()
+
+ s2, err := serveTLS(t, "server2", "foo.net")
+ if err != nil {
+ t.Fatalf("serve TLS server2: %s", err)
+ }
+ defer s2.Close()
+
+ s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid")
+ if err != nil {
+ t.Fatalf("server TLS server3: %s", err)
+ }
+ defer s3.Close()
+
+ // One proxy
+ var p Proxy
+ l, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("create listener: %s", err)
+ }
+ defer l.Close()
+ go p.Serve(l)
+
+ if err := p.Config.ReadString(fmt.Sprintf(`
+test.com %s
+foo.net %s
+borkbork.tf %s
+`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil {
+ t.Fatalf("configure proxy: %s", err)
+ }
+
+ for _, test := range []struct {
+ N, V string
+ P *x509.CertPool
+ OK bool
+ }{
+ {"test.com", "server1", s1.Pool, true},
+ {"foo.net", "server2", s2.Pool, true},
+ {"bar.org", "", s1.Pool, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true},
+ } {
+ res, err := getTLS(l.Addr().String(), test.N, test.P)
+ switch {
+ case test.OK && err != nil:
+ t.Fatalf("get %q failed: %s", test.N, err)
+ case !test.OK && err == nil:
+ t.Fatalf("get %q should have failed, but returned %q", test.N, res)
+ case test.OK && res != test.V:
+ t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ }
+ }
+}
+
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) {
+ cfg := tls.Config{
+ RootCAs: pool,
+ ServerName: domain,
+ }
+ conn, err := tls.Dial("tcp", addr, &cfg)
+ if err != nil {
+ return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ }
+ defer conn.Close()
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ return string(bs), nil
+}
+
+type tlsServer struct {
+ Domains []string
+ Value string
+ Pool *x509.CertPool
+ Test *testing.T
+ NumHits uint32
+ l net.Listener
+}
+
+func (s *tlsServer) Serve() {
+ for {
+ c, err := s.l.Accept()
+ if err != nil {
+ s.Test.Logf("accept failed on %q: %s", s.Domains, err)
+ return
+ }
+ atomic.AddUint32(&s.NumHits, 1)
+ c.Write([]byte(s.Value))
+ c.Close()
+ }
+}
+
+func (s *tlsServer) Addr() string {
+ return s.l.Addr().String()
+}
+
+func (s *tlsServer) Close() error {
+ return s.l.Close()
+}
+
+func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) {
+ cert, pool, err := selfSignedCert(domains)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ l, err := tls.Listen("tcp", "localhost:0", cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ ret := &tlsServer{
+ Domains: domains,
+ Value: value,
+ Pool: pool,
+ Test: t,
+ l: l,
+ }
+ go ret.Serve()
+ return ret, nil
+}
+
+func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) {
+ pkey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domains[0],
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: domains[1:],
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert.Bytes()) {
+ return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains)
+ }
+
+ return tlscert, pool, nil
+}
diff --git a/main.go b/main.go
index c1ac324..cfdf5cc 100644
--- a/main.go
+++ b/main.go
@@ -31,34 +31,52 @@ var (
helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
)
-var config Config
-
func main() {
flag.Parse()
- if err := config.ReadFile(*cfgFile); err != nil {
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
}
- l, err := net.Listen("tcp", *listen)
- if err != nil {
- log.Fatalf("Failed to listen: %s", err)
- }
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
for {
c, err := l.Accept()
if err != nil {
- log.Fatalf("Error while accepting: %s", err)
+ return fmt.Errorf("accept new conn: %s", err)
}
- conn := &Conn{TCPConn: c.(*net.TCPConn)}
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
go conn.proxy()
}
}
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
// A Conn handles the TLS proxying of one user connection.
type Conn struct {
*net.TCPConn
+ config *Config
tlsMinor int
hostname string
@@ -109,12 +127,14 @@ func (c *Conn) proxy() {
return
}
+ c.logf("extracted SNI %s", c.hostname)
+
if err = c.SetReadDeadline(time.Time{}); err != nil {
c.internalError("Clearing read deadline for ClientHello: %s", err)
return
}
- c.backend = config.Match(c.hostname)
+ c.backend = c.config.Match(c.hostname)
if c.backend == "" {
c.sniFailed("no backend found for %q", c.hostname)
return