diff options
Diffstat (limited to 'cmd/tlsrouter/main.go')
-rw-r--r-- | cmd/tlsrouter/main.go | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/cmd/tlsrouter/main.go b/cmd/tlsrouter/main.go new file mode 100644 index 0000000..ff1a816 --- /dev/null +++ b/cmd/tlsrouter/main.go @@ -0,0 +1,191 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +var ( + cfgFile = flag.String("conf", "", "configuration file") + listen = flag.String("listen", ":443", "listening port") + helloTimeout = flag.Duration("hello-timeout", 3*time.Second, "how long to wait for the TLS ClientHello") +) + +func main() { + flag.Parse() + + p := &Proxy{} + if err := p.Config.ReadFile(*cfgFile); err != nil { + log.Fatalf("Failed to read config %q: %s", *cfgFile, err) + } + + log.Fatalf("%s", p.ListenAndServe(*listen)) +} + +// Proxy routes connections to backends based on a Config. +type Proxy struct { + Config Config + l net.Listener +} + +// Serve accepts connections from l and routes them according to TLS SNI. +func (p *Proxy) Serve(l net.Listener) error { + for { + c, err := l.Accept() + if err != nil { + return fmt.Errorf("accept new conn: %s", err) + } + + conn := &Conn{ + TCPConn: c.(*net.TCPConn), + config: &p.Config, + } + go conn.proxy() + } +} + +// ListenAndServe creates a listener on addr calls Serve on it. +func (p *Proxy) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("create listener: %s", err) + } + return p.Serve(l) +} + +// A Conn handles the TLS proxying of one user connection. +type Conn struct { + *net.TCPConn + config *Config + + tlsMinor int + hostname string + backend string + backendConn *net.TCPConn +} + +func (c *Conn) logf(msg string, args ...interface{}) { + msg = fmt.Sprintf(msg, args...) + log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg) +} + +func (c *Conn) abort(alert byte, msg string, args ...interface{}) { + c.logf(msg, args...) + alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert} + + if err := c.SetWriteDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.logf("error while setting write deadline during abort: %s", err) + // Do NOT send the alert if we can't set a write deadline, + // that could result in leaking a connection for an extended + // period. + return + } + + if _, err := c.Write(alertMsg); err != nil { + c.logf("error while sending alert: %s", err) + } +} + +func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) } +func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) } + +func (c *Conn) proxy() { + defer c.Close() + + if err := c.SetReadDeadline(time.Now().Add(*helloTimeout)); err != nil { + c.internalError("Setting read deadline for ClientHello: %s", err) + return + } + + var ( + err error + handshakeBuf bytes.Buffer + ) + c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf)) + if err != nil { + c.internalError("Extracting SNI: %s", err) + return + } + + c.logf("extracted SNI %s", c.hostname) + + if err = c.SetReadDeadline(time.Time{}); err != nil { + c.internalError("Clearing read deadline for ClientHello: %s", err) + return + } + + addProxyHeader := false + c.backend, addProxyHeader = c.config.Match(c.hostname) + if c.backend == "" { + c.sniFailed("no backend found for %q", c.hostname) + return + } + + c.logf("routing %q to %q", c.hostname, c.backend) + backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second) + if err != nil { + c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err) + return + } + defer backend.Close() + + c.backendConn = backend.(*net.TCPConn) + + // If the backend supports the HAProxy PROXY protocol, give it the + // real source information about the connection. + if addProxyHeader { + remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) + local := c.TCPConn.LocalAddr().(*net.TCPAddr) + family := "TCP6" + if remote.IP.To4() != nil { + family = "TCP4" + } + if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { + c.internalError("failed to send PROXY header to %q: %s", c.backend, err) + return + } + } + + // Replay the piece of the handshake we had to read to do the + // routing, then blindly proxy any other bytes. + if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { + c.internalError("failed to replay handshake to %q: %s", c.backend, err) + return + } + + var wg sync.WaitGroup + wg.Add(2) + go proxy(&wg, c.TCPConn, c.backendConn) + go proxy(&wg, c.backendConn, c.TCPConn) + wg.Wait() +} + +func proxy(wg *sync.WaitGroup, a, b net.Conn) { + defer wg.Done() + atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn) + if _, err := io.Copy(atcp, btcp); err != nil { + log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err) + } + btcp.CloseWrite() + atcp.CloseRead() +} |