summaryrefslogtreecommitdiff
path: root/tcpproxy_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'tcpproxy_test.go')
-rw-r--r--tcpproxy_test.go175
1 files changed, 172 insertions, 3 deletions
diff --git a/tcpproxy_test.go b/tcpproxy_test.go
index 7150372..ac7c917 100644
--- a/tcpproxy_test.go
+++ b/tcpproxy_test.go
@@ -17,19 +17,26 @@ package tcpproxy
import (
"bufio"
"bytes"
+ "crypto/rand"
+ "crypto/rsa"
"crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
+ "math/big"
"net"
"strings"
"testing"
+ "time"
)
type noopTarget struct{}
-func (t *noopTarget) HandleConn(net.Conn) {}
+func (noopTarget) HandleConn(net.Conn) {}
func TestMatchHTTPHost(t *testing.T) {
tests := []struct {
@@ -57,7 +64,6 @@ func TestMatchHTTPHost(t *testing.T) {
want: true,
},
}
- target := &noopTarget{}
for i, tt := range tests {
name := tt.name
if name == "" {
@@ -65,7 +71,7 @@ func TestMatchHTTPHost(t *testing.T) {
}
t.Run(name, func(t *testing.T) {
br := bufio.NewReader(tt.r)
- r := httpHostMatch{tt.host, target}
+ r := httpHostMatch{tt.host, noopTarget{}}
got := r.match(br) != nil
if got != tt.want {
t.Fatalf("match = %v; want %v", got, tt.want)
@@ -313,3 +319,166 @@ func TestProxyPROXYOut(t *testing.T) {
t.Fatalf("got %q; want %q", bs, want)
}
}
+
+type tlsServer struct {
+ Listener net.Listener
+ Domain string
+ Test *testing.T
+}
+
+func (t *tlsServer) Start() {
+ cert, acmeCert := cert(t.Test, t.Domain), cert(t.Test, t.Domain+".acme.invalid")
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+
+ go func() {
+ for {
+ rawConn, err := t.Listener.Accept()
+ if err != nil {
+ return // assume Close()
+ }
+
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, t.Domain); err != nil {
+ t.Test.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+}
+
+func (t *tlsServer) Close() {
+ t.Listener.Close()
+}
+
+// cert creates a well-formed, but completely insecure self-signed
+// cert for domain.
+func cert(t *testing.T, domain string) tls.Certificate {
+ private, err := rsa.GenerateKey(rand.Reader, 512)
+ if err != nil {
+ t.Fatal(err)
+ }
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Test Co"},
+ CommonName: domain,
+ },
+ NotBefore: time.Time{},
+ NotAfter: time.Now().Add(60 * time.Minute),
+ IsCA: true,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &private.PublicKey, private)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var cert, key bytes.Buffer
+ pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(private)})
+
+ tlscert, err := tls.X509KeyPair(cert.Bytes(), key.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return tlscert
+}
+
+// newTLSServer starts a TLS server that serves a self-signed cert for
+// domain, and a corresonding acme.invalid dummy domain.
+func newTLSServer(t *testing.T, domain string) net.Listener {
+ cert, acmeCert := cert(t, domain), cert(t, domain+".acme.invalid")
+
+ l := newLocalListener(t)
+ go func() {
+ for {
+ rawConn, err := l.Accept()
+ if err != nil {
+ return // assume closed
+ }
+
+ cfg := &tls.Config{
+ Certificates: []tls.Certificate{cert, acmeCert},
+ }
+ cfg.BuildNameToCertificate()
+ conn := tls.Server(rawConn, cfg)
+ if _, err = io.WriteString(conn, domain); err != nil {
+ t.Errorf("writing to tlsconn: %s", err)
+ }
+ conn.Close()
+ }
+ }()
+
+ return l
+}
+
+func readTLS(dest, domain string) (string, error) {
+ conn, err := tls.Dial("tcp", dest, &tls.Config{
+ ServerName: domain,
+ InsecureSkipVerify: true,
+ })
+ if err != nil {
+ return "", err
+ }
+ defer conn.Close()
+
+ bs, err := ioutil.ReadAll(conn)
+ if err != nil {
+ return "", err
+ }
+ return string(bs), nil
+}
+
+func TestProxyACME(t *testing.T) {
+ front := newLocalListener(t)
+ defer front.Close()
+
+ backFoo := newTLSServer(t, "foo.com")
+ defer backFoo.Close()
+ backBar := newTLSServer(t, "bar.com")
+ defer backBar.Close()
+ backQuux := newTLSServer(t, "quux.com")
+ defer backQuux.Close()
+
+ p := testProxy(t, front)
+ p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String()))
+ p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
+ p.AddStopACMESearch(testFrontAddr)
+ p.AddSNIRoute(testFrontAddr, "quux.com", To(backQuux.Addr().String()))
+ if err := p.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ domain, want string
+ succeeds bool
+ }{
+ {"foo.com", "foo.com", true},
+ {"bar.com", "bar.com", true},
+ {"quux.com", "quux.com", true},
+ {"xyzzy.com", "", false},
+ {"foo.com.acme.invalid", "foo.com", true},
+ {"bar.com.acme.invalid", "bar.com", true},
+ {"quux.com.acme.invalid", "", false},
+ }
+ for _, test := range tests {
+ got, err := readTLS(front.Addr().String(), test.domain)
+ if test.succeeds {
+ if err != nil {
+ t.Fatalf("readTLS %q got error %q, want nil", test.domain, err)
+ }
+ if got != test.want {
+ t.Fatalf("readTLS %q got %q, want %q", test.domain, got, test.want)
+ }
+ } else if err == nil {
+ t.Fatalf("readTLS %q unexpectedly succeeded", test.domain)
+ }
+ }
+}