From 3eb49e9b3902de95b3c9f5729d51ca7f61f02e5a Mon Sep 17 00:00:00 2001 From: David Anderson Date: Sun, 2 Jul 2017 14:46:05 -0700 Subject: Move tlsrouter to cmd/tlsrouter, in preparation for rewrite as a pkg. --- .travis.yml | 8 +- acme.go | 101 ---------- cmd/tlsrouter/acme.go | 101 ++++++++++ cmd/tlsrouter/config.go | 146 ++++++++++++++ cmd/tlsrouter/config_test.go | 61 ++++++ cmd/tlsrouter/e2e_test.go | 224 +++++++++++++++++++++ cmd/tlsrouter/main.go | 191 ++++++++++++++++++ cmd/tlsrouter/sni.go | 232 ++++++++++++++++++++++ cmd/tlsrouter/sni_test.go | 456 +++++++++++++++++++++++++++++++++++++++++++ config.go | 146 -------------- config_test.go | 61 ------ e2e_test.go | 224 --------------------- main.go | 191 ------------------ sni.go | 232 ---------------------- sni_test.go | 456 ------------------------------------------- 15 files changed, 1415 insertions(+), 1415 deletions(-) delete mode 100644 acme.go create mode 100644 cmd/tlsrouter/acme.go create mode 100644 cmd/tlsrouter/config.go create mode 100644 cmd/tlsrouter/config_test.go create mode 100644 cmd/tlsrouter/e2e_test.go create mode 100644 cmd/tlsrouter/main.go create mode 100644 cmd/tlsrouter/sni.go create mode 100644 cmd/tlsrouter/sni_test.go delete mode 100644 config.go delete mode 100644 config_test.go delete mode 100644 e2e_test.go delete mode 100644 main.go delete mode 100644 sni.go delete mode 100644 sni_test.go diff --git a/.travis.yml b/.travis.yml index 9e5c641..7d98c9f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,9 +11,9 @@ install: before_script: script: - go get -t . -- go build . -- go test . -- go vet . +- go build ./... +- go test ./... +- go vet ./... - golint -set_exit_status . jobs: @@ -30,7 +30,7 @@ jobs: --maintainer "David Anderson " --description "TLS SNI router" --url "https://github.com/google/tlsrouter" - ./tlsrouter=/usr/bin/tlsrouter + ./cmd/tlsrouter/tlsrouter=/usr/bin/tlsrouter ./systemd/tlsrouter.service=/lib/systemd/system/tlsrouter.service deploy: - provider: packagecloud diff --git a/acme.go b/acme.go deleted file mode 100644 index ab8d59a..0000000 --- a/acme.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "crypto/tls" - "net" - "time" -) - -type acmeCacheEntry struct { - backend string - expires time.Time -} - -// ACME locates backends that are attempting ACME SNI-based validation. -type ACME struct { - backends []string - // *.acme.invalid domain to cache entry - cache map[string]acmeCacheEntry -} - -// Match returns the backend for hostname, if one is found. -func (s *ACME) Match(hostname string) string { - c := s.cache[hostname] - if time.Now().Before(c.expires) { - return c.backend - } - - // Cache entry is either expired or invalid, need to figure out - // which backend is the right one. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - ch := make(chan string, len(s.backends)) - for _, backend := range s.backends { - go tryAcme(ctx, ch, backend, hostname) - } - for range s.backends { - backend := <-ch - if backend != "" { - s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)} - return backend - } - } - - // No usable backends found :( - s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)} - return "" -} - -func tryAcme(ctx context.Context, ch chan string, backend, hostname string) { - var res string - var err error - defer func() { ch <- res }() - - dialer := net.Dialer{Timeout: 10 * time.Second} - conn, err := dialer.DialContext(ctx, "tcp", backend) - if err != nil { - return - } - defer conn.Close() - - deadline, ok := ctx.Deadline() - if ok { - conn.SetDeadline(deadline) - } - client := tls.Client(conn, &tls.Config{ - ServerName: hostname, - InsecureSkipVerify: true, - }) - if err != nil { - return - } - if err = client.Handshake(); err != nil { - return - } - - certs := client.ConnectionState().PeerCertificates - if len(certs) == 0 { - return - } - if err = certs[0].VerifyHostname(hostname); err != nil { - return - } - - res = backend -} diff --git a/cmd/tlsrouter/acme.go b/cmd/tlsrouter/acme.go new file mode 100644 index 0000000..ab8d59a --- /dev/null +++ b/cmd/tlsrouter/acme.go @@ -0,0 +1,101 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "crypto/tls" + "net" + "time" +) + +type acmeCacheEntry struct { + backend string + expires time.Time +} + +// ACME locates backends that are attempting ACME SNI-based validation. +type ACME struct { + backends []string + // *.acme.invalid domain to cache entry + cache map[string]acmeCacheEntry +} + +// Match returns the backend for hostname, if one is found. +func (s *ACME) Match(hostname string) string { + c := s.cache[hostname] + if time.Now().Before(c.expires) { + return c.backend + } + + // Cache entry is either expired or invalid, need to figure out + // which backend is the right one. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ch := make(chan string, len(s.backends)) + for _, backend := range s.backends { + go tryAcme(ctx, ch, backend, hostname) + } + for range s.backends { + backend := <-ch + if backend != "" { + s.cache[hostname] = acmeCacheEntry{backend, time.Now().Add(5 * time.Second)} + return backend + } + } + + // No usable backends found :( + s.cache[hostname] = acmeCacheEntry{"", time.Now().Add(5 * time.Second)} + return "" +} + +func tryAcme(ctx context.Context, ch chan string, backend, hostname string) { + var res string + var err error + defer func() { ch <- res }() + + dialer := net.Dialer{Timeout: 10 * time.Second} + conn, err := dialer.DialContext(ctx, "tcp", backend) + if err != nil { + return + } + defer conn.Close() + + deadline, ok := ctx.Deadline() + if ok { + conn.SetDeadline(deadline) + } + client := tls.Client(conn, &tls.Config{ + ServerName: hostname, + InsecureSkipVerify: true, + }) + if err != nil { + return + } + if err = client.Handshake(); err != nil { + return + } + + certs := client.ConnectionState().PeerCertificates + if len(certs) == 0 { + return + } + if err = certs[0].VerifyHostname(hostname); err != nil { + return + } + + res = backend +} diff --git a/cmd/tlsrouter/config.go b/cmd/tlsrouter/config.go new file mode 100644 index 0000000..1c8151f --- /dev/null +++ b/cmd/tlsrouter/config.go @@ -0,0 +1,146 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "regexp" + "strings" + "sync" +) + +// A Route maps a match on a domain name to a backend. +type Route struct { + match *regexp.Regexp + backend string + proxyInfo bool +} + +// Config stores the TLS routing configuration. +type Config struct { + mu sync.Mutex + routes []Route + acme *ACME +} + +func dnsRegex(s string) (*regexp.Regexp, error) { + if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' { + return regexp.Compile(s[1 : len(s)-1]) + } + + var b []string + for _, f := range strings.Split(s, ".") { + switch f { + case "*": + b = append(b, `[^.]+`) + case "": + return nil, fmt.Errorf("DNS name %q has empty label", s) + default: + b = append(b, regexp.QuoteMeta(f)) + } + } + return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`))) +} + +// Match returns the backend for hostname, and whether to use the PROXY protocol. +func (c *Config) Match(hostname string) (string, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + if strings.HasSuffix(hostname, ".acme.invalid") { + return c.acme.Match(hostname), false + } + + for _, r := range c.routes { + if r.match.MatchString(hostname) { + return r.backend, r.proxyInfo + } + } + return "", false +} + +// Read replaces the current Config with one read from r. +func (c *Config) Read(r io.Reader) error { + var routes []Route + var backends []string + + s := bufio.NewScanner(r) + for s.Scan() { + if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") { + // Comment, ignore. + continue + } + + fs := strings.Fields(s.Text()) + switch len(fs) { + case 0: + continue + case 1: + return fmt.Errorf("invalid %q on a line by itself", s.Text()) + case 2: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + routes = append(routes, Route{re, fs[1], false}) + backends = append(backends, fs[1]) + case 3: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + if fs[2] != "PROXY" { + return errors.New("third item on a line can only be PROXY") + } + routes = append(routes, Route{re, fs[1], true}) + backends = append(backends, fs[1]) + default: + // TODO: multiple backends? + return fmt.Errorf("too many fields on line: %q", s.Text()) + } + } + if err := s.Err(); err != nil { + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + c.routes = routes + c.acme = &ACME{ + backends: backends, + cache: make(map[string]acmeCacheEntry), + } + return nil +} + +// ReadFile replaces the current Config with one read from path. +func (c *Config) ReadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + return c.Read(f) +} + +// ReadString replaces the current Config with one read from cfg. +func (c *Config) ReadString(cfg string) error { + b := bytes.NewBufferString(cfg) + return c.Read(b) +} diff --git a/cmd/tlsrouter/config_test.go b/cmd/tlsrouter/config_test.go new file mode 100644 index 0000000..9819b91 --- /dev/null +++ b/cmd/tlsrouter/config_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestConfig(t *testing.T) { + type result struct { + backend string + proxy bool + } + + cases := []struct { + Config string + Tests map[string]result + }{ + { + Config: ` +# Comment +go.universe.tf 1.2.3.4 +*.universe.tf 2.3.4.5 +# Comment +google.* 3.4.5.6 +/gooo+gle\.com/ 4.5.6.7 +foobar.net 6.7.8.9 PROXY +`, + Tests: map[string]result{ + "go.universe.tf": result{"1.2.3.4", false}, + "foo.universe.tf": result{"2.3.4.5", false}, + "bar.universe.tf": result{"2.3.4.5", false}, + "google.com": result{"3.4.5.6", false}, + "google.fr": result{"3.4.5.6", false}, + "goooooooooogle.com": result{"4.5.6.7", false}, + "foobar.net": result{"6.7.8.9", true}, + + "blah.com": result{"", false}, + "google.com.br": result{"", false}, + "foo.bar.universe.tf": result{"", false}, + "goooooglexcom": result{"", false}, + }, + }, + } + + for _, test := range cases { + var cfg Config + if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil { + t.Fatalf("Failed to read config (%s):\n%q", err, test.Config) + } + + for hostname, expected := range test.Tests { + backend, proxy := cfg.Match(hostname) + if expected.backend != backend { + t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend) + } + if expected.proxy != proxy { + t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy) + } + } + } +} diff --git a/cmd/tlsrouter/e2e_test.go b/cmd/tlsrouter/e2e_test.go new file mode 100644 index 0000000..c53e8c5 --- /dev/null +++ b/cmd/tlsrouter/e2e_test.go @@ -0,0 +1,224 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "strings" + "sync/atomic" + "testing" + "time" + + proxyproto "github.com/armon/go-proxyproto" +) + +func TestRouting(t *testing.T) { + // Backend servers + s1, err := serveTLS(t, "server1", false, "test.com") + if err != nil { + t.Fatalf("serve TLS server1: %s", err) + } + defer s1.Close() + + s2, err := serveTLS(t, "server2", false, "foo.net") + if err != nil { + t.Fatalf("serve TLS server2: %s", err) + } + defer s2.Close() + + s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid") + if err != nil { + t.Fatalf("server TLS server3: %s", err) + } + defer s3.Close() + + s4, err := serveTLS(t, "server4", true, "proxy.design") + if err != nil { + t.Fatalf("server TLS server4: %s", err) + } + defer s4.Close() + + // One proxy + var p Proxy + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("create listener: %s", err) + } + defer l.Close() + go p.Serve(l) + + if err := p.Config.ReadString(fmt.Sprintf(` +test.com %s +foo.net %s +borkbork.tf %s +proxy.design %s PROXY +`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil { + t.Fatalf("configure proxy: %s", err) + } + + for _, test := range []struct { + N, V string + P *x509.CertPool + OK bool + Transparent bool + }{ + {"test.com", "server1", s1.Pool, true, false}, + {"foo.net", "server2", s2.Pool, true, false}, + {"bar.org", "", s1.Pool, false, false}, + {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false}, + {"proxy.design", "server4", s4.Pool, true, true}, + } { + res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) + switch { + case test.OK && err != nil: + t.Fatalf("get %q failed: %s", test.N, err) + case !test.OK && err == nil: + t.Fatalf("get %q should have failed, but returned %q", test.N, res) + case test.OK && res != test.V: + t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + case test.OK && transparent != test.Transparent: + t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) + } + } +} + +// getTLS attempts to set up a TLS session using the given proxy +// address, domain, and cert pool. It returns the value served by the +// server, as well as a bool indicating whether the server knew the +// true client address, indicating that the PROXY protocol was in use. +func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { + cfg := tls.Config{ + RootCAs: pool, + ServerName: domain, + } + conn, err := tls.Dial("tcp", addr, &cfg) + if err != nil { + return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + } + defer conn.Close() + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + fs := strings.Split(string(bs), " ") + if len(fs) != 2 { + return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) + } + transparent := fs[1] == conn.LocalAddr().String() + return fs[0], transparent, nil +} + +type tlsServer struct { + Domains []string + Value string + Pool *x509.CertPool + Test *testing.T + NumHits uint32 + l net.Listener +} + +func (s *tlsServer) Serve() { + for { + c, err := s.l.Accept() + if err != nil { + s.Test.Logf("accept failed on %q: %s", s.Domains, err) + return + } + atomic.AddUint32(&s.NumHits, 1) + fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) + c.Close() + } +} + +func (s *tlsServer) Addr() string { + return s.l.Addr().String() +} + +func (s *tlsServer) Close() error { + return s.l.Close() +} + +func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { + cert, pool, err := selfSignedCert(domains) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + cfg.BuildNameToCertificate() + + var l net.Listener + + l, err = net.Listen("tcp", "localhost:0") + if err != nil { + return nil, err + } + + if understandProxy { + l = &proxyproto.Listener{Listener: l} + } + + l = tls.NewListener(l, cfg) + + ret := &tlsServer{ + Domains: domains, + Value: value, + Pool: pool, + Test: t, + l: l, + } + go ret.Serve() + return ret, nil +} + +func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { + pkey, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + return tls.Certificate{}, nil, err + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domains[0], + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: domains[1:], + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey) + if err != nil { + return tls.Certificate{}, nil, err + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + return tls.Certificate{}, nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Bytes()) { + return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) + } + + return tlscert, pool, nil +} diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go new file mode 100644 index 0000000..ff1a816 --- /dev/null +++ b/cmd/tlsrouter/main.go @@ -0,0 +1,191 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +var ( + cfgFile = flag.String("conf", "", "configuration file") + listen = flag.String("listen", ":443", "listening port") + helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") +) + +func main() { + flag.Parse() + + p := &Proxy{} + if err := p.Config.ReadFile(*cfgFile); err != nil { + log.Fatalf("Failed to read config %q: %s", *cfgFile, err) + } + + log.Fatalf("%s", p.ListenAndServe(*listen)) +} + +// Proxy routes connections to backends based on a Config. +type Proxy struct { + Config Config + l net.Listener +} + +// Serve accepts connections from l and routes them according to TLS SNI. +func (p *Proxy) Serve(l net.Listener) error { + for { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("accept new conn: %s", err) + } + + conn := &Conn{ + TCPConn: c.(*net.TCPConn), + config: &p.Config, + } + go conn.proxy() + } +} + +// ListenAndServe creates a listener on addr calls Serve on it. +func (p *Proxy) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("create listener: %s", err) + } + return p.Serve(l) +} + +// A Conn handles the TLS proxying of one user connection. +type Conn struct { + *net.TCPConn + config *Config + + tlsMinor int + hostname string + backend string + backendConn *net.TCPConn +} + +func (c *Conn) logf(msg string, args ...interface{}) { + msg = fmt.Sprintf(msg, args...) + log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) +} + +func (c *Conn) abort(alert byte, msg string, args ...interface{}) { + c.logf(msg, args...) + alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} + + if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.logf("error while setting write deadline during abort: %s", err) + // Do NOT send the alert if we can't set a write deadline, + // that could result in leaking a connection for an extended + // period. + return + } + + if _, err := c.Write(alertMsg); err != nil { + c.logf("error while sending alert: %s", err) + } +} + +func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } +func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } + +func (c *Conn) proxy() { + defer c.Close() + + if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.internalError("Setting read deadline for ClientHello: %s", err) + return + } + + var ( + err error + handshakeBuf bytes.Buffer + ) + c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) + if err != nil { + c.internalError("Extracting SNI: %s", err) + return + } + + c.logf("extracted SNI %s", c.hostname) + + if err = c.SetReadDeadline(time.Time{}); err != nil { + c.internalError("Clearing read deadline for ClientHello: %s", err) + return + } + + addProxyHeader := false + c.backend, addProxyHeader = c.config.Match(c.hostname) + if c.backend == "" { + c.sniFailed("no backend found for %q", c.hostname) + return + } + + c.logf("routing %q to %q", c.hostname, c.backend) + backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) + if err != nil { + c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) + return + } + defer backend.Close() + + c.backendConn = backend.(*net.TCPConn) + + // If the backend supports the HAProxy PROXY protocol, give it the + // real source information about the connection. + if addProxyHeader { + remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) + local := c.TCPConn.LocalAddr().(*net.TCPAddr) + family := "TCP6" + if remote.IP.To4() != nil { + family = "TCP4" + } + if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { + c.internalError("failed to send PROXY header to %q: %s", c.backend, err) + return + } + } + + // Replay the piece of the handshake we had to read to do the + // routing, then blindly proxy any other bytes. + if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { + c.internalError("failed to replay handshake to %q: %s", c.backend, err) + return + } + + var wg sync.WaitGroup + wg.Add(2) + go proxy(&wg, c.TCPConn, c.backendConn) + go proxy(&wg, c.backendConn, c.TCPConn) + wg.Wait() +} + +func proxy(wg *sync.WaitGroup, a, b net.Conn) { + defer wg.Done() + atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) + if _, err := io.Copy(atcp, btcp); err != nil { + log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) + } + btcp.CloseWrite() + atcp.CloseRead() +} diff --git a/cmd/tlsrouter/sni.go b/cmd/tlsrouter/sni.go new file mode 100644 index 0000000..ed79df2 --- /dev/null +++ b/cmd/tlsrouter/sni.go @@ -0,0 +1,232 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +func extractSNI(r io.Reader) (string, int, error) { + handshake, tlsver, err := handshakeRecord(r) + if err != nil { + return "", 0, fmt.Errorf("reading TLS record: %s", err) + } + + sni, err := parseHello(handshake) + if err != nil { + return "", 0, fmt.Errorf("reading ClientHello: %s", err) + } + if len(sni) == 0 { + // ClientHello did not present an SNI extension. Valid packet, + // no hostname. + return "", tlsver, nil + } + + hostname, err := parseSNI(sni) + if err != nil { + return "", 0, fmt.Errorf("parsing SNI extension: %s", err) + } + return hostname, tlsver, nil +} + +// Extract the indicated hostname, if any, from the given SNI +// extension bytes. +func parseSNI(b []byte) (string, error) { + b, _, err := vector(b, 2) + if err != nil { + return "", err + } + + var ret []byte + for len(b) >= 3 { + typ := b[0] + ret, b, err = vector(b[1:], 2) + if err != nil { + return "", fmt.Errorf("truncated SNI extension") + } + + if typ == sniHostnameID { + return string(ret), nil + } + } + + if len(b) != 0 { + return "", fmt.Errorf("trailing garbage at end of SNI extension") + } + + // No DNS-based SNI present. + return "", nil +} + +const sniExtensionID = 0 +const sniHostnameID = 0 + +// Parse a TLS handshake record as a ClientHello message and extract +// the SNI extension bytes, if any. +func parseHello(b []byte) ([]byte, error) { + if len(b) == 0 { + return nil, errors.New("zero length handshake record") + } + if b[0] != 1 { + return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) + } + + // We're expecting a stricter TLS parser to run after we've + // proxied, so we ignore any trailing bytes that might be present + // (e.g. another handshake message). + b, _, err := vector(b[1:], 3) + if err != nil { + return nil, fmt.Errorf("reading ClientHello: %s", err) + } + + // ClientHello must be at least 34 bytes to reach the first vector + // length byte. The actual minimal size is larger than that, but + // vector() will correctly handle truncated packets. + if len(b) < 34 { + return nil, errors.New("ClientHello packet too short") + } + + if b[0] != 3 { + return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) + } + switch b[1] { + case 1, 2, 3: + // TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) + } + + // Skip over version and random struct + b = b[34:] + + // We don't technically care about SessionID, but we care that the + // framing is well-formed all the way up to the SNI field, so that + // we are sure that we're pulling the same SNI bytes as the + // eventual TLS implementation. + vec, b, err := vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) + } + if len(vec) > 32 { + return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) + } + + // Likewise, we're just checking the bare minimum of framing. + vec, b, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) + } + if len(vec) < 2 || len(vec)%2 != 0 { + return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) + } + + vec, b, err = vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) + } + if len(vec) < 1 { + return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) + } + + // Finally, we reach the extensions. + if len(b) == 0 { + // No extensions. This is not an error, it just means we have + // no SNI payload. + return nil, nil + } + b, vec, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extensions: %s", err) + } + if len(vec) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) + } + + for len(b) >= 4 { + typ := binary.BigEndian.Uint16(b[:2]) + vec, b, err = vector(b[2:], 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) + } + if typ == sniExtensionID { + // Found the SNI extension, return its payload. We don't + // care about anything in the packet beyond this point. + return vec, nil + } + } + + if len(b) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) + } + + // Successfully parsed all extensions, but there was no SNI. + return nil, nil +} + +const maxTLSRecordLength = 16384 + +// Read one TLS record, which must be for the handshake protocol, from r. +func handshakeRecord(r io.Reader) ([]byte, int, error) { + var hdr struct { + Type uint8 + Major, Minor uint8 + Length uint16 + } + if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { + return nil, 0, fmt.Errorf("reading TLS record header: %s", err) + } + + if hdr.Type != 22 { + return nil, 0, fmt.Errorf("TLS record is not a handshake") + } + + if hdr.Major != 3 { + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + switch hdr.Minor { + case 1, 2, 3: + // TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + + if hdr.Length > maxTLSRecordLength { + return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) + } + + ret := make([]byte, hdr.Length) + if _, err := io.ReadFull(r, ret); err != nil { + return nil, 0, err + } + + return ret, int(hdr.Minor), nil +} + +func vector(b []byte, lenBytes int) ([]byte, []byte, error) { + if len(b) < lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + var l int + for _, b := range b[:lenBytes] { + l = (l << 8) + int(b) + } + if len(b) < l+lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil +} diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go new file mode 100644 index 0000000..8c87d24 --- /dev/null +++ b/cmd/tlsrouter/sni_test.go @@ -0,0 +1,456 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "testing" +) + +func slice(l int) []byte { + ret := make([]byte, l) + for i := 0; i < l; i++ { + ret[i] = byte(i) + } + return ret +} + +func vec(l, lenBytes int) []byte { + b := slice(l) + vecLen := len(b) + ret := make([]byte, vecLen+l) + for i := l - 1; i >= 0; i-- { + ret[i] = byte(vecLen & 0xff) + vecLen >>= 8 + } + copy(ret[l:], b) + return ret +} + +func packet(bs ...[]byte) []byte { + var ret []byte + for _, b := range bs { + ret = append(ret, b...) + } + return ret +} + +func offset(b []byte, off int) []byte { + return b[off:] +} + +func TestVector(t *testing.T) { + tests := []struct { + in []byte + inLen int + out1, out2 []byte + err bool + }{ + { + // 1b length + append([]byte{3}, slice(10)...), 1, + slice(3), offset(slice(10), 3), false, + }, + { + // 1b length, no trailer + append([]byte{10}, slice(10)...), 1, + slice(10), []byte{}, false, + }, + { + // 1b length, no vector + append([]byte{0}, slice(10)...), 1, + []byte{}, slice(10), false, + }, + { + // 1b length, no vector or trailer + []byte{0}, 1, + []byte{}, []byte{}, false, + }, + { + // 2b length, LSB only + append([]byte{0, 3}, slice(10)...), 2, + slice(3), offset(slice(10), 3), false, + }, + { + // 2b length, MSB only + append([]byte{3, 0}, slice(1024)...), 2, + slice(768), offset(slice(1024), 768), false, + }, + { + // 2b length, both bytes + append([]byte{3, 2}, slice(1024)...), 2, + slice(770), offset(slice(1024), 770), false, + }, + { + // 3b length + append([]byte{1, 2, 3}, slice(100000)...), 3, + slice(66051), offset(slice(100000), 66051), false, + }, + { + // no bytes + []byte{}, 1, + nil, nil, true, + }, + { + // no slice + nil, 1, + nil, nil, true, + }, + { + // not enough bytes for length + []byte{1}, 2, + nil, nil, true, + }, + { + // no bytes after length + []byte{1}, 1, + nil, nil, true, + }, + { + // not enough bytes for vector + []byte{4, 1, 2}, 1, + nil, nil, true, + }, + } + + for _, test := range tests { + actual1, actual2, err := vector(test.in, test.inLen) + if !test.err && (err != nil) { + t.Errorf("unexpected error %q", err) + } + if test.err && (err == nil) { + t.Errorf("unexpected success") + } + if err != nil { + continue + } + if !bytes.Equal(actual1, test.out1) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) + } + if !bytes.Equal(actual2, test.out2) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) + } + } +} + +func TestHandshakeRecord(t *testing.T) { + tests := []struct { + in []byte + out []byte + tlsver int + }{ + { + // TLS 1.0, 1b packet + []byte{22, 3, 1, 0, 1, 3}, + []byte{3}, + 1, + }, + { + // TLS 1.1, 1b packet + []byte{22, 3, 2, 0, 1, 3}, + []byte{3}, + 2, + }, + { + // TLS 1.2, 1b packet + []byte{22, 3, 3, 0, 1, 3}, + []byte{3}, + 3, + }, + { + // TLS 1.2, no payload bytes + []byte{22, 3, 3, 0, 0}, + []byte{}, + 3, + }, + { + // TLS 1.2, >255b payload w/ trailing stuff + append([]byte{22, 3, 3, 3, 2}, slice(1024)...), + slice(770), + 3, + }, + { + // TLS 1.2, 2^14 payload + append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), + slice(maxTLSRecordLength), + 3, + }, + { + // TLS 1.2, >2^14 payload + append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), + nil, + 0, + }, + { + // TLS 1.2, truncated payload + []byte{22, 3, 3, 0, 4, 1, 2}, + nil, + 0, + }, + { + // truncated header + []byte{22}, + nil, + 0, + }, + { + // wrong record type + []byte{42, 3, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS major version + []byte{22, 2, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS minor version + []byte{22, 3, 42, 0, 1, 3}, + nil, + 0, + }, + { + // Obsolete SSL 3.0 + []byte{22, 3, 0, 0, 1, 3}, + nil, + 0, + }, + } + + for _, test := range tests { + r := bytes.NewBuffer(test.in) + actual, tlsver, err := handshakeRecord(r) + if test.out == nil && err == nil { + t.Errorf("unexpected success") + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) + } + if tlsver != test.tlsver { + t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) + } + } +} + +func TestParseHello(t *testing.T) { + tests := []struct { + in []byte + out []byte + err bool + }{ + { + // Wrong record type + packet([]byte{42, 0, 0, 1, 1}), + nil, + true, + }, + { + // Truncated payload + packet([]byte{1, 0, 0, 1}), + nil, + true, + }, + { + // Payload too small + packet([]byte{1, 0, 0, 1, 1}), + nil, + true, + }, + { + // Unknown major version + packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), + nil, + true, + }, + { + // Unknown minor version + packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), + nil, + true, + }, + { + // Missing required variadic fields + packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression) + packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) + packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), + nil, + true, + }, + { + // Session + ciphersuites, no compression + packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), + nil, + true, + }, + { + // First valid packet. TLS 1.0, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.1, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, garbage extensions + packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), + nil, + true, + }, + { + // empty extensions vector + packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), + nil, + false, + }, + { + // non-SNI extensions + packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), + nil, + false, + }, + { + // SNI present + packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), + []byte{182}, + false, + }, + { + // Longer SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), + slice(4), + false, + }, + { + // Embedded SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), + slice(4), + false, + }, + } + + for _, test := range tests { + actual, err := parseHello(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) + } + } +} + +func TestParseSNI(t *testing.T) { + tests := []struct { + in []byte + out string + err bool + }{ + { + // Empty packet + []byte{}, + "", + true, + }, + { + // Truncated packet + []byte{0, 2, 1}, + "", + true, + }, + { + // Truncated packet within SNI vector + []byte{0, 2, 1, 2}, + "", + true, + }, + { + // Wrong SNI kind + []byte{0, 3, 1, 0, 0}, + "", + false, + }, + { + // Right SNI kind, no hostname + []byte{0, 3, 0, 0, 0}, + "", + false, + }, + { + // SNI hostname + packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), + "lol", + false, + }, + { + // Multiple SNI kinds + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), + "lol", + false, + }, + { + // Multiple SNI hostnames (illegal, but we just return the first) + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), + "bar", + false, + }, + } + + for _, test := range tests { + actual, err := parseSNI(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if test.out != actual { + t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out) + } + } +} diff --git a/config.go b/config.go deleted file mode 100644 index 1c8151f..0000000 --- a/config.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "os" - "regexp" - "strings" - "sync" -) - -// A Route maps a match on a domain name to a backend. -type Route struct { - match *regexp.Regexp - backend string - proxyInfo bool -} - -// Config stores the TLS routing configuration. -type Config struct { - mu sync.Mutex - routes []Route - acme *ACME -} - -func dnsRegex(s string) (*regexp.Regexp, error) { - if len(s) >= 2 && s[0] == '/' && s[len(s)-1] == '/' { - return regexp.Compile(s[1 : len(s)-1]) - } - - var b []string - for _, f := range strings.Split(s, ".") { - switch f { - case "*": - b = append(b, `[^.]+`) - case "": - return nil, fmt.Errorf("DNS name %q has empty label", s) - default: - b = append(b, regexp.QuoteMeta(f)) - } - } - return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`))) -} - -// Match returns the backend for hostname, and whether to use the PROXY protocol. -func (c *Config) Match(hostname string) (string, bool) { - c.mu.Lock() - defer c.mu.Unlock() - - if strings.HasSuffix(hostname, ".acme.invalid") { - return c.acme.Match(hostname), false - } - - for _, r := range c.routes { - if r.match.MatchString(hostname) { - return r.backend, r.proxyInfo - } - } - return "", false -} - -// Read replaces the current Config with one read from r. -func (c *Config) Read(r io.Reader) error { - var routes []Route - var backends []string - - s := bufio.NewScanner(r) - for s.Scan() { - if strings.HasPrefix(strings.TrimSpace(s.Text()), "#") { - // Comment, ignore. - continue - } - - fs := strings.Fields(s.Text()) - switch len(fs) { - case 0: - continue - case 1: - return fmt.Errorf("invalid %q on a line by itself", s.Text()) - case 2: - re, err := dnsRegex(fs[0]) - if err != nil { - return err - } - routes = append(routes, Route{re, fs[1], false}) - backends = append(backends, fs[1]) - case 3: - re, err := dnsRegex(fs[0]) - if err != nil { - return err - } - if fs[2] != "PROXY" { - return errors.New("third item on a line can only be PROXY") - } - routes = append(routes, Route{re, fs[1], true}) - backends = append(backends, fs[1]) - default: - // TODO: multiple backends? - return fmt.Errorf("too many fields on line: %q", s.Text()) - } - } - if err := s.Err(); err != nil { - return err - } - - c.mu.Lock() - defer c.mu.Unlock() - c.routes = routes - c.acme = &ACME{ - backends: backends, - cache: make(map[string]acmeCacheEntry), - } - return nil -} - -// ReadFile replaces the current Config with one read from path. -func (c *Config) ReadFile(path string) error { - f, err := os.Open(path) - if err != nil { - return err - } - return c.Read(f) -} - -// ReadString replaces the current Config with one read from cfg. -func (c *Config) ReadString(cfg string) error { - b := bytes.NewBufferString(cfg) - return c.Read(b) -} diff --git a/config_test.go b/config_test.go deleted file mode 100644 index 9819b91..0000000 --- a/config_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package main - -import ( - "bytes" - "testing" -) - -func TestConfig(t *testing.T) { - type result struct { - backend string - proxy bool - } - - cases := []struct { - Config string - Tests map[string]result - }{ - { - Config: ` -# Comment -go.universe.tf 1.2.3.4 -*.universe.tf 2.3.4.5 -# Comment -google.* 3.4.5.6 -/gooo+gle\.com/ 4.5.6.7 -foobar.net 6.7.8.9 PROXY -`, - Tests: map[string]result{ - "go.universe.tf": result{"1.2.3.4", false}, - "foo.universe.tf": result{"2.3.4.5", false}, - "bar.universe.tf": result{"2.3.4.5", false}, - "google.com": result{"3.4.5.6", false}, - "google.fr": result{"3.4.5.6", false}, - "goooooooooogle.com": result{"4.5.6.7", false}, - "foobar.net": result{"6.7.8.9", true}, - - "blah.com": result{"", false}, - "google.com.br": result{"", false}, - "foo.bar.universe.tf": result{"", false}, - "goooooglexcom": result{"", false}, - }, - }, - } - - for _, test := range cases { - var cfg Config - if err := cfg.Read(bytes.NewBufferString(test.Config)); err != nil { - t.Fatalf("Failed to read config (%s):\n%q", err, test.Config) - } - - for hostname, expected := range test.Tests { - backend, proxy := cfg.Match(hostname) - if expected.backend != backend { - t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend) - } - if expected.proxy != proxy { - t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy) - } - } - } -} diff --git a/e2e_test.go b/e2e_test.go deleted file mode 100644 index c53e8c5..0000000 --- a/e2e_test.go +++ /dev/null @@ -1,224 +0,0 @@ -package main - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "io/ioutil" - "math/big" - "net" - "strings" - "sync/atomic" - "testing" - "time" - - proxyproto "github.com/armon/go-proxyproto" -) - -func TestRouting(t *testing.T) { - // Backend servers - s1, err := serveTLS(t, "server1", false, "test.com") - if err != nil { - t.Fatalf("serve TLS server1: %s", err) - } - defer s1.Close() - - s2, err := serveTLS(t, "server2", false, "foo.net") - if err != nil { - t.Fatalf("serve TLS server2: %s", err) - } - defer s2.Close() - - s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid") - if err != nil { - t.Fatalf("server TLS server3: %s", err) - } - defer s3.Close() - - s4, err := serveTLS(t, "server4", true, "proxy.design") - if err != nil { - t.Fatalf("server TLS server4: %s", err) - } - defer s4.Close() - - // One proxy - var p Proxy - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("create listener: %s", err) - } - defer l.Close() - go p.Serve(l) - - if err := p.Config.ReadString(fmt.Sprintf(` -test.com %s -foo.net %s -borkbork.tf %s -proxy.design %s PROXY -`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil { - t.Fatalf("configure proxy: %s", err) - } - - for _, test := range []struct { - N, V string - P *x509.CertPool - OK bool - Transparent bool - }{ - {"test.com", "server1", s1.Pool, true, false}, - {"foo.net", "server2", s2.Pool, true, false}, - {"bar.org", "", s1.Pool, false, false}, - {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false}, - {"proxy.design", "server4", s4.Pool, true, true}, - } { - res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) - switch { - case test.OK && err != nil: - t.Fatalf("get %q failed: %s", test.N, err) - case !test.OK && err == nil: - t.Fatalf("get %q should have failed, but returned %q", test.N, res) - case test.OK && res != test.V: - t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) - case test.OK && transparent != test.Transparent: - t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) - } - } -} - -// getTLS attempts to set up a TLS session using the given proxy -// address, domain, and cert pool. It returns the value served by the -// server, as well as a bool indicating whether the server knew the -// true client address, indicating that the PROXY protocol was in use. -func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { - cfg := tls.Config{ - RootCAs: pool, - ServerName: domain, - } - conn, err := tls.Dial("tcp", addr, &cfg) - if err != nil { - return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) - } - defer conn.Close() - bs, err := ioutil.ReadAll(conn) - if err != nil { - return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) - } - fs := strings.Split(string(bs), " ") - if len(fs) != 2 { - return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) - } - transparent := fs[1] == conn.LocalAddr().String() - return fs[0], transparent, nil -} - -type tlsServer struct { - Domains []string - Value string - Pool *x509.CertPool - Test *testing.T - NumHits uint32 - l net.Listener -} - -func (s *tlsServer) Serve() { - for { - c, err := s.l.Accept() - if err != nil { - s.Test.Logf("accept failed on %q: %s", s.Domains, err) - return - } - atomic.AddUint32(&s.NumHits, 1) - fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) - c.Close() - } -} - -func (s *tlsServer) Addr() string { - return s.l.Addr().String() -} - -func (s *tlsServer) Close() error { - return s.l.Close() -} - -func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { - cert, pool, err := selfSignedCert(domains) - if err != nil { - return nil, err - } - - cfg := &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - cfg.BuildNameToCertificate() - - var l net.Listener - - l, err = net.Listen("tcp", "localhost:0") - if err != nil { - return nil, err - } - - if understandProxy { - l = &proxyproto.Listener{Listener: l} - } - - l = tls.NewListener(l, cfg) - - ret := &tlsServer{ - Domains: domains, - Value: value, - Pool: pool, - Test: t, - l: l, - } - go ret.Serve() - return ret, nil -} - -func selfSignedCert(domains []string) (tls.Certificate, *x509.CertPool, error) { - pkey, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - return tls.Certificate{}, nil, err - } - template := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Co"}, - CommonName: domains[0], - }, - NotBefore: time.Time{}, - NotAfter: time.Now().Add(60 * time.Minute), - IsCA: true, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - DNSNames: domains[1:], - } - - derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &pkey.PublicKey, pkey) - if err != nil { - return tls.Certificate{}, nil, err - } - - var cert, key bytes.Buffer - pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pkey)}) - - tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) - if err != nil { - return tls.Certificate{}, nil, err - } - - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(cert.Bytes()) { - return tls.Certificate{}, nil, fmt.Errorf("failed to add cert %q to pool", domains) - } - - return tlscert, pool, nil -} diff --git a/main.go b/main.go deleted file mode 100644 index ff1a816..0000000 --- a/main.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bytes" - "flag" - "fmt" - "io" - "log" - "net" - "sync" - "time" -) - -var ( - cfgFile = flag.String("conf", "", "configuration file") - listen = flag.String("listen", ":443", "listening port") - helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") -) - -func main() { - flag.Parse() - - p := &Proxy{} - if err := p.Config.ReadFile(*cfgFile); err != nil { - log.Fatalf("Failed to read config %q: %s", *cfgFile, err) - } - - log.Fatalf("%s", p.ListenAndServe(*listen)) -} - -// Proxy routes connections to backends based on a Config. -type Proxy struct { - Config Config - l net.Listener -} - -// Serve accepts connections from l and routes them according to TLS SNI. -func (p *Proxy) Serve(l net.Listener) error { - for { - c, err := l.Accept() - if err != nil { - return fmt.Errorf("accept new conn: %s", err) - } - - conn := &Conn{ - TCPConn: c.(*net.TCPConn), - config: &p.Config, - } - go conn.proxy() - } -} - -// ListenAndServe creates a listener on addr calls Serve on it. -func (p *Proxy) ListenAndServe(addr string) error { - l, err := net.Listen("tcp", addr) - if err != nil { - return fmt.Errorf("create listener: %s", err) - } - return p.Serve(l) -} - -// A Conn handles the TLS proxying of one user connection. -type Conn struct { - *net.TCPConn - config *Config - - tlsMinor int - hostname string - backend string - backendConn *net.TCPConn -} - -func (c *Conn) logf(msg string, args ...interface{}) { - msg = fmt.Sprintf(msg, args...) - log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) -} - -func (c *Conn) abort(alert byte, msg string, args ...interface{}) { - c.logf(msg, args...) - alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} - - if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil { - c.logf("error while setting write deadline during abort: %s", err) - // Do NOT send the alert if we can't set a write deadline, - // that could result in leaking a connection for an extended - // period. - return - } - - if _, err := c.Write(alertMsg); err != nil { - c.logf("error while sending alert: %s", err) - } -} - -func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } -func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } - -func (c *Conn) proxy() { - defer c.Close() - - if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil { - c.internalError("Setting read deadline for ClientHello: %s", err) - return - } - - var ( - err error - handshakeBuf bytes.Buffer - ) - c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) - if err != nil { - c.internalError("Extracting SNI: %s", err) - return - } - - c.logf("extracted SNI %s", c.hostname) - - if err = c.SetReadDeadline(time.Time{}); err != nil { - c.internalError("Clearing read deadline for ClientHello: %s", err) - return - } - - addProxyHeader := false - c.backend, addProxyHeader = c.config.Match(c.hostname) - if c.backend == "" { - c.sniFailed("no backend found for %q", c.hostname) - return - } - - c.logf("routing %q to %q", c.hostname, c.backend) - backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) - if err != nil { - c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) - return - } - defer backend.Close() - - c.backendConn = backend.(*net.TCPConn) - - // If the backend supports the HAProxy PROXY protocol, give it the - // real source information about the connection. - if addProxyHeader { - remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) - local := c.TCPConn.LocalAddr().(*net.TCPAddr) - family := "TCP6" - if remote.IP.To4() != nil { - family = "TCP4" - } - if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { - c.internalError("failed to send PROXY header to %q: %s", c.backend, err) - return - } - } - - // Replay the piece of the handshake we had to read to do the - // routing, then blindly proxy any other bytes. - if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { - c.internalError("failed to replay handshake to %q: %s", c.backend, err) - return - } - - var wg sync.WaitGroup - wg.Add(2) - go proxy(&wg, c.TCPConn, c.backendConn) - go proxy(&wg, c.backendConn, c.TCPConn) - wg.Wait() -} - -func proxy(wg *sync.WaitGroup, a, b net.Conn) { - defer wg.Done() - atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) - if _, err := io.Copy(atcp, btcp); err != nil { - log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) - } - btcp.CloseWrite() - atcp.CloseRead() -} diff --git a/sni.go b/sni.go deleted file mode 100644 index ed79df2..0000000 --- a/sni.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "encoding/binary" - "errors" - "fmt" - "io" -) - -func extractSNI(r io.Reader) (string, int, error) { - handshake, tlsver, err := handshakeRecord(r) - if err != nil { - return "", 0, fmt.Errorf("reading TLS record: %s", err) - } - - sni, err := parseHello(handshake) - if err != nil { - return "", 0, fmt.Errorf("reading ClientHello: %s", err) - } - if len(sni) == 0 { - // ClientHello did not present an SNI extension. Valid packet, - // no hostname. - return "", tlsver, nil - } - - hostname, err := parseSNI(sni) - if err != nil { - return "", 0, fmt.Errorf("parsing SNI extension: %s", err) - } - return hostname, tlsver, nil -} - -// Extract the indicated hostname, if any, from the given SNI -// extension bytes. -func parseSNI(b []byte) (string, error) { - b, _, err := vector(b, 2) - if err != nil { - return "", err - } - - var ret []byte - for len(b) >= 3 { - typ := b[0] - ret, b, err = vector(b[1:], 2) - if err != nil { - return "", fmt.Errorf("truncated SNI extension") - } - - if typ == sniHostnameID { - return string(ret), nil - } - } - - if len(b) != 0 { - return "", fmt.Errorf("trailing garbage at end of SNI extension") - } - - // No DNS-based SNI present. - return "", nil -} - -const sniExtensionID = 0 -const sniHostnameID = 0 - -// Parse a TLS handshake record as a ClientHello message and extract -// the SNI extension bytes, if any. -func parseHello(b []byte) ([]byte, error) { - if len(b) == 0 { - return nil, errors.New("zero length handshake record") - } - if b[0] != 1 { - return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) - } - - // We're expecting a stricter TLS parser to run after we've - // proxied, so we ignore any trailing bytes that might be present - // (e.g. another handshake message). - b, _, err := vector(b[1:], 3) - if err != nil { - return nil, fmt.Errorf("reading ClientHello: %s", err) - } - - // ClientHello must be at least 34 bytes to reach the first vector - // length byte. The actual minimal size is larger than that, but - // vector() will correctly handle truncated packets. - if len(b) < 34 { - return nil, errors.New("ClientHello packet too short") - } - - if b[0] != 3 { - return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) - } - switch b[1] { - case 1, 2, 3: - // TLS 1.0, TLS 1.1, TLS 1.2 - default: - return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) - } - - // Skip over version and random struct - b = b[34:] - - // We don't technically care about SessionID, but we care that the - // framing is well-formed all the way up to the SNI field, so that - // we are sure that we're pulling the same SNI bytes as the - // eventual TLS implementation. - vec, b, err := vector(b, 1) - if err != nil { - return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) - } - if len(vec) > 32 { - return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) - } - - // Likewise, we're just checking the bare minimum of framing. - vec, b, err = vector(b, 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) - } - if len(vec) < 2 || len(vec)%2 != 0 { - return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) - } - - vec, b, err = vector(b, 1) - if err != nil { - return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) - } - if len(vec) < 1 { - return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) - } - - // Finally, we reach the extensions. - if len(b) == 0 { - // No extensions. This is not an error, it just means we have - // no SNI payload. - return nil, nil - } - b, vec, err = vector(b, 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello extensions: %s", err) - } - if len(vec) != 0 { - return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) - } - - for len(b) >= 4 { - typ := binary.BigEndian.Uint16(b[:2]) - vec, b, err = vector(b[2:], 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) - } - if typ == sniExtensionID { - // Found the SNI extension, return its payload. We don't - // care about anything in the packet beyond this point. - return vec, nil - } - } - - if len(b) != 0 { - return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) - } - - // Successfully parsed all extensions, but there was no SNI. - return nil, nil -} - -const maxTLSRecordLength = 16384 - -// Read one TLS record, which must be for the handshake protocol, from r. -func handshakeRecord(r io.Reader) ([]byte, int, error) { - var hdr struct { - Type uint8 - Major, Minor uint8 - Length uint16 - } - if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { - return nil, 0, fmt.Errorf("reading TLS record header: %s", err) - } - - if hdr.Type != 22 { - return nil, 0, fmt.Errorf("TLS record is not a handshake") - } - - if hdr.Major != 3 { - return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) - } - switch hdr.Minor { - case 1, 2, 3: - // TLS 1.0, TLS 1.1, TLS 1.2 - default: - return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) - } - - if hdr.Length > maxTLSRecordLength { - return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) - } - - ret := make([]byte, hdr.Length) - if _, err := io.ReadFull(r, ret); err != nil { - return nil, 0, err - } - - return ret, int(hdr.Minor), nil -} - -func vector(b []byte, lenBytes int) ([]byte, []byte, error) { - if len(b) < lenBytes { - return nil, nil, errors.New("not enough space in packet for vector") - } - var l int - for _, b := range b[:lenBytes] { - l = (l << 8) + int(b) - } - if len(b) < l+lenBytes { - return nil, nil, errors.New("not enough space in packet for vector") - } - return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil -} diff --git a/sni_test.go b/sni_test.go deleted file mode 100644 index 8c87d24..0000000 --- a/sni_test.go +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bytes" - "testing" -) - -func slice(l int) []byte { - ret := make([]byte, l) - for i := 0; i < l; i++ { - ret[i] = byte(i) - } - return ret -} - -func vec(l, lenBytes int) []byte { - b := slice(l) - vecLen := len(b) - ret := make([]byte, vecLen+l) - for i := l - 1; i >= 0; i-- { - ret[i] = byte(vecLen & 0xff) - vecLen >>= 8 - } - copy(ret[l:], b) - return ret -} - -func packet(bs ...[]byte) []byte { - var ret []byte - for _, b := range bs { - ret = append(ret, b...) - } - return ret -} - -func offset(b []byte, off int) []byte { - return b[off:] -} - -func TestVector(t *testing.T) { - tests := []struct { - in []byte - inLen int - out1, out2 []byte - err bool - }{ - { - // 1b length - append([]byte{3}, slice(10)...), 1, - slice(3), offset(slice(10), 3), false, - }, - { - // 1b length, no trailer - append([]byte{10}, slice(10)...), 1, - slice(10), []byte{}, false, - }, - { - // 1b length, no vector - append([]byte{0}, slice(10)...), 1, - []byte{}, slice(10), false, - }, - { - // 1b length, no vector or trailer - []byte{0}, 1, - []byte{}, []byte{}, false, - }, - { - // 2b length, LSB only - append([]byte{0, 3}, slice(10)...), 2, - slice(3), offset(slice(10), 3), false, - }, - { - // 2b length, MSB only - append([]byte{3, 0}, slice(1024)...), 2, - slice(768), offset(slice(1024), 768), false, - }, - { - // 2b length, both bytes - append([]byte{3, 2}, slice(1024)...), 2, - slice(770), offset(slice(1024), 770), false, - }, - { - // 3b length - append([]byte{1, 2, 3}, slice(100000)...), 3, - slice(66051), offset(slice(100000), 66051), false, - }, - { - // no bytes - []byte{}, 1, - nil, nil, true, - }, - { - // no slice - nil, 1, - nil, nil, true, - }, - { - // not enough bytes for length - []byte{1}, 2, - nil, nil, true, - }, - { - // no bytes after length - []byte{1}, 1, - nil, nil, true, - }, - { - // not enough bytes for vector - []byte{4, 1, 2}, 1, - nil, nil, true, - }, - } - - for _, test := range tests { - actual1, actual2, err := vector(test.in, test.inLen) - if !test.err && (err != nil) { - t.Errorf("unexpected error %q", err) - } - if test.err && (err == nil) { - t.Errorf("unexpected success") - } - if err != nil { - continue - } - if !bytes.Equal(actual1, test.out1) { - t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) - } - if !bytes.Equal(actual2, test.out2) { - t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) - } - } -} - -func TestHandshakeRecord(t *testing.T) { - tests := []struct { - in []byte - out []byte - tlsver int - }{ - { - // TLS 1.0, 1b packet - []byte{22, 3, 1, 0, 1, 3}, - []byte{3}, - 1, - }, - { - // TLS 1.1, 1b packet - []byte{22, 3, 2, 0, 1, 3}, - []byte{3}, - 2, - }, - { - // TLS 1.2, 1b packet - []byte{22, 3, 3, 0, 1, 3}, - []byte{3}, - 3, - }, - { - // TLS 1.2, no payload bytes - []byte{22, 3, 3, 0, 0}, - []byte{}, - 3, - }, - { - // TLS 1.2, >255b payload w/ trailing stuff - append([]byte{22, 3, 3, 3, 2}, slice(1024)...), - slice(770), - 3, - }, - { - // TLS 1.2, 2^14 payload - append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), - slice(maxTLSRecordLength), - 3, - }, - { - // TLS 1.2, >2^14 payload - append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), - nil, - 0, - }, - { - // TLS 1.2, truncated payload - []byte{22, 3, 3, 0, 4, 1, 2}, - nil, - 0, - }, - { - // truncated header - []byte{22}, - nil, - 0, - }, - { - // wrong record type - []byte{42, 3, 3, 0, 1, 3}, - nil, - 0, - }, - { - // wrong TLS major version - []byte{22, 2, 3, 0, 1, 3}, - nil, - 0, - }, - { - // wrong TLS minor version - []byte{22, 3, 42, 0, 1, 3}, - nil, - 0, - }, - { - // Obsolete SSL 3.0 - []byte{22, 3, 0, 0, 1, 3}, - nil, - 0, - }, - } - - for _, test := range tests { - r := bytes.NewBuffer(test.in) - actual, tlsver, err := handshakeRecord(r) - if test.out == nil && err == nil { - t.Errorf("unexpected success") - continue - } - if !bytes.Equal(test.out, actual) { - t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) - } - if tlsver != test.tlsver { - t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) - } - } -} - -func TestParseHello(t *testing.T) { - tests := []struct { - in []byte - out []byte - err bool - }{ - { - // Wrong record type - packet([]byte{42, 0, 0, 1, 1}), - nil, - true, - }, - { - // Truncated payload - packet([]byte{1, 0, 0, 1}), - nil, - true, - }, - { - // Payload too small - packet([]byte{1, 0, 0, 1, 1}), - nil, - true, - }, - { - // Unknown major version - packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), - nil, - true, - }, - { - // Unknown minor version - packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), - nil, - true, - }, - { - // Missing required variadic fields - packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), - nil, - true, - }, - { - // All zero variadic fields (no ciphersuites, no compression) - packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), - nil, - true, - }, - { - // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) - packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), - nil, - true, - }, - { - // Session + ciphersuites, no compression - packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), - nil, - true, - }, - { - // First valid packet. TLS 1.0, no extensions present. - packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), - nil, - false, - }, - { - // TLS 1.1, no extensions present. - packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), - nil, - false, - }, - { - // TLS 1.2, no extensions present. - packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), - nil, - false, - }, - { - // TLS 1.2, garbage extensions - packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), - nil, - true, - }, - { - // empty extensions vector - packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), - nil, - false, - }, - { - // non-SNI extensions - packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), - nil, - false, - }, - { - // SNI present - packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), - []byte{182}, - false, - }, - { - // Longer SNI - packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), - slice(4), - false, - }, - { - // Embedded SNI - packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), - slice(4), - false, - }, - } - - for _, test := range tests { - actual, err := parseHello(test.in) - if test.err { - if err == nil { - t.Errorf("unexpected success") - } - continue - } - if err != nil { - t.Errorf("unexpected error %q", err) - continue - } - if !bytes.Equal(test.out, actual) { - t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) - } - } -} - -func TestParseSNI(t *testing.T) { - tests := []struct { - in []byte - out string - err bool - }{ - { - // Empty packet - []byte{}, - "", - true, - }, - { - // Truncated packet - []byte{0, 2, 1}, - "", - true, - }, - { - // Truncated packet within SNI vector - []byte{0, 2, 1, 2}, - "", - true, - }, - { - // Wrong SNI kind - []byte{0, 3, 1, 0, 0}, - "", - false, - }, - { - // Right SNI kind, no hostname - []byte{0, 3, 0, 0, 0}, - "", - false, - }, - { - // SNI hostname - packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), - "lol", - false, - }, - { - // Multiple SNI kinds - packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), - "lol", - false, - }, - { - // Multiple SNI hostnames (illegal, but we just return the first) - packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), - "bar", - false, - }, - } - - for _, test := range tests { - actual, err := parseSNI(test.in) - if test.err { - if err == nil { - t.Errorf("unexpected success") - } - continue - } - if err != nil { - t.Errorf("unexpected error %q", err) - continue - } - if test.out != actual { - t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out) - } - } -} -- cgit v1.2.3