diff options
author | Wolfy-J <[email protected]> | 2018-07-08 10:23:30 -0700 |
---|---|---|
committer | Wolfy-J <[email protected]> | 2018-07-08 10:23:30 -0700 |
commit | 3c3a7801100f29c99a5e446646c818bf16ccd5f0 (patch) | |
tree | 4d558ae6e799b5e76e8e56b192cf057d47f3527f | |
parent | 466383c72d921aba728de40b60910741e561c1d1 (diff) |
minor attributes refactoring
-rw-r--r-- | cmd/rr/cmd/root.go | 2 | ||||
-rw-r--r-- | service/container.go | 6 | ||||
-rw-r--r-- | service/http/attributes.go | 69 | ||||
-rw-r--r-- | service/http/attributes/attributes.go | 74 | ||||
-rw-r--r-- | service/http/attributes/attributes_test.go | 67 | ||||
-rw-r--r-- | service/http/attributes_test.go | 67 | ||||
-rw-r--r-- | service/http/request.go | 3 | ||||
-rw-r--r-- | service/http/service.go | 6 |
8 files changed, 151 insertions, 143 deletions
diff --git a/cmd/rr/cmd/root.go b/cmd/rr/cmd/root.go index 1a21cfc9..086f518c 100644 --- a/cmd/rr/cmd/root.go +++ b/cmd/rr/cmd/root.go @@ -58,7 +58,7 @@ type ViperWrapper struct { v *viper.Viper } -// Get nested config section (sub-map), returns nil if section not found. +// get nested config section (sub-map), returns nil if section not found. func (w *ViperWrapper) Get(key string) service.Config { sub := w.v.Sub(key) if sub == nil { diff --git a/service/container.go b/service/container.go index 0987b1ae..12c5a4a1 100644 --- a/service/container.go +++ b/service/container.go @@ -10,7 +10,7 @@ import ( // Config provides ability to slice configuration sections and unmarshal configuration data into // given structure. type Config interface { - // Get nested config section (sub-map), returns nil if section not found. + // get nested config section (sub-map), returns nil if section not found. Get(service string) Config // Unmarshal unmarshal config data into given struct. @@ -28,7 +28,7 @@ type Container interface { // Check if svc has been registered. Has(service string) bool - // Get returns svc instance by it's name or nil if svc not found. Method returns current service status + // get returns svc instance by it's name or nil if svc not found. Method returns current service status // as second value. Get(service string) (svc Service, status int) @@ -81,7 +81,7 @@ func (c *container) Has(target string) bool { return false } -// Get returns svc instance by it's name or nil if svc not found. +// get returns svc instance by it's name or nil if svc not found. func (c *container) Get(target string) (svc Service, status int) { c.mu.Lock() defer c.mu.Unlock() diff --git a/service/http/attributes.go b/service/http/attributes.go deleted file mode 100644 index acea38a1..00000000 --- a/service/http/attributes.go +++ /dev/null @@ -1,69 +0,0 @@ -package http - -import ( - "context" - "net/http" - "errors" -) - -const contextKey = "psr:attributes" - -type attrs map[string]interface{} - -// InitAttributes returns request with new context and attribute bag. -func InitAttributes(r *http.Request) *http.Request { - return r.WithContext(context.WithValue(r.Context(), contextKey, attrs{})) -} - -// AllAttributes returns all context attributes. -func AllAttributes(r *http.Request) map[string]interface{} { - v := r.Context().Value(contextKey) - if v == nil { - return attrs{} - } - - return v.(attrs) -} - -// Get gets the value from request context. It replaces any existing -// values. -func GetAttribute(r *http.Request, key string) interface{} { - v := r.Context().Value(contextKey) - if v == nil { - return nil - } - - return v.(attrs).Get(key) -} - -// Set sets the key to value. It replaces any existing -// values. Context specific. -func SetAttribute(r *http.Request, key string, value interface{}) error { - v := r.Context().Value(contextKey) - if v == nil { - return errors.New("unable to find psr:attributes context value") - } - - v.(attrs).Set(key, value) - return nil -} - -// Get gets the value associated with the given key. -func (v attrs) Get(key string) interface{} { - if v == nil { - return "" - } - - return v[key] -} - -// Set sets the key to value. It replaces any existing -// values. -func (v attrs) Set(key string, value interface{}) { - v[key] = value -} - -// Del deletes the value associated with key. -func (v attrs) Del(key string) { - delete(v, key) -} diff --git a/service/http/attributes/attributes.go b/service/http/attributes/attributes.go new file mode 100644 index 00000000..94d0e9c1 --- /dev/null +++ b/service/http/attributes/attributes.go @@ -0,0 +1,74 @@ +package attributes + +import ( + "context" + "errors" + "net/http" +) + +const contextKey = "psr:attributes" + +type attrs map[string]interface{} + +func (v attrs) get(key string) interface{} { + if v == nil { + return "" + } + + return v[key] +} + +func (v attrs) set(key string, value interface{}) { + v[key] = value +} + +func (v attrs) del(key string) { + delete(v, key) +} + +// Init returns request with new context and attribute bag. +func Init(r *http.Request) *http.Request { + return r.WithContext(context.WithValue(r.Context(), contextKey, attrs{})) +} + +// All returns all context attributes. +func All(r *http.Request) map[string]interface{} { + v := r.Context().Value(contextKey) + if v == nil { + return attrs{} + } + + return v.(attrs) +} + +// get gets the value from request context. It replaces any existing +// values. +func Get(r *http.Request, key string) interface{} { + v := r.Context().Value(contextKey) + if v == nil { + return nil + } + + return v.(attrs).get(key) +} + +// set sets the key to value. It replaces any existing +// values. Context specific. +func Set(r *http.Request, key string, value interface{}) error { + v := r.Context().Value(contextKey) + if v == nil { + return errors.New("unable to find `psr:attributes` context key") + } + + v.(attrs).set(key, value) + return nil +} + +// Delete deletes values associated with attribute key. +func (v attrs) Delete(key string) { + if v == nil { + return + } + + v.del(key) +} diff --git a/service/http/attributes/attributes_test.go b/service/http/attributes/attributes_test.go new file mode 100644 index 00000000..a71d6542 --- /dev/null +++ b/service/http/attributes/attributes_test.go @@ -0,0 +1,67 @@ +package attributes + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestAllAttributes(t *testing.T) { + r := &http.Request{} + r = Init(r) + + Set(r, "key", "value") + + assert.Equal(t, All(r), map[string]interface{}{ + "key": "value", + }) +} + +func TestAllAttributesNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestAllAttributesNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestGetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + Set(r, "key", "value") + assert.Equal(t, Get(r, "key"), "value") +} + +func TestGetAttributeNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestGetAttributeNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestSetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + Set(r, "key", "value") + assert.Equal(t, Get(r, "key"), "value") +} + +func TestSetAttributeNone(t *testing.T) { + r := &http.Request{} + + Set(r, "key", "value") + assert.Equal(t, Get(r, "key"), nil) +} diff --git a/service/http/attributes_test.go b/service/http/attributes_test.go deleted file mode 100644 index aeb7fe74..00000000 --- a/service/http/attributes_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package http - -import ( - "testing" - "net/http" - "github.com/stretchr/testify/assert" -) - -func TestAllAttributes(t *testing.T) { - r := &http.Request{} - r = InitAttributes(r) - - SetAttribute(r, "key", "value") - - assert.Equal(t, AllAttributes(r), map[string]interface{}{ - "key": "value", - }) -} - -func TestAllAttributesNone(t *testing.T) { - r := &http.Request{} - r = InitAttributes(r) - - assert.Equal(t, AllAttributes(r), map[string]interface{}{}) -} - -func TestAllAttributesNone2(t *testing.T) { - r := &http.Request{} - - assert.Equal(t, AllAttributes(r), map[string]interface{}{}) -} - -func TestGetAttribute(t *testing.T) { - r := &http.Request{} - r = InitAttributes(r) - - SetAttribute(r, "key", "value") - assert.Equal(t, GetAttribute(r, "key"), "value") -} - -func TestGetAttributeNone(t *testing.T) { - r := &http.Request{} - r = InitAttributes(r) - - assert.Equal(t, GetAttribute(r, "key"), nil) -} - -func TestGetAttributeNone2(t *testing.T) { - r := &http.Request{} - - assert.Equal(t, GetAttribute(r, "key"), nil) -} - -func TestSetAttribute(t *testing.T) { - r := &http.Request{} - r = InitAttributes(r) - - SetAttribute(r, "key", "value") - assert.Equal(t, GetAttribute(r, "key"), "value") -} - -func TestSetAttributeNone(t *testing.T) { - r := &http.Request{} - - SetAttribute(r, "key", "value") - assert.Equal(t, GetAttribute(r, "key"), nil) -}
\ No newline at end of file diff --git a/service/http/request.go b/service/http/request.go index 21566416..912843e9 100644 --- a/service/http/request.go +++ b/service/http/request.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "strings" + "github.com/spiral/roadrunner/service/http/attributes" ) const ( @@ -60,7 +61,7 @@ func NewRequest(r *http.Request, cfg *UploadsConfig) (req *Request, err error) { Headers: r.Header, Cookies: make(map[string]string), RawQuery: r.URL.RawQuery, - Attributes: AllAttributes(r), + Attributes: attributes.All(r), } for _, c := range r.Cookies() { diff --git a/service/http/service.go b/service/http/service.go index 710cd60c..7405bf37 100644 --- a/service/http/service.go +++ b/service/http/service.go @@ -8,6 +8,7 @@ import ( "net/http" "sync" "sync/atomic" + "github.com/spiral/roadrunner/service/http/attributes" ) // ID contains default svc name. @@ -113,16 +114,17 @@ func (s *Service) Stop() { // middleware handles connection using set of mdws and rr PSR-7 server. func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { - r = InitAttributes(r) + r = attributes.Init(r) + // chaining middlewares f := s.srv.ServeHTTP for _, m := range s.mdws { f = m(f) } - f(w, r) } +// listener handles service, server and pool events. func (s *Service) listener(event int, ctx interface{}) { for _, l := range s.lsns { l(event, ctx) |