diff options
author | David Anderson <[email protected]> | 2016-11-27 20:31:39 -0800 |
---|---|---|
committer | David Anderson <[email protected]> | 2016-11-27 20:31:39 -0800 |
commit | b1edd90c0436159dcf4d3f794121633fb8ed9035 (patch) | |
tree | 00bf5a757ae9e657cca5ffa6de9e995d0ea2ab90 |
Initial commit.
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | config.go | 77 | ||||
-rw-r--r-- | main.go | 118 | ||||
-rw-r--r-- | sni.go | 218 | ||||
-rw-r--r-- | sni_test.go | 448 |
5 files changed, 863 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ab78466 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +tlsrouter +tlsrouter.test diff --git a/config.go b/config.go new file mode 100644 index 0000000..c6a40ac --- /dev/null +++ b/config.go @@ -0,0 +1,77 @@ +package main + +import ( + "bufio" + "fmt" + "io" + "os" + "regexp" + "strings" + "sync" +) + +type Route struct { + match *regexp.Regexp + backend string +} + +// Config stores the TLS routing configuration. +type Config struct { + mu sync.Mutex + routes []Route +} + +func dnsRegex(s string) (*regexp.Regexp, error) { + return regexp.Compile(s) +} + +func (c *Config) Match(hostname string) string { + c.mu.Lock() + defer c.mu.Unlock() + for _, r := range c.routes { + if r.match.MatchString(hostname) { + return r.backend + } + } + return "" +} + +func (c *Config) Read(r io.Reader) error { + var routes []Route + + s := bufio.NewScanner(r) + for s.Scan() { + fs := strings.Fields(s.Text()) + switch len(fs) { + case 0: + continue + case 1: + return fmt.Errorf("invalid %q on a line by itself", s.Text()) + case 2: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + routes = append(routes, Route{re, fs[1]}) + default: + // TODO: multiple backends? + return fmt.Errorf("too many fields on line: %q", s.Text()) + } + } + if err := s.Err(); err != nil { + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + c.routes = routes + return nil +} + +func (c *Config) ReadFile(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + return c.Read(f) +} @@ -0,0 +1,118 @@ +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +var cfgFile = flag.String("conf", "", "configuration file") +var listen = flag.String("listen", ":443", "listening port") + +var config Config + +func main() { + flag.Parse() + + if err := config.ReadFile(*cfgFile); err != nil { + log.Fatalf("Failed to read config %q: %s", *cfgFile, err) + } + + l, err := net.Listen("tcp", *listen) + if err != nil { + log.Fatalf("Failed to listen: %s", err) + } + + for { + c, err := l.Accept() + if err != nil { + log.Fatalf("Error while accepting: %s", err) + } + + conn := &Conn{TCPConn: c.(*net.TCPConn)} + go conn.proxy() + } +} + +type Conn struct { + *net.TCPConn + + tlsMinor int + hostname string + backend string + backendConn *net.TCPConn +} + +func (c *Conn) log(msg string, args ...interface{}) { + msg = fmt.Sprintf(msg, args...) + log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) +} + +func (c *Conn) abort(alert byte, msg string, args ...interface{}) { + c.log(msg, args...) + alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} + if _, err := c.Write(alertMsg); err != nil { + c.log("error while sending alert: %s", err) + } +} + +func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } +func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } + +func (c *Conn) proxy() { + defer c.Close() + + var ( + err error + handshakeBuf bytes.Buffer + ) + c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) + if err != nil { + c.internalError("Extracting SNI: %s", err) + return + } + + c.backend = config.Match(c.hostname) + if c.backend == "" { + c.sniFailed("no backend found for %q", c.hostname) + return + } + + c.log("routing %q to %q", c.hostname, c.backend) + backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) + if err != nil { + c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) + return + } + defer backend.Close() + + c.backendConn = backend.(*net.TCPConn) + + // Replay the piece of the handshake we had to read to do the + // routing, then blindly proxy any other bytes. + if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { + c.internalError("failed to replay handshake to %q: %s", c.backend, err) + return + } + + var wg sync.WaitGroup + wg.Add(2) + go proxy(&wg, c.TCPConn, c.backendConn) + go proxy(&wg, c.backendConn, c.TCPConn) + wg.Wait() +} + +func proxy(wg *sync.WaitGroup, a, b net.Conn) { + defer wg.Done() + atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) + if _, err := io.Copy(atcp, btcp); err != nil { + log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) + } + btcp.CloseWrite() + atcp.CloseRead() +} @@ -0,0 +1,218 @@ +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +func extractSNI(r io.Reader) (string, int, error) { + handshake, tlsver, err := handshakeRecord(r) + if err != nil { + return "", 0, fmt.Errorf("reading TLS record: %s", err) + } + + sni, err := parseHello(handshake) + if err != nil { + return "", 0, fmt.Errorf("reading ClientHello: %s", err) + } + if len(sni) == 0 { + // ClientHello did not present an SNI extension. Valid packet, + // no hostname. + return "", tlsver, nil + } + + hostname, err := parseSNI(sni) + if err != nil { + return "", 0, fmt.Errorf("parsing SNI extension: %s", err) + } + return hostname, tlsver, nil +} + +// Extract the indicated hostname, if any, from the given SNI +// extension bytes. +func parseSNI(b []byte) (string, error) { + b, _, err := vector(b, 2) + if err != nil { + return "", err + } + + var ret []byte + for len(b) >= 3 { + typ := b[0] + ret, b, err = vector(b[1:], 2) + if err != nil { + return "", fmt.Errorf("truncated SNI extension") + } + + if typ == sniHostnameID { + return string(ret), nil + } + } + + if len(b) != 0 { + return "", fmt.Errorf("trailing garbage at end of SNI extension") + } + + // No DNS-based SNI present. + return "", nil +} + +const sniExtensionID = 0 +const sniHostnameID = 0 + +// Parse a TLS handshake record as a ClientHello message and extract +// the SNI extension bytes, if any. +func parseHello(b []byte) ([]byte, error) { + if len(b) == 0 { + return nil, errors.New("zero length handshake record") + } + if b[0] != 1 { + return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) + } + + // We're expecting a stricter TLS parser to run after we've + // proxied, so we ignore any trailing bytes that might be present + // (e.g. another handshake message). + b, _, err := vector(b[1:], 3) + if err != nil { + return nil, fmt.Errorf("reading ClientHello: %s", err) + } + + // ClientHello must be at least 34 bytes to reach the first vector + // length byte. The actual minimal size is larger than that, but + // vector() will correctly handle truncated packets. + if len(b) < 34 { + return nil, errors.New("ClientHello packet too short") + } + + if b[0] != 3 { + return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) + } + switch b[1] { + case 0, 1, 2, 3: + // SSL 3, TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) + } + + // Skip over version and random struct + b = b[34:] + + // We don't technically care about SessionID, but we care that the + // framing is well-formed all the way up to the SNI field, so that + // we are sure that we're pulling the same SNI bytes as the + // eventual TLS implementation. + vec, b, err := vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) + } + if len(vec) > 32 { + return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) + } + + // Likewise, we're just checking the bare minimum of framing. + vec, b, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) + } + if len(vec) < 2 || len(vec)%2 != 0 { + return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) + } + + vec, b, err = vector(b, 1) + if err != nil { + return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) + } + if len(vec) < 1 { + return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) + } + + // Finally, we reach the extensions. + if len(b) == 0 { + // No extensions. This is not an error, it just means we have + // no SNI payload. + return nil, nil + } + b, vec, err = vector(b, 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extensions: %s", err) + } + if len(vec) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) + } + + for len(b) >= 4 { + typ := binary.BigEndian.Uint16(b[:2]) + vec, b, err = vector(b[2:], 2) + if err != nil { + return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) + } + if typ == sniExtensionID { + // Found the SNI extension, return its payload. We don't + // care about anything in the packet beyond this point. + return vec, nil + } + } + + if len(b) != 0 { + return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) + } + + // Successfully parsed all extensions, but there was no SNI. + return nil, nil +} + +const maxTLSRecordLength = 16384 + +// Read one TLS record, which must be for the handshake protocol, from r. +func handshakeRecord(r io.Reader) ([]byte, int, error) { + var hdr struct { + Type uint8 + Major, Minor uint8 + Length uint16 + } + if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { + return nil, 0, fmt.Errorf("reading TLS record header: %s", err) + } + + if hdr.Type != 22 { + return nil, 0, fmt.Errorf("TLS record is not a handshake") + } + + if hdr.Major != 3 { + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + switch hdr.Minor { + case 0, 1, 2, 3: + // SSL 3, TLS 1.0, TLS 1.1, TLS 1.2 + default: + return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) + } + + if hdr.Length > maxTLSRecordLength { + return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) + } + + ret := make([]byte, hdr.Length) + if _, err := io.ReadFull(r, ret); err != nil { + return nil, 0, err + } + + return ret, int(hdr.Minor), nil +} + +func vector(b []byte, lenBytes int) ([]byte, []byte, error) { + if len(b) < lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + var l int + for _, b := range b[:lenBytes] { + l = (l << 8) + int(b) + } + if len(b) < l+lenBytes { + return nil, nil, errors.New("not enough space in packet for vector") + } + return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil +} diff --git a/sni_test.go b/sni_test.go new file mode 100644 index 0000000..ca7adf7 --- /dev/null +++ b/sni_test.go @@ -0,0 +1,448 @@ +package main + +import ( + "bytes" + "testing" +) + +func slice(l int) []byte { + ret := make([]byte, l) + for i := 0; i < l; i++ { + ret[i] = byte(i) + } + return ret +} + +func vec(l, lenBytes int) []byte { + b := slice(l) + vecLen := len(b) + ret := make([]byte, vecLen+l) + for i := l - 1; i >= 0; i-- { + ret[i] = byte(vecLen & 0xff) + vecLen >>= 8 + } + copy(ret[l:], b) + return ret +} + +func packet(bs ...[]byte) []byte { + var ret []byte + for _, b := range bs { + ret = append(ret, b...) + } + return ret +} + +func offset(b []byte, off int) []byte { + return b[off:] +} + +func TestVector(t *testing.T) { + tests := []struct { + in []byte + inLen int + out1, out2 []byte + err bool + }{ + { + // 1b length + append([]byte{3}, slice(10)...), 1, + slice(3), offset(slice(10), 3), false, + }, + { + // 1b length, no trailer + append([]byte{10}, slice(10)...), 1, + slice(10), []byte{}, false, + }, + { + // 1b length, no vector + append([]byte{0}, slice(10)...), 1, + []byte{}, slice(10), false, + }, + { + // 1b length, no vector or trailer + []byte{0}, 1, + []byte{}, []byte{}, false, + }, + { + // 2b length, LSB only + append([]byte{0, 3}, slice(10)...), 2, + slice(3), offset(slice(10), 3), false, + }, + { + // 2b length, MSB only + append([]byte{3, 0}, slice(1024)...), 2, + slice(768), offset(slice(1024), 768), false, + }, + { + // 2b length, both bytes + append([]byte{3, 2}, slice(1024)...), 2, + slice(770), offset(slice(1024), 770), false, + }, + { + // 3b length + append([]byte{1, 2, 3}, slice(100000)...), 3, + slice(66051), offset(slice(100000), 66051), false, + }, + { + // no bytes + []byte{}, 1, + nil, nil, true, + }, + { + // no slice + nil, 1, + nil, nil, true, + }, + { + // not enough bytes for length + []byte{1}, 2, + nil, nil, true, + }, + { + // no bytes after length + []byte{1}, 1, + nil, nil, true, + }, + { + // not enough bytes for vector + []byte{4, 1, 2}, 1, + nil, nil, true, + }, + } + + for _, test := range tests { + actual1, actual2, err := vector(test.in, test.inLen) + if !test.err && (err != nil) { + t.Errorf("unexpected error %q", err) + } + if test.err && (err == nil) { + t.Errorf("unexpected success") + } + if err != nil { + continue + } + if !bytes.Equal(actual1, test.out1) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) + } + if !bytes.Equal(actual2, test.out2) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) + } + } +} + +func TestHandshakeRecord(t *testing.T) { + tests := []struct { + in []byte + out []byte + tlsver int + }{ + { + // SSL 3.0, 1b packet + []byte{22, 3, 0, 0, 1, 3}, + []byte{3}, + 0, + }, + { + // TLS 1.0, 1b packet + []byte{22, 3, 1, 0, 1, 3}, + []byte{3}, + 1, + }, + { + // TLS 1.1, 1b packet + []byte{22, 3, 2, 0, 1, 3}, + []byte{3}, + 2, + }, + { + // TLS 1.2, 1b packet + []byte{22, 3, 3, 0, 1, 3}, + []byte{3}, + 3, + }, + { + // TLS 1.2, no payload bytes + []byte{22, 3, 3, 0, 0}, + []byte{}, + 3, + }, + { + // TLS 1.2, >255b payload w/ trailing stuff + append([]byte{22, 3, 3, 3, 2}, slice(1024)...), + slice(770), + 3, + }, + { + // TLS 1.2, 2^14 payload + append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), + slice(maxTLSRecordLength), + 3, + }, + { + // TLS 1.2, >2^14 payload + append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), + nil, + 0, + }, + { + // TLS 1.2, truncated payload + []byte{22, 3, 3, 0, 4, 1, 2}, + nil, + 0, + }, + { + // truncated header + []byte{22}, + nil, + 0, + }, + { + // wrong record type + []byte{42, 3, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS major version + []byte{22, 2, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS minor version + []byte{22, 3, 42, 0, 1, 3}, + nil, + 0, + }, + } + + for _, test := range tests { + r := bytes.NewBuffer(test.in) + actual, tlsver, err := handshakeRecord(r) + if test.out == nil && err == nil { + t.Errorf("unexpected success") + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) + } + if tlsver != test.tlsver { + t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) + } + } +} + +func TestParseHello(t *testing.T) { + tests := []struct { + in []byte + out []byte + err bool + }{ + { + // Wrong record type + packet([]byte{42, 0, 0, 1, 1}), + nil, + true, + }, + { + // Truncated payload + packet([]byte{1, 0, 0, 1}), + nil, + true, + }, + { + // Payload too small + packet([]byte{1, 0, 0, 1, 1}), + nil, + true, + }, + { + // Unknown major version + packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), + nil, + true, + }, + { + // Unknown minor version + packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), + nil, + true, + }, + { + // Missing required variadic fields + packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression) + packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) + packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), + nil, + true, + }, + { + // Session + ciphersuites, no compression + packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), + nil, + true, + }, + { + // First valid packet. SSL 3.0, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 0}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.0, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.1, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, garbage extensions + packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), + nil, + true, + }, + { + // empty extensions vector + packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), + nil, + false, + }, + { + // non-SNI extensions + packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), + nil, + false, + }, + { + // SNI present + packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), + []byte{182}, + false, + }, + { + // Longer SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), + slice(4), + false, + }, + { + // Embedded SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), + slice(4), + false, + }, + } + + for _, test := range tests { + actual, err := parseHello(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) + } + } +} + +func TestParseSNI(t *testing.T) { + tests := []struct { + in []byte + out string + err bool + }{ + { + // Empty packet + []byte{}, + "", + true, + }, + { + // Truncated packet + []byte{0, 2, 1}, + "", + true, + }, + { + // Truncated packet within SNI vector + []byte{0, 2, 1, 2}, + "", + true, + }, + { + // Wrong SNI kind + []byte{0, 3, 1, 0, 0}, + "", + false, + }, + { + // Right SNI kind, no hostname + []byte{0, 3, 0, 0, 0}, + "", + false, + }, + { + // SNI hostname + packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), + "lol", + false, + }, + { + // Multiple SNI kinds + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), + "lol", + false, + }, + { + // Multiple SNI hostnames (illegal, but we just return the first) + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), + "bar", + false, + }, + } + + for _, test := range tests { + actual, err := parseSNI(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if test.out != actual { + t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out) + } + } +} |