diff options
Diffstat (limited to 'sni.go')
-rw-r--r-- | sni.go | 66 |
1 files changed, 66 insertions, 0 deletions
@@ -0,0 +1,66 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpproxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net" +) + +// AddSNIRoute appends a route to the ipPort listener that says if the +// incoming TLS SNI server name is sni, the connection is given to +// dest. If it doesn't match, rule processing continues for any +// additional routes on ipPort. +// +// The ipPort is any valid net.Listen TCP address. +func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) { + p.addRoute(ipPort, sniMatch(sni), dest) +} + +type sniMatch string + +func (sni sniMatch) match(br *bufio.Reader) bool { + return clientHelloServerName(br) == string(sni) +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (sni string) { + const recordHeaderLen = 5 + hdr, err := br.Peek(recordHeaderLen) + if err != nil { + return "" + } + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + return "" // Not TLS. + } + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + return "" + } + tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }).Handshake() + return +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } |