diff options
Diffstat (limited to 'tcpproxy_test.go')
-rw-r--r-- | tcpproxy_test.go | 270 |
1 files changed, 270 insertions, 0 deletions
diff --git a/tcpproxy_test.go b/tcpproxy_test.go new file mode 100644 index 0000000..4cfb0ab --- /dev/null +++ b/tcpproxy_test.go @@ -0,0 +1,270 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "strings" + "testing" +) + +func TestMatchHTTPHost(t *testing.T) { + tests := []struct { + name string + r io.Reader + host string + want bool + }{ + { + name: "match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "foo.com", + want: true, + }, + { + name: "no-match", + r: strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n"), + host: "bar.com", + want: false, + }, + { + name: "match-huge-request", + r: io.MultiReader(strings.NewReader("GET / HTTP/1.1\r\nHost: foo.com\r\n"), neverEnding('a')), + host: "foo.com", + want: true, + }, + } + for i, tt := range tests { + name := tt.name + if name == "" { + name = fmt.Sprintf("test_index_%d", i) + } + t.Run(name, func(t *testing.T) { + br := bufio.NewReader(tt.r) + var matcher matcher = httpHostMatch(tt.host) + got := matcher.match(br) + if got != tt.want { + t.Fatalf("match = %v; want %v", got, tt.want) + } + get := make([]byte, 3) + if _, err := io.ReadFull(br, get); err != nil { + t.Fatal(err) + } + if string(get) != "GET" { + t.Fatalf("did bufio.Reader consume bytes? got %q; want GET", get) + } + }) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type recordWritesConn struct { + buf bytes.Buffer + net.Conn +} + +func (c *recordWritesConn) Write(p []byte) (int, error) { + c.buf.Write(p) + return len(p), nil +} + +func (c *recordWritesConn) Read(p []byte) (int, error) { return 0, io.EOF } + +func clientHelloRecord(t *testing.T, hostName string) string { + rec := new(recordWritesConn) + cl := tls.Client(rec, &tls.Config{ServerName: hostName}) + cl.Handshake() + + s := rec.buf.String() + if !strings.Contains(s, hostName) { + t.Fatalf("clientHello sent in test didn't contain %q", hostName) + } + return s +} + +func TestSNI(t *testing.T) { + const hostName = "foo.com" + greeting := clientHelloRecord(t, hostName) + got := clientHelloServerName(bufio.NewReader(strings.NewReader(greeting))) + if got != hostName { + t.Errorf("got SNI %q; want %q", got, hostName) + } +} + +func TestProxyStartNone(t *testing.T) { + var p Proxy + if err := p.Start(); err != nil { + t.Fatal(err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp", "[::1]:0") + if err != nil { + t.Fatal(err) + } + } + return ln +} + +const testFrontAddr = "1.2.3.4:567" + +func testListenFunc(t *testing.T, ln net.Listener) func(network, laddr string) (net.Listener, error) { + return func(network, laddr string) (net.Listener, error) { + if network != "tcp" { + t.Errorf("got Listen call with network %q, not tcp", network) + return nil, errors.New("invalid network") + } + if laddr != testFrontAddr { + t.Fatalf("got Listen call with laddr %q, want %q", laddr, testFrontAddr) + panic("bogus address") + } + return ln, nil + } +} + +func testProxy(t *testing.T, front net.Listener) *Proxy { + return &Proxy{ + ListenFunc: testListenFunc(t, front), + } +} + +func TestProxyAlwaysMatch(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + back := newLocalListener(t) + defer back.Close() + + p := testProxy(t, front) + p.AddRoute(testFrontAddr, To(back.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + const msg = "message" + io.WriteString(toFront, msg) + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxyHTTP(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddHTTPHostRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddHTTPHostRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + +func TestProxySNI(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddSNIRoute(testFrontAddr, "foo.com", To(backFoo.Addr().String())) + p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String())) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + msg := clientHelloRecord(t, "bar.com") + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} |