summaryrefslogtreecommitdiff
path: root/cmd/tlsrouter/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/tlsrouter/main.go')
-rw-r--r--cmd/tlsrouter/main.go191
1 files changed, 191 insertions, 0 deletions
diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go
new file mode 100644
index 0000000..ff1a816
--- /dev/null
+++ b/cmd/tlsrouter/main.go
@@ -0,0 +1,191 @@
+// Copyright 2016 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+ "time"
+)
+
+var (
+ cfgFile = flag.String("conf", "", "configuration file")
+ listen = flag.String("listen", ":443", "listening port")
+ helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello")
+)
+
+func main() {
+ flag.Parse()
+
+ p := &Proxy{}
+ if err := p.Config.ReadFile(*cfgFile); err != nil {
+ log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
+ }
+
+ log.Fatalf("%s", p.ListenAndServe(*listen))
+}
+
+// Proxy routes connections to backends based on a Config.
+type Proxy struct {
+ Config Config
+ l net.Listener
+}
+
+// Serve accepts connections from l and routes them according to TLS SNI.
+func (p *Proxy) Serve(l net.Listener) error {
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ return fmt.Errorf("accept new conn: %s", err)
+ }
+
+ conn := &Conn{
+ TCPConn: c.(*net.TCPConn),
+ config: &p.Config,
+ }
+ go conn.proxy()
+ }
+}
+
+// ListenAndServe creates a listener on addr calls Serve on it.
+func (p *Proxy) ListenAndServe(addr string) error {
+ l, err := net.Listen("tcp", addr)
+ if err != nil {
+ return fmt.Errorf("create listener: %s", err)
+ }
+ return p.Serve(l)
+}
+
+// A Conn handles the TLS proxying of one user connection.
+type Conn struct {
+ *net.TCPConn
+ config *Config
+
+ tlsMinor int
+ hostname string
+ backend string
+ backendConn *net.TCPConn
+}
+
+func (c *Conn) logf(msg string, args ...interface{}) {
+ msg = fmt.Sprintf(msg, args...)
+ log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg)
+}
+
+func (c *Conn) abort(alert byte, msg string, args ...interface{}) {
+ c.logf(msg, args...)
+ alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert}
+
+ if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.logf("error while setting write deadline during abort: %s", err)
+ // Do NOT send the alert if we can't set a write deadline,
+ // that could result in leaking a connection for an extended
+ // period.
+ return
+ }
+
+ if _, err := c.Write(alertMsg); err != nil {
+ c.logf("error while sending alert: %s", err)
+ }
+}
+
+func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) }
+func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) }
+
+func (c *Conn) proxy() {
+ defer c.Close()
+
+ if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil {
+ c.internalError("Setting read deadline for ClientHello: %s", err)
+ return
+ }
+
+ var (
+ err error
+ handshakeBuf bytes.Buffer
+ )
+ c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf))
+ if err != nil {
+ c.internalError("Extracting SNI: %s", err)
+ return
+ }
+
+ c.logf("extracted SNI %s", c.hostname)
+
+ if err = c.SetReadDeadline(time.Time{}); err != nil {
+ c.internalError("Clearing read deadline for ClientHello: %s", err)
+ return
+ }
+
+ addProxyHeader := false
+ c.backend, addProxyHeader = c.config.Match(c.hostname)
+ if c.backend == "" {
+ c.sniFailed("no backend found for %q", c.hostname)
+ return
+ }
+
+ c.logf("routing %q to %q", c.hostname, c.backend)
+ backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second)
+ if err != nil {
+ c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err)
+ return
+ }
+ defer backend.Close()
+
+ c.backendConn = backend.(*net.TCPConn)
+
+ // If the backend supports the HAProxy PROXY protocol, give it the
+ // real source information about the connection.
+ if addProxyHeader {
+ remote := c.TCPConn.RemoteAddr().(*net.TCPAddr)
+ local := c.TCPConn.LocalAddr().(*net.TCPAddr)
+ family := "TCP6"
+ if remote.IP.To4() != nil {
+ family = "TCP4"
+ }
+ if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil {
+ c.internalError("failed to send PROXY header to %q: %s", c.backend, err)
+ return
+ }
+ }
+
+ // Replay the piece of the handshake we had to read to do the
+ // routing, then blindly proxy any other bytes.
+ if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {
+ c.internalError("failed to replay handshake to %q: %s", c.backend, err)
+ return
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go proxy(&wg, c.TCPConn, c.backendConn)
+ go proxy(&wg, c.backendConn, c.TCPConn)
+ wg.Wait()
+}
+
+func proxy(wg *sync.WaitGroup, a, b net.Conn) {
+ defer wg.Done()
+ atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn)
+ if _, err := io.Copy(atcp, btcp); err != nil {
+ log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err)
+ }
+ btcp.CloseWrite()
+ atcp.CloseRead()
+}