diff options
author | David Anderson <[email protected]> | 2017-05-14 01:48:18 -0700 |
---|---|---|
committer | David Anderson <[email protected]> | 2017-05-14 01:48:18 -0700 |
commit | 4b8641f40e04705b8227f245be36457c05ccba2c (patch) | |
tree | 6d33d6d234e2fee20cdab2376dbb5bd959ae699f | |
parent | d23eadc3a6c89bf5058db893acee26d5f1d7e350 (diff) |
Add support for HAProxy's PROXY protocol.
This allows backends that support it to receive the client's true
ip:port as out-of-band information, despite the connection being
proxied.
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | config.go | 29 | ||||
-rw-r--r-- | config_test.go | 40 | ||||
-rw-r--r-- | e2e_test.go | 71 | ||||
-rw-r--r-- | main.go | 18 |
5 files changed, 119 insertions, 44 deletions
@@ -31,6 +31,11 @@ google.* 10.20.30.40:443 # RE2 regexes are also available /(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443 + +# If your backend supports HAProxy's PROXY protocol, you can enable +# it to receive the real client ip:port. + +fancy.backend 2.3.4.5:443 PROXY ``` TLSRouter takes one mandatory commandline argument, the configuration file to use: @@ -17,6 +17,7 @@ package main import ( "bufio" "bytes" + "errors" "fmt" "io" "os" @@ -27,8 +28,9 @@ import ( // A Route maps a match on a domain name to a backend. type Route struct { - match *regexp.Regexp - backend string + match *regexp.Regexp + backend string + proxyInfo bool } // Config stores the TLS routing configuration. @@ -57,21 +59,21 @@ func dnsRegex(s string) (*regexp.Regexp, error) { return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`))) } -// Match returns the backend for hostname. -func (c *Config) Match(hostname string) string { +// Match returns the backend for hostname, and whether to use the PROXY protocol. +func (c *Config) Match(hostname string) (string, bool) { c.mu.Lock() defer c.mu.Unlock() if strings.HasSuffix(hostname, ".acme.invalid") { - return c.acme.Match(hostname) + return c.acme.Match(hostname), false } for _, r := range c.routes { if r.match.MatchString(hostname) { - return r.backend + return r.backend, r.proxyInfo } } - return "" + return "", false } // Read replaces the current Config with one read from r. @@ -97,7 +99,17 @@ func (c *Config) Read(r io.Reader) error { if err != nil { return err } - routes = append(routes, Route{re, fs[1]}) + routes = append(routes, Route{re, fs[1], false}) + backends = append(backends, fs[1]) + case 3: + re, err := dnsRegex(fs[0]) + if err != nil { + return err + } + if fs[2] != "PROXY" { + return errors.New("third item on a line can only be PROXY") + } + routes = append(routes, Route{re, fs[1], true}) backends = append(backends, fs[1]) default: // TODO: multiple backends? @@ -127,6 +139,7 @@ func (c *Config) ReadFile(path string) error { return c.Read(f) } +// ReadString replaces the current Config with one read from cfg. func (c *Config) ReadString(cfg string) error { b := bytes.NewBufferString(cfg) return c.Read(b) diff --git a/config_test.go b/config_test.go index bcb9e5f..9819b91 100644 --- a/config_test.go +++ b/config_test.go @@ -6,9 +6,14 @@ import ( ) func TestConfig(t *testing.T) { + type result struct { + backend string + proxy bool + } + cases := []struct { Config string - Tests map[string]string + Tests map[string]result }{ { Config: ` @@ -18,19 +23,21 @@ go.universe.tf 1.2.3.4 # Comment google.* 3.4.5.6 /gooo+gle\.com/ 4.5.6.7 +foobar.net 6.7.8.9 PROXY `, - Tests: map[string]string{ - "go.universe.tf": "1.2.3.4", - "foo.universe.tf": "2.3.4.5", - "bar.universe.tf": "2.3.4.5", - "google.com": "3.4.5.6", - "google.fr": "3.4.5.6", - "goooooooooogle.com": "4.5.6.7", + Tests: map[string]result{ + "go.universe.tf": result{"1.2.3.4", false}, + "foo.universe.tf": result{"2.3.4.5", false}, + "bar.universe.tf": result{"2.3.4.5", false}, + "google.com": result{"3.4.5.6", false}, + "google.fr": result{"3.4.5.6", false}, + "goooooooooogle.com": result{"4.5.6.7", false}, + "foobar.net": result{"6.7.8.9", true}, - "blah.com": "", - "google.com.br": "", - "foo.bar.universe.tf": "", - "goooooglexcom": "", + "blah.com": result{"", false}, + "google.com.br": result{"", false}, + "foo.bar.universe.tf": result{"", false}, + "goooooglexcom": result{"", false}, }, }, } @@ -42,9 +49,12 @@ google.* 3.4.5.6 } for hostname, expected := range test.Tests { - actual := cfg.Match(hostname) - if expected != actual { - t.Errorf("cfg.Match(%q) is %q, want %q", hostname, actual, expected) + backend, proxy := cfg.Match(hostname) + if expected.backend != backend { + t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend) + } + if expected.proxy != proxy { + t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy) } } } diff --git a/e2e_test.go b/e2e_test.go index 769e60c..c53e8c5 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -12,31 +12,40 @@ import ( "io/ioutil" "math/big" "net" + "strings" "sync/atomic" "testing" "time" + + proxyproto "github.com/armon/go-proxyproto" ) func TestRouting(t *testing.T) { - // Two backend servers - s1, err := serveTLS(t, "server1", "test.com") + // Backend servers + s1, err := serveTLS(t, "server1", false, "test.com") if err != nil { t.Fatalf("serve TLS server1: %s", err) } defer s1.Close() - s2, err := serveTLS(t, "server2", "foo.net") + s2, err := serveTLS(t, "server2", false, "foo.net") if err != nil { t.Fatalf("serve TLS server2: %s", err) } defer s2.Close() - s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid") + s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid") if err != nil { t.Fatalf("server TLS server3: %s", err) } defer s3.Close() + s4, err := serveTLS(t, "server4", true, "proxy.design") + if err != nil { + t.Fatalf("server TLS server4: %s", err) + } + defer s4.Close() + // One proxy var p Proxy l, err := net.Listen("tcp", "localhost:0") @@ -50,21 +59,24 @@ func TestRouting(t *testing.T) { test.com %s foo.net %s borkbork.tf %s -`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil { +proxy.design %s PROXY +`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil { t.Fatalf("configure proxy: %s", err) } for _, test := range []struct { - N, V string - P *x509.CertPool - OK bool + N, V string + P *x509.CertPool + OK bool + Transparent bool }{ - {"test.com", "server1", s1.Pool, true}, - {"foo.net", "server2", s2.Pool, true}, - {"bar.org", "", s1.Pool, false}, - {"blarghblargh.acme.invalid", "server3", s3.Pool, true}, + {"test.com", "server1", s1.Pool, true, false}, + {"foo.net", "server2", s2.Pool, true, false}, + {"bar.org", "", s1.Pool, false, false}, + {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false}, + {"proxy.design", "server4", s4.Pool, true, true}, } { - res, err := getTLS(l.Addr().String(), test.N, test.P) + res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) switch { case test.OK && err != nil: t.Fatalf("get %q failed: %s", test.N, err) @@ -72,25 +84,36 @@ borkbork.tf %s t.Fatalf("get %q should have failed, but returned %q", test.N, res) case test.OK && res != test.V: t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + case test.OK && transparent != test.Transparent: + t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) } } } -func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) { +// getTLS attempts to set up a TLS session using the given proxy +// address, domain, and cert pool. It returns the value served by the +// server, as well as a bool indicating whether the server knew the +// true client address, indicating that the PROXY protocol was in use. +func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { cfg := tls.Config{ RootCAs: pool, ServerName: domain, } conn, err := tls.Dial("tcp", addr, &cfg) if err != nil { - return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) } defer conn.Close() bs, err := ioutil.ReadAll(conn) if err != nil { - return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + fs := strings.Split(string(bs), " ") + if len(fs) != 2 { + return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) } - return string(bs), nil + transparent := fs[1] == conn.LocalAddr().String() + return fs[0], transparent, nil } type tlsServer struct { @@ -110,7 +133,7 @@ func (s *tlsServer) Serve() { return } atomic.AddUint32(&s.NumHits, 1) - c.Write([]byte(s.Value)) + fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) c.Close() } } @@ -123,7 +146,7 @@ func (s *tlsServer) Close() error { return s.l.Close() } -func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) { +func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { cert, pool, err := selfSignedCert(domains) if err != nil { return nil, err @@ -134,11 +157,19 @@ func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) } cfg.BuildNameToCertificate() - l, err := tls.Listen("tcp", "localhost:0", cfg) + var l net.Listener + + l, err = net.Listen("tcp", "localhost:0") if err != nil { return nil, err } + if understandProxy { + l = &proxyproto.Listener{Listener: l} + } + + l = tls.NewListener(l, cfg) + ret := &tlsServer{ Domains: domains, Value: value, @@ -134,7 +134,8 @@ func (c *Conn) proxy() { return } - c.backend = c.config.Match(c.hostname) + addProxyHeader := false + c.backend, addProxyHeader = c.config.Match(c.hostname) if c.backend == "" { c.sniFailed("no backend found for %q", c.hostname) return @@ -150,6 +151,21 @@ func (c *Conn) proxy() { c.backendConn = backend.(*net.TCPConn) + // If the backend supports the HAProxy PROXY protocol, give it the + // real source information about the connection. + if addProxyHeader { + remote := c.TCPConn.RemoteAddr().(*net.TCPAddr) + local := c.TCPConn.LocalAddr().(*net.TCPAddr) + family := "TCP6" + if remote.IP.To4() != nil { + family = "TCP4" + } + if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil { + c.internalError("failed to send PROXY header to %q: %s", c.backend, err) + return + } + } + // Replay the piece of the handshake we had to read to do the // routing, then blindly proxy any other bytes. if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil { |