summaryrefslogtreecommitdiff
path: root/sni.go
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2016-11-27 20:31:39 -0800
committerDavid Anderson <[email protected]>2016-11-27 20:31:39 -0800
commitb1edd90c0436159dcf4d3f794121633fb8ed9035 (patch)
tree00bf5a757ae9e657cca5ffa6de9e995d0ea2ab90 /sni.go
Initial commit.
Diffstat (limited to 'sni.go')
-rw-r--r--sni.go218
1 files changed, 218 insertions, 0 deletions
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..42025ee
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,218 @@
+package main
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+func extractSNI(r io.Reader) (string, int, error) {
+ handshake, tlsver, err := handshakeRecord(r)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading TLS record: %s", err)
+ }
+
+ sni, err := parseHello(handshake)
+ if err != nil {
+ return "", 0, fmt.Errorf("reading ClientHello: %s", err)
+ }
+ if len(sni) == 0 {
+ // ClientHello did not present an SNI extension. Valid packet,
+ // no hostname.
+ return "", tlsver, nil
+ }
+
+ hostname, err := parseSNI(sni)
+ if err != nil {
+ return "", 0, fmt.Errorf("parsing SNI extension: %s", err)
+ }
+ return hostname, tlsver, nil
+}
+
+// Extract the indicated hostname, if any, from the given SNI
+// extension bytes.
+func parseSNI(b []byte) (string, error) {
+ b, _, err := vector(b, 2)
+ if err != nil {
+ return "", err
+ }
+
+ var ret []byte
+ for len(b) >= 3 {
+ typ := b[0]
+ ret, b, err = vector(b[1:], 2)
+ if err != nil {
+ return "", fmt.Errorf("truncated SNI extension")
+ }
+
+ if typ == sniHostnameID {
+ return string(ret), nil
+ }
+ }
+
+ if len(b) != 0 {
+ return "", fmt.Errorf("trailing garbage at end of SNI extension")
+ }
+
+ // No DNS-based SNI present.
+ return "", nil
+}
+
+const sniExtensionID = 0
+const sniHostnameID = 0
+
+// Parse a TLS handshake record as a ClientHello message and extract
+// the SNI extension bytes, if any.
+func parseHello(b []byte) ([]byte, error) {
+ if len(b) == 0 {
+ return nil, errors.New("zero length handshake record")
+ }
+ if b[0] != 1 {
+ return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
+ }
+
+ // We're expecting a stricter TLS parser to run after we've
+ // proxied, so we ignore any trailing bytes that might be present
+ // (e.g. another handshake message).
+ b, _, err := vector(b[1:], 3)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello: %s", err)
+ }
+
+ // ClientHello must be at least 34 bytes to reach the first vector
+ // length byte. The actual minimal size is larger than that, but
+ // vector() will correctly handle truncated packets.
+ if len(b) < 34 {
+ return nil, errors.New("ClientHello packet too short")
+ }
+
+ if b[0] != 3 {
+ return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
+ }
+ switch b[1] {
+ case 0, 1, 2, 3:
+ // SSL 3, TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1])
+ }
+
+ // Skip over version and random struct
+ b = b[34:]
+
+ // We don't technically care about SessionID, but we care that the
+ // framing is well-formed all the way up to the SNI field, so that
+ // we are sure that we're pulling the same SNI bytes as the
+ // eventual TLS implementation.
+ vec, b, err := vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello SessionID: %s", err)
+ }
+ if len(vec) > 32 {
+ return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
+ }
+
+ // Likewise, we're just checking the bare minimum of framing.
+ vec, b, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err)
+ }
+ if len(vec) < 2 || len(vec)%2 != 0 {
+ return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
+ }
+
+ vec, b, err = vector(b, 1)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err)
+ }
+ if len(vec) < 1 {
+ return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
+ }
+
+ // Finally, we reach the extensions.
+ if len(b) == 0 {
+ // No extensions. This is not an error, it just means we have
+ // no SNI payload.
+ return nil, nil
+ }
+ b, vec, err = vector(b, 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extensions: %s", err)
+ }
+ if len(vec) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec))
+ }
+
+ for len(b) >= 4 {
+ typ := binary.BigEndian.Uint16(b[:2])
+ vec, b, err = vector(b[2:], 2)
+ if err != nil {
+ return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err)
+ }
+ if typ == sniExtensionID {
+ // Found the SNI extension, return its payload. We don't
+ // care about anything in the packet beyond this point.
+ return vec, nil
+ }
+ }
+
+ if len(b) != 0 {
+ return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b))
+ }
+
+ // Successfully parsed all extensions, but there was no SNI.
+ return nil, nil
+}
+
+const maxTLSRecordLength = 16384
+
+// Read one TLS record, which must be for the handshake protocol, from r.
+func handshakeRecord(r io.Reader) ([]byte, int, error) {
+ var hdr struct {
+ Type uint8
+ Major, Minor uint8
+ Length uint16
+ }
+ if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
+ return nil, 0, fmt.Errorf("reading TLS record header: %s", err)
+ }
+
+ if hdr.Type != 22 {
+ return nil, 0, fmt.Errorf("TLS record is not a handshake")
+ }
+
+ if hdr.Major != 3 {
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+ switch hdr.Minor {
+ case 0, 1, 2, 3:
+ // SSL 3, TLS 1.0, TLS 1.1, TLS 1.2
+ default:
+ return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor)
+ }
+
+ if hdr.Length > maxTLSRecordLength {
+ return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
+ }
+
+ ret := make([]byte, hdr.Length)
+ if _, err := io.ReadFull(r, ret); err != nil {
+ return nil, 0, err
+ }
+
+ return ret, int(hdr.Minor), nil
+}
+
+func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
+ if len(b) < lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ var l int
+ for _, b := range b[:lenBytes] {
+ l = (l << 8) + int(b)
+ }
+ if len(b) < l+lenBytes {
+ return nil, nil, errors.New("not enough space in packet for vector")
+ }
+ return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil
+}