diff options
Diffstat (limited to 'plugins/http/attributes')
-rw-r--r-- | plugins/http/attributes/attributes.go | 76 | ||||
-rw-r--r-- | plugins/http/attributes/attributes_test.go | 79 |
2 files changed, 155 insertions, 0 deletions
diff --git a/plugins/http/attributes/attributes.go b/plugins/http/attributes/attributes.go new file mode 100644 index 00000000..77d6ea69 --- /dev/null +++ b/plugins/http/attributes/attributes.go @@ -0,0 +1,76 @@ +package attributes + +import ( + "context" + "errors" + "net/http" +) + +type attrKey int + +const contextKey attrKey = iota + +type attrs map[string]interface{} + +func (v attrs) get(key string) interface{} { + if v == nil { + return "" + } + + return v[key] +} + +func (v attrs) set(key string, value interface{}) { + v[key] = value +} + +func (v attrs) del(key string) { + delete(v, key) +} + +// Init returns request with new context and attribute bag. +func Init(r *http.Request) *http.Request { + return r.WithContext(context.WithValue(r.Context(), contextKey, attrs{})) +} + +// All returns all context attributes. +func All(r *http.Request) map[string]interface{} { + v := r.Context().Value(contextKey) + if v == nil { + return attrs{} + } + + return v.(attrs) +} + +// Get gets the value from request context. It replaces any existing +// values. +func Get(r *http.Request, key string) interface{} { + v := r.Context().Value(contextKey) + if v == nil { + return nil + } + + return v.(attrs).get(key) +} + +// Set sets the key to value. It replaces any existing +// values. Context specific. +func Set(r *http.Request, key string, value interface{}) error { + v := r.Context().Value(contextKey) + if v == nil { + return errors.New("unable to find `psr:attributes` context key") + } + + v.(attrs).set(key, value) + return nil +} + +// Delete deletes values associated with attribute key. +func (v attrs) Delete(key string) { + if v == nil { + return + } + + v.del(key) +} diff --git a/plugins/http/attributes/attributes_test.go b/plugins/http/attributes/attributes_test.go new file mode 100644 index 00000000..2360fd12 --- /dev/null +++ b/plugins/http/attributes/attributes_test.go @@ -0,0 +1,79 @@ +package attributes + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestAllAttributes(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + + assert.Equal(t, All(r), map[string]interface{}{ + "key": "value", + }) +} + +func TestAllAttributesNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestAllAttributesNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestGetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + assert.Equal(t, Get(r, "key"), "value") +} + +func TestGetAttributeNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestGetAttributeNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestSetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + assert.Equal(t, Get(r, "key"), "value") +} + +func TestSetAttributeNone(t *testing.T) { + r := &http.Request{} + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + assert.Equal(t, Get(r, "key"), nil) +} |