diff options
author | Valery Piashchynski <[email protected]> | 2020-11-30 16:19:35 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2020-11-30 16:19:35 +0300 |
commit | f44faa14e6aaaf596da806dcbde062b7c4fb30ee (patch) | |
tree | 2710552f754d64389cd29c204999038c9574a9ba /plugins | |
parent | 0a5116e9dcce76c8f845f4fdda41d448f3e38955 (diff) |
Initial commit of headers plugin
Diffstat (limited to 'plugins')
-rw-r--r-- | plugins/headers/config.go | 36 | ||||
-rw-r--r-- | plugins/headers/plugin.go | 115 | ||||
-rw-r--r-- | plugins/headers/tests/headers_plugin_test.go | 1 | ||||
-rw-r--r-- | plugins/headers/tests/old.go | 362 |
4 files changed, 514 insertions, 0 deletions
diff --git a/plugins/headers/config.go b/plugins/headers/config.go new file mode 100644 index 00000000..8d4e29c2 --- /dev/null +++ b/plugins/headers/config.go @@ -0,0 +1,36 @@ +package headers + +// Config declares headers service configuration. +type Config struct { + Headers struct { + // CORS settings. + CORS *CORSConfig + + // Request headers to add to every payload send to PHP. + Request map[string]string + + // Response headers to add to every payload generated by PHP. + Response map[string]string + } +} + +// CORSConfig headers configuration. +type CORSConfig struct { + // AllowedOrigin: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowedOrigin string + + // AllowedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowedHeaders string + + // AllowedMethods: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowedMethods string + + // AllowCredentials https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials *bool + + // ExposeHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + ExposedHeaders string + + // MaxAge of CORS headers in seconds/ + MaxAge int +} diff --git a/plugins/headers/plugin.go b/plugins/headers/plugin.go new file mode 100644 index 00000000..b49575bb --- /dev/null +++ b/plugins/headers/plugin.go @@ -0,0 +1,115 @@ +package headers + +import ( + "net/http" + "strconv" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// ID contains default service name. +const PluginName = "headers" +const RootPluginName = "http" + +// Service serves headers files. Potentially convert into middleware? +type Plugin struct { + // server configuration (location, forbidden files and etc) + cfg *Config +} + +// Init must return configure service and return true if service hasStatus enabled. Must return error in case of +// misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer) error { + const op = errors.Op("headers plugin init") + err := cfg.UnmarshalKey(RootPluginName, s.cfg) + if err != nil { + return errors.E(op, err) + } + + return nil +} + +// middleware must return true if request/response pair is handled within the middleware. +func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + + if s.cfg.Headers.Request != nil { + for k, v := range s.cfg.Headers.Request { + r.Header.Add(k, v) + } + } + + if s.cfg.Headers.Response != nil { + for k, v := range s.cfg.Headers.Response { + w.Header().Set(k, v) + } + } + + if s.cfg.Headers.CORS != nil { + if r.Method == http.MethodOptions { + s.preflightRequest(w, r) + return + } + + s.corsHeaders(w, r) + } + + next.ServeHTTP(w, r) + } +} + +// configure OPTIONS response +func (s *Plugin) preflightRequest(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.AllowedMethods != "" { + headers.Set("Access-Control-Allow-Methods", s.cfg.Headers.CORS.AllowedMethods) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } + + if s.cfg.Headers.CORS.MaxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.Headers.CORS.MaxAge)) + } + + w.WriteHeader(http.StatusOK) +} + +// configure CORS headers +func (s *Plugin) corsHeaders(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + + headers.Add("Vary", "Origin") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.ExposedHeaders != "" { + headers.Set("Access-Control-Expose-Headers", s.cfg.Headers.CORS.ExposedHeaders) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } +} diff --git a/plugins/headers/tests/headers_plugin_test.go b/plugins/headers/tests/headers_plugin_test.go new file mode 100644 index 00000000..ca8701d2 --- /dev/null +++ b/plugins/headers/tests/headers_plugin_test.go @@ -0,0 +1 @@ +package tests diff --git a/plugins/headers/tests/old.go b/plugins/headers/tests/old.go new file mode 100644 index 00000000..b3d95aa5 --- /dev/null +++ b/plugins/headers/tests/old.go @@ -0,0 +1,362 @@ +package tests + +//type mockCfg struct{ cfg string } +// +//func (cfg *mockCfg) Get(name string) service.Config { return nil } +//func (cfg *mockCfg) Unmarshal(out interface{}) error { +// j := json.ConfigCompatibleWithStandardLibrary +// return j.Unmarshal([]byte(cfg.cfg), out) +//} +// +//func Test_Config_Hydrate_Error1(t *testing.T) { +// cfg := &mockCfg{`{"request": {"From": "Something"}}`} +// c := &Config{} +// +// assert.NoError(t, c.Hydrate(cfg)) +//} +// +//func Test_Config_Hydrate_Error2(t *testing.T) { +// cfg := &mockCfg{`{"dir": "/dir/"`} +// c := &Config{} +// +// assert.Error(t, c.Hydrate(cfg)) +//} +//------------------------------------------------------------------------------- +//import ( +// "github.com/cenkalti/backoff/v4" +// json "github.com/json-iterator/go" +// "github.com/sirupsen/logrus" +// "github.com/sirupsen/logrus/hooks/test" +// "github.com/spiral/roadrunner/service" +// rrhttp "github.com/spiral/roadrunner/service/http" +// "github.com/stretchr/testify/assert" +// "io/ioutil" +// "net/http" +// "testing" +// "time" +//) +// +//type testCfg struct { +// httpCfg string +// headers string +// target string +//} +// +//func (cfg *testCfg) Get(name string) service.Config { +// if name == rrhttp.ID { +// return &testCfg{target: cfg.httpCfg} +// } +// +// if name == ID { +// return &testCfg{target: cfg.headers} +// } +// return nil +//} +// +//func (cfg *testCfg) Unmarshal(out interface{}) error { +// return json.Unmarshal([]byte(cfg.target), out) +//} +// +//func Test_RequestHeaders(t *testing.T) { +// bkoff := backoff.NewExponentialBackOff() +// bkoff.MaxElapsedTime = time.Second * 15 +// +// err := backoff.Retry(func() error { +// logger, _ := test.NewNullLogger() +// logger.SetLevel(logrus.DebugLevel) +// +// c := service.NewContainer(logger) +// c.Register(rrhttp.ID, &rrhttp.Service{}) +// c.Register(ID, &Service{}) +// +// assert.NoError(t, c.Init(&testCfg{ +// headers: `{"request":{"input": "custom-header"}}`, +// httpCfg: `{ +// "enable": true, +// "address": ":6078", +// "maxRequestSize": 1024, +// "workers":{ +// "command": "php ../../tests/http/client.php header pipes", +// "relay": "pipes", +// "pool": { +// "numWorkers": 1, +// "allocateTimeout": 10000000, +// "destroyTimeout": 10000000 +// } +// } +// }`})) +// +// go func() { +// err := c.Serve() +// if err != nil { +// t.Errorf("error during Serve: error %v", err) +// } +// }() +// +// time.Sleep(time.Millisecond * 100) +// defer c.Stop() +// +// req, err := http.NewRequest("GET", "http://localhost:6078?hello=value", nil) +// if err != nil { +// return err +// } +// +// r, err := http.DefaultClient.Do(req) +// if err != nil { +// return err +// } +// +// b, err := ioutil.ReadAll(r.Body) +// if err != nil { +// return err +// } +// +// assert.Equal(t, 200, r.StatusCode) +// assert.Equal(t, "CUSTOM-HEADER", string(b)) +// +// err = r.Body.Close() +// if err != nil { +// return err +// } +// +// return nil +// }, bkoff) +// +// if err != nil { +// t.Fatal(err) +// } +//} +// +//func Test_ResponseHeaders(t *testing.T) { +// bkoff := backoff.NewExponentialBackOff() +// bkoff.MaxElapsedTime = time.Second * 15 +// +// err := backoff.Retry(func() error { +// logger, _ := test.NewNullLogger() +// logger.SetLevel(logrus.DebugLevel) +// +// c := service.NewContainer(logger) +// c.Register(rrhttp.ID, &rrhttp.Service{}) +// c.Register(ID, &Service{}) +// +// assert.NoError(t, c.Init(&testCfg{ +// headers: `{"response":{"output": "output-header"},"request":{"input": "custom-header"}}`, +// httpCfg: `{ +// "enable": true, +// "address": ":6079", +// "maxRequestSize": 1024, +// "workers":{ +// "command": "php ../../tests/http/client.php header pipes", +// "relay": "pipes", +// "pool": { +// "numWorkers": 1, +// "allocateTimeout": 10000000, +// "destroyTimeout": 10000000 +// } +// } +// }`})) +// +// go func() { +// err := c.Serve() +// if err != nil { +// t.Errorf("error during the Serve: error %v", err) +// } +// }() +// time.Sleep(time.Millisecond * 100) +// defer c.Stop() +// +// req, err := http.NewRequest("GET", "http://localhost:6079?hello=value", nil) +// if err != nil { +// return err +// } +// +// r, err := http.DefaultClient.Do(req) +// if err != nil { +// return err +// } +// +// assert.Equal(t, "output-header", r.Header.Get("output")) +// +// b, err := ioutil.ReadAll(r.Body) +// if err != nil { +// return err +// } +// assert.Equal(t, 200, r.StatusCode) +// assert.Equal(t, "CUSTOM-HEADER", string(b)) +// +// err = r.Body.Close() +// if err != nil { +// return err +// } +// +// return nil +// }, bkoff) +// +// if err != nil { +// t.Fatal(err) +// } +//} +// +//func TestCORS_OPTIONS(t *testing.T) { +// bkoff := backoff.NewExponentialBackOff() +// bkoff.MaxElapsedTime = time.Second * 15 +// +// err := backoff.Retry(func() error { +// logger, _ := test.NewNullLogger() +// logger.SetLevel(logrus.DebugLevel) +// +// c := service.NewContainer(logger) +// c.Register(rrhttp.ID, &rrhttp.Service{}) +// c.Register(ID, &Service{}) +// +// assert.NoError(t, c.Init(&testCfg{ +// headers: `{ +//"cors":{ +// "allowedOrigin": "*", +// "allowedHeaders": "*", +// "allowedMethods": "GET,POST,PUT,DELETE", +// "allowCredentials": true, +// "exposedHeaders": "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma", +// "maxAge": 600 +//} +//}`, +// httpCfg: `{ +// "enable": true, +// "address": ":16379", +// "maxRequestSize": 1024, +// "workers":{ +// "command": "php ../../tests/http/client.php headers pipes", +// "relay": "pipes", +// "pool": { +// "numWorkers": 1, +// "allocateTimeout": 10000000, +// "destroyTimeout": 10000000 +// } +// } +// }`})) +// +// go func() { +// err := c.Serve() +// if err != nil { +// t.Errorf("error during the Serve: error %v", err) +// } +// }() +// time.Sleep(time.Millisecond * 100) +// defer c.Stop() +// +// req, err := http.NewRequest("OPTIONS", "http://localhost:16379", nil) +// if err != nil { +// return err +// } +// +// r, err := http.DefaultClient.Do(req) +// if err != nil { +// return err +// } +// +// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) +// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) +// assert.Equal(t, "GET,POST,PUT,DELETE", r.Header.Get("Access-Control-Allow-Methods")) +// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) +// assert.Equal(t, "600", r.Header.Get("Access-Control-Max-Age")) +// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) +// +// _, err = ioutil.ReadAll(r.Body) +// if err != nil { +// return err +// } +// assert.Equal(t, 200, r.StatusCode) +// +// err = r.Body.Close() +// if err != nil { +// return err +// } +// +// return nil +// }, bkoff) +// +// if err != nil { +// t.Fatal(err) +// } +//} +// +//func TestCORS_Pass(t *testing.T) { +// bkoff := backoff.NewExponentialBackOff() +// bkoff.MaxElapsedTime = time.Second * 15 +// +// err := backoff.Retry(func() error { +// logger, _ := test.NewNullLogger() +// logger.SetLevel(logrus.DebugLevel) +// +// c := service.NewContainer(logger) +// c.Register(rrhttp.ID, &rrhttp.Service{}) +// c.Register(ID, &Service{}) +// +// assert.NoError(t, c.Init(&testCfg{ +// headers: `{ +//"cors":{ +// "allowedOrigin": "*", +// "allowedHeaders": "*", +// "allowedMethods": "GET,POST,PUT,DELETE", +// "allowCredentials": true, +// "exposedHeaders": "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma", +// "maxAge": 600 +//} +//}`, +// httpCfg: `{ +// "enable": true, +// "address": ":6672", +// "maxRequestSize": 1024, +// "workers":{ +// "command": "php ../../tests/http/client.php headers pipes", +// "relay": "pipes", +// "pool": { +// "numWorkers": 1, +// "allocateTimeout": 10000000, +// "destroyTimeout": 10000000 +// } +// } +// }`})) +// +// go func() { +// err := c.Serve() +// if err != nil { +// t.Errorf("error during the Serve: error %v", err) +// } +// }() +// time.Sleep(time.Millisecond * 100) +// defer c.Stop() +// +// req, err := http.NewRequest("GET", "http://localhost:6672", nil) +// if err != nil { +// return err +// } +// +// r, err := http.DefaultClient.Do(req) +// if err != nil { +// return err +// } +// +// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) +// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) +// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) +// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) +// +// _, err = ioutil.ReadAll(r.Body) +// if err != nil { +// return err +// } +// assert.Equal(t, 200, r.StatusCode) +// +// err = r.Body.Close() +// if err != nil { +// return err +// } +// +// return nil +// }, bkoff) +// +// if err != nil { +// t.Fatal(err) +// } +//}
\ No newline at end of file |