summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Anderson <[email protected]>2017-05-14 01:48:18 -0700
committerDavid Anderson <[email protected]>2017-05-14 01:48:18 -0700
commit4b8641f40e04705b8227f245be36457c05ccba2c (patch)
tree6d33d6d234e2fee20cdab2376dbb5bd959ae699f
parentd23eadc3a6c89bf5058db893acee26d5f1d7e350 (diff)
Add support for HAProxy's PROXY protocol.
This allows backends that support it to receive the client's true ip:port as out-of-band information, despite the connection being proxied.
-rw-r--r--README.md5
-rw-r--r--config.go29
-rw-r--r--config_test.go40
-rw-r--r--e2e_test.go71
-rw-r--r--main.go18
5 files changed, 119 insertions, 44 deletions
diff --git a/README.md b/README.md
index 12a7800..5a75935 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,11 @@ google.* 10.20.30.40:443
# RE2 regexes are also available
/(alpha|beta|gamma)\.mon(itoring)?\.dave\.tf/ 100.200.100.200:443
+
+# If your backend supports HAProxy's PROXY protocol, you can enable
+# it to receive the real client ip:port.
+
+fancy.backend 2.3.4.5:443 PROXY
```
TLSRouter takes one mandatory commandline argument, the configuration file to use:
diff --git a/config.go b/config.go
index c205c22..1c8151f 100644
--- a/config.go
+++ b/config.go
@@ -17,6 +17,7 @@ package main
import (
"bufio"
"bytes"
+ "errors"
"fmt"
"io"
"os"
@@ -27,8 +28,9 @@ import (
// A Route maps a match on a domain name to a backend.
type Route struct {
- match *regexp.Regexp
- backend string
+ match *regexp.Regexp
+ backend string
+ proxyInfo bool
}
// Config stores the TLS routing configuration.
@@ -57,21 +59,21 @@ func dnsRegex(s string) (*regexp.Regexp, error) {
return regexp.Compile(fmt.Sprintf("^%s$", strings.Join(b, `\.`)))
}
-// Match returns the backend for hostname.
-func (c *Config) Match(hostname string) string {
+// Match returns the backend for hostname, and whether to use the PROXY protocol.
+func (c *Config) Match(hostname string) (string, bool) {
c.mu.Lock()
defer c.mu.Unlock()
if strings.HasSuffix(hostname, ".acme.invalid") {
- return c.acme.Match(hostname)
+ return c.acme.Match(hostname), false
}
for _, r := range c.routes {
if r.match.MatchString(hostname) {
- return r.backend
+ return r.backend, r.proxyInfo
}
}
- return ""
+ return "", false
}
// Read replaces the current Config with one read from r.
@@ -97,7 +99,17 @@ func (c *Config) Read(r io.Reader) error {
if err != nil {
return err
}
- routes = append(routes, Route{re, fs[1]})
+ routes = append(routes, Route{re, fs[1], false})
+ backends = append(backends, fs[1])
+ case 3:
+ re, err := dnsRegex(fs[0])
+ if err != nil {
+ return err
+ }
+ if fs[2] != "PROXY" {
+ return errors.New("third item on a line can only be PROXY")
+ }
+ routes = append(routes, Route{re, fs[1], true})
backends = append(backends, fs[1])
default:
// TODO: multiple backends?
@@ -127,6 +139,7 @@ func (c *Config) ReadFile(path string) error {
return c.Read(f)
}
+// ReadString replaces the current Config with one read from cfg.
func (c *Config) ReadString(cfg string) error {
b := bytes.NewBufferString(cfg)
return c.Read(b)
diff --git a/config_test.go b/config_test.go
index bcb9e5f..9819b91 100644
--- a/config_test.go
+++ b/config_test.go
@@ -6,9 +6,14 @@ import (
)
func TestConfig(t *testing.T) {
+ type result struct {
+ backend string
+ proxy bool
+ }
+
cases := []struct {
Config string
- Tests map[string]string
+ Tests map[string]result
}{
{
Config: `
@@ -18,19 +23,21 @@ go.universe.tf 1.2.3.4
# Comment
google.* 3.4.5.6
/gooo+gle\.com/ 4.5.6.7
+foobar.net 6.7.8.9 PROXY
`,
- Tests: map[string]string{
- "go.universe.tf": "1.2.3.4",
- "foo.universe.tf": "2.3.4.5",
- "bar.universe.tf": "2.3.4.5",
- "google.com": "3.4.5.6",
- "google.fr": "3.4.5.6",
- "goooooooooogle.com": "4.5.6.7",
+ Tests: map[string]result{
+ "go.universe.tf": result{"1.2.3.4", false},
+ "foo.universe.tf": result{"2.3.4.5", false},
+ "bar.universe.tf": result{"2.3.4.5", false},
+ "google.com": result{"3.4.5.6", false},
+ "google.fr": result{"3.4.5.6", false},
+ "goooooooooogle.com": result{"4.5.6.7", false},
+ "foobar.net": result{"6.7.8.9", true},
- "blah.com": "",
- "google.com.br": "",
- "foo.bar.universe.tf": "",
- "goooooglexcom": "",
+ "blah.com": result{"", false},
+ "google.com.br": result{"", false},
+ "foo.bar.universe.tf": result{"", false},
+ "goooooglexcom": result{"", false},
},
},
}
@@ -42,9 +49,12 @@ google.* 3.4.5.6
}
for hostname, expected := range test.Tests {
- actual := cfg.Match(hostname)
- if expected != actual {
- t.Errorf("cfg.Match(%q) is %q, want %q", hostname, actual, expected)
+ backend, proxy := cfg.Match(hostname)
+ if expected.backend != backend {
+ t.Errorf("cfg.Match(%q) is %q, want %q", hostname, backend, expected.backend)
+ }
+ if expected.proxy != proxy {
+ t.Errorf("cfg.Match(%q).proxy is %v, want %v", hostname, proxy, expected.proxy)
}
}
}
diff --git a/e2e_test.go b/e2e_test.go
index 769e60c..c53e8c5 100644
--- a/e2e_test.go
+++ b/e2e_test.go
@@ -12,31 +12,40 @@ import (
"io/ioutil"
"math/big"
"net"
+ "strings"
"sync/atomic"
"testing"
"time"
+
+ proxyproto "github.com/armon/go-proxyproto"
)
func TestRouting(t *testing.T) {
- // Two backend servers
- s1, err := serveTLS(t, "server1", "test.com")
+ // Backend servers
+ s1, err := serveTLS(t, "server1", false, "test.com")
if err != nil {
t.Fatalf("serve TLS server1: %s", err)
}
defer s1.Close()
- s2, err := serveTLS(t, "server2", "foo.net")
+ s2, err := serveTLS(t, "server2", false, "foo.net")
if err != nil {
t.Fatalf("serve TLS server2: %s", err)
}
defer s2.Close()
- s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid")
+ s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid")
if err != nil {
t.Fatalf("server TLS server3: %s", err)
}
defer s3.Close()
+ s4, err := serveTLS(t, "server4", true, "proxy.design")
+ if err != nil {
+ t.Fatalf("server TLS server4: %s", err)
+ }
+ defer s4.Close()
+
// One proxy
var p Proxy
l, err := net.Listen("tcp", "localhost:0")
@@ -50,21 +59,24 @@ func TestRouting(t *testing.T) {
test.com %s
foo.net %s
borkbork.tf %s
-`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil {
+proxy.design %s PROXY
+`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil {
t.Fatalf("configure proxy: %s", err)
}
for _, test := range []struct {
- N, V string
- P *x509.CertPool
- OK bool
+ N, V string
+ P *x509.CertPool
+ OK bool
+ Transparent bool
}{
- {"test.com", "server1", s1.Pool, true},
- {"foo.net", "server2", s2.Pool, true},
- {"bar.org", "", s1.Pool, false},
- {"blarghblargh.acme.invalid", "server3", s3.Pool, true},
+ {"test.com", "server1", s1.Pool, true, false},
+ {"foo.net", "server2", s2.Pool, true, false},
+ {"bar.org", "", s1.Pool, false, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false},
+ {"proxy.design", "server4", s4.Pool, true, true},
} {
- res, err := getTLS(l.Addr().String(), test.N, test.P)
+ res, transparent, err := getTLS(l.Addr().String(), test.N, test.P)
switch {
case test.OK && err != nil:
t.Fatalf("get %q failed: %s", test.N, err)
@@ -72,25 +84,36 @@ borkbork.tf %s
t.Fatalf("get %q should have failed, but returned %q", test.N, res)
case test.OK && res != test.V:
t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ case test.OK && transparent != test.Transparent:
+ t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent)
}
}
}
-func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) {
+// getTLS attempts to set up a TLS session using the given proxy
+// address, domain, and cert pool. It returns the value served by the
+// server, as well as a bool indicating whether the server knew the
+// true client address, indicating that the PROXY protocol was in use.
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) {
cfg := tls.Config{
RootCAs: pool,
ServerName: domain,
}
conn, err := tls.Dial("tcp", addr, &cfg)
if err != nil {
- return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
}
defer conn.Close()
bs, err := ioutil.ReadAll(conn)
if err != nil {
- return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ fs := strings.Split(string(bs), " ")
+ if len(fs) != 2 {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs))
}
- return string(bs), nil
+ transparent := fs[1] == conn.LocalAddr().String()
+ return fs[0], transparent, nil
}
type tlsServer struct {
@@ -110,7 +133,7 @@ func (s *tlsServer) Serve() {
return
}
atomic.AddUint32(&s.NumHits, 1)
- c.Write([]byte(s.Value))
+ fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr())
c.Close()
}
}
@@ -123,7 +146,7 @@ func (s *tlsServer) Close() error {
return s.l.Close()
}
-func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) {
+func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) {
cert, pool, err := selfSignedCert(domains)
if err != nil {
return nil, err
@@ -134,11 +157,19 @@ func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error)
}
cfg.BuildNameToCertificate()
- l, err := tls.Listen("tcp", "localhost:0", cfg)
+ var l net.Listener
+
+ l, err = net.Listen("tcp", "localhost:0")
if err != nil {
return nil, err
}
+ if understandProxy {
+ l = &proxyproto.Listener{Listener: l}
+ }
+
+ l = tls.NewListener(l, cfg)
+
ret := &tlsServer{
Domains: domains,
Value: value,
diff --git a/main.go b/main.go
index cfdf5cc..ff1a816 100644
--- a/main.go
+++ b/main.go
@@ -134,7 +134,8 @@ func (c *Conn) proxy() {
return
}
- c.backend = c.config.Match(c.hostname)
+ addProxyHeader := false
+ c.backend, addProxyHeader = c.config.Match(c.hostname)
if c.backend == "" {
c.sniFailed("no backend found for %q", c.hostname)
return
@@ -150,6 +151,21 @@ func (c *Conn) proxy() {
c.backendConn = backend.(*net.TCPConn)
+ // If the backend supports the HAProxy PROXY protocol, give it the
+ // real source information about the connection.
+ if addProxyHeader {
+ remote := c.TCPConn.RemoteAddr().(*net.TCPAddr)
+ local := c.TCPConn.LocalAddr().(*net.TCPAddr)
+ family := "TCP6"
+ if remote.IP.To4() != nil {
+ family = "TCP4"
+ }
+ if _, err := fmt.Fprintf(c.backendConn, "PROXY %s %s %s %d %d\r\n", family, remote.IP, local.IP, remote.Port, local.Port); err != nil {
+ c.internalError("failed to send PROXY header to %q: %s", c.backend, err)
+ return
+ }
+ }
+
// Replay the piece of the handshake we had to read to do the
// routing, then blindly proxy any other bytes.
if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {