summaryrefslogtreecommitdiff
path: root/sni.go
diff options
context:
space:
mode:
Diffstat (limited to 'sni.go')
-rw-r--r--sni.go76
1 files changed, 76 insertions, 0 deletions
diff --git a/sni.go b/sni.go
new file mode 100644
index 0000000..e12c744
--- /dev/null
+++ b/sni.go
@@ -0,0 +1,76 @@
+// Copyright 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpproxy
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net"
+)
+
+// AddSNIRoute appends a route to the ipPort listener that says if the
+// incoming TLS SNI server name is sni, the connection is given to
+// dest. If it doesn't match, rule processing continues for any
+// additional routes on ipPort.
+//
+// The ipPort is any valid net.Listen TCP address.
+func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
+ p.addRoute(ipPort, sniMatch(sni), dest)
+}
+
+type sniMatch string
+
+func (sni sniMatch) match(br *bufio.Reader) bool {
+ return clientHelloServerName(br) == string(sni)
+}
+
+// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
+// without consuming any bytes from br.
+// On any error, the empty string is returned.
+func clientHelloServerName(br *bufio.Reader) (sni string) {
+ const recordHeaderLen = 5
+ hdr, err := br.Peek(recordHeaderLen)
+ if err != nil {
+ return ""
+ }
+ const recordTypeHandshake = 0x16
+ if hdr[0] != recordTypeHandshake {
+ return "" // Not TLS.
+ }
+ recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3]
+ helloBytes, err := br.Peek(recordHeaderLen + recLen)
+ if err != nil {
+ return ""
+ }
+ tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{
+ GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sni = hello.ServerName
+ return nil, nil
+ },
+ }).Handshake()
+ return
+}
+
+// sniSniffConn is a net.Conn that reads from r, fails on Writes,
+// and crashes otherwise.
+type sniSniffConn struct {
+ r io.Reader
+ net.Conn // nil; crash on any unexpected use
+}
+
+func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) }
+func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }