diff options
Diffstat (limited to 'service/http')
-rw-r--r-- | service/http/config.go | 34 | ||||
-rw-r--r-- | service/http/service.go | 100 |
2 files changed, 134 insertions, 0 deletions
diff --git a/service/http/config.go b/service/http/config.go index 5454f124..e2f42626 100644 --- a/service/http/config.go +++ b/service/http/config.go @@ -33,10 +33,40 @@ type Config struct { // HTTP2 configuration HTTP2 *HTTP2Config + // Middlewares + Middlewares *MiddlewaresConfig + // Workers configures rr server and worker pool. Workers *roadrunner.ServerConfig } +type MiddlewaresConfig struct { + Headers *HeaderMiddlewareConfig + CORS *CORSMiddlewareConfig +} + +type CORSMiddlewareConfig struct { + AllowedOrigin string + AllowedMethods string + AllowedHeaders string + AllowCredentials *bool + ExposedHeaders string + MaxAge int +} + +type HeaderMiddlewareConfig struct { + CustomRequestHeaders map[string]string + CustomResponseHeaders map[string]string +} + +func (c *MiddlewaresConfig) EnableCORS() bool { + return c.CORS != nil +} + +func (c *MiddlewaresConfig) EnableHeaders() bool { + return c.Headers.CustomRequestHeaders != nil || c.Headers.CustomResponseHeaders != nil +} + type FCGIConfig struct { // Port and port to handle as http server. Address string @@ -73,6 +103,10 @@ func (c *Config) EnableHTTP() bool { return c.Address != "" } +func (c *Config) EnableMiddlewares() bool { + return c.Middlewares != nil +} + // EnableTLS returns true if rr must listen TLS connections. func (c *Config) EnableTLS() bool { return c.SSL.Key != "" || c.SSL.Cert != "" diff --git a/service/http/service.go b/service/http/service.go index 00d877ec..f394f6af 100644 --- a/service/http/service.go +++ b/service/http/service.go @@ -12,6 +12,7 @@ import ( "net/http" "net/http/fcgi" "net/url" + "strconv" "strings" "sync" ) @@ -95,6 +96,10 @@ func (s *Service) Serve() error { s.rr.Attach(s.controller) } + if s.cfg.EnableMiddlewares() { + s.initMiddlewares() + } + s.handler = &Handler{cfg: s.cfg, rr: s.rr} s.handler.Listen(s.throw) @@ -247,3 +252,98 @@ func (s *Service) tlsAddr(host string, forcePort bool) string { return host } + +func (s *Service) headersMiddleware(f http.HandlerFunc) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if s.cfg.Middlewares.Headers.CustomRequestHeaders != nil { + for k, v := range s.cfg.Middlewares.Headers.CustomRequestHeaders { + r.Header.Add(k, v) + } + } + + if s.cfg.Middlewares.Headers.CustomResponseHeaders != nil { + for k, v := range s.cfg.Middlewares.Headers.CustomResponseHeaders { + w.Header().Set(k, v) + } + } + + f(w, r) + } +} + +func handlePreflightRequest(w http.ResponseWriter, r *http.Request, options *CORSMiddlewareConfig) { + headers := w.Header() + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if options.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", options.AllowedOrigin) + } + + if options.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", options.AllowedHeaders) + } + + if options.AllowedMethods != "" { + headers.Set("Access-Control-Allow-Methods", options.AllowedMethods) + } + + if options.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*options.AllowCredentials)) + } + + if options.MaxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(options.MaxAge)) + } + + w.WriteHeader(http.StatusOK); +} + +func addCORSHeaders(w http.ResponseWriter, r *http.Request, options *CORSMiddlewareConfig) { + headers := w.Header() + + headers.Add("Vary", "Origin") + + if options.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", options.AllowedOrigin) + } + + if options.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", options.AllowedHeaders) + } + + if options.ExposedHeaders != "" { + headers.Set("Access-Control-Expose-Headers", options.ExposedHeaders) + } + + if options.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*options.AllowCredentials)) + } +} + +func (s *Service) corsMiddleware(f http.HandlerFunc) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + handlePreflightRequest(w, r, s.cfg.Middlewares.CORS) + } else { + addCORSHeaders(w, r, s.cfg.Middlewares.CORS) + f(w, r) + } + } +} + +func (s *Service) initMiddlewares() error { + if s.cfg.Middlewares.EnableHeaders() { + s.AddMiddleware(s.headersMiddleware) + } + + if s.cfg.Middlewares.EnableCORS() { + s.AddMiddleware(s.corsMiddleware) + } + + return nil +} |