summaryrefslogtreecommitdiff
path: root/cmd/tlsrouter/sni.go
blob: ed79df2afe6fbd954105684db1119d5272fa26da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)

// extractSNI reads a single TLS record from r and interprets it as a
// ClientHello. It returns the hostname carried in the server_name
// (SNI) extension together with the record-layer minor version
// reported by handshakeRecord.
//
// A well-formed ClientHello that carries no SNI extension yields an
// empty hostname, the minor version and a nil error. Whenever an error
// is returned, the hostname is "" and the version is 0.
func extractSNI(r io.Reader) (string, int, error) {
	record, minor, err := handshakeRecord(r)
	if err != nil {
		return "", 0, fmt.Errorf("reading TLS record: %s", err)
	}

	ext, err := parseHello(record)
	switch {
	case err != nil:
		return "", 0, fmt.Errorf("reading ClientHello: %s", err)
	case len(ext) == 0:
		// Valid ClientHello that simply has no SNI extension, so
		// there is no hostname to report.
		return "", minor, nil
	}

	name, err := parseSNI(ext)
	if err != nil {
		return "", 0, fmt.Errorf("parsing SNI extension: %s", err)
	}
	return name, minor, nil
}

// parseSNI extracts the host_name entry, if any, from the body of an
// SNI extension.
//
// The body is a 2-byte-length-prefixed list. Each entry is a 1-byte
// name type followed by a 2-byte-length-prefixed name. The first entry
// whose type is sniHostnameID is returned. Bytes following the list
// itself are ignored. A list that parses cleanly but contains no
// host_name entry yields "" and a nil error.
func parseSNI(b []byte) (string, error) {
	list, _, err := vector(b, 2)
	if err != nil {
		return "", err
	}

	for len(list) >= 3 {
		nameType := list[0]
		var name []byte
		name, list, err = vector(list[1:], 2)
		if err != nil {
			return "", fmt.Errorf("truncated SNI extension")
		}
		if nameType == sniHostnameID {
			return string(name), nil
		}
	}

	// One or two leftover bytes cannot form an entry header.
	if len(list) > 0 {
		return "", fmt.Errorf("trailing garbage at end of SNI extension")
	}

	// The list was well-formed but named no DNS host.
	return "", nil
}

// Identifiers used by the TLS server_name extension (RFC 6066, section 3).
const (
	// sniExtensionID is the ExtensionType value of the server_name
	// extension inside a ClientHello's extensions list.
	sniExtensionID = 0
	// sniHostnameID is the NameType value of a host_name entry inside
	// the server_name list.
	sniHostnameID = 0
)

