diff options
-rwxr-xr-x | .github/workflows/ci-build.yml | 31 | ||||
-rwxr-xr-x | .golangci.yml | 1 | ||||
-rw-r--r-- | Makefile | 6 | ||||
-rw-r--r-- | codecov.yml | 4 | ||||
-rw-r--r-- | plugins/gzip/tests/plugin_test.go | 4 | ||||
-rw-r--r-- | plugins/headers/config.go | 36 | ||||
-rw-r--r-- | plugins/headers/plugin.go | 117 | ||||
-rw-r--r-- | plugins/headers/tests/configs/.rr-cors-headers.yaml | 37 | ||||
-rw-r--r-- | plugins/headers/tests/configs/.rr-headers-init.yaml | 37 | ||||
-rw-r--r-- | plugins/headers/tests/configs/.rr-req-headers.yaml | 30 | ||||
-rw-r--r-- | plugins/headers/tests/configs/.rr-res-headers.yaml | 30 | ||||
-rw-r--r-- | plugins/headers/tests/headers_plugin_test.go | 359 | ||||
-rw-r--r-- | plugins/http/plugin.go | 10 | ||||
-rw-r--r-- | plugins/http/tests/http_test.go | 25 | ||||
-rw-r--r-- | plugins/logger/plugin.go | 8 | ||||
-rw-r--r-- | plugins/metrics/plugin.go | 8 | ||||
-rw-r--r-- | plugins/metrics/tests/metrics_test.go | 4 | ||||
-rwxr-xr-x | plugins/rpc/plugin.go | 10 | ||||
-rw-r--r-- | plugins/server/plugin.go | 6 | ||||
-rw-r--r-- | plugins/static/tests/static_plugin_test.go | 8 | ||||
-rwxr-xr-x | static_pool_test.go | 1 |
21 files changed, 711 insertions, 61 deletions
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index c64f0f11..7e3e9c03 100755 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -65,26 +65,27 @@ jobs: - name: Run golang tests run: | - go test -v -race . -tags=debug -coverprofile=lib.txt -covermode=atomic - go test -v -race ./plugins/rpc -tags=debug -coverprofile=rpc_config.txt -covermode=atomic - go test -v -race ./plugins/rpc/tests -tags=debug -coverprofile=rpc.txt -covermode=atomic - go test -v -race ./plugins/config/tests -tags=debug -coverprofile=plugin_config.txt -covermode=atomic - go test -v -race ./plugins/logger/tests -tags=debug -coverprofile=logger.txt -covermode=atomic - go test -v -race ./plugins/server/tests -tags=debug -coverprofile=server.txt -covermode=atomic - go test -v -race ./plugins/metrics/tests -tags=debug -coverprofile=metrics.txt -covermode=atomic - go test -v -race ./plugins/informer/tests -tags=debug -coverprofile=informer.txt -covermode=atomic - go test -v -race ./plugins/resetter/tests -tags=debug -coverprofile=informer.txt -covermode=atomic - go test -v -race ./plugins/http/attributes -tags=debug -coverprofile=attributes.txt -covermode=atomic - go test -v -race ./plugins/http/tests -tags=debug -coverprofile=http_tests.txt -covermode=atomic - go test -v -race ./plugins/gzip/tests -tags=debug -coverprofile=gzip.txt -covermode=atomic - go test -v -race -cover ./plugins/static/tests -tags=debug -coverprofile=static.txt -covermode=atomic - go test -v -race -cover ./plugins/static -tags=debug -coverprofile=static_root.txt -covermode=atomic + go test -v -race -cover -tags=debug -coverprofile=lib.txt -covermode=atomic . + go test -v -race -cover -tags=debug -coverprofile=rpc_config.txt -covermode=atomic ./plugins/rpc + go test -v -race -cover -tags=debug -coverprofile=rpc.txt -covermode=atomic ./plugins/rpc/tests + go test -v -race -cover -tags=debug -coverprofile=plugin_config.txt -covermode=atomic ./plugins/config/tests + go test -v -race -cover -tags=debug -coverprofile=logger.txt -covermode=atomic ./plugins/logger/tests + go test -v -race -cover -tags=debug -coverprofile=server.txt -covermode=atomic ./plugins/server/tests + go test -v -race -cover -tags=debug -coverprofile=metrics.txt -covermode=atomic ./plugins/metrics/tests + go test -v -race -cover -tags=debug -coverprofile=informer.txt -covermode=atomic ./plugins/informer/tests + go test -v -race -cover -tags=debug -coverprofile=informer.txt -covermode=atomic ./plugins/resetter/tests + go test -v -race -cover -tags=debug -coverprofile=attributes.txt -covermode=atomic ./plugins/http/attributes + go test -v -race -cover -tags=debug -coverprofile=http_tests.txt -covermode=atomic ./plugins/http/tests + go test -v -race -cover -tags=debug -coverprofile=gzip.txt -covermode=atomic ./plugins/gzip/tests + go test -v -race -cover -tags=debug -coverprofile=static.txt -covermode=atomic ./plugins/static/tests + go test -v -race -cover -tags=debug -coverprofile=static_root.txt -covermode=atomic ./plugins/static + go test -v -race -cover -tags=debug -coverprofile=headers.txt -covermode=atomic ./plugins/headers/tests - name: Run code coverage uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} - files: static.txt, static_root.txt, gzip.txt, lib.txt, rpc_config.txt, rpc.txt, plugin_config.txt, logger.txt, server.txt, metrics.txt, informer.txt attributes.txt http_tests.txt + files: headers.txt, static.txt, static_root.txt, gzip.txt, lib.txt, rpc_config.txt, rpc.txt, plugin_config.txt, logger.txt, server.txt, metrics.txt, informer.txt attributes.txt http_tests.txt flags: unittests name: codecov-umbrella fail_ci_if_error: false diff --git a/.golangci.yml b/.golangci.yml index a49abafb..7a49f0c8 100755 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,6 +5,7 @@ run: - plugins/http/tests/rpc_test_old.go - plugins/http/tests/config_test.go - plugins/static/tests/static_plugin_test.go + - plugins/headers/tests/old.go linters: disable-all: true enable: @@ -13,4 +13,8 @@ test: go test -v -race -cover ./plugins/http/tests -tags=debug go test -v -race -cover ./plugins/gzip/tests -tags=debug go test -v -race -cover ./plugins/static/tests -tags=debug - go test -v -race -cover ./plugins/static -tags=debug
\ No newline at end of file + go test -v -race -cover ./plugins/static -tags=debug + go test -v -race -cover ./plugins/headers/tests -tags=debug + +test_headers: + go test -v -race -cover ./plugins/headers/tests -tags=debug
\ No newline at end of file diff --git a/codecov.yml b/codecov.yml index 5dd21786..8b6b8760 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,4 +1,4 @@ coverage: status: - project: off - patch: off
\ No newline at end of file + project: true + patch: false
\ No newline at end of file diff --git a/plugins/gzip/tests/plugin_test.go b/plugins/gzip/tests/plugin_test.go index c7f12643..39979895 100644 --- a/plugins/gzip/tests/plugin_test.go +++ b/plugins/gzip/tests/plugin_test.go @@ -21,7 +21,7 @@ import ( ) func TestGzipPlugin(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -102,7 +102,7 @@ func headerCheck(t *testing.T) { } func TestMiddlewareNotExist(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ diff --git a/plugins/headers/config.go b/plugins/headers/config.go new file mode 100644 index 00000000..8d4e29c2 --- /dev/null +++ b/plugins/headers/config.go @@ -0,0 +1,36 @@ +package headers + +// Config declares headers service configuration. +type Config struct { + Headers struct { + // CORS settings. + CORS *CORSConfig + + // Request headers to add to every payload send to PHP. + Request map[string]string + + // Response headers to add to every payload generated by PHP. + Response map[string]string + } +} + +// CORSConfig headers configuration. +type CORSConfig struct { + // AllowedOrigin: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowedOrigin string + + // AllowedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowedHeaders string + + // AllowedMethods: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowedMethods string + + // AllowCredentials https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials *bool + + // ExposeHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + ExposedHeaders string + + // MaxAge of CORS headers in seconds/ + MaxAge int +} diff --git a/plugins/headers/plugin.go b/plugins/headers/plugin.go new file mode 100644 index 00000000..f1c6e6f3 --- /dev/null +++ b/plugins/headers/plugin.go @@ -0,0 +1,117 @@ +package headers + +import ( + "net/http" + "strconv" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// ID contains default service name. +const PluginName = "headers" +const RootPluginName = "http" + +// Service serves headers files. Potentially convert into middleware? +type Plugin struct { + // server configuration (location, forbidden files and etc) + cfg *Config +} + +// Init must return configure service and return true if service hasStatus enabled. Must return error in case of +// misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer) error { + const op = errors.Op("headers plugin init") + err := cfg.UnmarshalKey(RootPluginName, &s.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + return nil +} + +// middleware must return true if request/response pair is handled within the middleware. +func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if s.cfg.Headers.Request != nil { + for k, v := range s.cfg.Headers.Request { + r.Header.Add(k, v) + } + } + + if s.cfg.Headers.Response != nil { + for k, v := range s.cfg.Headers.Response { + w.Header().Set(k, v) + } + } + + if s.cfg.Headers.CORS != nil { + if r.Method == http.MethodOptions { + s.preflightRequest(w) + return + } + s.corsHeaders(w) + } + + next.ServeHTTP(w, r) + } +} + +func (s *Plugin) Name() string { + return PluginName +} + +// configure OPTIONS response +func (s *Plugin) preflightRequest(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.AllowedMethods != "" { + headers.Set("Access-Control-Allow-Methods", s.cfg.Headers.CORS.AllowedMethods) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } + + if s.cfg.Headers.CORS.MaxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.Headers.CORS.MaxAge)) + } + + w.WriteHeader(http.StatusOK) +} + +// configure CORS headers +func (s *Plugin) corsHeaders(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.ExposedHeaders != "" { + headers.Set("Access-Control-Expose-Headers", s.cfg.Headers.CORS.ExposedHeaders) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } +} diff --git a/plugins/headers/tests/configs/.rr-cors-headers.yaml b/plugins/headers/tests/configs/.rr-cors-headers.yaml new file mode 100644 index 00000000..5c1a200b --- /dev/null +++ b/plugins/headers/tests/configs/.rr-cors-headers.yaml @@ -0,0 +1,37 @@ +server: + command: "php ../../../tests/http/client.php headers pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22855 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + cors: + allowedOrigin: "*" + allowedHeaders: "*" + allowedMethods: "GET,POST,PUT,DELETE" + allowCredentials: true + exposedHeaders: "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma" + maxAge: 600 + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-headers-init.yaml b/plugins/headers/tests/configs/.rr-headers-init.yaml new file mode 100644 index 00000000..252fe8f3 --- /dev/null +++ b/plugins/headers/tests/configs/.rr-headers-init.yaml @@ -0,0 +1,37 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:33453 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + cors: + allowedOrigin: "*" + allowedHeaders: "*" + allowedMethods: "GET,POST,PUT,DELETE" + allowCredentials: true + exposedHeaders: "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma" + maxAge: 600 + request: + "Example-Request-Header": "Value" + response: + "X-Powered-By": "RoadRunner" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-req-headers.yaml b/plugins/headers/tests/configs/.rr-req-headers.yaml new file mode 100644 index 00000000..9256e98d --- /dev/null +++ b/plugins/headers/tests/configs/.rr-req-headers.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php header pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22655 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-res-headers.yaml b/plugins/headers/tests/configs/.rr-res-headers.yaml new file mode 100644 index 00000000..1bca2c3d --- /dev/null +++ b/plugins/headers/tests/configs/.rr-res-headers.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php header pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22455 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/headers_plugin_test.go b/plugins/headers/tests/headers_plugin_test.go new file mode 100644 index 00000000..f1de8cb9 --- /dev/null +++ b/plugins/headers/tests/headers_plugin_test.go @@ -0,0 +1,359 @@ +package tests + +import ( + "io/ioutil" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/headers" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestHeadersInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-headers-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 5) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + wg.Wait() +} + +func TestRequestHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-req-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("RequestHeaders", reqHeaders) + wg.Wait() +} + +func reqHeaders(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22655?hello=value", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "CUSTOM-HEADER", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestResponseHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-res-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("ResponseHeaders", resHeaders) + wg.Wait() +} + +func resHeaders(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22455?hello=value", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "output-header", r.Header.Get("output")) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "CUSTOM-HEADER", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestCORSHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-cors-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("CORSHeaders", corsHeaders) + t.Run("CORSHeadersPass", corsHeadersPass) + wg.Wait() +} + +func corsHeadersPass(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22855", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + + _, err = ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func corsHeaders(t *testing.T) { + req, err := http.NewRequest("OPTIONS", "http://localhost:22855", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) + assert.Equal(t, "GET,POST,PUT,DELETE", r.Header.Get("Access-Control-Allow-Methods")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "600", r.Header.Get("Access-Control-Max-Age")) + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + + _, err = ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + err = r.Body.Close() + assert.NoError(t, err) +} diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go index 8e8957bd..79e8aa94 100644 --- a/plugins/http/plugin.go +++ b/plugins/http/plugin.go @@ -27,8 +27,8 @@ import ( ) const ( - // ID contains default service name. - ServiceName = "http" + // PluginName declares plugin name. + PluginName = "http" // EventInitSSL thrown at moment of https initialization. SSL server passed as context. EventInitSSL = 750 @@ -78,7 +78,7 @@ func (s *Plugin) AddListener(listener util.EventListener) { // misconfiguration. Services must not be used without proper configuration pushed first. func (s *Plugin) Init(cfg config.Configurer, log log.Logger, server factory.Server) error { const op = errors.Op("http Init") - err := cfg.UnmarshalKey(ServiceName, &s.cfg) + err := cfg.UnmarshalKey(PluginName, &s.cfg) if err != nil { return errors.E(op, err) } @@ -287,7 +287,7 @@ func (s *Plugin) Workers() []roadrunner.WorkerBase { } func (s *Plugin) Name() string { - return ServiceName + return PluginName } func (s *Plugin) Reset() error { @@ -298,7 +298,7 @@ func (s *Plugin) Reset() error { s.pool.Destroy(context.Background()) // re-read the config - err := s.configurer.UnmarshalKey(ServiceName, &s.cfg) + err := s.configurer.UnmarshalKey(PluginName, &s.cfg) if err != nil { return errors.E(op, err) } diff --git a/plugins/http/tests/http_test.go b/plugins/http/tests/http_test.go index 73bb53a0..06bf3f5d 100644 --- a/plugins/http/tests/http_test.go +++ b/plugins/http/tests/http_test.go @@ -41,7 +41,7 @@ var sslClient = &http.Client{ } func TestHTTPInit(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -104,7 +104,7 @@ func TestHTTPInit(t *testing.T) { } func TestHTTPInformerReset(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -225,7 +225,7 @@ func informerTest(t *testing.T) { } func TestSSL(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -353,7 +353,7 @@ func fcgiEcho(t *testing.T) { } func TestSSLRedirect(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -439,8 +439,7 @@ func sslRedirect(t *testing.T) { } func TestSSLPushPipes(t *testing.T) { - time.Sleep(time.Second) - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -528,7 +527,7 @@ func sslPush(t *testing.T) { } func TestFastCGI_RequestUri(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -612,7 +611,7 @@ func fcgiReqURI(t *testing.T) { } func TestH2CUpgrade(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -701,7 +700,7 @@ func h2cUpgrade(t *testing.T) { } func TestH2C(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -789,7 +788,7 @@ func h2c(t *testing.T) { } func TestHttpMiddleware(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -888,7 +887,7 @@ func middleware(t *testing.T) { } func TestHttpEchoErr(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -978,7 +977,7 @@ func echoError(t *testing.T) { } func TestHttpEnvVariables(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -1062,7 +1061,7 @@ func envVarsTest(t *testing.T) { } func TestHttpBrokenPipes(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ diff --git a/plugins/logger/plugin.go b/plugins/logger/plugin.go index 2937056c..64b77a64 100644 --- a/plugins/logger/plugin.go +++ b/plugins/logger/plugin.go @@ -7,8 +7,8 @@ import ( "go.uber.org/zap" ) -// ServiceName declares service name. -const ServiceName = "logs" +// PluginName declares plugin name. +const PluginName = "logs" // ZapLogger manages zap logger. type ZapLogger struct { @@ -19,12 +19,12 @@ type ZapLogger struct { // Init logger service. func (z *ZapLogger) Init(cfg config.Configurer) error { - err := cfg.UnmarshalKey(ServiceName, &z.cfg) + err := cfg.UnmarshalKey(PluginName, &z.cfg) if err != nil { return err } - err = cfg.UnmarshalKey(ServiceName, &z.channels) + err = cfg.UnmarshalKey(PluginName, &z.channels) if err != nil { return err } diff --git a/plugins/metrics/plugin.go b/plugins/metrics/plugin.go index 3fd42ee4..c115826b 100644 --- a/plugins/metrics/plugin.go +++ b/plugins/metrics/plugin.go @@ -18,8 +18,8 @@ import ( ) const ( - // ID declares public service name. - ServiceName = "metrics" + // PluginName declares plugin name. + PluginName = "metrics" // maxHeaderSize declares max header size for prometheus server maxHeaderSize = 1024 * 1024 * 100 // 104MB ) @@ -42,7 +42,7 @@ type Plugin struct { // Init service. func (m *Plugin) Init(cfg config.Configurer, log log.Logger) error { const op = errors.Op("Metrics Init") - err := cfg.UnmarshalKey(ServiceName, &m.cfg) + err := cfg.UnmarshalKey(PluginName, &m.cfg) if err != nil { return err } @@ -218,7 +218,7 @@ func (m *Plugin) AddStatProvider(name endure.Named, stat metrics.StatProvider) e // RPC interface satisfaction func (m *Plugin) Name() string { - return ServiceName + return PluginName } // RPC interface satisfaction diff --git a/plugins/metrics/tests/metrics_test.go b/plugins/metrics/tests/metrics_test.go index 4709d275..f9014c95 100644 --- a/plugins/metrics/tests/metrics_test.go +++ b/plugins/metrics/tests/metrics_test.go @@ -107,7 +107,7 @@ func TestMetricsInit(t *testing.T) { } func TestMetricsGaugeCollector(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) if err != nil { t.Fatal(err) } @@ -174,7 +174,7 @@ func TestMetricsGaugeCollector(t *testing.T) { } func TestMetricsDifferentRPCCalls(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) if err != nil { t.Fatal(err) } diff --git a/plugins/rpc/plugin.go b/plugins/rpc/plugin.go index 82b30563..24624d91 100755 --- a/plugins/rpc/plugin.go +++ b/plugins/rpc/plugin.go @@ -13,8 +13,8 @@ import ( "github.com/spiral/roadrunner/v2/plugins/config" ) -// ServiceName contains default service name. -const ServiceName = "RPC" +// PluginName contains default plugin name. +const PluginName = "RPC" type pluggable struct { service rpc_.RPCer @@ -34,11 +34,11 @@ type Plugin struct { // Init rpc service. Must return true if service is enabled. func (s *Plugin) Init(cfg config.Configurer, log log.Logger) error { const op = errors.Op("RPC Init") - if !cfg.Has(ServiceName) { + if !cfg.Has(PluginName) { return errors.E(op, errors.Disabled) } - err := cfg.UnmarshalKey(ServiceName, &s.cfg) + err := cfg.UnmarshalKey(PluginName, &s.cfg) if err != nil { return err } @@ -120,7 +120,7 @@ func (s *Plugin) Stop() error { // Name contains service name. func (s *Plugin) Name() string { - return ServiceName + return PluginName } // Depends declares services to collect for RPC. diff --git a/plugins/server/plugin.go b/plugins/server/plugin.go index 3411b007..ea6d42eb 100644 --- a/plugins/server/plugin.go +++ b/plugins/server/plugin.go @@ -15,7 +15,7 @@ import ( "github.com/spiral/roadrunner/v2/util" ) -const ServiceName = "server" +const PluginName = "server" // Plugin manages worker type Plugin struct { @@ -27,7 +27,7 @@ type Plugin struct { // Init application provider. func (server *Plugin) Init(cfg config.Configurer, log log.Logger) error { const op = errors.Op("Init") - err := cfg.UnmarshalKey(ServiceName, &server.cfg) + err := cfg.UnmarshalKey(PluginName, &server.cfg) if err != nil { return errors.E(op, errors.Init, err) } @@ -44,7 +44,7 @@ func (server *Plugin) Init(cfg config.Configurer, log log.Logger) error { // Name contains service name. func (server *Plugin) Name() string { - return ServiceName + return PluginName } func (server *Plugin) Serve() chan error { diff --git a/plugins/static/tests/static_plugin_test.go b/plugins/static/tests/static_plugin_test.go index d0403160..528d5eea 100644 --- a/plugins/static/tests/static_plugin_test.go +++ b/plugins/static/tests/static_plugin_test.go @@ -25,7 +25,7 @@ import ( ) func TestStaticPlugin(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -133,7 +133,7 @@ func serveStaticSample(t *testing.T) { } func TestStaticDisabled(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -206,7 +206,7 @@ func staticDisabled(t *testing.T) { } func TestStaticFilesDisabled(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ @@ -282,7 +282,7 @@ func staticFilesDisabled(t *testing.T) { } func TestStaticFilesForbid(t *testing.T) { - cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel), endure.Visualize(endure.StdOut, "")) + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) assert.NoError(t, err) cfg := &config.Viper{ diff --git a/static_pool_test.go b/static_pool_test.go index 88585318..747f26c4 100755 --- a/static_pool_test.go +++ b/static_pool_test.go @@ -196,7 +196,6 @@ func Test_StaticPool_Broken_Replace(t *testing.T) { p.Destroy(ctx) } -// func Test_StaticPool_Broken_FromOutside(t *testing.T) { ctx := context.Background() p, err := NewPool( |