summaryrefslogtreecommitdiff
path: root/plugins/websockets/validator/access_validator.go
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/websockets/validator/access_validator.go')
-rw-r--r--plugins/websockets/validator/access_validator.go142
1 files changed, 40 insertions, 102 deletions
diff --git a/plugins/websockets/validator/access_validator.go b/plugins/websockets/validator/access_validator.go
index cd70d9a7..e666f846 100644
--- a/plugins/websockets/validator/access_validator.go
+++ b/plugins/websockets/validator/access_validator.go
@@ -1,138 +1,76 @@
package validator
import (
- "bytes"
- "io"
"net/http"
"strings"
+ json "github.com/json-iterator/go"
"github.com/spiral/errors"
- "github.com/spiral/roadrunner/v2/plugins/channel"
+ handler "github.com/spiral/roadrunner/v2/pkg/worker_handler"
"github.com/spiral/roadrunner/v2/plugins/http/attributes"
)
-type AccessValidator struct {
- buffer *bytes.Buffer
- header http.Header
- status int
-}
-
-func NewValidator() *AccessValidator {
- return &AccessValidator{
- buffer: bytes.NewBuffer(nil),
- header: make(http.Header),
- }
-}
-
-// Copy all content to parent response writer.
-func (w *AccessValidator) Copy(rw http.ResponseWriter) {
- rw.WriteHeader(w.status)
-
- for k, v := range w.header {
- for _, vv := range v {
- rw.Header().Add(k, vv)
- }
- }
+type AccessValidatorFn = func(r *http.Request, channels ...string) (*AccessValidator, error)
- _, _ = io.Copy(rw, w.buffer)
-}
-
-// Header returns the header map that will be sent by WriteHeader.
-func (w *AccessValidator) Header() http.Header {
- return w.header
-}
-
-// Write writes the data to the connection as part of an HTTP reply.
-func (w *AccessValidator) Write(p []byte) (int, error) {
- return w.buffer.Write(p)
-}
-
-// WriteHeader sends an HTTP response header with the provided status code.
-func (w *AccessValidator) WriteHeader(statusCode int) {
- w.status = statusCode
-}
-
-// IsOK returns true if response contained 200 status code.
-func (w *AccessValidator) IsOK() bool {
- return w.status == 200
-}
-
-// Body returns response body to rely to user.
-func (w *AccessValidator) Body() []byte {
- return w.buffer.Bytes()
-}
-
-// Error contains server response.
-func (w *AccessValidator) Error() string {
- return w.buffer.String()
+type AccessValidator struct {
+ Header http.Header `json:"headers"`
+ Status int `json:"status"`
+ Body []byte
}
-// AssertServerAccess checks if user can join server and returns error and body if user can not. Must return nil in
-// case of error
-func (w *AccessValidator) AssertServerAccess(hub channel.Hub, r *http.Request) error {
+func ServerAccessValidator(r *http.Request) ([]byte, error) {
const op = errors.Op("server_access_validator")
+
err := attributes.Set(r, "ws:joinServer", true)
if err != nil {
- return errors.E(op, err)
+ return nil, errors.E(op, err)
}
defer delete(attributes.All(r), "ws:joinServer")
- // send payload to the worker
- hub.ToWorker() <- struct {
- RW http.ResponseWriter
- Req *http.Request
- }{
- w,
- r,
+ req := &handler.Request{
+ RemoteAddr: handler.FetchIP(r.RemoteAddr),
+ Protocol: r.Proto,
+ Method: r.Method,
+ URI: handler.URI(r),
+ Header: r.Header,
+ Cookies: make(map[string]string),
+ RawQuery: r.URL.RawQuery,
+ Attributes: attributes.All(r),
}
- resp := <-hub.FromWorker()
-
- rmsg := resp.(struct {
- RW http.ResponseWriter
- Req *http.Request
- })
-
- if !rmsg.RW.(*AccessValidator).IsOK() {
- return w
+ data, err := json.Marshal(req)
+ if err != nil {
+ return nil, errors.E(op, err)
}
- return nil
+ return data, nil
}
-// AssertTopicsAccess checks if user can access given upstream, the application will receive all user headers and cookies.
-// the decision to authorize user will be based on response code (200).
-func (w *AccessValidator) AssertTopicsAccess(hub channel.Hub, r *http.Request, channels ...string) error {
- const op = errors.Op("topics_access_validator")
-
- err := attributes.Set(r, "ws:joinTopics", strings.Join(channels, ","))
+func TopicsAccessValidator(r *http.Request, topics ...string) ([]byte, error) {
+ const op = errors.Op("topic_access_validator")
+ err := attributes.Set(r, "ws:joinTopics", strings.Join(topics, ","))
if err != nil {
- return errors.E(op, err)
+ return nil, errors.E(op, err)
}
defer delete(attributes.All(r), "ws:joinTopics")
- // send payload to worker
- hub.ToWorker() <- struct {
- RW http.ResponseWriter
- Req *http.Request
- }{
- w,
- r,
+ req := &handler.Request{
+ RemoteAddr: handler.FetchIP(r.RemoteAddr),
+ Protocol: r.Proto,
+ Method: r.Method,
+ URI: handler.URI(r),
+ Header: r.Header,
+ Cookies: make(map[string]string),
+ RawQuery: r.URL.RawQuery,
+ Attributes: attributes.All(r),
}
- // wait response
- resp := <-hub.FromWorker()
-
- rmsg := resp.(struct {
- RW http.ResponseWriter
- Req *http.Request
- })
-
- if !rmsg.RW.(*AccessValidator).IsOK() {
- return w
+ data, err := json.Marshal(req)
+ if err != nil {
+ return nil, errors.E(op, err)
}
- return nil
+ return data, nil
}