summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-07-02 14:46:05 -0700
committerDavid Anderson <[email protected]>2017-07-02 14:46:05 -0700
commit3eb49e9b3902de95b3c9f5729d51ca7f61f02e5a (patch)
tree0f7c2ea9e93dfff7c63c7c35a0531a582eff1e42 /cmd
parentc58b44c4fc69a3602d751d679c69c07e6bcbe24a (diff)
Move tlsrouter to cmd/tlsrouter, in preparation for rewrite as a pkg.
Diffstat (limited to 'cmd')
-rw-r--r--cmd/tlsrouter/acme.go101
-rw-r--r--cmd/tlsrouter/config.go146
-rw-r--r--cmd/tlsrouter/config_test.go61
-rw-r--r--cmd/tlsrouter/e2e_test.go224
-rw-r--r--cmd/tlsrouter/main.go191
-rw-r--r--cmd/tlsrouter/sni.go232
-rw-r--r--cmd/tlsrouter/sni_test.go456
7 files changed, 1411 insertions, 0 deletions
diff --git a/cmd/tlsrouter/acme.go b/cmd/tlsrouter/acme.go
new file mode 100644
index 0000000..ab8d59a
--- /dev/null
+++ b/cmd/tlsrouter/acme.go
@@ -0,0 +1,101 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "net"
+ "time"
+)
+
+type acmeCacheEntry struct {
+ backend string
+ expires time.Time
+}
+
+// ACME locates backends that are attempting ACME SNI-based validation.
+type ACME struct {
+ backends []string
+ // *.acme.invalid domain to cache entry
+ cache map[string]acmeCacheEntry
+}
+
+// Match returns the backend for hostname, if one is found.
+func (s *ACME) Match(hostname string) string {
+ c := s.cache[hostname]
+ if time.Now().Before(c.expires) {
+ return c.backend
+ }
+
+ // Cache entry is either expired or invalid, need to figure out
+ // which backend is the right one.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ch := make(chan string, len(s.backends))
+ for _, backend := range s.backends {
+ go tryAcme(ctx, ch, backend, hostname)
+ }
+ for range s.backends {
+ backend := <-ch
+ if backend != "" {
+ s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)}
+ return backend
+ }
+ }
+
+ // No usable backends found :(
+ s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)}
+ return ""
+}
+
+func tryAcme(ctx context.Context, ch chan string, backend, hostname string) {
+ var res string
+ var err error
+ defer func() { ch <- res }()
+
+ dialer := net.Dialer{Timeout: 10 * time.Second}
+ conn, err := dialer.DialContext(ctx, "tcp", backend)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+ client := tls.Client(conn, &tls.Config{
+ ServerName: hostname,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return
+ }
+ if err = client.Handshake(); err != nil {
+ return
+ }
+
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ return
+ }
+ if err = certs[0].VerifyHostname(hostname); err != nil {
+ return
+ }
+
+ res = backend
+}
diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go
new file mode 100644
index 0000000..1c8151f
--- /dev/null
+++ b/cmd/tlsrouter/config.go
@@ -0,0 +1,146 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "regexp"
+ "strings"
+ "sync"
+)
+
+// A Route maps a match on a domain name to a backend.
+type Route struct {
+ match *regexp.Regexp
+ backend string
+ proxyInfo bool
+}
+
+// Config stores the TLS routing configuration.
+type Config struct {
+ mu sync.Mutex
+ routes []Route
+ acme *ACME
+}
+
+func dnsRegex(s string) (*regexp.Regexp, error) {
+ if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' {
+ return regexp.Compile(s[1 : len(s)-1])
+ }
+
+ var b []string
+ for _, f := range strings.Split(s, ".") {
+ switch f {
+ case "*":
+ b = append(b, `[^.]+`)
+ case "":
+ return nil, fmt.Errorf("DNS name %q has empty label", s)
+ default:
+ b = append(b, regexp.QuoteMeta(f))
+ }
+ }
+ return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`)))
+}
+
+// Match returns the backend for hostname, and whether to use the PROXY protocol.
+func (c *Config) Match(hostname string) (string, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if strings.HasSuffix(hostname, ".acme.invalid") {
+ return c.acme.Match(hostname), false
+ }
+
+ for _, r := range c.routes {
+ if r.match.MatchString(hostname) {
+ return r.backend, r.proxyInfo
+ }
+ }
+ return "", false
+}
+
+// Read replaces the current Config with one read from r.
+func (c *Config) Read(r io.Reader) error {
+ var routes []Route
+ var backends []string
+
+ s := bufio.NewScanner(r)
+ for s.Scan() {
+ if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") {
+ // Comment, ignore.
+ continue
+ }
+
+ fs := strings.Fields(s.Text())
+ switch len(fs) {
+ case 0:
+ continue
+ case 1:
+ return fmt.Errorf("invalid %q on a line by itself", s.Text())
+ case 2:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ routes = append(routes, Route{re, fs[1], false})
+ backends = append(backends, fs[1])
+ case 3:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ if fs[2] != "PROXY" {
+ return errors.New("third item on a line can only be PROXY")
+ }
+ routes = append(routes, Route{re, fs[1], true})
+ backends = append(backends, fs[1])
+ default:
+ // TODO: multiple backends?
+ return fmt.Errorf("too many fields on line: %q", s.Text())
+ }
+ }
+ if err := s.Err(); err != nil {
+ return err
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.routes = routes
+ c.acme = &ACME{
+ backends: backends,
+ cache: make(map[string]acmeCacheEntry),
+ }
+ return nil
+}
+
+// ReadFile replaces the current Config with one read from path.
+func (c *Config) ReadFile(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ return c.Read(f)
+}
+
+// ReadString replaces the current Config with one read from cfg.
+func (c *Config) ReadString(cfg string) error {
+ b := bytes.NewBufferString(cfg)
+ return c.Read(b)
+}
diff --git a/cmd/tlsrouter/config_test.go b/cmd/tlsrouter/config_test.go
new file mode 100644
index 0000000..9819b91
--- /dev/null
+++ b/cmd/tlsrouter/config_test.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestConfig(t *testing.T) {
+ type result struct {
+ backend string
+ proxy bool
+ }
+
+ cases := []struct {
+ Config string
+ Tests map[string]result
+ }{
+ {
+ Config: `
+# Comment
+go.universe.tf 1.2.3.4
+*.universe.tf 2.3.4.5
+# Comment
+google.* 3.4.5.6
+/gooo+gle\.com/ 4.5.6.7
+foobar.net 6.7.8.9 PROXY
+`,
+ Tests: map[string]result{
+ "go.universe.tf": result{"1.2.3.4", false},
+ "foo.universe.tf": result{"2.3.4.5", false},
+ "bar.universe.tf": result{"2.3.4.5", false},
+ "google.com": result{"3.4.5.6", false},
+ "google.fr": result{"3.4.5.6", false},
+ "goooooooooogle.com": result{"4.5.6.7", false},
+ "foobar.net": result{"6.7.8.9", true},
+
+ "blah.com": result{"", false},
+ "google.com.br": result{"", false},
+ "foo.bar.universe.tf": result{"", false},
+ "goooooglexcom": result{"", false},
+ },
+ },
+ }
+
+ for _, test := range cases {
+ var cfg Config
+ if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil {
+ t.Fatalf("Failed to read config (%s):\n%q", err, test.Config)
+ }
+
+ for hostname, expected := range test.Tests {
+ backend, proxy := cfg.Match(hostname)
+ if expected.backend != backend {
+ t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend)
+ }
+ if expected.proxy != proxy {
+ t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy)
+ }
+ }
+ }
+}
diff --git a/cmd/tlsrouter/e2e_test.go b/cmd/tlsrouter/e2e_test.go
new file mode 100644
index 0000000..c53e8c5
--- /dev/null
+++ b/cmd/tlsrouter/e2e_test.go
@@ -0,0 +1,224 @@
+package main
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ proxyproto "github.com/armon/go-proxyproto"
+)
+
+func TestRouting(t *testing.T) {
+ // Backend servers
+ s1, err := serveTLS(t, "server1", false, "test.com")
+ if err != nil {
+ t.Fatalf("serve TLS server1: %s", err)
+ }
+ defer s1.Close()
+
+ s2, err := serveTLS(t, "server2", false, "foo.net")
+ if err != nil {
+ t.Fatalf("serve TLS server2: %s", err)
+ }
+ defer s2.Close()
+
+ s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid")
+ if err != nil {
+ t.Fatalf("server TLS server3: %s", err)
+ }
+ defer s3.Close()
+
+ s4, err := serveTLS(t, "server4", true, "proxy.design")
+ if err != nil {
+ t.Fatalf("server TLS server4: %s", err)
+ }
+ defer s4.Close()
+
+ // One proxy
+ var p Proxy
+ l, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("create listener: %s", err)
+ }
+ defer l.Close()
+ go p.Serve(l)
+
+ if err := p.Config.ReadString(fmt.Sprintf(`
+test.com %s
+foo.net %s
+borkbork.tf %s
+proxy.design %s PROXY
+`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil {
+ t.Fatalf("configure proxy: %s", err)
+ }
+
+ for _, test := range []struct {
+ N, V string
+ P *x509.CertPool
+ OK bool
+ Transparent bool
+ }{
+ {"test.com", "server1", s1.Pool, true, false},
+ {"foo.net", "server2", s2.Pool, true, false},
+ {"bar.org", "", s1.Pool, false, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false},
+ {"proxy.design", "server4", s4.Pool, true, true},
+ } {
+ res, transparent, err := getTLS(l.Addr().String(), test.N, test.P)
+ switch {
+ case test.OK && err != nil:
+ t.Fatalf("get %q failed: %s", test.N, err)
+ case !test.OK && err == nil:
+ t.Fatalf("get %q should have failed, but returned %q", test.N, res)
+ case test.OK && res != test.V:
+ t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ case test.OK && transparent != test.Transparent:
+ t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent)
+ }
+ }
+}
+
+// getTLS attempts to set up a TLS session using the given proxy
+// address, domain, and cert pool. It returns the value served by the
+// server, as well as a bool indicating whether the server knew the
+// true client address, indicating that the PROXY protocol was in use.
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) {
+ cfg := tls.Config{
+ RootCAs: pool,
+ ServerName: domain,
+ }
+ conn, err := tls.Dial("tcp", addr, &cfg)
+ if err != nil {
+ return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ }
+ defer conn.Close()
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ fs := strings.Split(string(bs), " ")
+ if len(fs) != 2 {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs))
+ }
+ transparent := fs[1] == conn.LocalAddr().String()
+ return fs[0], transparent, nil
+}
+
+type tlsServer struct {
+ Domains []string
+ Value string
+ Pool *x509.CertPool
+ Test *testing.T
+ NumHits uint32
+ l net.Listener
+}
+
+func (s *tlsServer) Serve() {
+ for {
+ c, err := s.l.Accept()
+ if err != nil {
+ s.Test.Logf("accept failed on %q: %s", s.Domains, err)
+ return
+ }
+ atomic.AddUint32(&s.NumHits, 1)
+ fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr())
+ c.Close()
+ }
+}
+
+func (s *tlsServer) Addr() string {
+ return s.l.Addr().String()
+}
+
+func (s *tlsServer) Close() error {
+ return s.l.Close()
+}
+
+func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) {
+ cert, pool, err := selfSignedCert(domains)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ var l net.Listener
+
+ l, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return nil, err
+ }
+
+ if understandProxy {
+ l = &proxyproto.Listener{Listener: l}
+ }
+
+ l = tls.NewListener(l, cfg)
+
+ ret := &tlsServer{
+ Domains: domains,
+ Value: value,
+ Pool: pool,
+ Test: t,
+ l: l,
+ }
+ go ret.Serve()
+ return ret, nil
+}
+
+func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) {
+ pkey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domains[0],
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: domains[1:],
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert.Bytes()) {
+ return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains)
+ }
+
+ return tlscert, pool, nil
+}
diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go
new file mode 100644
index 0000000..ff1a816
--- /dev/null
+++ b/cmd/tlsrouter/main.go
@@ -0,0 +1,191 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+ "time"
+)
+
+var (
+ cfgFile = flag.String("conf", "", "configuration file")
+ listen = flag.String("listen", ":443", "listening port")
+ helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
+)
+
+func main() {
+ flag.Parse()
+
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
+ log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
+ }
+
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ return fmt.Errorf("accept new conn: %s", err)
+ }
+
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
+ go conn.proxy()
+ }
+}
+
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
+// A Conn handles the TLS proxying of one user connection.
+type Conn struct {
+ *net.TCPConn
+ config *Config
+
+ tlsMinor int
+ hostname string
+ backend string
+ backendConn *net.TCPConn
+}
+
+func (c *Conn) logf(msg string, args ...interface{}) {
+ msg = fmt.Sprintf(msg, args...)
+ log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg)
+}
+
+func (c *Conn) abort(alert byte, msg string, args ...interface{}) {
+ c.logf(msg, args...)
+ alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert}
+
+ if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.logf("error while setting write deadline during abort: %s", err)
+ // Do NOT send the alert if we can't set a write deadline,
+ // that could result in leaking a connection for an extended
+ // period.
+ return
+ }
+
+ if _, err := c.Write(alertMsg); err != nil {
+ c.logf("error while sending alert: %s", err)
+ }
+}
+
+func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) }
+func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) }
+
+func (c *Conn) proxy() {
+ defer c.Close()
+
+ if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.internalError("Setting read deadline for ClientHello: %s", err)
+ return
+ }
+
+ var (
+ err error
+ handshakeBuf bytes.Buffer
+ )
+ c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf))
+ if err != nil {
+ c.internalError("Extracting SNI: %s", err)
+ return
+ }
+
+ c.logf("extracted SNI %s", c.hostname)
+
+ if err = c.SetReadDeadline(time.Time{}); err != nil {
+ c.internalError("Clearing read deadline for ClientHello: %s", err)
+ return
+ }
+
+ addProxyHeader := false
+ c.backend, addProxyHeader = c.config.Match(c.hostname)
+ if c.backend == "" {
+ c.sniFailed("no backend found for %q", c.hostname)
+ return
+ }
+
+ c.logf("routing %q to %q", c.hostname, c.backend)
+ backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second)
+ if err != nil {
+ c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err)
+ return
+ }
+ defer backend.Close()
+
+ c.backendConn = backend.(*net.TCPConn)
+
+ // If the backend supports the HAProxy PROXY protocol, give it the
+ // real source information about the connection.
+ if addProxyHeader {
+ remote := c.TCPConn.RemoteAddr().(*net.TCPAddr)
+ local := c.TCPConn.LocalAddr().(*net.TCPAddr)
+ family := "TCP6"
+ if remote.IP.To4() != nil {
+ family = "TCP4"
+ }
+ if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil {
+ c.internalError("failed to send PROXY header to %q: %s", c.backend, err)
+ return
+ }
+ }
+
+ // Replay the piece of the handshake we had to read to do the
+ // routing, then blindly proxy any other bytes.
+ if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {
+ c.internalError("failed to replay handshake to %q: %s", c.backend, err)
+ return
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go proxy(&wg, c.TCPConn, c.backendConn)
+ go proxy(&wg, c.backendConn, c.TCPConn)
+ wg.Wait()
+}
+
+func proxy(wg *sync.WaitGroup, a, b net.Conn) {
+ defer wg.Done()
+ atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn)
+ if _, err := io.Copy(atcp, btcp); err != nil {
+ log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err)
+ }
+ btcp.CloseWrite()
+ atcp.CloseRead()
+}
diff --git a/cmd/tlsrouter/sni.go b/cmd/tlsrouter/sni.go
new file mode 100644
index 0000000..ed79df2
--- /dev/null
+++ b/cmd/tlsrouter/sni.go
@@ -0,0 +1,232 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+func extractSNI(r io.Reader) (string, int, error) {
+ handshake, tlsver, err := handshakeRecord(r)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading TLS record: %s", err)
+ }
+
+ sni, err := parseHello(handshake)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading ClientHello: %s", err)
+ }
+ if len(sni) == 0 {
+ // ClientHello did not present an SNI extension. Valid packet,
+ // no hostname.
+ return "", tlsver, nil
+ }
+
+ hostname, err := parseSNI(sni)
+ if err != nil {
+ return "", 0, fmt.Errorf("parsing SNI extension: %s", err)
+ }
+ return hostname, tlsver, nil
+}
+
+// Extract the indicated hostname, if any, from the given SNI
+// extension bytes.
+func parseSNI(b []byte) (string, error) {
+ b, _, err := vector(b, 2)
+ if err != nil {
+ return "", err
+ }
+
+ var ret []byte
+ for len(b) >= 3 {
+ typ := b[0]
+ ret, b, err = vector(b[1:], 2)
+ if err != nil {
+ return "", fmt.Errorf("truncated SNI extension")
+ }
+
+ if typ == sniHostnameID {
+ return string(ret), nil
+ }
+ }
+
+ if len(b) != 0 {
+ return "", fmt.Errorf("trailing garbage at end of SNI extension")
+ }
+
+ // No DNS-based SNI present.
+ return "", nil
+}
+
+const sniExtensionID = 0
+const sniHostnameID = 0
+
+// Parse a TLS handshake record as a ClientHello message and extract
+// the SNI extension bytes, if any.
+func parseHello(b []byte) ([]byte, error) {
+ if len(b) == 0 {
+ return nil, errors.New("zero length handshake record")
+ }
+ if b[0] != 1 {
+ return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
+ }
+
+ // We're expecting a stricter TLS parser to run after we've
+ // proxied, so we ignore any trailing bytes that might be present
+ // (e.g. another handshake message).
+ b, _, err := vector(b[1:], 3)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello: %s", err)
+ }
+
+ // ClientHello must be at least 34 bytes to reach the first vector
+ // length byte. The actual minimal size is larger than that, but
+ // vector() will correctly handle truncated packets.
+ if len(b) < 34 {
+ return nil, errors.New("ClientHello packet too short")
+ }
+
+ if b[0] != 3 {
+ return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
+ }
+ switch b[1] {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1])
+ }
+
+ // Skip over version and random struct
+ b = b[34:]
+
+ // We don't technically care about SessionID, but we care that the
+ // framing is well-formed all the way up to the SNI field, so that
+ // we are sure that we're pulling the same SNI bytes as the
+ // eventual TLS implementation.
+ vec, b, err := vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello SessionID: %s", err)
+ }
+ if len(vec) > 32 {
+ return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
+ }
+
+ // Likewise, we're just checking the bare minimum of framing.
+ vec, b, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err)
+ }
+ if len(vec) < 2 || len(vec)%2 != 0 {
+ return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
+ }
+
+ vec, b, err = vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err)
+ }
+ if len(vec) < 1 {
+ return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
+ }
+
+ // Finally, we reach the extensions.
+ if len(b) == 0 {
+ // No extensions. This is not an error, it just means we have
+ // no SNI payload.
+ return nil, nil
+ }
+ b, vec, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extensions: %s", err)
+ }
+ if len(vec) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec))
+ }
+
+ for len(b) >= 4 {
+ typ := binary.BigEndian.Uint16(b[:2])
+ vec, b, err = vector(b[2:], 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err)
+ }
+ if typ == sniExtensionID {
+ // Found the SNI extension, return its payload. We don't
+ // care about anything in the packet beyond this point.
+ return vec, nil
+ }
+ }
+
+ if len(b) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b))
+ }
+
+ // Successfully parsed all extensions, but there was no SNI.
+ return nil, nil
+}
+
+const maxTLSRecordLength = 16384
+
+// Read one TLS record, which must be for the handshake protocol, from r.
+func handshakeRecord(r io.Reader) ([]byte, int, error) {
+ var hdr struct {
+ Type uint8
+ Major, Minor uint8
+ Length uint16
+ }
+ if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
+ return nil, 0, fmt.Errorf("reading TLS record header: %s", err)
+ }
+
+ if hdr.Type != 22 {
+ return nil, 0, fmt.Errorf("TLS record is not a handshake")
+ }
+
+ if hdr.Major != 3 {
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+ switch hdr.Minor {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+
+ if hdr.Length > maxTLSRecordLength {
+ return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
+ }
+
+ ret := make([]byte, hdr.Length)
+ if _, err := io.ReadFull(r, ret); err != nil {
+ return nil, 0, err
+ }
+
+ return ret, int(hdr.Minor), nil
+}
+
+func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
+ if len(b) < lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ var l int
+ for _, b := range b[:lenBytes] {
+ l = (l << 8) + int(b)
+ }
+ if len(b) < l+lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil
+}
diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go
new file mode 100644
index 0000000..8c87d24
--- /dev/null
+++ b/cmd/tlsrouter/sni_test.go
@@ -0,0 +1,456 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func slice(l int) []byte {
+ ret := make([]byte, l)
+ for i := 0; i < l; i++ {
+ ret[i] = byte(i)
+ }
+ return ret
+}
+
+func vec(l, lenBytes int) []byte {
+ b := slice(l)
+ vecLen := len(b)
+ ret := make([]byte, vecLen+l)
+ for i := l - 1; i >= 0; i-- {
+ ret[i] = byte(vecLen & 0xff)
+ vecLen >>= 8
+ }
+ copy(ret[l:], b)
+ return ret
+}
+
+func packet(bs ...[]byte) []byte {
+ var ret []byte
+ for _, b := range bs {
+ ret = append(ret, b...)
+ }
+ return ret
+}
+
+func offset(b []byte, off int) []byte {
+ return b[off:]
+}
+
+func TestVector(t *testing.T) {
+ tests := []struct {
+ in []byte
+ inLen int
+ out1, out2 []byte
+ err bool
+ }{
+ {
+ // 1b length
+ append([]byte{3}, slice(10)...), 1,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 1b length, no trailer
+ append([]byte{10}, slice(10)...), 1,
+ slice(10), []byte{}, false,
+ },
+ {
+ // 1b length, no vector
+ append([]byte{0}, slice(10)...), 1,
+ []byte{}, slice(10), false,
+ },
+ {
+ // 1b length, no vector or trailer
+ []byte{0}, 1,
+ []byte{}, []byte{}, false,
+ },
+ {
+ // 2b length, LSB only
+ append([]byte{0, 3}, slice(10)...), 2,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 2b length, MSB only
+ append([]byte{3, 0}, slice(1024)...), 2,
+ slice(768), offset(slice(1024), 768), false,
+ },
+ {
+ // 2b length, both bytes
+ append([]byte{3, 2}, slice(1024)...), 2,
+ slice(770), offset(slice(1024), 770), false,
+ },
+ {
+ // 3b length
+ append([]byte{1, 2, 3}, slice(100000)...), 3,
+ slice(66051), offset(slice(100000), 66051), false,
+ },
+ {
+ // no bytes
+ []byte{}, 1,
+ nil, nil, true,
+ },
+ {
+ // no slice
+ nil, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for length
+ []byte{1}, 2,
+ nil, nil, true,
+ },
+ {
+ // no bytes after length
+ []byte{1}, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for vector
+ []byte{4, 1, 2}, 1,
+ nil, nil, true,
+ },
+ }
+
+ for _, test := range tests {
+ actual1, actual2, err := vector(test.in, test.inLen)
+ if !test.err && (err != nil) {
+ t.Errorf("unexpected error %q", err)
+ }
+ if test.err && (err == nil) {
+ t.Errorf("unexpected success")
+ }
+ if err != nil {
+ continue
+ }
+ if !bytes.Equal(actual1, test.out1) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1)
+ }
+ if !bytes.Equal(actual2, test.out2) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2)
+ }
+ }
+}
+
+func TestHandshakeRecord(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ tlsver int
+ }{
+ {
+ // TLS 1.0, 1b packet
+ []byte{22, 3, 1, 0, 1, 3},
+ []byte{3},
+ 1,
+ },
+ {
+ // TLS 1.1, 1b packet
+ []byte{22, 3, 2, 0, 1, 3},
+ []byte{3},
+ 2,
+ },
+ {
+ // TLS 1.2, 1b packet
+ []byte{22, 3, 3, 0, 1, 3},
+ []byte{3},
+ 3,
+ },
+ {
+ // TLS 1.2, no payload bytes
+ []byte{22, 3, 3, 0, 0},
+ []byte{},
+ 3,
+ },
+ {
+ // TLS 1.2, >255b payload w/ trailing stuff
+ append([]byte{22, 3, 3, 3, 2}, slice(1024)...),
+ slice(770),
+ 3,
+ },
+ {
+ // TLS 1.2, 2^14 payload
+ append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...),
+ slice(maxTLSRecordLength),
+ 3,
+ },
+ {
+ // TLS 1.2, >2^14 payload
+ append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...),
+ nil,
+ 0,
+ },
+ {
+ // TLS 1.2, truncated payload
+ []byte{22, 3, 3, 0, 4, 1, 2},
+ nil,
+ 0,
+ },
+ {
+ // truncated header
+ []byte{22},
+ nil,
+ 0,
+ },
+ {
+ // wrong record type
+ []byte{42, 3, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS major version
+ []byte{22, 2, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS minor version
+ []byte{22, 3, 42, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // Obsolete SSL 3.0
+ []byte{22, 3, 0, 0, 1, 3},
+ nil,
+ 0,
+ },
+ }
+
+ for _, test := range tests {
+ r := bytes.NewBuffer(test.in)
+ actual, tlsver, err := handshakeRecord(r)
+ if test.out == nil && err == nil {
+ t.Errorf("unexpected success")
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out)
+ }
+ if tlsver != test.tlsver {
+ t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver)
+ }
+ }
+}
+
+func TestParseHello(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ err bool
+ }{
+ {
+ // Wrong record type
+ packet([]byte{42, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Truncated payload
+ packet([]byte{1, 0, 0, 1}),
+ nil,
+ true,
+ },
+ {
+ // Payload too small
+ packet([]byte{1, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Unknown major version
+ packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Unknown minor version
+ packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Missing required variadic fields
+ packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression)
+ packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression, nonzero session ID)
+ packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // Session + ciphersuites, no compression
+ packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}),
+ nil,
+ true,
+ },
+ {
+ // First valid packet. TLS 1.0, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.1, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, garbage extensions
+ packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)),
+ nil,
+ true,
+ },
+ {
+ // empty extensions vector
+ packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}),
+ nil,
+ false,
+ },
+ {
+ // non-SNI extensions
+ packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}),
+ nil,
+ false,
+ },
+ {
+ // SNI present
+ packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}),
+ []byte{182},
+ false,
+ },
+ {
+ // Longer SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)),
+ slice(4),
+ false,
+ },
+ {
+ // Embedded SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}),
+ slice(4),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseHello(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out)
+ }
+ }
+}
+
+func TestParseSNI(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out string
+ err bool
+ }{
+ {
+ // Empty packet
+ []byte{},
+ "",
+ true,
+ },
+ {
+ // Truncated packet
+ []byte{0, 2, 1},
+ "",
+ true,
+ },
+ {
+ // Truncated packet within SNI vector
+ []byte{0, 2, 1, 2},
+ "",
+ true,
+ },
+ {
+ // Wrong SNI kind
+ []byte{0, 3, 1, 0, 0},
+ "",
+ false,
+ },
+ {
+ // Right SNI kind, no hostname
+ []byte{0, 3, 0, 0, 0},
+ "",
+ false,
+ },
+ {
+ // SNI hostname
+ packet([]byte{0, 6, 0, 0, 3}, []byte("lol")),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI kinds
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI hostnames (illegal, but we just return the first)
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")),
+ "bar",
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseSNI(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if test.out != actual {
+ t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out)
+ }
+ }
+}