summaryrefslogtreecommitdiff
path: root/plugins/http/attributes
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/http/attributes')
-rw-r--r--plugins/http/attributes/attributes.go76
-rw-r--r--plugins/http/attributes/attributes_test.go79
2 files changed, 155 insertions, 0 deletions
diff --git a/plugins/http/attributes/attributes.go b/plugins/http/attributes/attributes.go
new file mode 100644
index 00000000..77d6ea69
--- /dev/null
+++ b/plugins/http/attributes/attributes.go
@@ -0,0 +1,76 @@
+package attributes
+
+import (
+ "context"
+ "errors"
+ "net/http"
+)
+
+type attrKey int
+
+const contextKey attrKey = iota
+
+type attrs map[string]interface{}
+
+func (v attrs) get(key string) interface{} {
+ if v == nil {
+ return ""
+ }
+
+ return v[key]
+}
+
+func (v attrs) set(key string, value interface{}) {
+ v[key] = value
+}
+
+func (v attrs) del(key string) {
+ delete(v, key)
+}
+
+// Init returns request with new context and attribute bag.
+func Init(r *http.Request) *http.Request {
+ return r.WithContext(context.WithValue(r.Context(), contextKey, attrs{}))
+}
+
+// All returns all context attributes.
+func All(r *http.Request) map[string]interface{} {
+ v := r.Context().Value(contextKey)
+ if v == nil {
+ return attrs{}
+ }
+
+ return v.(attrs)
+}
+
+// Get gets the value from request context. It replaces any existing
+// values.
+func Get(r *http.Request, key string) interface{} {
+ v := r.Context().Value(contextKey)
+ if v == nil {
+ return nil
+ }
+
+ return v.(attrs).get(key)
+}
+
+// Set sets the key to value. It replaces any existing
+// values. Context specific.
+func Set(r *http.Request, key string, value interface{}) error {
+ v := r.Context().Value(contextKey)
+ if v == nil {
+ return errors.New("unable to find `psr:attributes` context key")
+ }
+
+ v.(attrs).set(key, value)
+ return nil
+}
+
+// Delete deletes values associated with attribute key.
+func (v attrs) Delete(key string) {
+ if v == nil {
+ return
+ }
+
+ v.del(key)
+}
diff --git a/plugins/http/attributes/attributes_test.go b/plugins/http/attributes/attributes_test.go
new file mode 100644
index 00000000..2360fd12
--- /dev/null
+++ b/plugins/http/attributes/attributes_test.go
@@ -0,0 +1,79 @@
+package attributes
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net/http"
+ "testing"
+)
+
+func TestAllAttributes(t *testing.T) {
+ r := &http.Request{}
+ r = Init(r)
+
+ err := Set(r, "key", "value")
+ if err != nil {
+ t.Errorf("error during the Set: error %v", err)
+ }
+
+ assert.Equal(t, All(r), map[string]interface{}{
+ "key": "value",
+ })
+}
+
+func TestAllAttributesNone(t *testing.T) {
+ r := &http.Request{}
+ r = Init(r)
+
+ assert.Equal(t, All(r), map[string]interface{}{})
+}
+
+func TestAllAttributesNone2(t *testing.T) {
+ r := &http.Request{}
+
+ assert.Equal(t, All(r), map[string]interface{}{})
+}
+
+func TestGetAttribute(t *testing.T) {
+ r := &http.Request{}
+ r = Init(r)
+
+ err := Set(r, "key", "value")
+ if err != nil {
+ t.Errorf("error during the Set: error %v", err)
+ }
+ assert.Equal(t, Get(r, "key"), "value")
+}
+
+func TestGetAttributeNone(t *testing.T) {
+ r := &http.Request{}
+ r = Init(r)
+
+ assert.Equal(t, Get(r, "key"), nil)
+}
+
+func TestGetAttributeNone2(t *testing.T) {
+ r := &http.Request{}
+
+ assert.Equal(t, Get(r, "key"), nil)
+}
+
+func TestSetAttribute(t *testing.T) {
+ r := &http.Request{}
+ r = Init(r)
+
+ err := Set(r, "key", "value")
+ if err != nil {
+ t.Errorf("error during the Set: error %v", err)
+ }
+ assert.Equal(t, Get(r, "key"), "value")
+}
+
+func TestSetAttributeNone(t *testing.T) {
+ r := &http.Request{}
+
+ err := Set(r, "key", "value")
+ if err != nil {
+ t.Errorf("error during the Set: error %v", err)
+ }
+ assert.Equal(t, Get(r, "key"), nil)
+}