diff options
Diffstat (limited to 'tcpproxy_test.go')
-rw-r--r-- | tcpproxy_test.go | 175 |
1 files changed, 172 insertions, 3 deletions
diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 7150372..ac7c917 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -17,19 +17,26 @@ package tcpproxy import ( "bufio" "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" "io/ioutil" + "math/big" "net" "strings" "testing" + "time" ) type noopTarget struct{} -func (t *noopTarget) HandleConn(net.Conn) {} +func (noopTarget) HandleConn(net.Conn) {} func TestMatchHTTPHost(t *testing.T) { tests := []struct { @@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) { want: true, }, } - target := &noopTarget{} for i, tt := range tests { name := tt.name if name == "" { @@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) { } t.Run(name, func(t *testing.T) { br := bufio.NewReader(tt.r) - r := httpHostMatch{tt.host, target} + r := httpHostMatch{tt.host, noopTarget{}} got := r.match(br) != nil if got != tt.want { t.Fatalf("match = %v; want %v", got, tt.want) @@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) { t.Fatalf("got %q; want %q", bs, want) } } + +type tlsServer struct { + Listener net.Listener + Domain string + Test *testing.T +} + +func (t *tlsServer) Start() { + cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid") + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + + go func() { + for { + rawConn, err := t.Listener.Accept() + if err != nil { + return // assume Close() + } + + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, t.Domain); err != nil { + t.Test.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() +} + +func (t *tlsServer) Close() { + t.Listener.Close() +} + +// cert creates a well-formed, but completely insecure self-signed +// cert for domain. +func cert(t *testing.T, domain string) tls.Certificate { + private, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + CommonName: domain, + }, + NotBefore: time.Time{}, + NotAfter: time.Now().Add(60 * time.Minute), + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private) + if err != nil { + t.Fatal(err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)}) + + tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes()) + if err != nil { + t.Fatal(err) + } + + return tlscert +} + +// newTLSServer starts a TLS server that serves a self-signed cert for +// domain, and a corresonding acme.invalid dummy domain. +func newTLSServer(t *testing.T, domain string) net.Listener { + cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid") + + l := newLocalListener(t) + go func() { + for { + rawConn, err := l.Accept() + if err != nil { + return // assume closed + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert, acmeCert}, + } + cfg.BuildNameToCertificate() + conn := tls.Server(rawConn, cfg) + if _, err = io.WriteString(conn, domain); err != nil { + t.Errorf("writing to tlsconn: %s", err) + } + conn.Close() + } + }() + + return l +} + +func readTLS(dest, domain string) (string, error) { + conn, err := tls.Dial("tcp", dest, &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + }) + if err != nil { + return "", err + } + defer conn.Close() + + bs, err := ioutil.ReadAll(conn) + if err != nil { + return "", err + } + return string(bs), nil +} + +func TestProxyACME(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newTLSServer(t, "foo.com") + defer backFoo.Close() + backBar := newTLSServer(t, "bar.com") + defer backBar.Close() + backQuux := newTLSServer(t, "quux.com") + defer backQuux.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + p.AddStopACMESearch(testFrontAddr) + p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + tests := []struct { + domain, want string + succeeds bool + }{ + {"foo.com", "foo.com", true}, + {"bar.com", "bar.com", true}, + {"quux.com", "quux.com", true}, + {"xyzzy.com", "", false}, + {"foo.com.acme.invalid", "foo.com", true}, + {"bar.com.acme.invalid", "bar.com", true}, + {"quux.com.acme.invalid", "", false}, + } + for _, test := range tests { + got, err := readTLS(front.Addr().String(), test.domain) + if test.succeeds { + if err != nil { + t.Fatalf("readTLS %q got error %q, want nil", test.domain, err) + } + if got != test.want { + t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want) + } + } else if err == nil { + t.Fatalf("readTLS %q unexpectedly succeeded", test.domain) + } + } +} |