summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--.travis.yml47
-rw-r--r--CONTRIBUTING.md8
-rw-r--r--cmd/tlsrouter/README.md51
-rw-r--r--cmd/tlsrouter/acme.go101
-rw-r--r--cmd/tlsrouter/config.go146
-rw-r--r--cmd/tlsrouter/config_test.go61
-rw-r--r--cmd/tlsrouter/e2e_test.go224
-rw-r--r--cmd/tlsrouter/main.go191
-rw-r--r--cmd/tlsrouter/sni.go232
-rw-r--r--cmd/tlsrouter/sni_test.go456
-rw-r--r--scripts/prune_old_versions.go150
-rw-r--r--systemd/tlsrouter.service25
13 files changed, 1694 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ab78466
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+tlsrouter
+tlsrouter.test
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..7c11683
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,47 @@
+language: go
+go_import_path: go.universe.tf/tcpproxy
+go:
+- 1.7
+- 1.8
+- tip
+os:
+- linux
+install:
+- go get github.com/golang/lint/golint
+before_script:
+script:
+- go get -t .
+- go build ./...
+- go test ./...
+- go vet ./...
+- golint -set_exit_status .
+
+jobs:
+ include:
+ - stage: deploy
+ go: 1.8
+ install:
+ - gem install fpm
+ script:
+ - go build .
+ - fpm -s dir -t deb -n tlsrouter -v $(date '+%Y%m%d%H%M%S')
+ --license Apache2
+ --vendor "David Anderson <[email protected]>"
+ --maintainer "David Anderson <[email protected]>"
+ --description "TLS SNI router"
+ --url "https://github.com/google/tlsrouter"
+ ./cmd/tlsrouter/tlsrouter=/usr/bin/tlsrouter
+ ./systemd/tlsrouter.service=/lib/systemd/system/tlsrouter.service
+ deploy:
+ - provider: packagecloud
+ repository: tlsrouter
+ username: danderson
+ dist: debian/stretch
+ skip_cleanup: true
+ token:
+ secure: gNU3o70EU4oYeIS6pr0K5oLMGqqxrcf41EOv6c/YoHPVdV6Cx4j9NW0/ISgu6a1/Xf2NgWKT5BWwLpAuhmGdALuOz1Ah//YBWd9N8mGHGaC6RpOPDU8/9NkQdBEmjEH9sgX4PNOh1KQ7d7O0OH0g8RqJlJa0MkUYbTtN6KJ29oiUXxKmZM4D/iWB8VonKOnrtx1NwQL8jL8imZyEV/1fknhDwumz2iKeU1le4Neq9zkxwICMLUonmgphlrp+SDb1EOoHxT6cn51bqBQtQUplfC4dN4OQU/CPqE9E1N1noibvN29YA93qfcrjD3I95KT9wzq+3B6he33+kb0Gz+Cj5ypGy4P85l7TuX4CtQg0U3NAlJCk32IfsdjK+o47pdmADij9IIb9yKt+g99FMERkJJY5EInqEsxHlW/vNF5OqQCmpiHstZL4R2XaHEsWh6j77npnjjC1Aea8xZTWr8PTsbSzVkbG7bTmFpZoPH8eEmr4GNuw5gnbi6D1AJDjcA+UdY9s5qZNpzuWOqfhOFxL+zUW+8sHBvcoFw3R+pwHECs2LCL1c0xAC1LtNUnmW/gnwHavtvKkzErjR1P8Xl7obCbeChJjp+b/BcFYlNACldZcuzBAPyPwIdlWVyUonL4bm63upfMEEShiAIDDJ21y7fjsQK7CfPA7g25bpyo+hV8=
+ - provider: script
+ script: go run scripts/prune_old_versions.go -user=danderson -repo=tlsrouter -distro=debian -version=stretch -package=tlsrouter -arch=amd64 -limit=2
+ env:
+ # Packagecloud API key, for prune_old_versions.go
+ - secure: "SRcNwt+45QyPS1w9aGxMg9905Y6d9w4mBM29G6iTTnUB5nD7cAk4m+tf834knGSobVXlWcRnTDW8zrHdQ9yX22dPqCpH5qE+qzTmIvxRHrVJRMmPeYvligJ/9jYfHgQbvuRT8cUpIcpCQAla6rw8nXfKTOE3h8XqMP2hdc3DTVOu2HCfKCNco1tJ7is+AIAnFV2Wpsbb3ZsdKFvHvi2RKUfFaX61J1GNt2/XJIlZs8jC6Y1IAC+ftjql9UsAE/WjZ9fL0Ww1b9/LBIIGHXWI3HpVv9WvlhhIxIlJgOVjmU2lbSuj2w/EBDJ9cd1Qe+wJkT3yKzE1NRsNScVjGg+Ku5igJu/XXuaHkIX01+15BqgPduBYRL0atiNQDhqgBiSyVhXZBX9vsgsp0bgpKaBSF++CV18Q9dara8aljqqS33M3imO3I8JmXU10944QA9Wvu7pCYuIzXxhINcDXRvqxBqz5LnFJGwnGqngTrOCSVS2xn7Y+sjmhe1n5cPCEISlozfa9mPYPvMPp8zg3TbATOOM8CVfcpaNscLqa/+SExN3zMwSanjNKrBgoaQcBzGW5mIgSPxhXkWikBgapiEN7+2Y032Lhqdb9dYjH+EuwcnofspDjjMabWxnuJaln+E3/9vZi2ooQrBEtvymUTy4VMSnqwIX5bU7nPdIuQycdWhk="
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..188ad87
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,8 @@
+Contributions are welcome by pull request.
+
+You need to sign the Google Contributor License Agreement before your
+contributions can be accepted. You can find the individual and organization
+level CLAs here:
+
+Individual: https://cla.developers.google.com/about/google-individual
+Organization: https://cla.developers.google.com/about/google-corporate
diff --git a/cmd/tlsrouter/README.md b/cmd/tlsrouter/README.md
new file mode 100644
index 0000000..5a75935
--- /dev/null
+++ b/cmd/tlsrouter/README.md
@@ -0,0 +1,51 @@
+# TLS SNI router
+
+[![license](https://img.shields.io/github/license/google/tlsrouter.svg?maxAge=2592000)](https://github.com/google/tlsrouter/blob/master/LICENSE) [![Travis](https://img.shields.io/travis/google/tlsrouter.svg?maxAge=2592000)](https://travis-ci.org/google/tlsrouter) [![api](https://img.shields.io/badge/api-unstable-red.svg)](https://godoc.org/go.universe.tf/tlsrouter)
+
+TLSRouter is a TLS proxy that routes connections to backends based on the TLS SNI (Server Name Indication) of the TLS handshake. It carries no encryption keys and cannot decode the traffic that it proxies.
+
+This is not an official Google project.
+
+## Installation
+
+Install TLSRouter via `go get`:
+
+```shell
+go get go.universe.tf/tlsrouter
+```
+
+## Usage
+
+TLSRouter requires a configuration file that tells it what backend to
+use for a given hostname. The config file looks like:
+
+```
+# Basic hostname -> backend mapping
+go.universe.tf localhost:1234
+
+# DNS wildcards are understood as well.
+*.go.universe.tf 1.2.3.4:8080
+
+# DNS wildcards can go anywhere in name.
+google.* 10.20.30.40:443
+
+# RE2 regexes are also available
+/(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443
+
+# If your backend supports HAProxy's PROXY protocol, you can enable
+# it to receive the real client ip:port.
+
+fancy.backend 2.3.4.5:443 PROXY
+```
+
+TLSRouter takes one mandatory commandline argument, the configuration file to use:
+
+```shell
+tlsrouter -conf tlsrouter.conf
+```
+
+Optional flags are:
+
+ * `-listen <addr>`: set the listen address (default `:443`)
+ * `-hello-timeout <duration>`: how long to wait for the start of the
+ TLS handshake (default `3s`)
diff --git a/cmd/tlsrouter/acme.go b/cmd/tlsrouter/acme.go
new file mode 100644
index 0000000..ab8d59a
--- /dev/null
+++ b/cmd/tlsrouter/acme.go
@@ -0,0 +1,101 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "crypto/tls"
+ "net"
+ "time"
+)
+
+type acmeCacheEntry struct {
+ backend string
+ expires time.Time
+}
+
+// ACME locates backends that are attempting ACME SNI-based validation.
+type ACME struct {
+ backends []string
+ // *.acme.invalid domain to cache entry
+ cache map[string]acmeCacheEntry
+}
+
+// Match returns the backend for hostname, if one is found.
+func (s *ACME) Match(hostname string) string {
+ c := s.cache[hostname]
+ if time.Now().Before(c.expires) {
+ return c.backend
+ }
+
+ // Cache entry is either expired or invalid, need to figure out
+ // which backend is the right one.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ch := make(chan string, len(s.backends))
+ for _, backend := range s.backends {
+ go tryAcme(ctx, ch, backend, hostname)
+ }
+ for range s.backends {
+ backend := <-ch
+ if backend != "" {
+ s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)}
+ return backend
+ }
+ }
+
+ // No usable backends found :(
+ s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)}
+ return ""
+}
+
+func tryAcme(ctx context.Context, ch chan string, backend, hostname string) {
+ var res string
+ var err error
+ defer func() { ch <- res }()
+
+ dialer := net.Dialer{Timeout: 10 * time.Second}
+ conn, err := dialer.DialContext(ctx, "tcp", backend)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ deadline, ok := ctx.Deadline()
+ if ok {
+ conn.SetDeadline(deadline)
+ }
+ client := tls.Client(conn, &tls.Config{
+ ServerName: hostname,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return
+ }
+ if err = client.Handshake(); err != nil {
+ return
+ }
+
+ certs := client.ConnectionState().PeerCertificates
+ if len(certs) == 0 {
+ return
+ }
+ if err = certs[0].VerifyHostname(hostname); err != nil {
+ return
+ }
+
+ res = backend
+}
diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go
new file mode 100644
index 0000000..1c8151f
--- /dev/null
+++ b/cmd/tlsrouter/config.go
@@ -0,0 +1,146 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "regexp"
+ "strings"
+ "sync"
+)
+
+// A Route maps a match on a domain name to a backend.
+type Route struct {
+ match *regexp.Regexp
+ backend string
+ proxyInfo bool
+}
+
+// Config stores the TLS routing configuration.
+type Config struct {
+ mu sync.Mutex
+ routes []Route
+ acme *ACME
+}
+
+func dnsRegex(s string) (*regexp.Regexp, error) {
+ if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' {
+ return regexp.Compile(s[1 : len(s)-1])
+ }
+
+ var b []string
+ for _, f := range strings.Split(s, ".") {
+ switch f {
+ case "*":
+ b = append(b, `[^.]+`)
+ case "":
+ return nil, fmt.Errorf("DNS name %q has empty label", s)
+ default:
+ b = append(b, regexp.QuoteMeta(f))
+ }
+ }
+ return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`)))
+}
+
+// Match returns the backend for hostname, and whether to use the PROXY protocol.
+func (c *Config) Match(hostname string) (string, bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if strings.HasSuffix(hostname, ".acme.invalid") {
+ return c.acme.Match(hostname), false
+ }
+
+ for _, r := range c.routes {
+ if r.match.MatchString(hostname) {
+ return r.backend, r.proxyInfo
+ }
+ }
+ return "", false
+}
+
+// Read replaces the current Config with one read from r.
+func (c *Config) Read(r io.Reader) error {
+ var routes []Route
+ var backends []string
+
+ s := bufio.NewScanner(r)
+ for s.Scan() {
+ if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") {
+ // Comment, ignore.
+ continue
+ }
+
+ fs := strings.Fields(s.Text())
+ switch len(fs) {
+ case 0:
+ continue
+ case 1:
+ return fmt.Errorf("invalid %q on a line by itself", s.Text())
+ case 2:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ routes = append(routes, Route{re, fs[1], false})
+ backends = append(backends, fs[1])
+ case 3:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ if fs[2] != "PROXY" {
+ return errors.New("third item on a line can only be PROXY")
+ }
+ routes = append(routes, Route{re, fs[1], true})
+ backends = append(backends, fs[1])
+ default:
+ // TODO: multiple backends?
+ return fmt.Errorf("too many fields on line: %q", s.Text())
+ }
+ }
+ if err := s.Err(); err != nil {
+ return err
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.routes = routes
+ c.acme = &ACME{
+ backends: backends,
+ cache: make(map[string]acmeCacheEntry),
+ }
+ return nil
+}
+
+// ReadFile replaces the current Config with one read from path.
+func (c *Config) ReadFile(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ return c.Read(f)
+}
+
+// ReadString replaces the current Config with one read from cfg.
+func (c *Config) ReadString(cfg string) error {
+ b := bytes.NewBufferString(cfg)
+ return c.Read(b)
+}
diff --git a/cmd/tlsrouter/config_test.go b/cmd/tlsrouter/config_test.go
new file mode 100644
index 0000000..9819b91
--- /dev/null
+++ b/cmd/tlsrouter/config_test.go
@@ -0,0 +1,61 @@
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestConfig(t *testing.T) {
+ type result struct {
+ backend string
+ proxy bool
+ }
+
+ cases := []struct {
+ Config string
+ Tests map[string]result
+ }{
+ {
+ Config: `
+# Comment
+go.universe.tf 1.2.3.4
+*.universe.tf 2.3.4.5
+# Comment
+google.* 3.4.5.6
+/gooo+gle\.com/ 4.5.6.7
+foobar.net 6.7.8.9 PROXY
+`,
+ Tests: map[string]result{
+ "go.universe.tf": result{"1.2.3.4", false},
+ "foo.universe.tf": result{"2.3.4.5", false},
+ "bar.universe.tf": result{"2.3.4.5", false},
+ "google.com": result{"3.4.5.6", false},
+ "google.fr": result{"3.4.5.6", false},
+ "goooooooooogle.com": result{"4.5.6.7", false},
+ "foobar.net": result{"6.7.8.9", true},
+
+ "blah.com": result{"", false},
+ "google.com.br": result{"", false},
+ "foo.bar.universe.tf": result{"", false},
+ "goooooglexcom": result{"", false},
+ },
+ },
+ }
+
+ for _, test := range cases {
+ var cfg Config
+ if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil {
+ t.Fatalf("Failed to read config (%s):\n%q", err, test.Config)
+ }
+
+ for hostname, expected := range test.Tests {
+ backend, proxy := cfg.Match(hostname)
+ if expected.backend != backend {
+ t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend)
+ }
+ if expected.proxy != proxy {
+ t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy)
+ }
+ }
+ }
+}
diff --git a/cmd/tlsrouter/e2e_test.go b/cmd/tlsrouter/e2e_test.go
new file mode 100644
index 0000000..c53e8c5
--- /dev/null
+++ b/cmd/tlsrouter/e2e_test.go
@@ -0,0 +1,224 @@
+package main
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "math/big"
+ "net"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ proxyproto "github.com/armon/go-proxyproto"
+)
+
+func TestRouting(t *testing.T) {
+ // Backend servers
+ s1, err := serveTLS(t, "server1", false, "test.com")
+ if err != nil {
+ t.Fatalf("serve TLS server1: %s", err)
+ }
+ defer s1.Close()
+
+ s2, err := serveTLS(t, "server2", false, "foo.net")
+ if err != nil {
+ t.Fatalf("serve TLS server2: %s", err)
+ }
+ defer s2.Close()
+
+ s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid")
+ if err != nil {
+ t.Fatalf("server TLS server3: %s", err)
+ }
+ defer s3.Close()
+
+ s4, err := serveTLS(t, "server4", true, "proxy.design")
+ if err != nil {
+ t.Fatalf("server TLS server4: %s", err)
+ }
+ defer s4.Close()
+
+ // One proxy
+ var p Proxy
+ l, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("create listener: %s", err)
+ }
+ defer l.Close()
+ go p.Serve(l)
+
+ if err := p.Config.ReadString(fmt.Sprintf(`
+test.com %s
+foo.net %s
+borkbork.tf %s
+proxy.design %s PROXY
+`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil {
+ t.Fatalf("configure proxy: %s", err)
+ }
+
+ for _, test := range []struct {
+ N, V string
+ P *x509.CertPool
+ OK bool
+ Transparent bool
+ }{
+ {"test.com", "server1", s1.Pool, true, false},
+ {"foo.net", "server2", s2.Pool, true, false},
+ {"bar.org", "", s1.Pool, false, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false},
+ {"proxy.design", "server4", s4.Pool, true, true},
+ } {
+ res, transparent, err := getTLS(l.Addr().String(), test.N, test.P)
+ switch {
+ case test.OK && err != nil:
+ t.Fatalf("get %q failed: %s", test.N, err)
+ case !test.OK && err == nil:
+ t.Fatalf("get %q should have failed, but returned %q", test.N, res)
+ case test.OK && res != test.V:
+ t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ case test.OK && transparent != test.Transparent:
+ t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent)
+ }
+ }
+}
+
+// getTLS attempts to set up a TLS session using the given proxy
+// address, domain, and cert pool. It returns the value served by the
+// server, as well as a bool indicating whether the server knew the
+// true client address, indicating that the PROXY protocol was in use.
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) {
+ cfg := tls.Config{
+ RootCAs: pool,
+ ServerName: domain,
+ }
+ conn, err := tls.Dial("tcp", addr, &cfg)
+ if err != nil {
+ return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ }
+ defer conn.Close()
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ fs := strings.Split(string(bs), " ")
+ if len(fs) != 2 {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs))
+ }
+ transparent := fs[1] == conn.LocalAddr().String()
+ return fs[0], transparent, nil
+}
+
+type tlsServer struct {
+ Domains []string
+ Value string
+ Pool *x509.CertPool
+ Test *testing.T
+ NumHits uint32
+ l net.Listener
+}
+
+func (s *tlsServer) Serve() {
+ for {
+ c, err := s.l.Accept()
+ if err != nil {
+ s.Test.Logf("accept failed on %q: %s", s.Domains, err)
+ return
+ }
+ atomic.AddUint32(&s.NumHits, 1)
+ fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr())
+ c.Close()
+ }
+}
+
+func (s *tlsServer) Addr() string {
+ return s.l.Addr().String()
+}
+
+func (s *tlsServer) Close() error {
+ return s.l.Close()
+}
+
+func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) {
+ cert, pool, err := selfSignedCert(domains)
+ if err != nil {
+ return nil, err
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+ cfg.BuildNameToCertificate()
+
+ var l net.Listener
+
+ l, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return nil, err
+ }
+
+ if understandProxy {
+ l = &proxyproto.Listener{Listener: l}
+ }
+
+ l = tls.NewListener(l, cfg)
+
+ ret := &tlsServer{
+ Domains: domains,
+ Value: value,
+ Pool: pool,
+ Test: t,
+ l: l,
+ }
+ go ret.Serve()
+ return ret, nil
+}
+
+func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) {
+ pkey, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domains[0],
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: domains[1:],
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey)
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ return tls.Certificate{}, nil, err
+ }
+
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert.Bytes()) {
+ return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains)
+ }
+
+ return tlscert, pool, nil
+}
diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go
new file mode 100644
index 0000000..ff1a816
--- /dev/null
+++ b/cmd/tlsrouter/main.go
@@ -0,0 +1,191 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+ "time"
+)
+
+var (
+ cfgFile = flag.String("conf", "", "configuration file")
+ listen = flag.String("listen", ":443", "listening port")
+ helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
+)
+
+func main() {
+ flag.Parse()
+
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
+ log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
+ }
+
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ return fmt.Errorf("accept new conn: %s", err)
+ }
+
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
+ go conn.proxy()
+ }
+}
+
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
+// A Conn handles the TLS proxying of one user connection.
+type Conn struct {
+ *net.TCPConn
+ config *Config
+
+ tlsMinor int
+ hostname string
+ backend string
+ backendConn *net.TCPConn
+}
+
+func (c *Conn) logf(msg string, args ...interface{}) {
+ msg = fmt.Sprintf(msg, args...)
+ log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg)
+}
+
+func (c *Conn) abort(alert byte, msg string, args ...interface{}) {
+ c.logf(msg, args...)
+ alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert}
+
+ if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.logf("error while setting write deadline during abort: %s", err)
+ // Do NOT send the alert if we can't set a write deadline,
+ // that could result in leaking a connection for an extended
+ // period.
+ return
+ }
+
+ if _, err := c.Write(alertMsg); err != nil {
+ c.logf("error while sending alert: %s", err)
+ }
+}
+
+func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) }
+func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) }
+
+func (c *Conn) proxy() {
+ defer c.Close()
+
+ if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.internalError("Setting read deadline for ClientHello: %s", err)
+ return
+ }
+
+ var (
+ err error
+ handshakeBuf bytes.Buffer
+ )
+ c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf))
+ if err != nil {
+ c.internalError("Extracting SNI: %s", err)
+ return
+ }
+
+ c.logf("extracted SNI %s", c.hostname)
+
+ if err = c.SetReadDeadline(time.Time{}); err != nil {
+ c.internalError("Clearing read deadline for ClientHello: %s", err)
+ return
+ }
+
+ addProxyHeader := false
+ c.backend, addProxyHeader = c.config.Match(c.hostname)
+ if c.backend == "" {
+ c.sniFailed("no backend found for %q", c.hostname)
+ return
+ }
+
+ c.logf("routing %q to %q", c.hostname, c.backend)
+ backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second)
+ if err != nil {
+ c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err)
+ return
+ }
+ defer backend.Close()
+
+ c.backendConn = backend.(*net.TCPConn)
+
+ // If the backend supports the HAProxy PROXY protocol, give it the
+ // real source information about the connection.
+ if addProxyHeader {
+ remote := c.TCPConn.RemoteAddr().(*net.TCPAddr)
+ local := c.TCPConn.LocalAddr().(*net.TCPAddr)
+ family := "TCP6"
+ if remote.IP.To4() != nil {
+ family = "TCP4"
+ }
+ if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil {
+ c.internalError("failed to send PROXY header to %q: %s", c.backend, err)
+ return
+ }
+ }
+
+ // Replay the piece of the handshake we had to read to do the
+ // routing, then blindly proxy any other bytes.
+ if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {
+ c.internalError("failed to replay handshake to %q: %s", c.backend, err)
+ return
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go proxy(&wg, c.TCPConn, c.backendConn)
+ go proxy(&wg, c.backendConn, c.TCPConn)
+ wg.Wait()
+}
+
+func proxy(wg *sync.WaitGroup, a, b net.Conn) {
+ defer wg.Done()
+ atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn)
+ if _, err := io.Copy(atcp, btcp); err != nil {
+ log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err)
+ }
+ btcp.CloseWrite()
+ atcp.CloseRead()
+}
diff --git a/cmd/tlsrouter/sni.go b/cmd/tlsrouter/sni.go
new file mode 100644
index 0000000..ed79df2
--- /dev/null
+++ b/cmd/tlsrouter/sni.go
@@ -0,0 +1,232 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+func extractSNI(r io.Reader) (string, int, error) {
+ handshake, tlsver, err := handshakeRecord(r)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading TLS record: %s", err)
+ }
+
+ sni, err := parseHello(handshake)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading ClientHello: %s", err)
+ }
+ if len(sni) == 0 {
+ // ClientHello did not present an SNI extension. Valid packet,
+ // no hostname.
+ return "", tlsver, nil
+ }
+
+ hostname, err := parseSNI(sni)
+ if err != nil {
+ return "", 0, fmt.Errorf("parsing SNI extension: %s", err)
+ }
+ return hostname, tlsver, nil
+}
+
+// Extract the indicated hostname, if any, from the given SNI
+// extension bytes.
+func parseSNI(b []byte) (string, error) {
+ b, _, err := vector(b, 2)
+ if err != nil {
+ return "", err
+ }
+
+ var ret []byte
+ for len(b) >= 3 {
+ typ := b[0]
+ ret, b, err = vector(b[1:], 2)
+ if err != nil {
+ return "", fmt.Errorf("truncated SNI extension")
+ }
+
+ if typ == sniHostnameID {
+ return string(ret), nil
+ }
+ }
+
+ if len(b) != 0 {
+ return "", fmt.Errorf("trailing garbage at end of SNI extension")
+ }
+
+ // No DNS-based SNI present.
+ return "", nil
+}
+
+const sniExtensionID = 0
+const sniHostnameID = 0
+
+// Parse a TLS handshake record as a ClientHello message and extract
+// the SNI extension bytes, if any.
+func parseHello(b []byte) ([]byte, error) {
+ if len(b) == 0 {
+ return nil, errors.New("zero length handshake record")
+ }
+ if b[0] != 1 {
+ return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
+ }
+
+ // We're expecting a stricter TLS parser to run after we've
+ // proxied, so we ignore any trailing bytes that might be present
+ // (e.g. another handshake message).
+ b, _, err := vector(b[1:], 3)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello: %s", err)
+ }
+
+ // ClientHello must be at least 34 bytes to reach the first vector
+ // length byte. The actual minimal size is larger than that, but
+ // vector() will correctly handle truncated packets.
+ if len(b) < 34 {
+ return nil, errors.New("ClientHello packet too short")
+ }
+
+ if b[0] != 3 {
+ return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
+ }
+ switch b[1] {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1])
+ }
+
+ // Skip over version and random struct
+ b = b[34:]
+
+ // We don't technically care about SessionID, but we care that the
+ // framing is well-formed all the way up to the SNI field, so that
+ // we are sure that we're pulling the same SNI bytes as the
+ // eventual TLS implementation.
+ vec, b, err := vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello SessionID: %s", err)
+ }
+ if len(vec) > 32 {
+ return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
+ }
+
+ // Likewise, we're just checking the bare minimum of framing.
+ vec, b, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err)
+ }
+ if len(vec) < 2 || len(vec)%2 != 0 {
+ return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
+ }
+
+ vec, b, err = vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err)
+ }
+ if len(vec) < 1 {
+ return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
+ }
+
+ // Finally, we reach the extensions.
+ if len(b) == 0 {
+ // No extensions. This is not an error, it just means we have
+ // no SNI payload.
+ return nil, nil
+ }
+ b, vec, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extensions: %s", err)
+ }
+ if len(vec) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec))
+ }
+
+ for len(b) >= 4 {
+ typ := binary.BigEndian.Uint16(b[:2])
+ vec, b, err = vector(b[2:], 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err)
+ }
+ if typ == sniExtensionID {
+ // Found the SNI extension, return its payload. We don't
+ // care about anything in the packet beyond this point.
+ return vec, nil
+ }
+ }
+
+ if len(b) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b))
+ }
+
+ // Successfully parsed all extensions, but there was no SNI.
+ return nil, nil
+}
+
+const maxTLSRecordLength = 16384
+
+// Read one TLS record, which must be for the handshake protocol, from r.
+func handshakeRecord(r io.Reader) ([]byte, int, error) {
+ var hdr struct {
+ Type uint8
+ Major, Minor uint8
+ Length uint16
+ }
+ if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
+ return nil, 0, fmt.Errorf("reading TLS record header: %s", err)
+ }
+
+ if hdr.Type != 22 {
+ return nil, 0, fmt.Errorf("TLS record is not a handshake")
+ }
+
+ if hdr.Major != 3 {
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+ switch hdr.Minor {
+ case 1, 2, 3:
+ // TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+
+ if hdr.Length > maxTLSRecordLength {
+ return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
+ }
+
+ ret := make([]byte, hdr.Length)
+ if _, err := io.ReadFull(r, ret); err != nil {
+ return nil, 0, err
+ }
+
+ return ret, int(hdr.Minor), nil
+}
+
+func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
+ if len(b) < lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ var l int
+ for _, b := range b[:lenBytes] {
+ l = (l << 8) + int(b)
+ }
+ if len(b) < l+lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil
+}
diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go
new file mode 100644
index 0000000..8c87d24
--- /dev/null
+++ b/cmd/tlsrouter/sni_test.go
@@ -0,0 +1,456 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+func slice(l int) []byte {
+ ret := make([]byte, l)
+ for i := 0; i < l; i++ {
+ ret[i] = byte(i)
+ }
+ return ret
+}
+
+func vec(l, lenBytes int) []byte {
+ b := slice(l)
+ vecLen := len(b)
+ ret := make([]byte, vecLen+l)
+ for i := l - 1; i >= 0; i-- {
+ ret[i] = byte(vecLen & 0xff)
+ vecLen >>= 8
+ }
+ copy(ret[l:], b)
+ return ret
+}
+
+func packet(bs ...[]byte) []byte {
+ var ret []byte
+ for _, b := range bs {
+ ret = append(ret, b...)
+ }
+ return ret
+}
+
+func offset(b []byte, off int) []byte {
+ return b[off:]
+}
+
+func TestVector(t *testing.T) {
+ tests := []struct {
+ in []byte
+ inLen int
+ out1, out2 []byte
+ err bool
+ }{
+ {
+ // 1b length
+ append([]byte{3}, slice(10)...), 1,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 1b length, no trailer
+ append([]byte{10}, slice(10)...), 1,
+ slice(10), []byte{}, false,
+ },
+ {
+ // 1b length, no vector
+ append([]byte{0}, slice(10)...), 1,
+ []byte{}, slice(10), false,
+ },
+ {
+ // 1b length, no vector or trailer
+ []byte{0}, 1,
+ []byte{}, []byte{}, false,
+ },
+ {
+ // 2b length, LSB only
+ append([]byte{0, 3}, slice(10)...), 2,
+ slice(3), offset(slice(10), 3), false,
+ },
+ {
+ // 2b length, MSB only
+ append([]byte{3, 0}, slice(1024)...), 2,
+ slice(768), offset(slice(1024), 768), false,
+ },
+ {
+ // 2b length, both bytes
+ append([]byte{3, 2}, slice(1024)...), 2,
+ slice(770), offset(slice(1024), 770), false,
+ },
+ {
+ // 3b length
+ append([]byte{1, 2, 3}, slice(100000)...), 3,
+ slice(66051), offset(slice(100000), 66051), false,
+ },
+ {
+ // no bytes
+ []byte{}, 1,
+ nil, nil, true,
+ },
+ {
+ // no slice
+ nil, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for length
+ []byte{1}, 2,
+ nil, nil, true,
+ },
+ {
+ // no bytes after length
+ []byte{1}, 1,
+ nil, nil, true,
+ },
+ {
+ // not enough bytes for vector
+ []byte{4, 1, 2}, 1,
+ nil, nil, true,
+ },
+ }
+
+ for _, test := range tests {
+ actual1, actual2, err := vector(test.in, test.inLen)
+ if !test.err && (err != nil) {
+ t.Errorf("unexpected error %q", err)
+ }
+ if test.err && (err == nil) {
+ t.Errorf("unexpected success")
+ }
+ if err != nil {
+ continue
+ }
+ if !bytes.Equal(actual1, test.out1) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1)
+ }
+ if !bytes.Equal(actual2, test.out2) {
+ t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2)
+ }
+ }
+}
+
+func TestHandshakeRecord(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ tlsver int
+ }{
+ {
+ // TLS 1.0, 1b packet
+ []byte{22, 3, 1, 0, 1, 3},
+ []byte{3},
+ 1,
+ },
+ {
+ // TLS 1.1, 1b packet
+ []byte{22, 3, 2, 0, 1, 3},
+ []byte{3},
+ 2,
+ },
+ {
+ // TLS 1.2, 1b packet
+ []byte{22, 3, 3, 0, 1, 3},
+ []byte{3},
+ 3,
+ },
+ {
+ // TLS 1.2, no payload bytes
+ []byte{22, 3, 3, 0, 0},
+ []byte{},
+ 3,
+ },
+ {
+ // TLS 1.2, >255b payload w/ trailing stuff
+ append([]byte{22, 3, 3, 3, 2}, slice(1024)...),
+ slice(770),
+ 3,
+ },
+ {
+ // TLS 1.2, 2^14 payload
+ append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...),
+ slice(maxTLSRecordLength),
+ 3,
+ },
+ {
+ // TLS 1.2, >2^14 payload
+ append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...),
+ nil,
+ 0,
+ },
+ {
+ // TLS 1.2, truncated payload
+ []byte{22, 3, 3, 0, 4, 1, 2},
+ nil,
+ 0,
+ },
+ {
+ // truncated header
+ []byte{22},
+ nil,
+ 0,
+ },
+ {
+ // wrong record type
+ []byte{42, 3, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS major version
+ []byte{22, 2, 3, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // wrong TLS minor version
+ []byte{22, 3, 42, 0, 1, 3},
+ nil,
+ 0,
+ },
+ {
+ // Obsolete SSL 3.0
+ []byte{22, 3, 0, 0, 1, 3},
+ nil,
+ 0,
+ },
+ }
+
+ for _, test := range tests {
+ r := bytes.NewBuffer(test.in)
+ actual, tlsver, err := handshakeRecord(r)
+ if test.out == nil && err == nil {
+ t.Errorf("unexpected success")
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out)
+ }
+ if tlsver != test.tlsver {
+ t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver)
+ }
+ }
+}
+
+func TestParseHello(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out []byte
+ err bool
+ }{
+ {
+ // Wrong record type
+ packet([]byte{42, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Truncated payload
+ packet([]byte{1, 0, 0, 1}),
+ nil,
+ true,
+ },
+ {
+ // Payload too small
+ packet([]byte{1, 0, 0, 1, 1}),
+ nil,
+ true,
+ },
+ {
+ // Unknown major version
+ packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Unknown minor version
+ packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // Missing required variadic fields
+ packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression)
+ packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // All zero variadic fields (no ciphersuites, no compression, nonzero session ID)
+ packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}),
+ nil,
+ true,
+ },
+ {
+ // Session + ciphersuites, no compression
+ packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}),
+ nil,
+ true,
+ },
+ {
+ // First valid packet. TLS 1.0, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.1, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, no extensions present.
+ packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}),
+ nil,
+ false,
+ },
+ {
+ // TLS 1.2, garbage extensions
+ packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)),
+ nil,
+ true,
+ },
+ {
+ // empty extensions vector
+ packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}),
+ nil,
+ false,
+ },
+ {
+ // non-SNI extensions
+ packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}),
+ nil,
+ false,
+ },
+ {
+ // SNI present
+ packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}),
+ []byte{182},
+ false,
+ },
+ {
+ // Longer SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)),
+ slice(4),
+ false,
+ },
+ {
+ // Embedded SNI
+ packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}),
+ slice(4),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseHello(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if !bytes.Equal(test.out, actual) {
+ t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out)
+ }
+ }
+}
+
+func TestParseSNI(t *testing.T) {
+ tests := []struct {
+ in []byte
+ out string
+ err bool
+ }{
+ {
+ // Empty packet
+ []byte{},
+ "",
+ true,
+ },
+ {
+ // Truncated packet
+ []byte{0, 2, 1},
+ "",
+ true,
+ },
+ {
+ // Truncated packet within SNI vector
+ []byte{0, 2, 1, 2},
+ "",
+ true,
+ },
+ {
+ // Wrong SNI kind
+ []byte{0, 3, 1, 0, 0},
+ "",
+ false,
+ },
+ {
+ // Right SNI kind, no hostname
+ []byte{0, 3, 0, 0, 0},
+ "",
+ false,
+ },
+ {
+ // SNI hostname
+ packet([]byte{0, 6, 0, 0, 3}, []byte("lol")),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI kinds
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}),
+ "lol",
+ false,
+ },
+ {
+ // Multiple SNI hostnames (illegal, but we just return the first)
+ packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")),
+ "bar",
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ actual, err := parseSNI(test.in)
+ if test.err {
+ if err == nil {
+ t.Errorf("unexpected success")
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("unexpected error %q", err)
+ continue
+ }
+ if test.out != actual {
+ t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out)
+ }
+ }
+}
diff --git a/scripts/prune_old_versions.go b/scripts/prune_old_versions.go
new file mode 100644
index 0000000..42e031e
--- /dev/null
+++ b/scripts/prune_old_versions.go
@@ -0,0 +1,150 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "sort"
+ "strings"
+ "time"
+)
+
+var (
+ user = flag.String("user", "", "username")
+ repo = flag.String("repo", "", "repository name")
+ pkgType = flag.String("pkg-type", "deb", "Package type, e.g. 'deb'")
+ distro = flag.String("distro", "", "distro name, e.g. 'debian'")
+ distroVersion = flag.String("version", "", "distro version, e.g. 'stretch'")
+ pkg = flag.String("package", "", "package name")
+ arch = flag.String("arch", "", "package architecture")
+ limit = flag.Int("limit", 2, "package versions to keep")
+)
+
+func fatalf(msg string, args ...interface{}) {
+ fmt.Printf(msg+"\n", args...)
+ os.Exit(1)
+}
+
+func main() {
+ flag.Parse()
+ if *user == "" {
+ fatalf("missing -user")
+ }
+ if *repo == "" {
+ fatalf("missing -repo")
+ }
+ if *pkgType == "" {
+ fatalf("missing -pkg-type")
+ }
+ if *distro == "" {
+ fatalf("missing -distro")
+ }
+ if *distroVersion == "" {
+ fatalf("missing -version")
+ }
+ if *pkg == "" {
+ fatalf("missing -package")
+ }
+ if *arch == "" {
+ fatalf("missing -arch")
+ }
+ if *limit < 1 {
+ fatalf("limit must be >= 1")
+ }
+
+ files, err := packageVersions(*user, *repo, *pkgType, *distro, *distroVersion, *pkg, *arch)
+ if err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+ if len(files) <= *limit {
+ fmt.Println("Below limit, no packages deleted")
+ return
+ }
+ delete := files[:len(files)-*limit]
+ keep := files[len(files)-*limit:]
+ if err = deletePackages(delete); err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+
+ fmt.Printf("Deleted:\n\n%s\n\nKept:\n\n%s\n", strings.Join(delete, "\n"), strings.Join(keep, "\n"))
+}
+
+type packageMeta struct {
+ Created time.Time `json:"created_at"`
+ Filename string `json:"filename"`
+}
+
+type metaSort []packageMeta
+
+func (m metaSort) Len() int { return len(m) }
+func (m metaSort) Less(i, j int) bool { return m[i].Created.Before(m[j].Created) }
+func (m metaSort) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+func packageVersions(user, repo, typ, distro, version, pkgname, arch string) ([]string, error) {
+ url := fmt.Sprintf("https://%s:@packagecloud.io/api/v1/repos/%s/%s/package/%s/%s/%s/%s/%s/versions.json", os.Getenv("PACKAGECLOUD_API_KEY"), user, repo, typ, distro, version, pkgname, arch)
+ resp, err := http.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("get versions.json: %s", err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ msg, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("get error message of versions.json get: %s", err)
+ }
+ return nil, fmt.Errorf("get versions.json: %s (%q)", resp.Status, string(msg))
+ }
+
+ var files []packageMeta
+ if err := json.NewDecoder(resp.Body).Decode(&files); err != nil {
+ return nil, fmt.Errorf("decode versions.json: %s", err)
+ }
+
+ // Newest first
+ sort.Sort(metaSort(files))
+
+ var ret []string
+ for _, meta := range files {
+ ret = append(ret, fmt.Sprintf("/api/v1/repos/%s/%s/%s/%s/%s", user, repo, distro, version, meta.Filename))
+ }
+
+ return ret, nil
+}
+
+func deletePackages(urls []string) error {
+ for _, url := range urls {
+ fullURL := fmt.Sprintf("https://%s:@packagecloud.io%s", os.Getenv("PACKAGECLOUD_API_KEY"), url)
+ req, err := http.NewRequest("DELETE", fullURL, nil)
+ if err != nil {
+ return fmt.Errorf("build delete request for %s: %s", url, err)
+ }
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("delete %s: %s", url, err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ return fmt.Errorf("delete %s: %s", url, resp.Status)
+ }
+ }
+ return nil
+}
diff --git a/systemd/tlsrouter.service b/systemd/tlsrouter.service
new file mode 100644
index 0000000..23e8fe1
--- /dev/null
+++ b/systemd/tlsrouter.service
@@ -0,0 +1,25 @@
+[Unit]
+Description=TLS SNI proxy
+Documentation=https://github.com/google/tlsrouter
+
+[Service]
+WorkingDirectory=/tmp
+ExecStart=/usr/bin/tlsrouter -conf /etc/tlsrouter.conf
+Restart=always
+User=nobody
+Group=nogroup
+CapabilityBoundingSet=CAP_NET_BIND_SERVICE
+AmbientCapabilities=CAP_NET_BIND_SERVICE
+PrivateTmp=true
+PrivateDevices=true
+ProtectSystem=strict
+ProtectHome=true
+ProtectKernelTunables=true
+ProtectControlGroups=true
+ProtectKernelModules=true
+NoNewPrivileges=true
+SystemCallFilter=~@clock @cpu-emulation @debug @keyring @module @mount @obsolete @privileged @raw-io
+RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX
+
+[Install]
+WantedBy=multi-user.target