// parseHello interprets b as a handshake message that must be a
// ClientHello (handshake type 1) and returns the body of its SNI
// extension, or nil if there is none.
//
// The framing of every field before the extensions (SessionID,
// CipherSuites, CompressionMethods) is checked. So is the framing of
// every extension up to and including the SNI one. This ensures the
// SNI bytes returned are the same ones the eventual TLS implementation
// will see.
//
// It returns (nil, nil) for a well-formed ClientHello that has no
// extensions block or no SNI extension. Bytes after the ClientHello
// body within b are ignored.
func parseHello(b []byte) ([]byte, error) {
	if len(b) == 0 {
		return nil, errors.New("zero length handshake record")
	}
	if b[0] != 1 {
		return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
	}

	// We're expecting a stricter TLS parser to run after we've
	// proxied, so we ignore any trailing bytes that might be present
	// after the 3-byte-length-prefixed ClientHello body (e.g. another
	// handshake message).
	b, _, err := vector(b[1:], 3)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello: %s", err)
	}

	// client_version (2 bytes) and random (32 bytes) are fixed-size, so
	// at least 34 bytes are needed to reach the SessionID length byte.
	// The actual minimal size is larger than that, but vector() will
	// correctly handle truncated packets from here on.
	if len(b) < 34 {
		return nil, errors.New("ClientHello packet too short")
	}

	if b[0] != 3 {
		return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
	}
	switch b[1] {
	case 1, 2, 3:
		// TLS 1.0, TLS 1.1, TLS 1.2
	default:
		// This is the ClientHello's own version field, not the record
		// header's, so the error names the ClientHello.
		return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
	}

	// Skip over client_version and random.
	b = b[34:]

	// We don't technically care about SessionID, but we care that the
	// framing is well-formed all the way up to the SNI field, so that
	// we are sure that we're pulling the same SNI bytes as the
	// eventual TLS implementation.
	vec, b, err := vector(b, 1)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello SessionID: %s", err)
	}
	if len(vec) > 32 {
		return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
	}

	// Likewise, we're just checking the bare minimum of framing: cipher
	// suites are 2 bytes each and at least one must be present.
	vec, b, err = vector(b, 2)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err)
	}
	if len(vec) < 2 || len(vec)%2 != 0 {
		return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
	}

	// At least one compression method must be present.
	vec, b, err = vector(b, 1)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err)
	}
	if len(vec) < 1 {
		return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
	}

	// Finally, we reach the extensions.
	if len(b) == 0 {
		// No extensions. This is not an error, it just means we have
		// no SNI payload.
		return nil, nil
	}
	exts, trailing, err := vector(b, 2)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello extensions: %s", err)
	}
	// The extensions block must be the last thing in the ClientHello body.
	if len(trailing) != 0 {
		return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(trailing))
	}

	// Each extension is a 2-byte type followed by a 2-byte-length-
	// prefixed body, so fewer than 4 bytes cannot start another one.
	for len(exts) >= 4 {
		typ := binary.BigEndian.Uint16(exts[:2])
		vec, exts, err = vector(exts[2:], 2)
		if err != nil {
			return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err)
		}
		if typ == sniExtensionID {
			// Found the SNI extension, return its payload. We don't
			// care about anything in the packet beyond this point.
			return vec, nil
		}
	}

	if len(exts) != 0 {
		return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(exts))
	}

	// Successfully parsed all extensions, but there was no SNI.
	return nil, nil
}

// maxTLSRecordLength is the largest record payload, in bytes, that
// handshakeRecord accepts. It equals 2^14, the TLSPlaintext fragment
// limit (RFC 5246, section 6.2.1).
const maxTLSRecordLength = 16384

// handshakeRecord reads exactly one TLS record from r and returns its
// payload along with the record-layer minor version.
//
// The record must have content type 22 (handshake), major version 3,
// minor version 1, 2 or 3, and a declared length of at most
// maxTLSRecordLength; otherwise an error is returned. Only this one
// record is consumed from r, so a handshake message split across
// several records is not reassembled here.
func handshakeRecord(r io.Reader) ([]byte, int, error) {
	// Record header: content type (1 byte), version major and minor
	// (1 byte each), payload length (2 bytes, big-endian).
	var hdr [5]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, 0, fmt.Errorf("reading TLS record header: %s", err)
	}
	typ, major, minor := hdr[0], hdr[1], hdr[2]
	length := binary.BigEndian.Uint16(hdr[3:])

	if typ != 22 {
		return nil, 0, fmt.Errorf("TLS record is not a handshake")
	}

	if major != 3 {
		return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", major, minor)
	}
	switch minor {
	case 1, 2, 3:
		// TLS 1.0, TLS 1.1, TLS 1.2
	default:
		return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", major, minor)
	}

	if length > maxTLSRecordLength {
		return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
	}

	ret := make([]byte, length)
	if _, err := io.ReadFull(r, ret); err != nil {
		// Give the body read the same context as the header read so
		// callers can tell which part of the record was truncated.
		return nil, 0, fmt.Errorf("reading TLS record body: %s", err)
	}

	return ret, int(minor), nil
}

// vector splits b into a length-prefixed field and whatever follows it.
//
// The first lenBytes bytes of b hold the field length as a big-endian
// unsigned integer. vector returns the field contents and the bytes
// remaining after the field. It returns an error if b is too short to
// hold either the length prefix or the declared contents. Both
// returned slices alias b; nothing is copied.
func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
	if len(b) < lenBytes {
		return nil, nil, errors.New("not enough space in packet for vector")
	}

	n := 0
	for i := 0; i < lenBytes; i++ {
		n = n<<8 | int(b[i])
	}

	end := lenBytes + n
	if end > len(b) {
		return nil, nil, errors.New("not enough space in packet for vector")
	}
	return b[lenBytes:end], b[end:], nil
}