diff options
Diffstat (limited to 'e2e_test.go')
-rw-r--r-- | e2e_test.go | 71 |
1 files changed, 51 insertions, 20 deletions
diff --git a/e2e_test.go b/e2e_test.go index 769e60c..c53e8c5 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -12,31 +12,40 @@ import ( "io/ioutil" "math/big" "net" + "strings" "sync/atomic" "testing" "time" + + proxyproto "github.com/armon/go-proxyproto" ) func TestRouting(t *testing.T) { - // Two backend servers - s1, err := serveTLS(t, "server1", "test.com") + // Backend servers + s1, err := serveTLS(t, "server1", false, "test.com") if err != nil { t.Fatalf("serve TLS server1: %s", err) } defer s1.Close() - s2, err := serveTLS(t, "server2", "foo.net") + s2, err := serveTLS(t, "server2", false, "foo.net") if err != nil { t.Fatalf("serve TLS server2: %s", err) } defer s2.Close() - s3, err := serveTLS(t, "server3", "blarghblargh.acme.invalid") + s3, err := serveTLS(t, "server3", false, "blarghblargh.acme.invalid") if err != nil { t.Fatalf("server TLS server3: %s", err) } defer s3.Close() + s4, err := serveTLS(t, "server4", true, "proxy.design") + if err != nil { + t.Fatalf("server TLS server4: %s", err) + } + defer s4.Close() + // One proxy var p Proxy l, err := net.Listen("tcp", "localhost:0") @@ -50,21 +59,24 @@ func TestRouting(t *testing.T) { test.com %s foo.net %s borkbork.tf %s -`, s1.Addr(), s2.Addr(), s3.Addr())); err != nil { +proxy.design %s PROXY +`, s1.Addr(), s2.Addr(), s3.Addr(), s4.Addr())); err != nil { t.Fatalf("configure proxy: %s", err) } for _, test := range []struct { - N, V string - P *x509.CertPool - OK bool + N, V string + P *x509.CertPool + OK bool + Transparent bool }{ - {"test.com", "server1", s1.Pool, true}, - {"foo.net", "server2", s2.Pool, true}, - {"bar.org", "", s1.Pool, false}, - {"blarghblargh.acme.invalid", "server3", s3.Pool, true}, + {"test.com", "server1", s1.Pool, true, false}, + {"foo.net", "server2", s2.Pool, true, false}, + {"bar.org", "", s1.Pool, false, false}, + {"blarghblargh.acme.invalid", "server3", s3.Pool, true, false}, + {"proxy.design", "server4", s4.Pool, true, true}, } { - res, err := getTLS(l.Addr().String(), test.N, test.P) + res, transparent, err := getTLS(l.Addr().String(), test.N, test.P) switch { case test.OK && err != nil: t.Fatalf("get %q failed: %s", test.N, err) @@ -72,25 +84,36 @@ borkbork.tf %s t.Fatalf("get %q should have failed, but returned %q", test.N, res) case test.OK && res != test.V: t.Fatalf("got wrong value from %q, got %q, want %q", test.N, res, test.V) + case test.OK && transparent != test.Transparent: + t.Fatalf("connection transparency for %q was %v, want %v", test.N, transparent, test.Transparent) } } } -func getTLS(addr string, domain string, pool *x509.CertPool) (string, error) { +// getTLS attempts to set up a TLS session using the given proxy +// address, domain, and cert pool. It returns the value served by the +// server, as well as a bool indicating whether the server knew the +// true client address, indicating that the PROXY protocol was in use. +func getTLS(addr string, domain string, pool *x509.CertPool) (string, bool, error) { cfg := tls.Config{ RootCAs: pool, ServerName: domain, } conn, err := tls.Dial("tcp", addr, &cfg) if err != nil { - return "", fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) + return "", false, fmt.Errorf("dial TLS %q for %q: %s", addr, domain, err) } defer conn.Close() bs, err := ioutil.ReadAll(conn) if err != nil { - return "", fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + return "", false, fmt.Errorf("read TLS from %q (domain %q): %s", addr, domain, err) + } + fs := strings.Split(string(bs), " ") + if len(fs) != 2 { + return "", false, fmt.Errorf("read TLS from %q (domain %q): incoherent response %q", addr, domain, string(bs)) } - return string(bs), nil + transparent := fs[1] == conn.LocalAddr().String() + return fs[0], transparent, nil } type tlsServer struct { @@ -110,7 +133,7 @@ func (s *tlsServer) Serve() { return } atomic.AddUint32(&s.NumHits, 1) - c.Write([]byte(s.Value)) + fmt.Fprintf(c, "%s %s", s.Value, c.RemoteAddr()) c.Close() } } @@ -123,7 +146,7 @@ func (s *tlsServer) Close() error { return s.l.Close() } -func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) { +func serveTLS(t *testing.T, value string, understandProxy bool, domains ...string) (*tlsServer, error) { cert, pool, err := selfSignedCert(domains) if err != nil { return nil, err @@ -134,11 +157,19 @@ func serveTLS(t *testing.T, value string, domains ...string) (*tlsServer, error) } cfg.BuildNameToCertificate() - l, err := tls.Listen("tcp", "localhost:0", cfg) + var l net.Listener + + l, err = net.Listen("tcp", "localhost:0") if err != nil { return nil, err } + if understandProxy { + l = &proxyproto.Listener{Listener: l} + } + + l = tls.NewListener(l, cfg) + ret := &tlsServer{ Domains: domains, Value: value, |