summaryrefslogtreecommitdiff
path: root/service/headers/service.go
diff options
context:
space:
mode:
Diffstat (limited to 'service/headers/service.go')
-rw-r--r--service/headers/service.go115
1 files changed, 115 insertions, 0 deletions
diff --git a/service/headers/service.go b/service/headers/service.go
new file mode 100644
index 00000000..9060e9ff
--- /dev/null
+++ b/service/headers/service.go
@@ -0,0 +1,115 @@
+package headers
+
+import (
+ rrhttp "github.com/spiral/roadrunner/service/http"
+ "net/http"
+ "strconv"
+)
+
+const (
+ // ID contains default service name.
+ ID = "headers"
+)
+
+// Service serves static files. Potentially convert into middleware?
+type Service struct {
+ // server configuration (location, forbidden files and etc)
+ cfg *Config
+}
+
+// Init must return configure service and return true if service hasStatus enabled. Must return error in case of
+// misconfiguration. Services must not be used without proper configuration pushed first.
+func (s *Service) Init(cfg *Config, r *rrhttp.Service) (bool, error) {
+ if r == nil {
+ return false, nil
+ }
+
+ s.cfg = cfg
+ r.AddMiddleware(s.middleware)
+
+ return true, nil
+}
+
+// middleware must return true if request/response pair is handled within the middleware.
+func (s *Service) middleware(f http.HandlerFunc) http.HandlerFunc {
+ // Define the http.HandlerFunc
+ return func(w http.ResponseWriter, r *http.Request) {
+
+ if s.cfg.Request != nil {
+ for k, v := range s.cfg.Request {
+ r.Header.Add(k, v)
+ }
+ }
+
+ if s.cfg.Response != nil {
+ for k, v := range s.cfg.Response {
+ w.Header().Set(k, v)
+ }
+ }
+
+ if s.cfg.CORS != nil {
+ if r.Method == http.MethodOptions {
+ s.preflightRequest(w, r)
+ return
+ }
+
+ s.corsHeaders(w, r)
+ }
+
+ f(w, r)
+ }
+}
+
+// configure OPTIONS response
+func (s *Service) preflightRequest(w http.ResponseWriter, r *http.Request) {
+ headers := w.Header()
+
+ headers.Add("Vary", "Origin")
+ headers.Add("Vary", "Access-Control-Request-Method")
+ headers.Add("Vary", "Access-Control-Request-Headers")
+
+ if s.cfg.CORS.AllowedOrigin != "" {
+ headers.Set("Access-Control-Allow-Origin", s.cfg.CORS.AllowedOrigin)
+ }
+
+ if s.cfg.CORS.AllowedHeaders != "" {
+ headers.Set("Access-Control-Allow-Headers", s.cfg.CORS.AllowedHeaders)
+ }
+
+ if s.cfg.CORS.AllowedMethods != "" {
+ headers.Set("Access-Control-Allow-Methods", s.cfg.CORS.AllowedMethods)
+ }
+
+ if s.cfg.CORS.AllowCredentials != nil {
+ headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.CORS.AllowCredentials))
+ }
+
+ if s.cfg.CORS.MaxAge > 0 {
+ headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.CORS.MaxAge))
+ }
+
+ w.WriteHeader(http.StatusOK)
+}
+
+// configure CORS headers
+func (s *Service) corsHeaders(w http.ResponseWriter, r *http.Request) {
+ headers := w.Header()
+
+ headers.Add("Vary", "Origin")
+
+ if s.cfg.CORS.AllowedOrigin != "" {
+ headers.Set("Access-Control-Allow-Origin", s.cfg.CORS.AllowedOrigin)
+ }
+
+ if s.cfg.CORS.AllowedHeaders != "" {
+ headers.Set("Access-Control-Allow-Headers", s.cfg.CORS.AllowedHeaders)
+ }
+
+ if s.cfg.CORS.ExposedHeaders != "" {
+ headers.Set("Access-Control-Expose-Headers", s.cfg.CORS.ExposedHeaders)
+ }
+
+ if s.cfg.CORS.AllowCredentials != nil {
+ headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.CORS.AllowCredentials))
+ }
+}