summaryrefslogtreecommitdiff
path: root/e2e_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'e2e_test.go')
-rw-r--r--e2e_test.go71
1 files changed, 51 insertions, 20 deletions
diff --git a/e2e_test.go b/e2e_test.go
index 769e60c..c53e8c5 100644
--- a/e2e_test.go
+++ b/e2e_test.go
@@ -12,31 +12,40 @@ import (
"io/ioutil"
"math/big"
"net"
+ "strings"
"sync/atomic"
"testing"
"time"
+
+ proxyproto "github.com/armon/go-proxyproto"
)
func TestRouting(t *testing.T) {
- // Two backend servers
- s1, err := serveTLS(t, "server1", "test.com")
+ // Backend servers
+ s1, err := serveTLS(t, "server1", false, "test.com")
if err != nil {
t.Fatalf("serve TLS server1: %s", err)
}
defer s1.Close()
- s2, err := serveTLS(t, "server2", "foo.net")
+ s2, err := serveTLS(t, "server2", false, "foo.net")
if err != nil {
t.Fatalf("serve TLS server2: %s", err)
}
defer s2.Close()
- s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid")
+ s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid")
if err != nil {
t.Fatalf("server TLS server3: %s", err)
}
defer s3.Close()
+ s4, err := serveTLS(t, "server4", true, "proxy.design")
+ if err != nil {
+ t.Fatalf("server TLS server4: %s", err)
+ }
+ defer s4.Close()
+
// One proxy
var p Proxy
l, err := net.Listen("tcp", "localhost:0")
@@ -50,21 +59,24 @@ func TestRouting(t *testing.T) {
test.com %s
foo.net %s
borkbork.tf %s
-`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil {
+proxy.design %s PROXY
+`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil {
t.Fatalf("configure proxy: %s", err)
}
for _, test := range []struct {
- N, V string
- P *x509.CertPool
- OK bool
+ N, V string
+ P *x509.CertPool
+ OK bool
+ Transparent bool
}{
- {"test.com", "server1", s1.Pool, true},
- {"foo.net", "server2", s2.Pool, true},
- {"bar.org", "", s1.Pool, false},
- {"blarghblargh.acme.invalid", "server3", s3.Pool, true},
+ {"test.com", "server1", s1.Pool, true, false},
+ {"foo.net", "server2", s2.Pool, true, false},
+ {"bar.org", "", s1.Pool, false, false},
+ {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false},
+ {"proxy.design", "server4", s4.Pool, true, true},
} {
- res, err := getTLS(l.Addr().String(), test.N, test.P)
+ res, transparent, err := getTLS(l.Addr().String(), test.N, test.P)
switch {
case test.OK && err != nil:
t.Fatalf("get %q failed: %s", test.N, err)
@@ -72,25 +84,36 @@ borkbork.tf %s
t.Fatalf("get %q should have failed, but returned %q", test.N, res)
case test.OK && res != test.V:
t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V)
+ case test.OK && transparent != test.Transparent:
+ t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent)
}
}
}
-func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) {
+// getTLS attempts to set up a TLS session using the given proxy
+// address, domain, and cert pool. It returns the value served by the
+// server, as well as a bool indicating whether the server knew the
+// true client address, indicating that the PROXY protocol was in use.
+func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) {
cfg := tls.Config{
RootCAs: pool,
ServerName: domain,
}
conn, err := tls.Dial("tcp", addr, &cfg)
if err != nil {
- return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
+ return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err)
}
defer conn.Close()
bs, err := ioutil.ReadAll(conn)
if err != nil {
- return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err)
+ }
+ fs := strings.Split(string(bs), " ")
+ if len(fs) != 2 {
+ return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs))
}
- return string(bs), nil
+ transparent := fs[1] == conn.LocalAddr().String()
+ return fs[0], transparent, nil
}
type tlsServer struct {
@@ -110,7 +133,7 @@ func (s *tlsServer) Serve() {
return
}
atomic.AddUint32(&s.NumHits, 1)
- c.Write([]byte(s.Value))
+ fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr())
c.Close()
}
}
@@ -123,7 +146,7 @@ func (s *tlsServer) Close() error {
return s.l.Close()
}
-func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) {
+func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) {
cert, pool, err := selfSignedCert(domains)
if err != nil {
return nil, err
@@ -134,11 +157,19 @@ func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error)
}
cfg.BuildNameToCertificate()
- l, err := tls.Listen("tcp", "localhost:0", cfg)
+ var l net.Listener
+
+ l, err = net.Listen("tcp", "localhost:0")
if err != nil {
return nil, err
}
+ if understandProxy {
+ l = &proxyproto.Listener{Listener: l}
+ }
+
+ l = tls.NewListener(l, cfg)
+
ret := &tlsServer{
Domains: domains,
Value: value,