summaryrefslogtreecommitdiff
path: root/plugins/websockets/validator
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2021-05-27 00:09:33 +0300
committerValery Piashchynski <[email protected]>2021-05-27 00:09:33 +0300
commitdc3c5455e5c9b32737a0620c8bdb8bda0226dba7 (patch)
tree6ba562da6de7f32a8d528b72cbb56a8bc98c1b30 /plugins/websockets/validator
parentd2e9d8320857f5768c54843a43ad16f59d6a3e8f (diff)
- Update all main abstractions
- Desighn a new interfaces responsible for the whole PubSub - New plugin - websockets Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins/websockets/validator')
-rw-r--r--plugins/websockets/validator/access_validator.go102
-rw-r--r--plugins/websockets/validator/access_validator_test.go35
2 files changed, 137 insertions, 0 deletions
diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go
new file mode 100644
index 00000000..9d9522d4
--- /dev/null
+++ b/plugins/websockets/validator/access_validator.go
@@ -0,0 +1,102 @@
+package validator
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/spiral/roadrunner/v2/plugins/http/attributes"
+)
+
+type AccessValidator struct {
+ buffer *bytes.Buffer
+ header http.Header
+ status int
+}
+
+func NewValidator() *AccessValidator {
+ return &AccessValidator{
+ buffer: bytes.NewBuffer(nil),
+ header: make(http.Header),
+ }
+}
+
+// Copy all content to parent response writer.
+func (w *AccessValidator) Copy(rw http.ResponseWriter) {
+ rw.WriteHeader(w.status)
+
+ for k, v := range w.header {
+ for _, vv := range v {
+ rw.Header().Add(k, vv)
+ }
+ }
+
+ _, _ = io.Copy(rw, w.buffer)
+}
+
+// Header returns the header map that will be sent by WriteHeader.
+func (w *AccessValidator) Header() http.Header {
+ return w.header
+}
+
+// Write writes the data to the connection as part of an HTTP reply.
+func (w *AccessValidator) Write(p []byte) (int, error) {
+ return w.buffer.Write(p)
+}
+
+// WriteHeader sends an HTTP response header with the provided status code.
+func (w *AccessValidator) WriteHeader(statusCode int) {
+ w.status = statusCode
+}
+
+// IsOK returns true if response contained 200 status code.
+func (w *AccessValidator) IsOK() bool {
+ return w.status == 200
+}
+
+// Body returns response body to rely to user.
+func (w *AccessValidator) Body() []byte {
+ return w.buffer.Bytes()
+}
+
+// Error contains server response.
+func (w *AccessValidator) Error() string {
+ return w.buffer.String()
+}
+
+// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in
+// case of error
+func (w *AccessValidator) AssertServerAccess(f http.HandlerFunc, r *http.Request) error {
+ if err := attributes.Set(r, "ws:joinServer", true); err != nil {
+ return err
+ }
+
+ defer delete(attributes.All(r), "ws:joinServer")
+
+ f(w, r)
+
+ if !w.IsOK() {
+ return w
+ }
+
+ return nil
+}
+
+// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies.
+// the decision to authorize user will be based on response code (200).
+func (w *AccessValidator) AssertTopicsAccess(f http.HandlerFunc, r *http.Request, channels ...string) error {
+ if err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ",")); err != nil {
+ return err
+ }
+
+ defer delete(attributes.All(r), "ws:joinTopics")
+
+ f(w, r)
+
+ if !w.IsOK() {
+ return w
+ }
+
+ return nil
+}
diff --git a/plugins/websockets/validator/access_validator_test.go b/plugins/websockets/validator/access_validator_test.go
new file mode 100644
index 00000000..4a07b00f
--- /dev/null
+++ b/plugins/websockets/validator/access_validator_test.go
@@ -0,0 +1,35 @@
+package validator
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestResponseWrapper_Body(t *testing.T) {
+ w := NewValidator()
+ _, _ = w.Write([]byte("hello"))
+
+ assert.Equal(t, []byte("hello"), w.Body())
+}
+
+func TestResponseWrapper_Header(t *testing.T) {
+ w := NewValidator()
+ w.Header().Set("k", "value")
+
+ assert.Equal(t, "value", w.Header().Get("k"))
+}
+
+func TestResponseWrapper_StatusCode(t *testing.T) {
+ w := NewValidator()
+ w.WriteHeader(200)
+
+ assert.True(t, w.IsOK())
+}
+
+func TestResponseWrapper_StatusCodeBad(t *testing.T) {
+ w := NewValidator()
+ w.WriteHeader(400)
+
+ assert.False(t, w.IsOK())
+}