summaryrefslogtreecommitdiff
path: root/plugins/headers
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2020-11-30 16:19:35 +0300
committerValery Piashchynski <[email protected]>2020-11-30 16:19:35 +0300
commitf44faa14e6aaaf596da806dcbde062b7c4fb30ee (patch)
tree2710552f754d64389cd29c204999038c9574a9ba /plugins/headers
parent0a5116e9dcce76c8f845f4fdda41d448f3e38955 (diff)
Initial commit of headers plugin
Diffstat (limited to 'plugins/headers')
-rw-r--r--plugins/headers/config.go36
-rw-r--r--plugins/headers/plugin.go115
-rw-r--r--plugins/headers/tests/headers_plugin_test.go1
-rw-r--r--plugins/headers/tests/old.go362
4 files changed, 514 insertions, 0 deletions
diff --git a/plugins/headers/config.go b/plugins/headers/config.go
new file mode 100644
index 00000000..8d4e29c2
--- /dev/null
+++ b/plugins/headers/config.go
@@ -0,0 +1,36 @@
+package headers
+
+// Config declares headers service configuration.
+type Config struct {
+ Headers struct {
+ // CORS settings.
+ CORS *CORSConfig
+
+ // Request headers to add to every payload send to PHP.
+ Request map[string]string
+
+ // Response headers to add to every payload generated by PHP.
+ Response map[string]string
+ }
+}
+
+// CORSConfig headers configuration.
+type CORSConfig struct {
+ // AllowedOrigin: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
+ AllowedOrigin string
+
+ // AllowedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
+ AllowedHeaders string
+
+ // AllowedMethods: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
+ AllowedMethods string
+
+ // AllowCredentials https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
+ AllowCredentials *bool
+
+ // ExposeHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
+ ExposedHeaders string
+
+ // MaxAge of CORS headers in seconds/
+ MaxAge int
+}
diff --git a/plugins/headers/plugin.go b/plugins/headers/plugin.go
new file mode 100644
index 00000000..b49575bb
--- /dev/null
+++ b/plugins/headers/plugin.go
@@ -0,0 +1,115 @@
+package headers
+
+import (
+ "net/http"
+ "strconv"
+
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+)
+
+// ID contains default service name.
+const PluginName = "headers"
+const RootPluginName = "http"
+
+// Service serves headers files. Potentially convert into middleware?
+type Plugin struct {
+ // server configuration (location, forbidden files and etc)
+ cfg *Config
+}
+
+// Init must return configure service and return true if service hasStatus enabled. Must return error in case of
+// misconfiguration. Services must not be used without proper configuration pushed first.
+func (s *Plugin) Init(cfg config.Configurer) error {
+ const op = errors.Op("headers plugin init")
+ err := cfg.UnmarshalKey(RootPluginName, s.cfg)
+ if err != nil {
+ return errors.E(op, err)
+ }
+
+ return nil
+}
+
+// middleware must return true if request/response pair is handled within the middleware.
+func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc {
+ // Define the http.HandlerFunc
+ return func(w http.ResponseWriter, r *http.Request) {
+
+ if s.cfg.Headers.Request != nil {
+ for k, v := range s.cfg.Headers.Request {
+ r.Header.Add(k, v)
+ }
+ }
+
+ if s.cfg.Headers.Response != nil {
+ for k, v := range s.cfg.Headers.Response {
+ w.Header().Set(k, v)
+ }
+ }
+
+ if s.cfg.Headers.CORS != nil {
+ if r.Method == http.MethodOptions {
+ s.preflightRequest(w, r)
+ return
+ }
+
+ s.corsHeaders(w, r)
+ }
+
+ next.ServeHTTP(w, r)
+ }
+}
+
+// configure OPTIONS response
+func (s *Plugin) preflightRequest(w http.ResponseWriter, r *http.Request) {
+ headers := w.Header()
+
+ headers.Add("Vary", "Origin")
+ headers.Add("Vary", "Access-Control-Request-Method")
+ headers.Add("Vary", "Access-Control-Request-Headers")
+
+ if s.cfg.Headers.CORS.AllowedOrigin != "" {
+ headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin)
+ }
+
+ if s.cfg.Headers.CORS.AllowedHeaders != "" {
+ headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders)
+ }
+
+ if s.cfg.Headers.CORS.AllowedMethods != "" {
+ headers.Set("Access-Control-Allow-Methods", s.cfg.Headers.CORS.AllowedMethods)
+ }
+
+ if s.cfg.Headers.CORS.AllowCredentials != nil {
+ headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials))
+ }
+
+ if s.cfg.Headers.CORS.MaxAge > 0 {
+ headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.Headers.CORS.MaxAge))
+ }
+
+ w.WriteHeader(http.StatusOK)
+}
+
+// configure CORS headers
+func (s *Plugin) corsHeaders(w http.ResponseWriter, r *http.Request) {
+ headers := w.Header()
+
+ headers.Add("Vary", "Origin")
+
+ if s.cfg.Headers.CORS.AllowedOrigin != "" {
+ headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin)
+ }
+
+ if s.cfg.Headers.CORS.AllowedHeaders != "" {
+ headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders)
+ }
+
+ if s.cfg.Headers.CORS.ExposedHeaders != "" {
+ headers.Set("Access-Control-Expose-Headers", s.cfg.Headers.CORS.ExposedHeaders)
+ }
+
+ if s.cfg.Headers.CORS.AllowCredentials != nil {
+ headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials))
+ }
+}
diff --git a/plugins/headers/tests/headers_plugin_test.go b/plugins/headers/tests/headers_plugin_test.go
new file mode 100644
index 00000000..ca8701d2
--- /dev/null
+++ b/plugins/headers/tests/headers_plugin_test.go
@@ -0,0 +1 @@
+package tests
diff --git a/plugins/headers/tests/old.go b/plugins/headers/tests/old.go
new file mode 100644
index 00000000..b3d95aa5
--- /dev/null
+++ b/plugins/headers/tests/old.go
@@ -0,0 +1,362 @@
+package tests
+
+//type mockCfg struct{ cfg string }
+//
+//func (cfg *mockCfg) Get(name string) service.Config { return nil }
+//func (cfg *mockCfg) Unmarshal(out interface{}) error {
+// j := json.ConfigCompatibleWithStandardLibrary
+// return j.Unmarshal([]byte(cfg.cfg), out)
+//}
+//
+//func Test_Config_Hydrate_Error1(t *testing.T) {
+// cfg := &mockCfg{`{"request": {"From": "Something"}}`}
+// c := &Config{}
+//
+// assert.NoError(t, c.Hydrate(cfg))
+//}
+//
+//func Test_Config_Hydrate_Error2(t *testing.T) {
+// cfg := &mockCfg{`{"dir": "/dir/"`}
+// c := &Config{}
+//
+// assert.Error(t, c.Hydrate(cfg))
+//}
+//-------------------------------------------------------------------------------
+//import (
+// "github.com/cenkalti/backoff/v4"
+// json "github.com/json-iterator/go"
+// "github.com/sirupsen/logrus"
+// "github.com/sirupsen/logrus/hooks/test"
+// "github.com/spiral/roadrunner/service"
+// rrhttp "github.com/spiral/roadrunner/service/http"
+// "github.com/stretchr/testify/assert"
+// "io/ioutil"
+// "net/http"
+// "testing"
+// "time"
+//)
+//
+//type testCfg struct {
+// httpCfg string
+// headers string
+// target string
+//}
+//
+//func (cfg *testCfg) Get(name string) service.Config {
+// if name == rrhttp.ID {
+// return &testCfg{target: cfg.httpCfg}
+// }
+//
+// if name == ID {
+// return &testCfg{target: cfg.headers}
+// }
+// return nil
+//}
+//
+//func (cfg *testCfg) Unmarshal(out interface{}) error {
+// return json.Unmarshal([]byte(cfg.target), out)
+//}
+//
+//func Test_RequestHeaders(t *testing.T) {
+// bkoff := backoff.NewExponentialBackOff()
+// bkoff.MaxElapsedTime = time.Second * 15
+//
+// err := backoff.Retry(func() error {
+// logger, _ := test.NewNullLogger()
+// logger.SetLevel(logrus.DebugLevel)
+//
+// c := service.NewContainer(logger)
+// c.Register(rrhttp.ID, &rrhttp.Service{})
+// c.Register(ID, &Service{})
+//
+// assert.NoError(t, c.Init(&testCfg{
+// headers: `{"request":{"input": "custom-header"}}`,
+// httpCfg: `{
+// "enable": true,
+// "address": ":6078",
+// "maxRequestSize": 1024,
+// "workers":{
+// "command": "php ../../tests/http/client.php header pipes",
+// "relay": "pipes",
+// "pool": {
+// "numWorkers": 1,
+// "allocateTimeout": 10000000,
+// "destroyTimeout": 10000000
+// }
+// }
+// }`}))
+//
+// go func() {
+// err := c.Serve()
+// if err != nil {
+// t.Errorf("error during Serve: error %v", err)
+// }
+// }()
+//
+// time.Sleep(time.Millisecond * 100)
+// defer c.Stop()
+//
+// req, err := http.NewRequest("GET", "http://localhost:6078?hello=value", nil)
+// if err != nil {
+// return err
+// }
+//
+// r, err := http.DefaultClient.Do(req)
+// if err != nil {
+// return err
+// }
+//
+// b, err := ioutil.ReadAll(r.Body)
+// if err != nil {
+// return err
+// }
+//
+// assert.Equal(t, 200, r.StatusCode)
+// assert.Equal(t, "CUSTOM-HEADER", string(b))
+//
+// err = r.Body.Close()
+// if err != nil {
+// return err
+// }
+//
+// return nil
+// }, bkoff)
+//
+// if err != nil {
+// t.Fatal(err)
+// }
+//}
+//
+//func Test_ResponseHeaders(t *testing.T) {
+// bkoff := backoff.NewExponentialBackOff()
+// bkoff.MaxElapsedTime = time.Second * 15
+//
+// err := backoff.Retry(func() error {
+// logger, _ := test.NewNullLogger()
+// logger.SetLevel(logrus.DebugLevel)
+//
+// c := service.NewContainer(logger)
+// c.Register(rrhttp.ID, &rrhttp.Service{})
+// c.Register(ID, &Service{})
+//
+// assert.NoError(t, c.Init(&testCfg{
+// headers: `{"response":{"output": "output-header"},"request":{"input": "custom-header"}}`,
+// httpCfg: `{
+// "enable": true,
+// "address": ":6079",
+// "maxRequestSize": 1024,
+// "workers":{
+// "command": "php ../../tests/http/client.php header pipes",
+// "relay": "pipes",
+// "pool": {
+// "numWorkers": 1,
+// "allocateTimeout": 10000000,
+// "destroyTimeout": 10000000
+// }
+// }
+// }`}))
+//
+// go func() {
+// err := c.Serve()
+// if err != nil {
+// t.Errorf("error during the Serve: error %v", err)
+// }
+// }()
+// time.Sleep(time.Millisecond * 100)
+// defer c.Stop()
+//
+// req, err := http.NewRequest("GET", "http://localhost:6079?hello=value", nil)
+// if err != nil {
+// return err
+// }
+//
+// r, err := http.DefaultClient.Do(req)
+// if err != nil {
+// return err
+// }
+//
+// assert.Equal(t, "output-header", r.Header.Get("output"))
+//
+// b, err := ioutil.ReadAll(r.Body)
+// if err != nil {
+// return err
+// }
+// assert.Equal(t, 200, r.StatusCode)
+// assert.Equal(t, "CUSTOM-HEADER", string(b))
+//
+// err = r.Body.Close()
+// if err != nil {
+// return err
+// }
+//
+// return nil
+// }, bkoff)
+//
+// if err != nil {
+// t.Fatal(err)
+// }
+//}
+//
+//func TestCORS_OPTIONS(t *testing.T) {
+// bkoff := backoff.NewExponentialBackOff()
+// bkoff.MaxElapsedTime = time.Second * 15
+//
+// err := backoff.Retry(func() error {
+// logger, _ := test.NewNullLogger()
+// logger.SetLevel(logrus.DebugLevel)
+//
+// c := service.NewContainer(logger)
+// c.Register(rrhttp.ID, &rrhttp.Service{})
+// c.Register(ID, &Service{})
+//
+// assert.NoError(t, c.Init(&testCfg{
+// headers: `{
+//"cors":{
+// "allowedOrigin": "*",
+// "allowedHeaders": "*",
+// "allowedMethods": "GET,POST,PUT,DELETE",
+// "allowCredentials": true,
+// "exposedHeaders": "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma",
+// "maxAge": 600
+//}
+//}`,
+// httpCfg: `{
+// "enable": true,
+// "address": ":16379",
+// "maxRequestSize": 1024,
+// "workers":{
+// "command": "php ../../tests/http/client.php headers pipes",
+// "relay": "pipes",
+// "pool": {
+// "numWorkers": 1,
+// "allocateTimeout": 10000000,
+// "destroyTimeout": 10000000
+// }
+// }
+// }`}))
+//
+// go func() {
+// err := c.Serve()
+// if err != nil {
+// t.Errorf("error during the Serve: error %v", err)
+// }
+// }()
+// time.Sleep(time.Millisecond * 100)
+// defer c.Stop()
+//
+// req, err := http.NewRequest("OPTIONS", "http://localhost:16379", nil)
+// if err != nil {
+// return err
+// }
+//
+// r, err := http.DefaultClient.Do(req)
+// if err != nil {
+// return err
+// }
+//
+// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials"))
+// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers"))
+// assert.Equal(t, "GET,POST,PUT,DELETE", r.Header.Get("Access-Control-Allow-Methods"))
+// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin"))
+// assert.Equal(t, "600", r.Header.Get("Access-Control-Max-Age"))
+// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials"))
+//
+// _, err = ioutil.ReadAll(r.Body)
+// if err != nil {
+// return err
+// }
+// assert.Equal(t, 200, r.StatusCode)
+//
+// err = r.Body.Close()
+// if err != nil {
+// return err
+// }
+//
+// return nil
+// }, bkoff)
+//
+// if err != nil {
+// t.Fatal(err)
+// }
+//}
+//
+//func TestCORS_Pass(t *testing.T) {
+// bkoff := backoff.NewExponentialBackOff()
+// bkoff.MaxElapsedTime = time.Second * 15
+//
+// err := backoff.Retry(func() error {
+// logger, _ := test.NewNullLogger()
+// logger.SetLevel(logrus.DebugLevel)
+//
+// c := service.NewContainer(logger)
+// c.Register(rrhttp.ID, &rrhttp.Service{})
+// c.Register(ID, &Service{})
+//
+// assert.NoError(t, c.Init(&testCfg{
+// headers: `{
+//"cors":{
+// "allowedOrigin": "*",
+// "allowedHeaders": "*",
+// "allowedMethods": "GET,POST,PUT,DELETE",
+// "allowCredentials": true,
+// "exposedHeaders": "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma",
+// "maxAge": 600
+//}
+//}`,
+// httpCfg: `{
+// "enable": true,
+// "address": ":6672",
+// "maxRequestSize": 1024,
+// "workers":{
+// "command": "php ../../tests/http/client.php headers pipes",
+// "relay": "pipes",
+// "pool": {
+// "numWorkers": 1,
+// "allocateTimeout": 10000000,
+// "destroyTimeout": 10000000
+// }
+// }
+// }`}))
+//
+// go func() {
+// err := c.Serve()
+// if err != nil {
+// t.Errorf("error during the Serve: error %v", err)
+// }
+// }()
+// time.Sleep(time.Millisecond * 100)
+// defer c.Stop()
+//
+// req, err := http.NewRequest("GET", "http://localhost:6672", nil)
+// if err != nil {
+// return err
+// }
+//
+// r, err := http.DefaultClient.Do(req)
+// if err != nil {
+// return err
+// }
+//
+// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials"))
+// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers"))
+// assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin"))
+// assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials"))
+//
+// _, err = ioutil.ReadAll(r.Body)
+// if err != nil {
+// return err
+// }
+// assert.Equal(t, 200, r.StatusCode)
+//
+// err = r.Body.Close()
+// if err != nil {
+// return err
+// }
+//
+// return nil
+// }, bkoff)
+//
+// if err != nil {
+// t.Fatal(err)
+// }
+//} \ No newline at end of file