diff options
Diffstat (limited to 'sni.go')
-rw-r--r-- | sni.go | 76 |
1 files changed, 76 insertions, 0 deletions
@@ -0,0 +1,76 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net" +) + +// AddSNIRoute appends a route to the ipPort listener that says if the +// incoming TLS SNI server name is sni, the connection is given to +// dest. If it doesn't match, rule processing continues for any +// additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + p.addRoute(ipPort, sniMatch(sni), dest) +} + +type sniMatch string + +func (sni sniMatch) match(br *bufio.Reader) bool { + return clientHelloServerName(br) == string(sni) +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (sni string) { + const recordHeaderLen = 5 + hdr, err := br.Peek(recordHeaderLen) + if err != nil { + return "" + } + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + return "" // Not TLS. + } + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + return "" + } + tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }).Handshake() + return +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } |