diff options
Diffstat (limited to 'cmd/tlsrouter/sni_test.go')
-rw-r--r-- | cmd/tlsrouter/sni_test.go | 456 |
1 files changed, 456 insertions, 0 deletions
diff --git a/cmd/tlsrouter/sni_test.go b/cmd/tlsrouter/sni_test.go new file mode 100644 index 0000000..8c87d24 --- /dev/null +++ b/cmd/tlsrouter/sni_test.go @@ -0,0 +1,456 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "testing" +) + +func slice(l int) []byte { + ret := make([]byte, l) + for i := 0; i < l; i++ { + ret[i] = byte(i) + } + return ret +} + +func vec(l, lenBytes int) []byte { + b := slice(l) + vecLen := len(b) + ret := make([]byte, vecLen+l) + for i := l - 1; i >= 0; i-- { + ret[i] = byte(vecLen & 0xff) + vecLen >>= 8 + } + copy(ret[l:], b) + return ret +} + +func packet(bs ...[]byte) []byte { + var ret []byte + for _, b := range bs { + ret = append(ret, b...) + } + return ret +} + +func offset(b []byte, off int) []byte { + return b[off:] +} + +func TestVector(t *testing.T) { + tests := []struct { + in []byte + inLen int + out1, out2 []byte + err bool + }{ + { + // 1b length + append([]byte{3}, slice(10)...), 1, + slice(3), offset(slice(10), 3), false, + }, + { + // 1b length, no trailer + append([]byte{10}, slice(10)...), 1, + slice(10), []byte{}, false, + }, + { + // 1b length, no vector + append([]byte{0}, slice(10)...), 1, + []byte{}, slice(10), false, + }, + { + // 1b length, no vector or trailer + []byte{0}, 1, + []byte{}, []byte{}, false, + }, + { + // 2b length, LSB only + append([]byte{0, 3}, slice(10)...), 2, + slice(3), offset(slice(10), 3), false, + }, + { + // 2b length, MSB only + append([]byte{3, 0}, slice(1024)...), 2, + slice(768), offset(slice(1024), 768), false, + }, + { + // 2b length, both bytes + append([]byte{3, 2}, slice(1024)...), 2, + slice(770), offset(slice(1024), 770), false, + }, + { + // 3b length + append([]byte{1, 2, 3}, slice(100000)...), 3, + slice(66051), offset(slice(100000), 66051), false, + }, + { + // no bytes + []byte{}, 1, + nil, nil, true, + }, + { + // no slice + nil, 1, + nil, nil, true, + }, + { + // not enough bytes for length + []byte{1}, 2, + nil, nil, true, + }, + { + // no bytes after length + []byte{1}, 1, + nil, nil, true, + }, + { + // not enough bytes for vector + []byte{4, 1, 2}, 1, + nil, nil, true, + }, + } + + for _, test := range tests { + actual1, actual2, err := vector(test.in, test.inLen) + if !test.err && (err != nil) { + t.Errorf("unexpected error %q", err) + } + if test.err && (err == nil) { + t.Errorf("unexpected success") + } + if err != nil { + continue + } + if !bytes.Equal(actual1, test.out1) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual1, test.out1) + } + if !bytes.Equal(actual2, test.out2) { + t.Errorf("wrong bytes for vector slice. Got %#v, want %#v", actual2, test.out2) + } + } +} + +func TestHandshakeRecord(t *testing.T) { + tests := []struct { + in []byte + out []byte + tlsver int + }{ + { + // TLS 1.0, 1b packet + []byte{22, 3, 1, 0, 1, 3}, + []byte{3}, + 1, + }, + { + // TLS 1.1, 1b packet + []byte{22, 3, 2, 0, 1, 3}, + []byte{3}, + 2, + }, + { + // TLS 1.2, 1b packet + []byte{22, 3, 3, 0, 1, 3}, + []byte{3}, + 3, + }, + { + // TLS 1.2, no payload bytes + []byte{22, 3, 3, 0, 0}, + []byte{}, + 3, + }, + { + // TLS 1.2, >255b payload w/ trailing stuff + append([]byte{22, 3, 3, 3, 2}, slice(1024)...), + slice(770), + 3, + }, + { + // TLS 1.2, 2^14 payload + append([]byte{22, 3, 3, 64, 0}, slice(maxTLSRecordLength)...), + slice(maxTLSRecordLength), + 3, + }, + { + // TLS 1.2, >2^14 payload + append([]byte{22, 3, 3, 64, 1}, slice(maxTLSRecordLength+1)...), + nil, + 0, + }, + { + // TLS 1.2, truncated payload + []byte{22, 3, 3, 0, 4, 1, 2}, + nil, + 0, + }, + { + // truncated header + []byte{22}, + nil, + 0, + }, + { + // wrong record type + []byte{42, 3, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS major version + []byte{22, 2, 3, 0, 1, 3}, + nil, + 0, + }, + { + // wrong TLS minor version + []byte{22, 3, 42, 0, 1, 3}, + nil, + 0, + }, + { + // Obsolete SSL 3.0 + []byte{22, 3, 0, 0, 1, 3}, + nil, + 0, + }, + } + + for _, test := range tests { + r := bytes.NewBuffer(test.in) + actual, tlsver, err := handshakeRecord(r) + if test.out == nil && err == nil { + t.Errorf("unexpected success") + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for TLS record. Got %#v, want %#v", actual, test.out) + } + if tlsver != test.tlsver { + t.Errorf("wrong TLS version returned. Got %d, want %d", tlsver, test.tlsver) + } + } +} + +func TestParseHello(t *testing.T) { + tests := []struct { + in []byte + out []byte + err bool + }{ + { + // Wrong record type + packet([]byte{42, 0, 0, 1, 1}), + nil, + true, + }, + { + // Truncated payload + packet([]byte{1, 0, 0, 1}), + nil, + true, + }, + { + // Payload too small + packet([]byte{1, 0, 0, 1, 1}), + nil, + true, + }, + { + // Unknown major version + packet([]byte{1, 0, 0, 34, 1, 0}, slice(32)), + nil, + true, + }, + { + // Unknown minor version + packet([]byte{1, 0, 0, 34, 3, 42}, slice(32)), + nil, + true, + }, + { + // Missing required variadic fields + packet([]byte{1, 0, 0, 34, 3, 1}, slice(32)), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression) + packet([]byte{1, 0, 0, 38, 3, 1}, slice(32), []byte{0, 0, 0, 0}), + nil, + true, + }, + { + // All zero variadic fields (no ciphersuites, no compression, nonzero session ID) + packet([]byte{1, 0, 0, 70, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 0, 0}), + nil, + true, + }, + { + // Session + ciphersuites, no compression + packet([]byte{1, 0, 0, 72, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 0}), + nil, + true, + }, + { + // First valid packet. TLS 1.0, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 1}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.1, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 2}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, no extensions present. + packet([]byte{1, 0, 0, 73, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}), + nil, + false, + }, + { + // TLS 1.2, garbage extensions + packet([]byte{1, 0, 0, 115, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, slice(42)), + nil, + true, + }, + { + // empty extensions vector + packet([]byte{1, 0, 0, 75, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 0}), + nil, + false, + }, + { + // non-SNI extensions + packet([]byte{1, 0, 0, 85, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 10, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2}), + nil, + false, + }, + { + // SNI present + packet([]byte{1, 0, 0, 90, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 15, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 1, 182}), + []byte{182}, + false, + }, + { + // Longer SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 100, 100, 0, 2, 1, 2, 0, 0, 0, 4}, slice(4)), + slice(4), + false, + }, + { + // Embedded SNI + packet([]byte{1, 0, 0, 93, 3, 3}, slice(32), []byte{32}, slice(32), []byte{0, 2, 1, 2, 1, 0}, []byte{0, 18, 42, 42, 0, 0, 0, 0, 0, 4}, slice(4), []byte{100, 100, 0, 2, 1, 2}), + slice(4), + false, + }, + } + + for _, test := range tests { + actual, err := parseHello(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if !bytes.Equal(test.out, actual) { + t.Errorf("wrong bytes for SNI data. Got %#v, want %#v", actual, test.out) + } + } +} + +func TestParseSNI(t *testing.T) { + tests := []struct { + in []byte + out string + err bool + }{ + { + // Empty packet + []byte{}, + "", + true, + }, + { + // Truncated packet + []byte{0, 2, 1}, + "", + true, + }, + { + // Truncated packet within SNI vector + []byte{0, 2, 1, 2}, + "", + true, + }, + { + // Wrong SNI kind + []byte{0, 3, 1, 0, 0}, + "", + false, + }, + { + // Right SNI kind, no hostname + []byte{0, 3, 0, 0, 0}, + "", + false, + }, + { + // SNI hostname + packet([]byte{0, 6, 0, 0, 3}, []byte("lol")), + "lol", + false, + }, + { + // Multiple SNI kinds + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("lol"), []byte{42, 0, 1, 2}), + "lol", + false, + }, + { + // Multiple SNI hostnames (illegal, but we just return the first) + packet([]byte{0, 13, 1, 0, 0, 0, 0, 3}, []byte("bar"), []byte{0, 0, 3}, []byte("lol")), + "bar", + false, + }, + } + + for _, test := range tests { + actual, err := parseSNI(test.in) + if test.err { + if err == nil { + t.Errorf("unexpected success") + } + continue + } + if err != nil { + t.Errorf("unexpected error %q", err) + continue + } + if test.out != actual { + t.Errorf("wrong SNI hostname. Got %q, want %q", actual, test.out) + } + } +} |