summaryrefslogtreecommitdiff
path: root/sni.go
diff options
context:
space:
mode:
Diffstat (limited to 'sni.go')
-rw-r--r--sni.go66
1 files changed, 66 insertions, 0 deletions
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..d57dd31
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,66 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net"
+)
+
+// AddSNIRoute appends a route to the ipPort listener that says if the
+// incoming TLS SNI server name is sni, the connection is given to
+// dest. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ p.addRoute(ipPort, sniMatch(sni), dest)
+}
+
+type sniMatch string
+
+func (sni sniMatch) match(br *bufio.Reader) bool {
+ return clientHelloServerName(br) == string(sni)
+}
+
+// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
+// without consuming any bytes from br.
+// On any error, the empty string is returned.
+func clientHelloServerName(br *bufio.Reader) (sni string) {
+ const recordHeaderLen = 5
+ hdr, err := br.Peek(recordHeaderLen)
+ if err != nil {
+ return ""
+ }
+ const recordTypeHandshake = 0x16
+ if hdr[0] != recordTypeHandshake {
+ return "" // Not TLS.
+ }
+ recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
+ helloBytes, err := br.Peek(recordHeaderLen + recLen)
+ if err != nil {
+ return ""
+ }
+ tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
+ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sni = hello.ServerName
+ return nil, nil
+ },
+ }).Handshake()
+ return
+}
+
+// sniSniffConn is a net.Conn that reads from r, fails on Writes,
+// and crashes otherwise.
+type sniSniffConn struct {
+ r io.Reader
+ net.Conn // nil; crash on any unexpected use
+}
+
+func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }