diff options
author | Valery Piashchynski <[email protected]> | 2021-06-16 14:04:11 +0300 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2021-06-16 14:04:11 +0300 |
commit | 8483f605037317e9ac266f893d8ef80ff3a6a56e (patch) | |
tree | 98b5c6f5a6fd47b39306b5f30bbaa772b7f59893 /plugins | |
parent | 9dc98d43b0c0de3e1e1bd8fdc97c122c7c7c594f (diff) |
- Add origin check for the websockets
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'plugins')
-rw-r--r-- | plugins/websockets/config.go | 35 | ||||
-rw-r--r-- | plugins/websockets/origin.go | 28 | ||||
-rw-r--r-- | plugins/websockets/origin_test.go | 69 | ||||
-rw-r--r-- | plugins/websockets/plugin.go | 3 | ||||
-rw-r--r-- | plugins/websockets/wildcard.go | 12 |
5 files changed, 142 insertions, 5 deletions
diff --git a/plugins/websockets/config.go b/plugins/websockets/config.go index 93d9ac3b..deb4406c 100644 --- a/plugins/websockets/config.go +++ b/plugins/websockets/config.go @@ -1,6 +1,7 @@ package websockets import ( + "strings" "time" "github.com/spiral/roadrunner/v2/pkg/pool" @@ -57,9 +58,15 @@ type Config struct { PubSubs []string `mapstructure:"pubsubs"` Middleware []string `mapstructure:"middleware"` - Redis *RedisConfig `mapstructure:"redis"` + AllowedOrigin string `mapstructure:"allowed_origin"` + + // wildcard origin + allowedWOrigins []wildcard + allowedOrigins []string + allowedAll bool - Pool *pool.Config `mapstructure:"pool"` + Redis *RedisConfig `mapstructure:"redis"` + Pool *pool.Config `mapstructure:"pool"` } // InitDefault initialize default values for the ws config @@ -67,6 +74,7 @@ func (c *Config) InitDefault() { if c.Path == "" { c.Path = "/ws" } + if len(c.PubSubs) == 0 { // memory used by default c.PubSubs = append(c.PubSubs, "memory") @@ -86,10 +94,9 @@ func (c *Config) InitDefault() { if c.Pool.DestroyTimeout == 0 { c.Pool.DestroyTimeout = time.Minute } - if c.Pool.Supervisor == nil { - return + if c.Pool.Supervisor != nil { + c.Pool.Supervisor.InitDefaults() } - c.Pool.Supervisor.InitDefaults() } if c.Redis != nil { @@ -98,4 +105,22 @@ func (c *Config) InitDefault() { c.Redis.Addrs = append(c.Redis.Addrs, "localhost:6379") } } + + if c.AllowedOrigin == "" { + c.AllowedOrigin = "*" + } + + // Normalize + origin := strings.ToLower(c.AllowedOrigin) + if origin == "*" { + // If "*" is present in the list, turn the whole list into a match all + c.allowedAll = true + return + } else if i := strings.IndexByte(origin, '*'); i >= 0 { + // Split the origin in two: start and end string without the * + w := wildcard{origin[0:i], origin[i+1:]} + c.allowedWOrigins = append(c.allowedWOrigins, w) + } else { + c.allowedOrigins = append(c.allowedOrigins, origin) + } } diff --git a/plugins/websockets/origin.go b/plugins/websockets/origin.go new file mode 100644 index 00000000..c6d9c9b8 --- /dev/null +++ b/plugins/websockets/origin.go @@ -0,0 +1,28 @@ +package websockets + +import ( + "strings" +) + +func isOriginAllowed(origin string, cfg *Config) bool { + if cfg.allowedAll { + return true + } + + origin = strings.ToLower(origin) + // simple case + origin = strings.ToLower(origin) + for _, o := range cfg.allowedOrigins { + if o == origin { + return true + } + } + // check wildcards + for _, w := range cfg.allowedWOrigins { + if w.match(origin) { + return true + } + } + + return false +} diff --git a/plugins/websockets/origin_test.go b/plugins/websockets/origin_test.go new file mode 100644 index 00000000..ccd94d21 --- /dev/null +++ b/plugins/websockets/origin_test.go @@ -0,0 +1,69 @@ +package websockets + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfig_Origin(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "*", + } + + cfg.InitDefault() + + assert.True(t, isOriginAllowed("http://some.some.some.sssome", cfg)) + assert.True(t, isOriginAllowed("http://", cfg)) + assert.True(t, isOriginAllowed("http://google.com", cfg)) + assert.True(t, isOriginAllowed("ws://*", cfg)) + assert.True(t, isOriginAllowed("*", cfg)) + assert.True(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.True(t, isOriginAllowed("****", cfg)) + assert.True(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.True(t, isOriginAllowed("http://*.domain.com", cfg)) +} + +func TestConfig_OriginWildCard(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "https://*my.site.com", + } + + cfg.InitDefault() + + assert.True(t, isOriginAllowed("https://my.site.com", cfg)) + assert.False(t, isOriginAllowed("http://", cfg)) + assert.False(t, isOriginAllowed("http://google.com", cfg)) + assert.False(t, isOriginAllowed("ws://*", cfg)) + assert.False(t, isOriginAllowed("*", cfg)) + assert.False(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.False(t, isOriginAllowed("****", cfg)) + assert.False(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.False(t, isOriginAllowed("http://*.domain.com", cfg)) + + + assert.False(t, isOriginAllowed("https://*site.com", cfg)) + assert.True(t, isOriginAllowed("https://some.my.site.com", cfg)) +} + +func TestConfig_OriginWildCard2(t *testing.T) { + cfg := &Config{ + AllowedOrigin: "https://my.*.com", + } + + cfg.InitDefault() + + assert.True(t, isOriginAllowed("https://my.site.com", cfg)) + assert.False(t, isOriginAllowed("http://", cfg)) + assert.False(t, isOriginAllowed("http://google.com", cfg)) + assert.False(t, isOriginAllowed("ws://*", cfg)) + assert.False(t, isOriginAllowed("*", cfg)) + assert.False(t, isOriginAllowed("you are bad programmer", cfg)) // True :( + assert.False(t, isOriginAllowed("****", cfg)) + assert.False(t, isOriginAllowed("asde!@#!!@#!%", cfg)) + assert.False(t, isOriginAllowed("http://*.domain.com", cfg)) + + + assert.False(t, isOriginAllowed("https://*site.com", cfg)) + assert.True(t, isOriginAllowed("https://my.bad.com", cfg)) +} diff --git a/plugins/websockets/plugin.go b/plugins/websockets/plugin.go index 6dfe6ca3..8b708187 100644 --- a/plugins/websockets/plugin.go +++ b/plugins/websockets/plugin.go @@ -82,6 +82,9 @@ func (p *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Se HandshakeTimeout: time.Second * 60, ReadBufferSize: 1024, WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return isOriginAllowed(r.Header.Get("Origin"), p.cfg) + }, } p.serveExit = make(chan struct{}) p.server = server diff --git a/plugins/websockets/wildcard.go b/plugins/websockets/wildcard.go new file mode 100644 index 00000000..2f1c6601 --- /dev/null +++ b/plugins/websockets/wildcard.go @@ -0,0 +1,12 @@ +package websockets + +import "strings" + +type wildcard struct { + prefix string + suffix string +} + +func (w wildcard) match(s string) bool { + return len(s) >= len(w.prefix)+len(w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix) +} |