diff options
Diffstat (limited to 'plugins')
115 files changed, 12919 insertions, 0 deletions
diff --git a/plugins/checker/config.go b/plugins/checker/config.go new file mode 100644 index 00000000..5f952592 --- /dev/null +++ b/plugins/checker/config.go @@ -0,0 +1,5 @@ +package checker + +type Config struct { + Address string +} diff --git a/plugins/checker/plugin.go b/plugins/checker/plugin.go new file mode 100644 index 00000000..e6250697 --- /dev/null +++ b/plugins/checker/plugin.go @@ -0,0 +1,152 @@ +package checker + +import ( + "fmt" + "net/http" + "time" + + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/interfaces/status" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +const ( + // PluginName declares public plugin name. + PluginName = "status" +) + +type Plugin struct { + registry map[string]status.Checker + server *fiber.App + log log.Logger + cfg *Config +} + +func (c *Plugin) Init(log log.Logger, cfg config.Configurer) error { + const op = errors.Op("status plugin init") + err := cfg.UnmarshalKey(PluginName, &c.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + if c.cfg == nil { + return errors.E(errors.Disabled) + } + + c.registry = make(map[string]status.Checker) + c.log = log + return nil +} + +func (c *Plugin) Serve() chan error { + errCh := make(chan error, 1) + c.server = fiber.New(fiber.Config{ + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + IdleTimeout: time.Second * 5, + }) + c.server.Group("/v1", c.healthHandler) + c.server.Use(logger.New()) + c.server.Use("/health", c.healthHandler) + + go func() { + err := c.server.Listen(c.cfg.Address) + if err != nil { + errCh <- err + } + }() + + return errCh +} + +func (c *Plugin) Stop() error { + const op = errors.Op("checker stop") + err := c.server.Shutdown() + if err != nil { + return errors.E(op, err) + } + return nil +} + +// Reset named service. +func (c *Plugin) Status(name string) (status.Status, error) { + const op = errors.Op("get status") + svc, ok := c.registry[name] + if !ok { + return status.Status{}, errors.E(op, errors.Errorf("no such service: %s", name)) + } + + return svc.Status(), nil +} + +// CollectTarget collecting services which can provide Status. +func (c *Plugin) CollectTarget(name endure.Named, r status.Checker) error { + c.registry[name.Name()] = r + return nil +} + +// Collects declares services to be collected. +func (c *Plugin) Collects() []interface{} { + return []interface{}{ + c.CollectTarget, + } +} + +// Name of the service. +func (c *Plugin) Name() string { + return PluginName +} + +// RPCService returns associated rpc service. +func (c *Plugin) RPC() interface{} { + return &rpc{srv: c, log: c.log} +} + +type Plugins struct { + Plugins []string `query:"plugin"` +} + +const template string = "Service: %s: Status: %d\n" + +func (c *Plugin) healthHandler(ctx *fiber.Ctx) error { + const op = errors.Op("health_handler") + plugins := &Plugins{} + err := ctx.QueryParser(plugins) + if err != nil { + return errors.E(op, err) + } + + if len(plugins.Plugins) == 0 { + ctx.Status(http.StatusOK) + _, _ = ctx.WriteString("No plugins provided in query. Query should be in form of: /v1/health?plugin=plugin1&plugin=plugin2 \n") + return nil + } + + failed := false + // iterate over all provided plugins + for i := 0; i < len(plugins.Plugins); i++ { + // check if the plugin exists + if plugin, ok := c.registry[plugins.Plugins[i]]; ok { + st := plugin.Status() + if st.Code >= 500 { + failed = true + continue + } else if st.Code >= 100 && st.Code <= 400 { + _, _ = ctx.WriteString(fmt.Sprintf(template, plugins.Plugins[i], st.Code)) + } + } else { + _, _ = ctx.WriteString(fmt.Sprintf("Service: %s not found", plugins.Plugins[i])) + } + } + if failed { + ctx.Status(http.StatusInternalServerError) + return nil + } + + ctx.Status(http.StatusOK) + return nil +} diff --git a/plugins/checker/rpc.go b/plugins/checker/rpc.go new file mode 100644 index 00000000..d03d1638 --- /dev/null +++ b/plugins/checker/rpc.go @@ -0,0 +1,28 @@ +package checker + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/interfaces/status" +) + +type rpc struct { + srv *Plugin + log log.Logger +} + +// Status return current status of the provided plugin +func (rpc *rpc) Status(service string, status *status.Status) error { + const op = errors.Op("status") + rpc.log.Debug("started Status method", "service", service) + st, err := rpc.srv.Status(service) + if err != nil { + return errors.E(op, err) + } + + *status = st + + rpc.log.Debug("status code", "code", st.Code) + rpc.log.Debug("successfully finished Status method") + return nil +} diff --git a/plugins/checker/tests/configs/.rr-checker-init.yaml b/plugins/checker/tests/configs/.rr-checker-init.yaml new file mode 100755 index 00000000..5943551b --- /dev/null +++ b/plugins/checker/tests/configs/.rr-checker-init.yaml @@ -0,0 +1,29 @@ +rpc: + listen: tcp://127.0.0.1:6005 + disabled: false + +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +status: + address: "127.0.0.1:34333" + +http: + debug: true + address: 127.0.0.1:11933 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/checker/tests/plugin_test.go b/plugins/checker/tests/plugin_test.go new file mode 100644 index 00000000..c93b68af --- /dev/null +++ b/plugins/checker/tests/plugin_test.go @@ -0,0 +1,188 @@ +package tests + +import ( + "io/ioutil" + "net" + "net/http" + "net/rpc" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2/interfaces/status" + "github.com/spiral/roadrunner/v2/plugins/checker" + "github.com/spiral/roadrunner/v2/plugins/config" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestStatusHttp(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-checker-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &checker.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("CheckerGetStatus", checkHTTPStatus) + wg.Wait() +} + +const resp = `Service: http: Status: 200 +Service: rpc not found` + +func checkHTTPStatus(t *testing.T) { + req, err := http.NewRequest("GET", "http://127.0.0.1:34333/v1/health?plugin=http&plugin=rpc", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, resp, string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestStatusRPC(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-checker-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &checker.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("CheckerGetStatusRpc", checkRPCStatus) + wg.Wait() +} + +func checkRPCStatus(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6005") + assert.NoError(t, err) + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + + st := &status.Status{} + + err = client.Call("status.Status", "http", &st) + assert.NoError(t, err) + assert.Equal(t, st.Code, 200) +} diff --git a/plugins/config/configurer.go b/plugins/config/configurer.go new file mode 100755 index 00000000..00010eae --- /dev/null +++ b/plugins/config/configurer.go @@ -0,0 +1,19 @@ +package config + +type Configurer interface { + // UnmarshalKey reads configuration section into configuration object. + // + // func (h *HttpService) Init(cp config.Configurer) error { + // h.config := &HttpConfig{} + // if err := configProvider.UnmarshalKey("http", h.config); err != nil { + // return err + // } + // } + UnmarshalKey(name string, out interface{}) error + + // Get used to get config section + Get(name string) interface{} + + // Has checks if config section exists. + Has(name string) bool +} diff --git a/plugins/config/plugin.go b/plugins/config/plugin.go new file mode 100755 index 00000000..2555d28a --- /dev/null +++ b/plugins/config/plugin.go @@ -0,0 +1,91 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + "github.com/spf13/viper" +) + +type Viper struct { + viper *viper.Viper + Path string + Prefix string +} + +// Inits config provider. +func (v *Viper) Init() error { + v.viper = viper.New() + + // read in environment variables that match + v.viper.AutomaticEnv() + if v.Prefix == "" { + return errors.New("prefix should be set") + } + + v.viper.SetEnvPrefix(v.Prefix) + if v.Path == "" { + return errors.New("path should be set") + } + + v.viper.SetConfigFile(v.Path) + v.viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + + return v.viper.ReadInConfig() +} + +// Overwrite overwrites existing config with provided values +func (v *Viper) Overwrite(values map[string]string) error { + if len(values) != 0 { + for _, flag := range values { + key, value, err := parseFlag(flag) + if err != nil { + return err + } + v.viper.Set(key, value) + } + } + + return nil +} + +// UnmarshalKey reads configuration section into configuration object. +func (v *Viper) UnmarshalKey(name string, out interface{}) error { + err := v.viper.UnmarshalKey(name, &out) + if err != nil { + return err + } + return nil +} + +// Get raw config in a form of config section. +func (v *Viper) Get(name string) interface{} { + return v.viper.Get(name) +} + +// Has checks if config section exists. +func (v *Viper) Has(name string) bool { + return v.viper.IsSet(name) +} + +func parseFlag(flag string) (string, string, error) { + if !strings.Contains(flag, "=") { + return "", "", fmt.Errorf("invalid flag `%s`", flag) + } + + parts := strings.SplitN(strings.TrimLeft(flag, " \"'`"), "=", 2) + + return strings.Trim(parts[0], " \n\t"), parseValue(strings.Trim(parts[1], " \n\t")), nil +} + +func parseValue(value string) string { + escape := []rune(value)[0] + + if escape == '"' || escape == '\'' || escape == '`' { + value = strings.Trim(value, string(escape)) + value = strings.ReplaceAll(value, fmt.Sprintf("\\%s", string(escape)), string(escape)) + } + + return value +} diff --git a/plugins/config/tests/.rr.yaml b/plugins/config/tests/.rr.yaml new file mode 100755 index 00000000..732a1366 --- /dev/null +++ b/plugins/config/tests/.rr.yaml @@ -0,0 +1,18 @@ +reload: + enabled: true + interval: 1s + patterns: [".php"] + services: + http: + recursive: true + ignore: ["vendor"] + patterns: [".php", ".go",".md",] + dirs: ["."] + jobs: + recursive: false + ignore: ["service/metrics"] + dirs: ["./jobs"] + rpc: + recursive: true + patterns: [".json"] + dirs: [""] diff --git a/plugins/config/tests/config_test.go b/plugins/config/tests/config_test.go new file mode 100755 index 00000000..422e7eee --- /dev/null +++ b/plugins/config/tests/config_test.go @@ -0,0 +1,66 @@ +package tests + +import ( + "os" + "os/signal" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/stretchr/testify/assert" +) + +func TestViperProvider_Init(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + vp := &config.Viper{} + vp.Path = ".rr.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + tt := time.NewTicker(time.Second * 2) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } +} diff --git a/plugins/config/tests/plugin1.go b/plugins/config/tests/plugin1.go new file mode 100755 index 00000000..a276c15f --- /dev/null +++ b/plugins/config/tests/plugin1.go @@ -0,0 +1,53 @@ +package tests + +import ( + "errors" + "time" + + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// ReloadConfig is a Reload configuration point. +type ReloadConfig struct { + Interval time.Duration + Patterns []string + Services map[string]ServiceConfig +} + +type ServiceConfig struct { + Enabled bool + Recursive bool + Patterns []string + Dirs []string + Ignore []string +} + +type Foo struct { + configProvider config.Configurer +} + +// Depends on S2 and DB (S3 in the current case) +func (f *Foo) Init(p config.Configurer) error { + f.configProvider = p + return nil +} + +func (f *Foo) Serve() chan error { + errCh := make(chan error, 1) + + r := &ReloadConfig{} + err := f.configProvider.UnmarshalKey("reload", r) + if err != nil { + errCh <- err + } + + if len(r.Patterns) == 0 { + errCh <- errors.New("should be at least one pattern, but got 0") + } + + return errCh +} + +func (f *Foo) Stop() error { + return nil +} diff --git a/plugins/gzip/plugin.go b/plugins/gzip/plugin.go new file mode 100644 index 00000000..e5b9e4f5 --- /dev/null +++ b/plugins/gzip/plugin.go @@ -0,0 +1,25 @@ +package gzip + +import ( + "net/http" + + "github.com/NYTimes/gziphandler" +) + +const PluginName = "gzip" + +type Gzip struct{} + +func (g *Gzip) Init() error { + return nil +} + +func (g *Gzip) Middleware(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + gziphandler.GzipHandler(next).ServeHTTP(w, r) + } +} + +func (g *Gzip) Name() string { + return PluginName +} diff --git a/plugins/gzip/tests/configs/.rr-http-middlewareNotExist.yaml b/plugins/gzip/tests/configs/.rr-http-middlewareNotExist.yaml new file mode 100644 index 00000000..df02c043 --- /dev/null +++ b/plugins/gzip/tests/configs/.rr-http-middlewareNotExist.yaml @@ -0,0 +1,22 @@ +server: + command: "php psr-worker.php" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:18103 + maxRequestSize: 1024 + middleware: [ "gzip", "foo" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/gzip/tests/configs/.rr-http-withGzip.yaml b/plugins/gzip/tests/configs/.rr-http-withGzip.yaml new file mode 100644 index 00000000..b91a9aad --- /dev/null +++ b/plugins/gzip/tests/configs/.rr-http-withGzip.yaml @@ -0,0 +1,22 @@ +server: + command: "php psr-worker.php" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:18953 + maxRequestSize: 1024 + middleware: [ "gzip" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/gzip/tests/plugin_test.go b/plugins/gzip/tests/plugin_test.go new file mode 100644 index 00000000..39979895 --- /dev/null +++ b/plugins/gzip/tests/plugin_test.go @@ -0,0 +1,170 @@ +package tests + +import ( + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/mocks" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/gzip" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestGzipPlugin(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-withGzip.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + tt := time.NewTimer(time.Second * 10) + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + t.Run("GzipCheckHeader", headerCheck) + wg.Wait() +} + +func headerCheck(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:18953", nil) + assert.NoError(t, err) + client := &http.Client{ + Transport: &http.Transport{ + DisableCompression: false, + }, + } + + r, err := client.Do(req) + assert.NoError(t, err) + assert.True(t, r.Uncompressed) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestMiddlewareNotExist(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-middlewareNotExist.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Warn("requested middleware does not exist", "requested", "foo") + + err = cont.RegisterAll( + cfg, + mockLogger, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + tt := time.NewTimer(time.Second * 5) + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + wg.Wait() +} diff --git a/plugins/gzip/tests/psr-worker.php b/plugins/gzip/tests/psr-worker.php new file mode 100644 index 00000000..ed936bde --- /dev/null +++ b/plugins/gzip/tests/psr-worker.php @@ -0,0 +1,23 @@ +<?php +/** + * @var Goridge\RelayInterface $relay + */ +use Spiral\Goridge; +use Spiral\RoadRunner; + +ini_set('display_errors', 'stderr'); +require dirname(__DIR__) . "/../../vendor_php/autoload.php"; + +$worker = new RoadRunner\Worker(new Goridge\StreamRelay(STDIN, STDOUT)); +$psr7 = new RoadRunner\PSR7Client($worker); + +while ($req = $psr7->acceptRequest()) { + try { + $resp = new \Zend\Diactoros\Response(); + $resp->getBody()->write(str_repeat("hello world", 1000)); + + $psr7->respond($resp); + } catch (\Throwable $e) { + $psr7->getWorker()->error((string)$e); + } +}
\ No newline at end of file diff --git a/plugins/headers/config.go b/plugins/headers/config.go new file mode 100644 index 00000000..8d4e29c2 --- /dev/null +++ b/plugins/headers/config.go @@ -0,0 +1,36 @@ +package headers + +// Config declares headers service configuration. +type Config struct { + Headers struct { + // CORS settings. + CORS *CORSConfig + + // Request headers to add to every payload send to PHP. + Request map[string]string + + // Response headers to add to every payload generated by PHP. + Response map[string]string + } +} + +// CORSConfig headers configuration. +type CORSConfig struct { + // AllowedOrigin: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowedOrigin string + + // AllowedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowedHeaders string + + // AllowedMethods: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowedMethods string + + // AllowCredentials https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials *bool + + // ExposeHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + ExposedHeaders string + + // MaxAge of CORS headers in seconds/ + MaxAge int +} diff --git a/plugins/headers/plugin.go b/plugins/headers/plugin.go new file mode 100644 index 00000000..f1c6e6f3 --- /dev/null +++ b/plugins/headers/plugin.go @@ -0,0 +1,117 @@ +package headers + +import ( + "net/http" + "strconv" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// ID contains default service name. +const PluginName = "headers" +const RootPluginName = "http" + +// Service serves headers files. Potentially convert into middleware? +type Plugin struct { + // server configuration (location, forbidden files and etc) + cfg *Config +} + +// Init must return configure service and return true if service hasStatus enabled. Must return error in case of +// misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer) error { + const op = errors.Op("headers plugin init") + err := cfg.UnmarshalKey(RootPluginName, &s.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + return nil +} + +// middleware must return true if request/response pair is handled within the middleware. +func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if s.cfg.Headers.Request != nil { + for k, v := range s.cfg.Headers.Request { + r.Header.Add(k, v) + } + } + + if s.cfg.Headers.Response != nil { + for k, v := range s.cfg.Headers.Response { + w.Header().Set(k, v) + } + } + + if s.cfg.Headers.CORS != nil { + if r.Method == http.MethodOptions { + s.preflightRequest(w) + return + } + s.corsHeaders(w) + } + + next.ServeHTTP(w, r) + } +} + +func (s *Plugin) Name() string { + return PluginName +} + +// configure OPTIONS response +func (s *Plugin) preflightRequest(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.AllowedMethods != "" { + headers.Set("Access-Control-Allow-Methods", s.cfg.Headers.CORS.AllowedMethods) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } + + if s.cfg.Headers.CORS.MaxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.Headers.CORS.MaxAge)) + } + + w.WriteHeader(http.StatusOK) +} + +// configure CORS headers +func (s *Plugin) corsHeaders(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.ExposedHeaders != "" { + headers.Set("Access-Control-Expose-Headers", s.cfg.Headers.CORS.ExposedHeaders) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } +} diff --git a/plugins/headers/tests/configs/.rr-cors-headers.yaml b/plugins/headers/tests/configs/.rr-cors-headers.yaml new file mode 100644 index 00000000..5c1a200b --- /dev/null +++ b/plugins/headers/tests/configs/.rr-cors-headers.yaml @@ -0,0 +1,37 @@ +server: + command: "php ../../../tests/http/client.php headers pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22855 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + cors: + allowedOrigin: "*" + allowedHeaders: "*" + allowedMethods: "GET,POST,PUT,DELETE" + allowCredentials: true + exposedHeaders: "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma" + maxAge: 600 + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-headers-init.yaml b/plugins/headers/tests/configs/.rr-headers-init.yaml new file mode 100644 index 00000000..252fe8f3 --- /dev/null +++ b/plugins/headers/tests/configs/.rr-headers-init.yaml @@ -0,0 +1,37 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:33453 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + cors: + allowedOrigin: "*" + allowedHeaders: "*" + allowedMethods: "GET,POST,PUT,DELETE" + allowCredentials: true + exposedHeaders: "Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma" + maxAge: 600 + request: + "Example-Request-Header": "Value" + response: + "X-Powered-By": "RoadRunner" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-req-headers.yaml b/plugins/headers/tests/configs/.rr-req-headers.yaml new file mode 100644 index 00000000..9256e98d --- /dev/null +++ b/plugins/headers/tests/configs/.rr-req-headers.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php header pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22655 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/configs/.rr-res-headers.yaml b/plugins/headers/tests/configs/.rr-res-headers.yaml new file mode 100644 index 00000000..1bca2c3d --- /dev/null +++ b/plugins/headers/tests/configs/.rr-res-headers.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php header pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:22455 + maxRequestSize: 1024 + middleware: [ "headers" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + # Additional HTTP headers and CORS control. + headers: + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/headers/tests/headers_plugin_test.go b/plugins/headers/tests/headers_plugin_test.go new file mode 100644 index 00000000..f1de8cb9 --- /dev/null +++ b/plugins/headers/tests/headers_plugin_test.go @@ -0,0 +1,359 @@ +package tests + +import ( + "io/ioutil" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/headers" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestHeadersInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-headers-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 5) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + wg.Wait() +} + +func TestRequestHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-req-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("RequestHeaders", reqHeaders) + wg.Wait() +} + +func reqHeaders(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22655?hello=value", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "CUSTOM-HEADER", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestResponseHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-res-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("ResponseHeaders", resHeaders) + wg.Wait() +} + +func resHeaders(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22455?hello=value", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "output-header", r.Header.Get("output")) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "CUSTOM-HEADER", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestCORSHeaders(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-cors-headers.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &headers.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("CORSHeaders", corsHeaders) + t.Run("CORSHeadersPass", corsHeadersPass) + wg.Wait() +} + +func corsHeadersPass(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:22855", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + + _, err = ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func corsHeaders(t *testing.T) { + req, err := http.NewRequest("OPTIONS", "http://localhost:22855", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Headers")) + assert.Equal(t, "GET,POST,PUT,DELETE", r.Header.Get("Access-Control-Allow-Methods")) + assert.Equal(t, "*", r.Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "600", r.Header.Get("Access-Control-Max-Age")) + assert.Equal(t, "true", r.Header.Get("Access-Control-Allow-Credentials")) + + _, err = ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + err = r.Body.Close() + assert.NoError(t, err) +} diff --git a/plugins/http/attributes/attributes.go b/plugins/http/attributes/attributes.go new file mode 100644 index 00000000..4c453766 --- /dev/null +++ b/plugins/http/attributes/attributes.go @@ -0,0 +1,85 @@ +package attributes + +import ( + "context" + "errors" + "net/http" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return k.name } + +var ( + // PsrContextKey is a context key. It can be used in the http attributes + PsrContextKey = &contextKey{"psr_attributes"} +) + +type attrs map[string]interface{} + +func (v attrs) get(key string) interface{} { + if v == nil { + return "" + } + + return v[key] +} + +func (v attrs) set(key string, value interface{}) { + v[key] = value +} + +func (v attrs) del(key string) { + delete(v, key) +} + +// Init returns request with new context and attribute bag. +func Init(r *http.Request) *http.Request { + return r.WithContext(context.WithValue(r.Context(), PsrContextKey, attrs{})) +} + +// All returns all context attributes. +func All(r *http.Request) map[string]interface{} { + v := r.Context().Value(PsrContextKey) + if v == nil { + return attrs{} + } + + return v.(attrs) +} + +// Get gets the value from request context. It replaces any existing +// values. +func Get(r *http.Request, key string) interface{} { + v := r.Context().Value(PsrContextKey) + if v == nil { + return nil + } + + return v.(attrs).get(key) +} + +// Set sets the key to value. It replaces any existing +// values. Context specific. +func Set(r *http.Request, key string, value interface{}) error { + v := r.Context().Value(PsrContextKey) + if v == nil { + return errors.New("unable to find `psr:attributes` context key") + } + + v.(attrs).set(key, value) + return nil +} + +// Delete deletes values associated with attribute key. +func (v attrs) Delete(key string) { + if v == nil { + return + } + + v.del(key) +} diff --git a/plugins/http/attributes/attributes_test.go b/plugins/http/attributes/attributes_test.go new file mode 100644 index 00000000..5622deb4 --- /dev/null +++ b/plugins/http/attributes/attributes_test.go @@ -0,0 +1,77 @@ +package attributes + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllAttributes(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + + assert.Equal(t, All(r), map[string]interface{}{ + "key": "value", + }) +} + +func TestAllAttributesNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestAllAttributesNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, All(r), map[string]interface{}{}) +} + +func TestGetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + assert.Equal(t, Get(r, "key"), "value") +} + +func TestGetAttributeNone(t *testing.T) { + r := &http.Request{} + r = Init(r) + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestGetAttributeNone2(t *testing.T) { + r := &http.Request{} + + assert.Equal(t, Get(r, "key"), nil) +} + +func TestSetAttribute(t *testing.T) { + r := &http.Request{} + r = Init(r) + + err := Set(r, "key", "value") + if err != nil { + t.Errorf("error during the Set: error %v", err) + } + assert.Equal(t, Get(r, "key"), "value") +} + +func TestSetAttributeNone(t *testing.T) { + r := &http.Request{} + err := Set(r, "key", "value") + assert.Error(t, err) + assert.Equal(t, Get(r, "key"), nil) +} diff --git a/plugins/http/config.go b/plugins/http/config.go new file mode 100644 index 00000000..d6efe310 --- /dev/null +++ b/plugins/http/config.go @@ -0,0 +1,291 @@ +package http + +import ( + "net" + "os" + "runtime" + "strings" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" +) + +type Cidrs []*net.IPNet + +func (c *Cidrs) IsTrusted(ip string) bool { + if len(*c) == 0 { + return false + } + + i := net.ParseIP(ip) + if i == nil { + return false + } + + for _, cird := range *c { + if cird.Contains(i) { + return true + } + } + + return false +} + +// Config configures RoadRunner HTTP server. +type Config struct { + // Port and port to handle as http server. + Address string + + // SSL defines https server options. + SSL *SSLConfig + + // FCGI configuration. You can use FastCGI without HTTP server. + FCGI *FCGIConfig + + // HTTP2 configuration + HTTP2 *HTTP2Config + + // MaxRequestSize specified max size for payload body in megabytes, set 0 to unlimited. + MaxRequestSize uint64 + + // TrustedSubnets declare IP subnets which are allowed to set ip using X-Real-Ip and X-Forwarded-For + TrustedSubnets []string + + // Uploads configures uploads configuration. + Uploads *UploadsConfig + + // Pool configures worker pool. + Pool *roadrunner.PoolConfig + + // Env is environment variables passed to the http pool + Env map[string]string + + // List of the middleware names (order will be preserved) + Middleware []string + + // slice of net.IPNet + cidrs Cidrs +} + +// FCGIConfig for FastCGI server. +type FCGIConfig struct { + // Address and port to handle as http server. + Address string +} + +// HTTP2Config HTTP/2 server customizations. +type HTTP2Config struct { + // Enable or disable HTTP/2 extension, default enable. + Enabled bool + + // H2C enables HTTP/2 over TCP + H2C bool + + // MaxConcurrentStreams defaults to 128. + MaxConcurrentStreams uint32 +} + +// InitDefaults sets default values for HTTP/2 configuration. +func (cfg *HTTP2Config) InitDefaults() error { + cfg.Enabled = true + cfg.MaxConcurrentStreams = 128 + + return nil +} + +// SSLConfig defines https server configuration. +type SSLConfig struct { + // Port to listen as HTTPS server, defaults to 443. + Port int + + // Redirect when enabled forces all http connections to switch to https. + Redirect bool + + // Key defined private server key. + Key string + + // Cert is https certificate. + Cert string + + // Root CA file + RootCA string +} + +// EnableHTTP is true when http server must run. +func (c *Config) EnableHTTP() bool { + return c.Address != "" +} + +// EnableTLS returns true if pool must listen TLS connections. +func (c *Config) EnableTLS() bool { + return c.SSL.Key != "" || c.SSL.Cert != "" || c.SSL.RootCA != "" +} + +// EnableHTTP2 when HTTP/2 extension must be enabled (only with TSL). +func (c *Config) EnableHTTP2() bool { + return c.HTTP2.Enabled +} + +// EnableH2C when HTTP/2 extension must be enabled on TCP. +func (c *Config) EnableH2C() bool { + return c.HTTP2.H2C +} + +// EnableFCGI is true when FastCGI server must be enabled. +func (c *Config) EnableFCGI() bool { + return c.FCGI.Address != "" +} + +// Hydrate must populate Config values using given Config source. Must return error if Config is not valid. +func (c *Config) InitDefaults() error { + if c.Pool == nil { + // default pool + c.Pool = &roadrunner.PoolConfig{ + Debug: false, + NumWorkers: int64(runtime.NumCPU()), + MaxJobs: 1000, + AllocateTimeout: time.Second * 60, + DestroyTimeout: time.Second * 60, + Supervisor: nil, + } + } + + if c.HTTP2 == nil { + c.HTTP2 = &HTTP2Config{} + } + + if c.FCGI == nil { + c.FCGI = &FCGIConfig{} + } + + if c.Uploads == nil { + c.Uploads = &UploadsConfig{} + } + + if c.SSL == nil { + c.SSL = &SSLConfig{} + } + + if c.SSL.Port == 0 { + c.SSL.Port = 443 + } + + err := c.HTTP2.InitDefaults() + if err != nil { + return err + } + err = c.Uploads.InitDefaults() + if err != nil { + return err + } + + if c.TrustedSubnets == nil { + // @see https://en.wikipedia.org/wiki/Reserved_IP_addresses + c.TrustedSubnets = []string{ + "10.0.0.0/8", + "127.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + } + + cidrs, err := ParseCIDRs(c.TrustedSubnets) + if err != nil { + return err + } + c.cidrs = cidrs + + return c.Valid() +} + +func ParseCIDRs(subnets []string) (Cidrs, error) { + c := make(Cidrs, 0, len(subnets)) + for _, cidr := range subnets { + _, cr, err := net.ParseCIDR(cidr) + if err != nil { + return nil, err + } + + c = append(c, cr) + } + + return c, nil +} + +// IsTrusted if api can be trusted to use X-Real-Ip, X-Forwarded-For +func (c *Config) IsTrusted(ip string) bool { + if c.cidrs == nil { + return false + } + + i := net.ParseIP(ip) + if i == nil { + return false + } + + for _, cird := range c.cidrs { + if cird.Contains(i) { + return true + } + } + + return false +} + +// Valid validates the configuration. +func (c *Config) Valid() error { + const op = errors.Op("validation") + if c.Uploads == nil { + return errors.E(op, errors.Str("malformed uploads config")) + } + + if c.HTTP2 == nil { + return errors.E(op, errors.Str("malformed http2 config")) + } + + if c.Pool == nil { + return errors.E(op, "malformed pool config") + } + + if !c.EnableHTTP() && !c.EnableTLS() && !c.EnableFCGI() { + return errors.E(op, errors.Str("unable to run http service, no method has been specified (http, https, http/2 or FastCGI)")) + } + + if c.Address != "" && !strings.Contains(c.Address, ":") { + return errors.E(op, errors.Str("malformed http server address")) + } + + if c.EnableTLS() { + if _, err := os.Stat(c.SSL.Key); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("key file '%s' does not exists", c.SSL.Key)) + } + + return err + } + + if _, err := os.Stat(c.SSL.Cert); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("cert file '%s' does not exists", c.SSL.Cert)) + } + + return err + } + + // RootCA is optional, but if provided - check it + if c.SSL.RootCA != "" { + if _, err := os.Stat(c.SSL.RootCA); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("root ca path provided, but path '%s' does not exists", c.SSL.RootCA)) + } + return err + } + } + } + + return nil +} diff --git a/plugins/http/constants.go b/plugins/http/constants.go new file mode 100644 index 00000000..773d1f46 --- /dev/null +++ b/plugins/http/constants.go @@ -0,0 +1,6 @@ +package http + +import "net/http" + +var http2pushHeaderKey = http.CanonicalHeaderKey("http2-push") +var TrailerHeaderKey = http.CanonicalHeaderKey("trailer") diff --git a/plugins/http/errors.go b/plugins/http/errors.go new file mode 100644 index 00000000..fb8762ef --- /dev/null +++ b/plugins/http/errors.go @@ -0,0 +1,25 @@ +// +build !windows + +package http + +import ( + "errors" + "net" + "os" + "syscall" +) + +// Broken pipe +var errEPIPE = errors.New("EPIPE(32) -> connection reset by peer") + +// handleWriteError just check if error was caused by aborted connection on linux +func handleWriteError(err error) error { + if netErr, ok2 := err.(*net.OpError); ok2 { + if syscallErr, ok3 := netErr.Err.(*os.SyscallError); ok3 { + if syscallErr.Err == syscall.EPIPE { + return errEPIPE + } + } + } + return err +} diff --git a/plugins/http/errors_windows.go b/plugins/http/errors_windows.go new file mode 100644 index 00000000..3d0ba04c --- /dev/null +++ b/plugins/http/errors_windows.go @@ -0,0 +1,27 @@ +// +build windows + +package http + +import ( + "errors" + "net" + "os" + "syscall" +) + +//Software caused connection abort. +//An established connection was aborted by the software in your host computer, +//possibly due to a data transmission time-out or protocol error. +var errEPIPE = errors.New("WSAECONNABORTED (10053) -> an established connection was aborted by peer") + +// handleWriteError just check if error was caused by aborted connection on windows +func handleWriteError(err error) error { + if netErr, ok2 := err.(*net.OpError); ok2 { + if syscallErr, ok3 := netErr.Err.(*os.SyscallError); ok3 { + if syscallErr.Err == syscall.WSAECONNABORTED { + return errEPIPE + } + } + } + return err +} diff --git a/plugins/http/handler.go b/plugins/http/handler.go new file mode 100644 index 00000000..74b038ff --- /dev/null +++ b/plugins/http/handler.go @@ -0,0 +1,238 @@ +package http + +import ( + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/util" +) + +const ( + // EventResponse thrown after the request been processed. See ErrorEvent as payload. + EventResponse = iota + 500 + + // EventError thrown on any non job error provided by road runner server. + EventError +) + +type Handle interface { + AddListener(l util.EventListener) + ServeHTTP(w http.ResponseWriter, r *http.Request) +} + +// ErrorEvent represents singular http error event. +type ErrorEvent struct { + // Request contains client request, must not be stored. + Request *http.Request + + // Error - associated error, if any. + Error error + + // event timings + start time.Time + elapsed time.Duration +} + +// Elapsed returns duration of the invocation. +func (e *ErrorEvent) Elapsed() time.Duration { + return e.elapsed +} + +// ResponseEvent represents singular http response event. +type ResponseEvent struct { + // Request contains client request, must not be stored. + Request *Request + + // Response contains service response. + Response *Response + + // event timings + start time.Time + elapsed time.Duration +} + +// Elapsed returns duration of the invocation. +func (e *ResponseEvent) Elapsed() time.Duration { + return e.elapsed +} + +// Handler serves http connections to underlying PHP application using PSR-7 protocol. Context will include request headers, +// parsed files and query, payload will include parsed form dataTree (if any). +type handler struct { + maxRequestSize uint64 + uploads UploadsConfig + trusted Cidrs + log log.Logger + pool roadrunner.Pool + mul sync.Mutex + lsn util.EventListener +} + +func NewHandler(maxReqSize uint64, uploads UploadsConfig, trusted Cidrs, pool roadrunner.Pool) (Handle, error) { + if pool == nil { + return nil, errors.E(errors.Str("pool should be initialized")) + } + return &handler{ + maxRequestSize: maxReqSize * roadrunner.MB, + uploads: uploads, + pool: pool, + trusted: trusted, + }, nil +} + +// Listen attaches handler event controller. +func (h *handler) AddListener(l util.EventListener) { + h.mul.Lock() + defer h.mul.Unlock() + + h.lsn = l +} + +// mdwr serve using PSR-7 requests passed to underlying application. Attempts to serve static files first if enabled. +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + const op = errors.Op("ServeHTTP") + start := time.Now() + + // validating request size + if h.maxRequestSize != 0 { + err := h.maxSize(w, r, start, op) + if err != nil { + return + } + } + + req, err := NewRequest(r, h.uploads) + if err != nil { + h.handleError(w, r, err, start) + return + } + + // proxy IP resolution + h.resolveIP(req) + + req.Open(h.log) + defer req.Close(h.log) + + p, err := req.Payload() + if err != nil { + h.handleError(w, r, err, start) + return + } + + rsp, err := h.pool.Exec(p) + if err != nil { + h.handleError(w, r, err, start) + return + } + + resp, err := NewResponse(rsp) + if err != nil { + h.handleError(w, r, err, start) + return + } + + h.handleResponse(req, resp, start) + err = resp.Write(w) + if err != nil { + h.handleError(w, r, err, start) + } +} + +func (h *handler) maxSize(w http.ResponseWriter, r *http.Request, start time.Time, op errors.Op) error { + if length := r.Header.Get("content-length"); length != "" { + if size, err := strconv.ParseInt(length, 10, 64); err != nil { + h.handleError(w, r, err, start) + return err + } else if size > int64(h.maxRequestSize) { + h.handleError(w, r, errors.E(op, errors.Str("request body max size is exceeded")), start) + return err + } + } + return nil +} + +// handleError sends error. +func (h *handler) handleError(w http.ResponseWriter, r *http.Request, err error, start time.Time) { + h.mul.Lock() + defer h.mul.Unlock() + // if pipe is broken, there is no sense to write the header + // in this case we just report about error + if err == errEPIPE { + h.throw(ErrorEvent{Request: r, Error: err, start: start, elapsed: time.Since(start)}) + return + } + err = multierror.Append(err) + // ResponseWriter is ok, write the error code + w.WriteHeader(500) + _, err2 := w.Write([]byte(err.Error())) + // error during the writing to the ResponseWriter + if err2 != nil { + err = multierror.Append(err2, err) + // concat original error with ResponseWriter error + h.throw(ErrorEvent{Request: r, Error: errors.E(err), start: start, elapsed: time.Since(start)}) + return + } + h.throw(ErrorEvent{Request: r, Error: err, start: start, elapsed: time.Since(start)}) +} + +// handleResponse triggers response event. +func (h *handler) handleResponse(req *Request, resp *Response, start time.Time) { + h.throw(ResponseEvent{Request: req, Response: resp, start: start, elapsed: time.Since(start)}) +} + +// throw invokes event handler if any. +func (h *handler) throw(ctx interface{}) { + if h.lsn != nil { + h.lsn(ctx) + } +} + +// get real ip passing multiple proxy +func (h *handler) resolveIP(r *Request) { + if h.trusted.IsTrusted(r.RemoteAddr) == false { + return + } + + if r.Header.Get("X-Forwarded-For") != "" { + ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",") + ipCount := len(ips) + + for i := ipCount - 1; i >= 0; i-- { + addr := strings.TrimSpace(ips[i]) + if net.ParseIP(addr) != nil { + r.RemoteAddr = addr + return + } + } + + return + } + + // The logic here is the following: + // In general case, we only expect X-Real-Ip header. If it exist, we get the IP address from header and set request Remote address + // But, if there is no X-Real-Ip header, we also trying to check CloudFlare headers + // True-Client-IP is a general CF header in which copied information from X-Real-Ip in CF. + // CF-Connecting-IP is an Enterprise feature and we check it last in order. + // This operations are near O(1) because Headers struct are the map type -> type MIMEHeader map[string][]string + if r.Header.Get("X-Real-Ip") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("X-Real-Ip")) + return + } + + if r.Header.Get("True-Client-IP") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("True-Client-IP")) + return + } + + if r.Header.Get("CF-Connecting-IP") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("CF-Connecting-IP")) + } +} diff --git a/plugins/http/parse.go b/plugins/http/parse.go new file mode 100644 index 00000000..d4a1604b --- /dev/null +++ b/plugins/http/parse.go @@ -0,0 +1,147 @@ +package http + +import ( + "net/http" +) + +// MaxLevel defines maximum tree depth for incoming request data and files. +const MaxLevel = 127 + +type dataTree map[string]interface{} +type fileTree map[string]interface{} + +// parseData parses incoming request body into data tree. +func parseData(r *http.Request) dataTree { + data := make(dataTree) + if r.PostForm != nil { + for k, v := range r.PostForm { + data.push(k, v) + } + } + + if r.MultipartForm != nil { + for k, v := range r.MultipartForm.Value { + data.push(k, v) + } + } + + return data +} + +// pushes value into data tree. +func (d dataTree) push(k string, v []string) { + keys := FetchIndexes(k) + if len(keys) <= MaxLevel { + d.mount(keys, v) + } +} + +// mount mounts data tree recursively. +func (d dataTree) mount(i []string, v []string) { + if len(i) == 1 { + // single value context (last element) + d[i[0]] = v[len(v)-1] + return + } + + if len(i) == 2 && i[1] == "" { + // non associated array of elements + d[i[0]] = v + return + } + + if p, ok := d[i[0]]; ok { + p.(dataTree).mount(i[1:], v) + return + } + + d[i[0]] = make(dataTree) + d[i[0]].(dataTree).mount(i[1:], v) +} + +// parse incoming dataTree request into JSON (including contentMultipart form dataTree) +func parseUploads(r *http.Request, cfg UploadsConfig) *Uploads { + u := &Uploads{ + cfg: cfg, + tree: make(fileTree), + list: make([]*FileUpload, 0), + } + + for k, v := range r.MultipartForm.File { + files := make([]*FileUpload, 0, len(v)) + for _, f := range v { + files = append(files, NewUpload(f)) + } + + u.list = append(u.list, files...) + u.tree.push(k, files) + } + + return u +} + +// pushes new file upload into it's proper place. +func (d fileTree) push(k string, v []*FileUpload) { + keys := FetchIndexes(k) + if len(keys) <= MaxLevel { + d.mount(keys, v) + } +} + +// mount mounts data tree recursively. +func (d fileTree) mount(i []string, v []*FileUpload) { + if len(i) == 1 { + // single value context + d[i[0]] = v[0] + return + } + + if len(i) == 2 && i[1] == "" { + // non associated array of elements + d[i[0]] = v + return + } + + if p, ok := d[i[0]]; ok { + p.(fileTree).mount(i[1:], v) + return + } + + d[i[0]] = make(fileTree) + d[i[0]].(fileTree).mount(i[1:], v) +} + +// FetchIndexes parses input name and splits it into separate indexes list. +func FetchIndexes(s string) []string { + var ( + pos int + ch string + keys = make([]string, 1) + ) + + for _, c := range s { + ch = string(c) + switch ch { + case " ": + // ignore all spaces + continue + case "[": + pos = 1 + continue + case "]": + if pos == 1 { + keys = append(keys, "") + } + pos = 2 + default: + if pos == 1 || pos == 2 { + keys = append(keys, "") + } + + keys[len(keys)-1] += ch + pos = 0 + } + } + + return keys +} diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go new file mode 100644 index 00000000..371cdb91 --- /dev/null +++ b/plugins/http/plugin.go @@ -0,0 +1,541 @@ +package http + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "net/http/fcgi" + "net/url" + "strings" + "sync" + + "github.com/hashicorp/go-multierror" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" + factory "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/interfaces/status" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" + "github.com/spiral/roadrunner/v2/util" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "golang.org/x/sys/cpu" +) + +const ( + // PluginName declares plugin name. + PluginName = "http" + + // EventInitSSL thrown at moment of https initialization. SSL server passed as context. + EventInitSSL = 750 +) + +// Middleware interface +type Middleware interface { + Middleware(f http.Handler) http.HandlerFunc +} + +type middleware map[string]Middleware + +// Service manages pool, http servers. +type Plugin struct { + sync.Mutex + + configurer config.Configurer + server factory.Server + log log.Logger + + cfg *Config + // middlewares to chain + mdwr middleware + // Event listener to stdout + listener util.EventListener + + // Pool which attached to all servers + pool roadrunner.Pool + + // servers RR handler + handler Handle + + // servers + http *http.Server + https *http.Server + fcgi *http.Server +} + +// AddListener attaches server event controller. +func (s *Plugin) AddListener(listener util.EventListener) { + // save listeners for Reset + s.listener = listener + s.pool.AddListener(listener) +} + +// Init must return configure svc and return true if svc hasStatus enabled. Must return error in case of +// misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer, log log.Logger, server factory.Server) error { + const op = errors.Op("http Init") + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, err) + } + + err = s.cfg.InitDefaults() + if err != nil { + return errors.E(op, err) + } + + s.configurer = cfg + s.log = log + s.mdwr = make(map[string]Middleware) + + if !s.cfg.EnableHTTP() && !s.cfg.EnableTLS() && !s.cfg.EnableFCGI() { + return errors.E(op, errors.Disabled) + } + + s.pool, err = server.NewWorkerPool(context.Background(), roadrunner.PoolConfig{ + Debug: s.cfg.Pool.Debug, + NumWorkers: s.cfg.Pool.NumWorkers, + MaxJobs: s.cfg.Pool.MaxJobs, + AllocateTimeout: s.cfg.Pool.AllocateTimeout, + DestroyTimeout: s.cfg.Pool.DestroyTimeout, + Supervisor: s.cfg.Pool.Supervisor, + }, s.cfg.Env) + if err != nil { + return errors.E(op, err) + } + + s.server = server + + s.AddListener(s.logCallback) + + return nil +} + +func (s *Plugin) logCallback(event interface{}) { + switch ev := event.(type) { + case ResponseEvent: + s.log.Debug("http handler response received", "elapsed", ev.Elapsed().String(), "remote address", ev.Request.RemoteAddr) + case ErrorEvent: + s.log.Error("error event received", "elapsed", ev.Elapsed().String(), "error", ev.Error) + case roadrunner.WorkerEvent: + s.log.Debug("worker event received", "event", ev.Event, "worker state", ev.Worker.State()) + default: + fmt.Println(event) + } +} + +// Serve serves the svc. +func (s *Plugin) Serve() chan error { + s.Lock() + defer s.Unlock() + + const op = errors.Op("serve http") + errCh := make(chan error, 2) + + var err error + s.handler, err = NewHandler( + s.cfg.MaxRequestSize, + *s.cfg.Uploads, + s.cfg.cidrs, + s.pool, + ) + if err != nil { + errCh <- errors.E(op, err) + return errCh + } + + if s.listener != nil { + s.handler.AddListener(s.listener) + } + + if s.cfg.EnableHTTP() { + if s.cfg.EnableH2C() { + s.http = &http.Server{Addr: s.cfg.Address, Handler: h2c.NewHandler(s, &http2.Server{})} + } else { + s.http = &http.Server{Addr: s.cfg.Address, Handler: s} + } + } + + if s.cfg.EnableTLS() { + s.https = s.initSSL() + if s.cfg.SSL.RootCA != "" { + err = s.appendRootCa() + if err != nil { + errCh <- errors.E(op, err) + return errCh + } + } + + if s.cfg.EnableHTTP2() { + if err := s.initHTTP2(); err != nil { + errCh <- errors.E(op, err) + return errCh + } + } + } + + if s.cfg.EnableFCGI() { + s.fcgi = &http.Server{Handler: s} + } + + if s.http != nil { + go func() { + httpErr := s.http.ListenAndServe() + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + if s.https != nil { + go func() { + httpErr := s.https.ListenAndServeTLS( + s.cfg.SSL.Cert, + s.cfg.SSL.Key, + ) + + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + if s.fcgi != nil { + go func() { + httpErr := s.serveFCGI() + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + if len(s.mdwr) > 0 { + s.addMiddlewares() + } + + return errCh +} + +// Stop stops the http. +func (s *Plugin) Stop() error { + s.Lock() + defer s.Unlock() + + var err error + if s.fcgi != nil { + err = s.fcgi.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the fcgi server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + if s.https != nil { + err = s.https.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the https server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + if s.http != nil { + err = s.http.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the http server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + return err +} + +// ServeHTTP handles connection using set of middleware and pool PSR-7 server. +func (s *Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if headerContainsUpgrade(r, s) { + http.Error(w, "server does not support upgrade header", http.StatusInternalServerError) + return + } + + if s.redirect(w, r) { + return + } + + if s.https != nil && r.TLS != nil { + w.Header().Add("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + } + + r = attributes.Init(r) + // protect the case, when user send Reset and we are replacing handler with pool + s.Lock() + s.handler.ServeHTTP(w, r) + s.Unlock() +} + +// Server returns associated pool workers +func (s *Plugin) Workers() []roadrunner.WorkerBase { + return s.pool.Workers() +} + +func (s *Plugin) Name() string { + return PluginName +} + +func (s *Plugin) Reset() error { + s.Lock() + defer s.Unlock() + const op = errors.Op("http reset") + s.log.Info("Resetting http plugin") + s.pool.Destroy(context.Background()) + + // re-read the config + err := s.configurer.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, err) + } + + s.pool, err = s.server.NewWorkerPool(context.Background(), roadrunner.PoolConfig{ + Debug: s.cfg.Pool.Debug, + NumWorkers: s.cfg.Pool.NumWorkers, + MaxJobs: s.cfg.Pool.MaxJobs, + AllocateTimeout: s.cfg.Pool.AllocateTimeout, + DestroyTimeout: s.cfg.Pool.DestroyTimeout, + Supervisor: s.cfg.Pool.Supervisor, + }, s.cfg.Env) + if err != nil { + return errors.E(op, err) + } + + s.handler, err = NewHandler( + s.cfg.MaxRequestSize, + *s.cfg.Uploads, + s.cfg.cidrs, + s.pool, + ) + if err != nil { + return errors.E(op, err) + } + + // restore original listeners + s.pool.AddListener(s.listener) + + return nil +} + +func (s *Plugin) Collects() []interface{} { + return []interface{}{ + s.AddMiddleware, + } +} + +func (s *Plugin) AddMiddleware(name endure.Named, m Middleware) { + s.mdwr[name.Name()] = m +} + +// Status return status of the particular plugin +func (s *Plugin) Status() status.Status { + workers := s.Workers() + for i := 0; i < len(workers); i++ { + if workers[i].State().IsActive() { + return status.Status{ + Code: http.StatusOK, + } + } + } + // if there are no workers, threat this as error + return status.Status{ + Code: http.StatusInternalServerError, + } +} + +func (s *Plugin) redirect(w http.ResponseWriter, r *http.Request) bool { + if s.https != nil && r.TLS == nil && s.cfg.SSL.Redirect { + target := &url.URL{ + Scheme: "https", + Host: s.tlsAddr(r.Host, false), + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + + http.Redirect(w, r, target.String(), http.StatusTemporaryRedirect) + return true + } + return false +} + +func headerContainsUpgrade(r *http.Request, s *Plugin) bool { + if _, ok := r.Header["Upgrade"]; ok { + // https://golang.org/pkg/net/http/#Hijacker + s.log.Error("server does not support Upgrade header") + return true + } + return false +} + +// append RootCA to the https server TLS config +func (s *Plugin) appendRootCa() error { + const op = errors.Op("append root CA") + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil + } + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + + CA, err := ioutil.ReadFile(s.cfg.SSL.RootCA) + if err != nil { + return err + } + + // should append our CA cert + ok := rootCAs.AppendCertsFromPEM(CA) + if !ok { + return errors.E(op, errors.Str("could not append Certs from PEM")) + } + // disable "G402 (CWE-295): TLS MinVersion too low. (Confidence: HIGH, Severity: HIGH)" + // #nosec G402 + cfg := &tls.Config{ + InsecureSkipVerify: false, + RootCAs: rootCAs, + } + s.http.TLSConfig = cfg + + return nil +} + +// Init https server +func (s *Plugin) initSSL() *http.Server { + var topCipherSuites []uint16 + var defaultCipherSuitesTLS13 []uint16 + + hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + + if hasGCMAsm { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } + + DefaultCipherSuites := make([]uint16, 0, 22) + DefaultCipherSuites = append(DefaultCipherSuites, topCipherSuites...) + DefaultCipherSuites = append(DefaultCipherSuites, defaultCipherSuitesTLS13...) + + server := &http.Server{ + Addr: s.tlsAddr(s.cfg.Address, true), + Handler: s, + TLSConfig: &tls.Config{ + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + tls.X25519, + }, + CipherSuites: DefaultCipherSuites, + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + }, + } + + return server +} + +// init http/2 server +func (s *Plugin) initHTTP2() error { + return http2.ConfigureServer(s.https, &http2.Server{ + MaxConcurrentStreams: s.cfg.HTTP2.MaxConcurrentStreams, + }) +} + +// serveFCGI starts FastCGI server. +func (s *Plugin) serveFCGI() error { + l, err := util.CreateListener(s.cfg.FCGI.Address) + if err != nil { + return err + } + + err = fcgi.Serve(l, s.fcgi.Handler) + if err != nil { + return err + } + + return nil +} + +// tlsAddr replaces listen or host port with port configured by SSL config. +func (s *Plugin) tlsAddr(host string, forcePort bool) string { + // remove current forcePort first + host = strings.Split(host, ":")[0] + + if forcePort || s.cfg.SSL.Port != 443 { + host = fmt.Sprintf("%s:%v", host, s.cfg.SSL.Port) + } + + return host +} + +func (s *Plugin) addMiddlewares() { + if s.http != nil { + applyMiddlewares(s.http, s.mdwr, s.cfg.Middleware, s.log) + } + if s.https != nil { + applyMiddlewares(s.https, s.mdwr, s.cfg.Middleware, s.log) + } + + if s.fcgi != nil { + applyMiddlewares(s.fcgi, s.mdwr, s.cfg.Middleware, s.log) + } +} + +func applyMiddlewares(server *http.Server, middlewares map[string]Middleware, order []string, log log.Logger) { + for i := 0; i < len(order); i++ { + if mdwr, ok := middlewares[order[i]]; ok { + server.Handler = mdwr.Middleware(server.Handler) + } else { + log.Warn("requested middleware does not exist", "requested", order[i]) + } + } +} diff --git a/plugins/http/request.go b/plugins/http/request.go new file mode 100644 index 00000000..640bdec2 --- /dev/null +++ b/plugins/http/request.go @@ -0,0 +1,186 @@ +package http + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + + j "github.com/json-iterator/go" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +const ( + defaultMaxMemory = 32 << 20 // 32 MB + contentNone = iota + 900 + contentStream + contentMultipart + contentFormData +) + +// Request maps net/http requests to PSR7 compatible structure and managed state of temporary uploaded files. +type Request struct { + // RemoteAddr contains ip address of client, make sure to check X-Real-Ip and X-Forwarded-For for real client address. + RemoteAddr string `json:"remoteAddr"` + + // Protocol includes HTTP protocol version. + Protocol string `json:"protocol"` + + // Method contains name of HTTP method used for the request. + Method string `json:"method"` + + // URI contains full request URI with scheme and query. + URI string `json:"uri"` + + // Header contains list of request headers. + Header http.Header `json:"headers"` + + // Cookies contains list of request cookies. + Cookies map[string]string `json:"cookies"` + + // RawQuery contains non parsed query string (to be parsed on php end). + RawQuery string `json:"rawQuery"` + + // Parsed indicates that request body has been parsed on RR end. + Parsed bool `json:"parsed"` + + // Uploads contains list of uploaded files, their names, sized and associations with temporary files. + Uploads *Uploads `json:"uploads"` + + // Attributes can be set by chained mdwr to safely pass value from Golang to PHP. See: GetAttribute, SetAttribute functions. + Attributes map[string]interface{} `json:"attributes"` + + // request body can be parsedData or []byte + body interface{} +} + +func fetchIP(pair string) string { + if !strings.ContainsRune(pair, ':') { + return pair + } + + addr, _, _ := net.SplitHostPort(pair) + return addr +} + +// NewRequest creates new PSR7 compatible request using net/http request. +func NewRequest(r *http.Request, cfg UploadsConfig) (*Request, error) { + req := &Request{ + RemoteAddr: fetchIP(r.RemoteAddr), + Protocol: r.Proto, + Method: r.Method, + URI: uri(r), + Header: r.Header, + Cookies: make(map[string]string), + RawQuery: r.URL.RawQuery, + Attributes: attributes.All(r), + } + + for _, c := range r.Cookies() { + if v, err := url.QueryUnescape(c.Value); err == nil { + req.Cookies[c.Name] = v + } + } + + switch req.contentType() { + case contentNone: + return req, nil + + case contentStream: + var err error + req.body, err = ioutil.ReadAll(r.Body) + return req, err + + case contentMultipart: + if err := r.ParseMultipartForm(defaultMaxMemory); err != nil { + return nil, err + } + + req.Uploads = parseUploads(r, cfg) + fallthrough + case contentFormData: + if err := r.ParseForm(); err != nil { + return nil, err + } + + req.body = parseData(r) + } + + req.Parsed = true + return req, nil +} + +// Open moves all uploaded files to temporary directory so it can be given to php later. +func (r *Request) Open(log log.Logger) { + if r.Uploads == nil { + return + } + + r.Uploads.Open(log) +} + +// Close clears all temp file uploads +func (r *Request) Close(log log.Logger) { + if r.Uploads == nil { + return + } + + r.Uploads.Clear(log) +} + +// Payload request marshaled RoadRunner payload based on PSR7 data. values encode method is JSON. Make sure to open +// files prior to calling this method. +func (r *Request) Payload() (roadrunner.Payload, error) { + p := roadrunner.Payload{} + + var err error + if p.Context, err = json.Marshal(r); err != nil { + return roadrunner.EmptyPayload, err + } + + if r.Parsed { + if p.Body, err = json.Marshal(r.body); err != nil { + return roadrunner.EmptyPayload, err + } + } else if r.body != nil { + p.Body = r.body.([]byte) + } + + return p, nil +} + +// contentType returns the payload content type. +func (r *Request) contentType() int { + if r.Method == "HEAD" || r.Method == "OPTIONS" { + return contentNone + } + + ct := r.Header.Get("content-type") + if strings.Contains(ct, "application/x-www-form-urlencoded") { + return contentFormData + } + + if strings.Contains(ct, "multipart/form-data") { + return contentMultipart + } + + return contentStream +} + +// uri fetches full uri from request in a form of string (including https scheme if TLS connection is enabled). +func uri(r *http.Request) string { + if r.URL.Host != "" { + return r.URL.String() + } + if r.TLS != nil { + return fmt.Sprintf("https://%s%s", r.Host, r.URL.String()) + } + + return fmt.Sprintf("http://%s%s", r.Host, r.URL.String()) +} diff --git a/plugins/http/response.go b/plugins/http/response.go new file mode 100644 index 00000000..e3ac2756 --- /dev/null +++ b/plugins/http/response.go @@ -0,0 +1,105 @@ +package http + +import ( + "io" + "net/http" + "strings" + "sync" + + "github.com/spiral/roadrunner/v2" +) + +// Response handles PSR7 response logic. +type Response struct { + // Status contains response status. + Status int `json:"status"` + + // Header contains list of response headers. + Headers map[string][]string `json:"headers"` + + // associated Body payload. + Body interface{} + sync.Mutex +} + +// NewResponse creates new response based on given pool payload. +func NewResponse(p roadrunner.Payload) (*Response, error) { + r := &Response{Body: p.Body} + if err := json.Unmarshal(p.Context, r); err != nil { + return nil, err + } + + return r, nil +} + +// Write writes response headers, status and body into ResponseWriter. +func (r *Response) Write(w http.ResponseWriter) error { + // INFO map is the reference type in golang + p := handlePushHeaders(r.Headers) + if pusher, ok := w.(http.Pusher); ok { + for _, v := range p { + err := pusher.Push(v, nil) + if err != nil { + return err + } + } + } + + handleTrailers(r.Headers) + for n, h := range r.Headers { + for _, v := range h { + w.Header().Add(n, v) + } + } + + w.WriteHeader(r.Status) + + if data, ok := r.Body.([]byte); ok { + _, err := w.Write(data) + if err != nil { + return handleWriteError(err) + } + } + + if rc, ok := r.Body.(io.Reader); ok { + if _, err := io.Copy(w, rc); err != nil { + return err + } + } + + return nil +} + +func handlePushHeaders(h map[string][]string) []string { + var p []string + pushHeader, ok := h[http2pushHeaderKey] + if !ok { + return p + } + + p = append(p, pushHeader...) + + delete(h, http2pushHeaderKey) + + return p +} + +func handleTrailers(h map[string][]string) { + trailers, ok := h[TrailerHeaderKey] + if !ok { + return + } + + for _, tr := range trailers { + for _, n := range strings.Split(tr, ",") { + n = strings.Trim(n, "\t ") + if v, ok := h[n]; ok { + h["Trailer:"+n] = v + + delete(h, n) + } + } + } + + delete(h, TrailerHeaderKey) +} diff --git a/plugins/http/tests/config_test.go b/plugins/http/tests/config_test.go new file mode 100644 index 00000000..068bd66e --- /dev/null +++ b/plugins/http/tests/config_test.go @@ -0,0 +1,330 @@ +package tests + +//import ( +// "os" +// "testing" +// "time" +// +// json "github.com/json-iterator/go" +// "github.com/stretchr/testify/assert" +//) +// +//type mockCfg struct{ cfg string } +// +//func (cfg *mockCfg) Get(name string) service.Config { return nil } +//func (cfg *mockCfg) Unmarshal(out interface{}) error { +// j := json.ConfigCompatibleWithStandardLibrary +// return j.Unmarshal([]byte(cfg.cfg), out) +//} +// +//func Test_Config_Hydrate_Error1(t *testing.T) { +// cfg := &mockCfg{`{"address": "localhost:8080"}`} +// c := &Config{} +// +// assert.NoError(t, c.Hydrate(cfg)) +//} +// +//func Test_Config_Hydrate_Error2(t *testing.T) { +// cfg := &mockCfg{`{"dir": "/dir/"`} +// c := &Config{} +// +// assert.Error(t, c.Hydrate(cfg)) +//} +// +//func Test_Config_Valid(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.NoError(t, cfg.Valid()) +//} +// +//func Test_Trusted_Subnets(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// TrustedSubnets: []string{"200.1.0.0/16"}, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.NoError(t, cfg.parseCIDRs()) +// +// assert.True(t, cfg.IsTrusted("200.1.0.10")) +// assert.False(t, cfg.IsTrusted("127.0.0.0.1")) +//} +// +//func Test_Trusted_Subnets_Err(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// TrustedSubnets: []string{"200.1.0.0"}, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.parseCIDRs()) +//} +// +//func Test_Config_Valid_SSL(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// SSL: SSLConfig{ +// Cert: "fixtures/server.crt", +// Key: "fixtures/server.key", +// }, +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Hydrate(&testCfg{httpCfg: "{}"})) +// +// assert.NoError(t, cfg.Valid()) +// assert.True(t, cfg.EnableTLS()) +// assert.Equal(t, 443, cfg.SSL.Port) +//} +// +//func Test_Config_SSL_No_key(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// SSL: SSLConfig{ +// Cert: "fixtures/server.crt", +// }, +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_SSL_No_Cert(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// SSL: SSLConfig{ +// Key: "fixtures/server.key", +// }, +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_NoUploads(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_NoHTTP2(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 0, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_NoWorkers(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_NoPool(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 0, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_DeadPool(t *testing.T) { +// cfg := &Config{ +// Address: ":8080", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} +// +//func Test_Config_InvalidAddress(t *testing.T) { +// cfg := &Config{ +// Address: "unexpected_address", +// MaxRequestSize: 1024, +// Uploads: &UploadsConfig{ +// Dir: os.TempDir(), +// Forbid: []string{".go"}, +// }, +// HTTP2: &HTTP2Config{ +// Enabled: true, +// }, +// Workers: &roadrunner.ServerConfig{ +// Command: "php tests/client.php echo pipes", +// Relay: "pipes", +// Pool: &roadrunner.Config{ +// NumWorkers: 1, +// AllocateTimeout: time.Second, +// DestroyTimeout: time.Second, +// }, +// }, +// } +// +// assert.Error(t, cfg.Valid()) +//} diff --git a/plugins/http/tests/configs/.rr-broken-pipes.yaml b/plugins/http/tests/configs/.rr-broken-pipes.yaml new file mode 100644 index 00000000..aacc303e --- /dev/null +++ b/plugins/http/tests/configs/.rr-broken-pipes.yaml @@ -0,0 +1,29 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php broken pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:12384 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + + diff --git a/plugins/http/tests/configs/.rr-echoErr.yaml b/plugins/http/tests/configs/.rr-echoErr.yaml new file mode 100644 index 00000000..6ecdbb2a --- /dev/null +++ b/plugins/http/tests/configs/.rr-echoErr.yaml @@ -0,0 +1,28 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php echoerr pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:8080 + maxRequestSize: 1024 + middleware: [ "pluginMiddleware", "pluginMiddleware2" ] + uploads: + forbid: [ "" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/http/tests/configs/.rr-env.yaml b/plugins/http/tests/configs/.rr-env.yaml new file mode 100644 index 00000000..c9fdc798 --- /dev/null +++ b/plugins/http/tests/configs/.rr-env.yaml @@ -0,0 +1,31 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php env pipes" + user: "" + group: "" + env: + "env_key": "ENV_VALUE" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:12084 + maxRequestSize: 1024 + middleware: [ "" ] + env: + "RR_HTTP": "true" + "env_key": "ENV_VALUE" + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/http/tests/configs/.rr-fcgi-reqUri.yaml b/plugins/http/tests/configs/.rr-fcgi-reqUri.yaml new file mode 100644 index 00000000..dbd19445 --- /dev/null +++ b/plugins/http/tests/configs/.rr-fcgi-reqUri.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php request-uri pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8082 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8890 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://127.0.0.1:6921 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/configs/.rr-fcgi.yaml b/plugins/http/tests/configs/.rr-fcgi.yaml new file mode 100644 index 00000000..0cbd6d02 --- /dev/null +++ b/plugins/http/tests/configs/.rr-fcgi.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8081 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8889 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:6920 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/configs/.rr-h2c.yaml b/plugins/http/tests/configs/.rr-h2c.yaml new file mode 100644 index 00000000..316daea9 --- /dev/null +++ b/plugins/http/tests/configs/.rr-h2c.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8083 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8891 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:6920 + http2: + enabled: true + h2c: true + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/configs/.rr-http.yaml b/plugins/http/tests/configs/.rr-http.yaml new file mode 100644 index 00000000..7b91f735 --- /dev/null +++ b/plugins/http/tests/configs/.rr-http.yaml @@ -0,0 +1,41 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:18903 + maxRequestSize: 1024 + middleware: [ "pluginMiddleware", "pluginMiddleware2" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8892 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:7921 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128 + + diff --git a/plugins/http/tests/configs/.rr-init.yaml b/plugins/http/tests/configs/.rr-init.yaml new file mode 100644 index 00000000..50aa91ec --- /dev/null +++ b/plugins/http/tests/configs/.rr-init.yaml @@ -0,0 +1,41 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:15395 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8892 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:7921 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128 + + diff --git a/plugins/http/tests/configs/.rr-resetter.yaml b/plugins/http/tests/configs/.rr-resetter.yaml new file mode 100644 index 00000000..b46b21f5 --- /dev/null +++ b/plugins/http/tests/configs/.rr-resetter.yaml @@ -0,0 +1,28 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:10084 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + diff --git a/plugins/http/tests/configs/.rr-ssl-push.yaml b/plugins/http/tests/configs/.rr-ssl-push.yaml new file mode 100644 index 00000000..90a99192 --- /dev/null +++ b/plugins/http/tests/configs/.rr-ssl-push.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php push pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8086 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8894 + redirect: true + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:6920 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/configs/.rr-ssl-redirect.yaml b/plugins/http/tests/configs/.rr-ssl-redirect.yaml new file mode 100644 index 00000000..1878ba53 --- /dev/null +++ b/plugins/http/tests/configs/.rr-ssl-redirect.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8087 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8895 + redirect: true + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:6920 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/configs/.rr-ssl.yaml b/plugins/http/tests/configs/.rr-ssl.yaml new file mode 100644 index 00000000..127c1678 --- /dev/null +++ b/plugins/http/tests/configs/.rr-ssl.yaml @@ -0,0 +1,35 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: :8085 + maxRequestSize: 1024 + middleware: [ "" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + pool: + numWorkers: 1 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s + + ssl: + port: 8893 + redirect: false + cert: fixtures/server.crt + key: fixtures/server.key + # rootCa: root.crt + fcgi: + address: tcp://0.0.0.0:6920 + http2: + enabled: false + h2c: false + maxConcurrentStreams: 128
\ No newline at end of file diff --git a/plugins/http/tests/fixtures/server.crt b/plugins/http/tests/fixtures/server.crt new file mode 100644 index 00000000..24d67fd7 --- /dev/null +++ b/plugins/http/tests/fixtures/server.crt @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICTTCCAdOgAwIBAgIJAOKyUd+llTRKMAoGCCqGSM49BAMCMGMxCzAJBgNVBAYT +AlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2Nv +MRMwEQYDVQQKDApSb2FkUnVubmVyMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcNMTgw +OTMwMTMzNDUzWhcNMjgwOTI3MTMzNDUzWjBjMQswCQYDVQQGEwJVUzETMBEGA1UE +CAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuIEZyYW5jaXNjbzETMBEGA1UECgwK +Um9hZFJ1bm5lcjESMBAGA1UEAwwJbG9jYWxob3N0MHYwEAYHKoZIzj0CAQYFK4EE +ACIDYgAEVnbShsM+l5RR3wfWWmGhzuFGwNzKCk7i9xyobDIyBUxG/UUSfj7KKlUX +puDnDEtF5xXcepl744CyIAYFLOXHb5WqI4jCOzG0o9f/00QQ4bQudJOdbqV910QF +C2vb7Fxro1MwUTAdBgNVHQ4EFgQU9xUexnbB6ORKayA7Pfjzs33otsAwHwYDVR0j +BBgwFoAU9xUexnbB6ORKayA7Pfjzs33otsAwDwYDVR0TAQH/BAUwAwEB/zAKBggq +hkjOPQQDAgNoADBlAjEAue3HhR/MUhxoa9tSDBtOJT3FYbDQswrsdqBTz97CGKst +e7XeZ3HMEvEXy0hGGEMhAjAqcD/4k9vViVppgWFtkk6+NFbm+Kw/QeeAiH5FgFSj +8xQcb+b7nPwNLp3JOkXkVd4= +-----END CERTIFICATE----- diff --git a/plugins/http/tests/fixtures/server.key b/plugins/http/tests/fixtures/server.key new file mode 100644 index 00000000..7501dd46 --- /dev/null +++ b/plugins/http/tests/fixtures/server.key @@ -0,0 +1,9 @@ +-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCQP8utxNbHR6xZOLAJgUhn88r6IrPqmN0MsgGJM/jePB+T9UhkmIU8 +PMm2HeScbcugBwYFK4EEACKhZANiAARWdtKGwz6XlFHfB9ZaYaHO4UbA3MoKTuL3 +HKhsMjIFTEb9RRJ+PsoqVRem4OcMS0XnFdx6mXvjgLIgBgUs5cdvlaojiMI7MbSj +1//TRBDhtC50k51upX3XRAULa9vsXGs= +-----END EC PRIVATE KEY----- diff --git a/plugins/http/tests/handler_test.go b/plugins/http/tests/handler_test.go new file mode 100644 index 00000000..0c6a39ef --- /dev/null +++ b/plugins/http/tests/handler_test.go @@ -0,0 +1,1859 @@ +package tests + +import ( + "bytes" + "context" + "io/ioutil" + "mime/multipart" + "net/url" + "os/exec" + "runtime" + "strings" + + "github.com/spiral/roadrunner/v2" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/stretchr/testify/assert" + + "net/http" + "os" + "testing" + "time" +) + +func TestHandler_Echo(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "echo", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + go func(server *http.Server) { + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }(hs) + time.Sleep(time.Millisecond * 10) + + body, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", body) +} + +func Test_HandlerErrors(t *testing.T) { + _, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, nil) + assert.Error(t, err) +} + +func TestHandler_Headers(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "header", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8078", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 100) + + req, err := http.NewRequest("GET", "http://localhost:8078?hello=world", nil) + assert.NoError(t, err) + + req.Header.Add("input", "sample") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "world", r.Header.Get("Header")) + assert.Equal(t, "SAMPLE", string(b)) +} + +func TestHandler_Empty_User_Agent(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "user-agent", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8088", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest("GET", "http://localhost:8088?hello=world", nil) + assert.NoError(t, err) + + req.Header.Add("user-agent", "") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "", string(b)) +} + +func TestHandler_User_Agent(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "user-agent", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8088", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest("GET", "http://localhost:8088?hello=world", nil) + assert.NoError(t, err) + + req.Header.Add("User-Agent", "go-agent") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "go-agent", string(b)) +} + +func TestHandler_Cookies(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "cookie", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8079", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest("GET", "http://localhost:8079", nil) + assert.NoError(t, err) + + req.AddCookie(&http.Cookie{Name: "input", Value: "input-value"}) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "INPUT-VALUE", string(b)) + + for _, c := range r.Cookies() { + assert.Equal(t, "output", c.Name) + assert.Equal(t, "cookie-output", c.Value) + } +} + +func TestHandler_JsonPayload_POST(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "payload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8090", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest( + "POST", + "http://localhost"+hs.Addr, + bytes.NewBufferString(`{"key":"value"}`), + ) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/json") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, `{"value":"key"}`, string(b)) +} + +func TestHandler_JsonPayload_PUT(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "payload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8081", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest("PUT", "http://localhost"+hs.Addr, bytes.NewBufferString(`{"key":"value"}`)) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/json") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, `{"value":"key"}`, string(b)) +} + +func TestHandler_JsonPayload_PATCH(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "payload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8082", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + req, err := http.NewRequest("PATCH", "http://localhost"+hs.Addr, bytes.NewBufferString(`{"key":"value"}`)) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/json") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, `{"value":"key"}`, string(b)) +} + +func TestHandler_FormData_POST(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8083", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 500) + + form := url.Values{} + + form.Add("key", "value") + form.Add("name[]", "name1") + form.Add("name[]", "name2") + form.Add("name[]", "name3") + form.Add("arr[x][y][z]", "y") + form.Add("arr[x][y][e]", "f") + form.Add("arr[c]p", "l") + form.Add("arr[c]z", "") + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, strings.NewReader(form.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + // Sorted + assert.Equal(t, "{\"arr\":{\"c\":{\"p\":\"l\",\"z\":\"\"},\"x\":{\"y\":{\"e\":\"f\",\"z\":\"y\"}}},\"key\":\"value\",\"name\":[\"name1\",\"name2\",\"name3\"]}", string(b)) +} + +func TestHandler_FormData_POST_Overwrite(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8083", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + form := url.Values{} + + form.Add("key", "value") + form.Add("key", "value2") + form.Add("name[]", "name1") + form.Add("name[]", "name2") + form.Add("name[]", "name3") + form.Add("arr[x][y][z]", "y") + form.Add("arr[x][y][e]", "f") + form.Add("arr[c]p", "l") + form.Add("arr[c]z", "") + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, strings.NewReader(form.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, `{"arr":{"c":{"p":"l","z":""},"x":{"y":{"e":"f","z":"y"}}},"key":"value2","name":["name1","name2","name3"]}`, string(b)) +} + +func TestHandler_FormData_POST_Form_UrlEncoded_Charset(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8083", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + form := url.Values{} + + form.Add("key", "value") + form.Add("name[]", "name1") + form.Add("name[]", "name2") + form.Add("name[]", "name3") + form.Add("arr[x][y][z]", "y") + form.Add("arr[x][y][e]", "f") + form.Add("arr[c]p", "l") + form.Add("arr[c]z", "") + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, strings.NewReader(form.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, `{"arr":{"c":{"p":"l","z":""},"x":{"y":{"e":"f","z":"y"}}},"key":"value","name":["name1","name2","name3"]}`, string(b)) +} + +func TestHandler_FormData_PUT(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":17834", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 500) + + form := url.Values{} + + form.Add("key", "value") + form.Add("name[]", "name1") + form.Add("name[]", "name2") + form.Add("name[]", "name3") + form.Add("arr[x][y][z]", "y") + form.Add("arr[x][y][e]", "f") + form.Add("arr[c]p", "l") + form.Add("arr[c]z", "") + + req, err := http.NewRequest("PUT", "http://localhost"+hs.Addr, strings.NewReader(form.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, `{"arr":{"c":{"p":"l","z":""},"x":{"y":{"e":"f","z":"y"}}},"key":"value","name":["name1","name2","name3"]}`, string(b)) +} + +func TestHandler_FormData_PATCH(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8085", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + form := url.Values{} + + form.Add("key", "value") + form.Add("name[]", "name1") + form.Add("name[]", "name2") + form.Add("name[]", "name3") + form.Add("arr[x][y][z]", "y") + form.Add("arr[x][y][e]", "f") + form.Add("arr[c]p", "l") + form.Add("arr[c]z", "") + + req, err := http.NewRequest("PATCH", "http://localhost"+hs.Addr, strings.NewReader(form.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, "{\"arr\":{\"c\":{\"p\":\"l\",\"z\":\"\"},\"x\":{\"y\":{\"e\":\"f\",\"z\":\"y\"}}},\"key\":\"value\",\"name\":[\"name1\",\"name2\",\"name3\"]}", string(b)) +} + +func TestHandler_Multipart_POST(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8019", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name1") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name2") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name3") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][z]", "y") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][e]", "f") + + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]p", "l") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]z", "") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the writer: error %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, "{\"arr\":{\"c\":{\"p\":\"l\",\"z\":\"\"},\"x\":{\"y\":{\"e\":\"f\",\"z\":\"y\"}}},\"key\":\"value\",\"name\":[\"name1\",\"name2\",\"name3\"]}", string(b)) +} + +func TestHandler_Multipart_PUT(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8020", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 500) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name1") + + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name2") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name3") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][z]", "y") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][e]", "f") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]p", "l") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]z", "") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the writer: error %v", err) + } + + req, err := http.NewRequest("PUT", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, `{"arr":{"c":{"p":"l","z":""},"x":{"y":{"e":"f","z":"y"}}},"key":"value","name":["name1","name2","name3"]}`, string(b)) +} + +func TestHandler_Multipart_PATCH(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "data", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8021", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 500) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("key", "value") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name1") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name2") + + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("name[]", "name3") + + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][z]", "y") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[x][y][e]", "f") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]p", "l") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.WriteField("arr[c]z", "") + if err != nil { + t.Errorf("error writing the field: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the writer: error %v", err) + } + + req, err := http.NewRequest("PATCH", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + assert.Equal(t, `{"arr":{"c":{"p":"l","z":""},"x":{"y":{"e":"f","z":"y"}}},"key":"value","name":["name1","name2","name3"]}`, string(b)) +} + +func TestHandler_Error(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "error", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + _, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + assert.Equal(t, 500, r.StatusCode) +} + +func TestHandler_Error2(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "error2", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + _, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + assert.Equal(t, 500, r.StatusCode) +} + +func TestHandler_Error3(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "pid", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err = hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + b2 := &bytes.Buffer{} + for i := 0; i < 1024*1024; i++ { + b2.Write([]byte(" ")) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, b2) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err = r.Body.Close() + if err != nil { + t.Errorf("error during the closing Body: error %v", err) + } + }() + + assert.NoError(t, err) + assert.Equal(t, 500, r.StatusCode) +} + +func TestHandler_ResponseDuration(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "echo", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + gotresp := make(chan interface{}) + h.AddListener(func(event interface{}) { + switch t := event.(type) { + case httpPlugin.ResponseEvent: + if t.Elapsed() > 0 { + close(gotresp) + } + default: + } + }) + + body, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + + <-gotresp + + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", body) +} + +func TestHandler_ResponseDurationDelayed(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "echoDelay", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + gotresp := make(chan interface{}) + h.AddListener(func(event interface{}) { + switch tp := event.(type) { + case httpPlugin.ResponseEvent: + if tp.Elapsed() > time.Second { + close(gotresp) + } + default: + } + }) + + body, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + <-gotresp + + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", body) +} + +func TestHandler_ErrorDuration(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "error", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err = hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + goterr := make(chan interface{}) + h.AddListener(func(event interface{}) { + switch tp := event.(type) { + case httpPlugin.ErrorEvent: + if tp.Elapsed() > 0 { + close(goterr) + } + default: + } + }) + + _, r, err := get("http://localhost:8177/?hello=world") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + + <-goterr + + assert.Equal(t, 500, r.StatusCode) +} + +func TestHandler_IP(t *testing.T) { + trusted := []string{ + "10.0.0.0/8", + "127.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + + cidrs, err := httpPlugin.ParseCIDRs(trusted) + assert.NoError(t, err) + assert.NotNil(t, cidrs) + + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "ip", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, cidrs, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: "127.0.0.1:8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + body, r, err := get("http://127.0.0.1:8177/") + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "127.0.0.1", body) +} + +func TestHandler_XRealIP(t *testing.T) { + trusted := []string{ + "10.0.0.0/8", + "127.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + + cidrs, err := httpPlugin.ParseCIDRs(trusted) + assert.NoError(t, err) + assert.NotNil(t, cidrs) + + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "ip", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, cidrs, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: "127.0.0.1:8179", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + body, r, err := getHeader("http://127.0.0.1:8179/", map[string]string{ + "X-Real-Ip": "200.0.0.1", + }) + + assert.NoError(t, err) + defer func() { + _ = r.Body.Close() + }() + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "200.0.0.1", body) +} + +func TestHandler_XForwardedFor(t *testing.T) { + trusted := []string{ + "10.0.0.0/8", + "127.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "100.0.0.0/16", + "200.0.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + + cidrs, err := httpPlugin.ParseCIDRs(trusted) + assert.NoError(t, err) + assert.NotNil(t, cidrs) + + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "ip", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, cidrs, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: "127.0.0.1:8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + body, r, err := getHeader("http://127.0.0.1:8177/", map[string]string{ + "X-Forwarded-For": "100.0.0.1, 200.0.0.1, invalid, 101.0.0.1", + }) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "101.0.0.1", body) + _ = r.Body.Close() + + body, r, err = getHeader("http://127.0.0.1:8177/", map[string]string{ + "X-Forwarded-For": "100.0.0.1, 200.0.0.1, 101.0.0.1, invalid", + }) + + assert.NoError(t, err) + _ = r.Body.Close() + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "101.0.0.1", body) +} + +func TestHandler_XForwardedFor_NotTrustedRemoteIp(t *testing.T) { + trusted := []string{ + "10.0.0.0/8", + } + + cidrs, err := httpPlugin.ParseCIDRs(trusted) + assert.NoError(t, err) + assert.NotNil(t, cidrs) + + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "ip", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, cidrs, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: "127.0.0.1:8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + body, r, err := getHeader("http://127.0.0.1:8177/", map[string]string{ + "X-Forwarded-For": "100.0.0.1, 200.0.0.1, invalid, 101.0.0.1", + }) + + assert.NoError(t, err) + _ = r.Body.Close() + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "127.0.0.1", body) +} + +func BenchmarkHandler_Listen_Echo(b *testing.B) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "echo", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: int64(runtime.NumCPU()), + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + b.Fatal(err) + } + defer func() { + pool.Destroy(context.Background()) + }() + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(b, err) + + hs := &http.Server{Addr: ":8177", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + b.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err = hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + b.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + b.ResetTimer() + b.ReportAllocs() + bb := "WORLD" + for n := 0; n < b.N; n++ { + r, err := http.Get("http://localhost:8177/?hello=world") + if err != nil { + b.Fail() + } + // Response might be nil here + if r != nil { + br, err := ioutil.ReadAll(r.Body) + if err != nil { + b.Errorf("error reading Body: error %v", err) + } + if string(br) != bb { + b.Fail() + } + err = r.Body.Close() + if err != nil { + b.Errorf("error closing the Body: error %v", err) + } + } else { + b.Errorf("got nil response") + } + } +} diff --git a/plugins/http/tests/http_test.go b/plugins/http/tests/http_test.go new file mode 100644 index 00000000..06bf3f5d --- /dev/null +++ b/plugins/http/tests/http_test.go @@ -0,0 +1,1130 @@ +package tests + +import ( + "bytes" + "crypto/tls" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/rpc" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/spiral/endure" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/mocks" + "github.com/spiral/roadrunner/v2/plugins/config" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/informer" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/resetter" + "github.com/yookoala/gofast" + + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +var sslClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, +} + +func TestHTTPInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-init.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 5) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + wg.Wait() +} + +func TestHTTPInformerReset(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-resetter.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &informer.Plugin{}, + &resetter.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + go func() { + tt := time.NewTimer(time.Second * 10) + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("HTTPInformerTest", informerTest) + t.Run("HTTPEchoTestBefore", echoHTTP) + t.Run("HTTPResetTest", resetTest) + t.Run("HTTPEchoTestAfter", echoHTTP) + + wg.Wait() +} + +func echoHTTP(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:10084?hello=world", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func resetTest(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6001") + assert.NoError(t, err) + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + // WorkerList contains list of workers. + + var ret bool + err = client.Call("resetter.Reset", "http", &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + var services []string + err = client.Call("resetter.List", nil, &services) + assert.NoError(t, err) + if services[0] != "http" { + t.Fatal("no enough services") + } +} + +func informerTest(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6001") + assert.NoError(t, err) + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + // WorkerList contains list of workers. + list := struct { + // Workers is list of workers. + Workers []roadrunner.ProcessState `json:"workers"` + }{} + + err = client.Call("informer.Workers", "http", &list) + assert.NoError(t, err) + assert.Len(t, list.Workers, 2) +} + +func TestSSL(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-ssl.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("SSLEcho", sslEcho) + t.Run("SSLNoRedirect", sslNoRedirect) + t.Run("fCGIecho", fcgiEcho) + wg.Wait() +} + +func sslNoRedirect(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:8085?hello=world", nil) + assert.NoError(t, err) + + r, err := sslClient.Do(req) + assert.NoError(t, err) + + assert.Nil(t, r.TLS) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err2 := r.Body.Close() + if err2 != nil { + t.Errorf("fail to close the Body: error %v", err2) + } +} + +func sslEcho(t *testing.T) { + req, err := http.NewRequest("GET", "https://localhost:8893?hello=world", nil) + assert.NoError(t, err) + + r, err := sslClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err2 := r.Body.Close() + if err2 != nil { + t.Errorf("fail to close the Body: error %v", err2) + } +} + +func fcgiEcho(t *testing.T) { + fcgiConnFactory := gofast.SimpleConnFactory("tcp", "0.0.0.0:6920") + + fcgiHandler := gofast.NewHandler( + gofast.BasicParamsMap(gofast.BasicSession), + gofast.SimpleClientFactory(fcgiConnFactory, 0), + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://site.local/?hello=world", nil) + fcgiHandler.ServeHTTP(w, req) + + body, err := ioutil.ReadAll(w.Result().Body) + + assert.NoError(t, err) + assert.Equal(t, 201, w.Result().StatusCode) + assert.Equal(t, "WORLD", string(body)) +} + +func TestSSLRedirect(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-ssl-redirect.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("SSLRedirect", sslRedirect) + wg.Wait() +} + +func sslRedirect(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:8087?hello=world", nil) + assert.NoError(t, err) + + r, err := sslClient.Do(req) + assert.NoError(t, err) + assert.NotNil(t, r.TLS) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err2 := r.Body.Close() + if err2 != nil { + t.Errorf("fail to close the Body: error %v", err2) + } +} + +func TestSSLPushPipes(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-ssl-push.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + tt := time.NewTimer(time.Second * 10) + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("SSLPush", sslPush) + wg.Wait() +} + +func sslPush(t *testing.T) { + req, err := http.NewRequest("GET", "https://localhost:8894?hello=world", nil) + assert.NoError(t, err) + + r, err := sslClient.Do(req) + assert.NoError(t, err) + + assert.NotNil(t, r.TLS) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, "", r.Header.Get("Http2-Push")) + + assert.NoError(t, err) + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err2 := r.Body.Close() + if err2 != nil { + t.Errorf("fail to close the Body: error %v", err2) + } +} + +func TestFastCGI_RequestUri(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-fcgi-reqUri.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("FastCGIServiceRequestUri", fcgiReqURI) + wg.Wait() +} + +func fcgiReqURI(t *testing.T) { + time.Sleep(time.Second * 2) + fcgiConnFactory := gofast.SimpleConnFactory("tcp", "127.0.0.1:6921") + + fcgiHandler := gofast.NewHandler( + gofast.BasicParamsMap(gofast.BasicSession), + gofast.SimpleClientFactory(fcgiConnFactory, 0), + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://site.local/hello-world", nil) + fcgiHandler.ServeHTTP(w, req) + + body, err := ioutil.ReadAll(w.Result().Body) + assert.NoError(t, err) + assert.Equal(t, 200, w.Result().StatusCode) + assert.Equal(t, "http://site.local/hello-world", string(body)) +} + +func TestH2CUpgrade(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-h2c.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("H2cUpgrade", h2cUpgrade) + wg.Wait() +} + +func h2cUpgrade(t *testing.T) { + req, err := http.NewRequest("PRI", "http://localhost:8083?hello=world", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Add("Upgrade", "h2c") + req.Header.Add("Connection", "HTTP2-Settings") + req.Header.Add("HTTP2-Settings", "") + + r, err2 := http.DefaultClient.Do(req) + if err2 != nil { + t.Fatal(err) + } + + assert.Equal(t, "101 Switching Protocols", r.Status) + + err3 := r.Body.Close() + if err3 != nil { + t.Fatal(err) + } +} + +func TestH2C(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-h2c.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("H2c", h2c) + wg.Wait() +} + +func h2c(t *testing.T) { + req, err := http.NewRequest("PRI", "http://localhost:8083?hello=world", nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Add("Connection", "HTTP2-Settings") + req.Header.Add("HTTP2-Settings", "") + + r, err2 := http.DefaultClient.Do(req) + if err2 != nil { + t.Fatal(err) + } + + assert.Equal(t, "201 Created", r.Status) + + err3 := r.Body.Close() + if err3 != nil { + t.Fatal(err) + } +} + +func TestHttpMiddleware(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &PluginMiddleware{}, + &PluginMiddleware2{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 15) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("MiddlewareTest", middleware) + wg.Wait() +} + +func middleware(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:18903?hello=world", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) + + req, err = http.NewRequest("GET", "http://localhost:18903/halt", nil) + assert.NoError(t, err) + + r, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + b, err = ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 500, r.StatusCode) + assert.Equal(t, "halted", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestHttpEchoErr(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-echoErr.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("http handler response received", "elapsed", gomock.Any(), "remote address", "127.0.0.1") + mockLogger.EXPECT().Debug("WORLD", "pid", gomock.Any()) + mockLogger.EXPECT().Debug("worker event received", "event", roadrunner.EventWorkerLog, "worker state", gomock.Any()) + + err = cont.RegisterAll( + cfg, + mockLogger, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &PluginMiddleware{}, + &PluginMiddleware2{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("HttpEchoError", echoError) + wg.Wait() +} + +func echoError(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:8080?hello=world", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 201, r.StatusCode) + assert.Equal(t, "WORLD", string(b)) + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestHttpEnvVariables(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-env.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &PluginMiddleware{}, + &PluginMiddleware2{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second * 1) + t.Run("EnvVariablesTest", envVarsTest) + wg.Wait() +} + +func envVarsTest(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:12084", nil) + assert.NoError(t, err) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.Equal(t, 200, r.StatusCode) + assert.Equal(t, "ENV_VALUE", string(b)) + + err = r.Body.Close() + assert.NoError(t, err) +} + +func TestHttpBrokenPipes(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-broken-pipes.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &PluginMiddleware{}, + &PluginMiddleware2{}, + ) + assert.NoError(t, err) + + err = cont.Init() + assert.Error(t, err) + + _, err = cont.Serve() + assert.Error(t, err) +} + +func get(url string) (string, *http.Response, error) { + r, err := http.Get(url) + if err != nil { + return "", nil, err + } + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return "", nil, err + } + defer func() { + _ = r.Body.Close() + }() + return string(b), r, err +} + +// get request and return body +func getHeader(url string, h map[string]string) (string, *http.Response, error) { + req, err := http.NewRequest("GET", url, bytes.NewBuffer(nil)) + if err != nil { + return "", nil, err + } + + for k, v := range h { + req.Header.Set(k, v) + } + + r, err := http.DefaultClient.Do(req) + if err != nil { + return "", nil, err + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return "", nil, err + } + + err = r.Body.Close() + if err != nil { + return "", nil, err + } + return string(b), r, err +} diff --git a/plugins/http/tests/parse_test.go b/plugins/http/tests/parse_test.go new file mode 100644 index 00000000..a93bc059 --- /dev/null +++ b/plugins/http/tests/parse_test.go @@ -0,0 +1,54 @@ +package tests + +import ( + "testing" + + "github.com/spiral/roadrunner/v2/plugins/http" +) + +var samples = []struct { + in string + out []string +}{ + {"key", []string{"key"}}, + {"key[subkey]", []string{"key", "subkey"}}, + {"key[subkey]value", []string{"key", "subkey", "value"}}, + {"key[subkey][value]", []string{"key", "subkey", "value"}}, + {"key[subkey][value][]", []string{"key", "subkey", "value", ""}}, + {"key[subkey] [value][]", []string{"key", "subkey", "value", ""}}, + {"key [ subkey ] [ value ] [ ]", []string{"key", "subkey", "value", ""}}, +} + +func Test_FetchIndexes(t *testing.T) { + for i := 0; i < len(samples); i++ { + r := http.FetchIndexes(samples[i].in) + if !same(r, samples[i].out) { + t.Errorf("got %q, want %q", r, samples[i].out) + } + } +} + +func BenchmarkConfig_FetchIndexes(b *testing.B) { + for _, tt := range samples { + for n := 0; n < b.N; n++ { + r := http.FetchIndexes(tt.in) + if !same(r, tt.out) { + b.Fail() + } + } + } +} + +func same(in, out []string) bool { + if len(in) != len(out) { + return false + } + + for i, v := range in { + if v != out[i] { + return false + } + } + + return true +} diff --git a/plugins/http/tests/plugin1.go b/plugins/http/tests/plugin1.go new file mode 100644 index 00000000..1cbca744 --- /dev/null +++ b/plugins/http/tests/plugin1.go @@ -0,0 +1,25 @@ +package tests + +import "github.com/spiral/roadrunner/v2/plugins/config" + +type Plugin1 struct { + config config.Configurer +} + +func (p1 *Plugin1) Init(cfg config.Configurer) error { + p1.config = cfg + return nil +} + +func (p1 *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (p1 *Plugin1) Stop() error { + return nil +} + +func (p1 *Plugin1) Name() string { + return "http_test.plugin1" +} diff --git a/plugins/http/tests/plugin_middleware.go b/plugins/http/tests/plugin_middleware.go new file mode 100644 index 00000000..de829d34 --- /dev/null +++ b/plugins/http/tests/plugin_middleware.go @@ -0,0 +1,61 @@ +package tests + +import ( + "net/http" + + "github.com/spiral/roadrunner/v2/plugins/config" +) + +type PluginMiddleware struct { + config config.Configurer +} + +func (p *PluginMiddleware) Init(cfg config.Configurer) error { + p.config = cfg + return nil +} + +func (p *PluginMiddleware) Middleware(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/halt" { + w.WriteHeader(500) + _, err := w.Write([]byte("halted")) + if err != nil { + panic("error writing the data to the http reply") + } + } else { + next.ServeHTTP(w, r) + } + } +} + +func (p *PluginMiddleware) Name() string { + return "pluginMiddleware" +} + +type PluginMiddleware2 struct { + config config.Configurer +} + +func (p *PluginMiddleware2) Init(cfg config.Configurer) error { + p.config = cfg + return nil +} + +func (p *PluginMiddleware2) Middleware(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/boom" { + w.WriteHeader(555) + _, err := w.Write([]byte("boom")) + if err != nil { + panic("error writing the data to the http reply") + } + } else { + next.ServeHTTP(w, r) + } + } +} + +func (p *PluginMiddleware2) Name() string { + return "pluginMiddleware2" +} diff --git a/plugins/http/tests/response_test.go b/plugins/http/tests/response_test.go new file mode 100644 index 00000000..2bfe7d56 --- /dev/null +++ b/plugins/http/tests/response_test.go @@ -0,0 +1,163 @@ +package tests + +import ( + "bytes" + "errors" + "net/http" + "testing" + + "github.com/spiral/roadrunner/v2" + http2 "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/stretchr/testify/assert" +) + +type testWriter struct { + h http.Header + buf bytes.Buffer + wroteHeader bool + code int + err error + pushErr error + pushes []string +} + +func (tw *testWriter) Header() http.Header { return tw.h } + +func (tw *testWriter) Write(p []byte) (int, error) { + if !tw.wroteHeader { + tw.WriteHeader(http.StatusOK) + } + + n, e := tw.buf.Write(p) + if e == nil { + e = tw.err + } + + return n, e +} + +func (tw *testWriter) WriteHeader(code int) { tw.wroteHeader = true; tw.code = code } + +func (tw *testWriter) Push(target string, opts *http.PushOptions) error { + tw.pushes = append(tw.pushes, target) + + return tw.pushErr +} + +func TestNewResponse_Error(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{Context: []byte(`invalid payload`)}) + assert.Error(t, err) + assert.Nil(t, r) +} + +func TestNewResponse_Write(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte(`{"headers":{"key":["value"]},"status": 301}`), + Body: []byte(`sample body`), + }) + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string))} + assert.NoError(t, r.Write(w)) + + assert.Equal(t, 301, w.code) + assert.Equal(t, "value", w.h.Get("key")) + assert.Equal(t, "sample body", w.buf.String()) +} + +func TestNewResponse_Stream(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte(`{"headers":{"key":["value"]},"status": 301}`), + }) + + // r is pointer, so, it might be nil + if r == nil { + t.Fatal("response is nil") + } + + r.Body = &bytes.Buffer{} + r.Body.(*bytes.Buffer).WriteString("hello world") + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string))} + assert.NoError(t, r.Write(w)) + + assert.Equal(t, 301, w.code) + assert.Equal(t, "value", w.h.Get("key")) + assert.Equal(t, "hello world", w.buf.String()) +} + +func TestNewResponse_StreamError(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte(`{"headers":{"key":["value"]},"status": 301}`), + }) + + // r is pointer, so, it might be nil + if r == nil { + t.Fatal("response is nil") + } + + r.Body = &bytes.Buffer{} + r.Body.(*bytes.Buffer).WriteString("hello world") + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string)), err: errors.New("error")} + assert.Error(t, r.Write(w)) +} + +func TestWrite_HandlesPush(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte(`{"headers":{"Http2-Push":["/test.js"],"content-type":["text/html"]},"status": 200}`), + }) + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string))} + assert.NoError(t, r.Write(w)) + + assert.Nil(t, w.h["Http2-Push"]) + assert.Equal(t, []string{"/test.js"}, w.pushes) +} + +func TestWrite_HandlesTrailers(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte(`{"headers":{"Trailer":["foo, bar", "baz"],"foo":["test"],"bar":["demo"]},"status": 200}`), + }) + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string))} + assert.NoError(t, r.Write(w)) + + assert.Nil(t, w.h[http2.TrailerHeaderKey]) + assert.Nil(t, w.h["foo"]) //nolint:golint,staticcheck + assert.Nil(t, w.h["baz"]) //nolint:golint,staticcheck + + assert.Equal(t, "test", w.h.Get("Trailer:foo")) + assert.Equal(t, "demo", w.h.Get("Trailer:bar")) +} + +func TestWrite_HandlesHandlesWhitespacesInTrailer(t *testing.T) { + r, err := http2.NewResponse(roadrunner.Payload{ + Context: []byte( + `{"headers":{"Trailer":["foo\t,bar , baz"],"foo":["a"],"bar":["b"],"baz":["c"]},"status": 200}`), + }) + + assert.NoError(t, err) + assert.NotNil(t, r) + + w := &testWriter{h: http.Header(make(map[string][]string))} + assert.NoError(t, r.Write(w)) + + assert.Equal(t, "a", w.h.Get("Trailer:foo")) + assert.Equal(t, "b", w.h.Get("Trailer:bar")) + assert.Equal(t, "c", w.h.Get("Trailer:baz")) +} diff --git a/plugins/http/tests/uploads_config_test.go b/plugins/http/tests/uploads_config_test.go new file mode 100644 index 00000000..497cd54f --- /dev/null +++ b/plugins/http/tests/uploads_config_test.go @@ -0,0 +1,26 @@ +package tests + +import ( + "os" + "testing" + + "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/stretchr/testify/assert" +) + +func TestFsConfig_Forbids(t *testing.T) { + cfg := http.UploadsConfig{Forbid: []string{".php"}} + + assert.True(t, cfg.Forbids("index.php")) + assert.True(t, cfg.Forbids("index.PHP")) + assert.True(t, cfg.Forbids("phpadmin/index.bak.php")) + assert.False(t, cfg.Forbids("index.html")) +} + +func TestFsConfig_TmpFallback(t *testing.T) { + cfg := http.UploadsConfig{Dir: "test"} + assert.Equal(t, "test", cfg.TmpDir()) + + cfg = http.UploadsConfig{Dir: ""} + assert.Equal(t, os.TempDir(), cfg.TmpDir()) +} diff --git a/plugins/http/tests/uploads_test.go b/plugins/http/tests/uploads_test.go new file mode 100644 index 00000000..ee244c06 --- /dev/null +++ b/plugins/http/tests/uploads_test.go @@ -0,0 +1,431 @@ +package tests + +import ( + "bytes" + "context" + "crypto/sha512" + "encoding/hex" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "os" + "os/exec" + "testing" + "time" + + j "github.com/json-iterator/go" + "github.com/spiral/roadrunner/v2" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/stretchr/testify/assert" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +const testFile = "uploads_test.go" + +func TestHandler_Upload_File(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "upload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8021", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + + f := mustOpen(testFile) + defer func() { + err := f.Close() + if err != nil { + t.Errorf("failed to close a file: error %v", err) + } + }() + fw, err := w.CreateFormFile("upload", f.Name()) + assert.NotNil(t, fw) + assert.NoError(t, err) + _, err = io.Copy(fw, f) + if err != nil { + t.Errorf("error copying the file: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the file: error %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error closing the Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + fs := fileString(testFile, 0, "application/octet-stream") + + assert.Equal(t, `{"upload":`+fs+`}`, string(b)) +} + +func TestHandler_Upload_NestedFile(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "upload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8021", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + + f := mustOpen(testFile) + defer func() { + err := f.Close() + if err != nil { + t.Errorf("failed to close a file: error %v", err) + } + }() + fw, err := w.CreateFormFile("upload[x][y][z][]", f.Name()) + assert.NotNil(t, fw) + assert.NoError(t, err) + _, err = io.Copy(fw, f) + if err != nil { + t.Errorf("error copying the file: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the file: error %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error closing the Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + fs := fileString(testFile, 0, "application/octet-stream") + + assert.Equal(t, `{"upload":{"x":{"y":{"z":[`+fs+`]}}}}`, string(b)) +} + +func TestHandler_Upload_File_NoTmpDir(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "upload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: "-------", + Forbid: []string{}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8021", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + + f := mustOpen(testFile) + defer func() { + err := f.Close() + if err != nil { + t.Errorf("failed to close a file: error %v", err) + } + }() + fw, err := w.CreateFormFile("upload", f.Name()) + assert.NotNil(t, fw) + assert.NoError(t, err) + _, err = io.Copy(fw, f) + if err != nil { + t.Errorf("error copying the file: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the file: error %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error closing the Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + fs := fileString(testFile, 5, "application/octet-stream") + + assert.Equal(t, `{"upload":`+fs+`}`, string(b)) +} + +func TestHandler_Upload_File_Forbids(t *testing.T) { + pool, err := roadrunner.NewPool(context.Background(), + func() *exec.Cmd { return exec.Command("php", "../../../tests/http/client.php", "upload", "pipes") }, + roadrunner.NewPipeFactory(), + roadrunner.PoolConfig{ + NumWorkers: 1, + AllocateTimeout: time.Second * 1000, + DestroyTimeout: time.Second * 1000, + }) + if err != nil { + t.Fatal(err) + } + + h, err := httpPlugin.NewHandler(1024, httpPlugin.UploadsConfig{ + Dir: os.TempDir(), + Forbid: []string{".go"}, + }, nil, pool) + assert.NoError(t, err) + + hs := &http.Server{Addr: ":8021", Handler: h} + defer func() { + err := hs.Shutdown(context.Background()) + if err != nil { + t.Errorf("error during the shutdown: error %v", err) + } + }() + + go func() { + err := hs.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + t.Errorf("error listening the interface: error %v", err) + } + }() + time.Sleep(time.Millisecond * 10) + + var mb bytes.Buffer + w := multipart.NewWriter(&mb) + + f := mustOpen(testFile) + defer func() { + err := f.Close() + if err != nil { + t.Errorf("failed to close a file: error %v", err) + } + }() + fw, err := w.CreateFormFile("upload", f.Name()) + assert.NotNil(t, fw) + assert.NoError(t, err) + _, err = io.Copy(fw, f) + if err != nil { + t.Errorf("error copying the file: error %v", err) + } + + err = w.Close() + if err != nil { + t.Errorf("error closing the file: error %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost"+hs.Addr, &mb) + assert.NoError(t, err) + + req.Header.Set("Content-Type", w.FormDataContentType()) + + r, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer func() { + err := r.Body.Close() + if err != nil { + t.Errorf("error closing the Body: error %v", err) + } + }() + + b, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, 200, r.StatusCode) + + fs := fileString(testFile, 7, "application/octet-stream") + + assert.Equal(t, `{"upload":`+fs+`}`, string(b)) +} + +func Test_FileExists(t *testing.T) { + assert.True(t, exists(testFile)) + assert.False(t, exists("uploads_test.")) +} + +func mustOpen(f string) *os.File { + r, err := os.Open(f) + if err != nil { + panic(err) + } + return r +} + +type fInfo struct { + Name string `json:"name"` + Size int64 `json:"size"` + Mime string `json:"mime"` + Error int `json:"error"` + Sha512 string `json:"sha512,omitempty"` +} + +func fileString(f string, errNo int, mime string) string { + s, err := os.Stat(f) + if err != nil { + fmt.Println(fmt.Errorf("error stat the file, error: %v", err)) + } + + ff, err := os.Open(f) + if err != nil { + fmt.Println(fmt.Errorf("error opening the file, error: %v", err)) + } + + defer func() { + er := ff.Close() + if er != nil { + fmt.Println(fmt.Errorf("error closing the file, error: %v", er)) + } + }() + + h := sha512.New() + _, err = io.Copy(h, ff) + if err != nil { + fmt.Println(fmt.Errorf("error copying the file, error: %v", err)) + } + + v := &fInfo{ + Name: s.Name(), + Size: s.Size(), + Error: errNo, + Mime: mime, + Sha512: hex.EncodeToString(h.Sum(nil)), + } + + if errNo != 0 { + v.Sha512 = "" + v.Size = 0 + } + + r, err := json.Marshal(v) + if err != nil { + fmt.Println(fmt.Errorf("error marshalling fInfo, error: %v", err)) + } + return string(r) +} + +// exists if file exists. +func exists(path string) bool { + if _, err := os.Stat(path); os.IsNotExist(err) { + return false + } + return true +} diff --git a/plugins/http/uploads.go b/plugins/http/uploads.go new file mode 100644 index 00000000..5fddb75d --- /dev/null +++ b/plugins/http/uploads.go @@ -0,0 +1,158 @@ +package http + +import ( + "github.com/spiral/roadrunner/v2/interfaces/log" + + "io" + "io/ioutil" + "mime/multipart" + "os" + "sync" +) + +const ( + // UploadErrorOK - no error, the file uploaded with success. + UploadErrorOK = 0 + + // UploadErrorNoFile - no file was uploaded. + UploadErrorNoFile = 4 + + // UploadErrorNoTmpDir - missing a temporary folder. + UploadErrorNoTmpDir = 5 + + // UploadErrorCantWrite - failed to write file to disk. + UploadErrorCantWrite = 6 + + // UploadErrorExtension - forbidden file extension. + UploadErrorExtension = 7 +) + +// Uploads tree manages uploaded files tree and temporary files. +type Uploads struct { + // associated temp directory and forbidden extensions. + cfg UploadsConfig + + // pre processed data tree for Uploads. + tree fileTree + + // flat list of all file Uploads. + list []*FileUpload +} + +// MarshalJSON marshal tree tree into JSON. +func (u *Uploads) MarshalJSON() ([]byte, error) { + return json.Marshal(u.tree) +} + +// Open moves all uploaded files to temp directory, return error in case of issue with temp directory. File errors +// will be handled individually. +func (u *Uploads) Open(log log.Logger) { + var wg sync.WaitGroup + for _, f := range u.list { + wg.Add(1) + go func(f *FileUpload) { + defer wg.Done() + err := f.Open(u.cfg) + if err != nil && log != nil { + log.Error("error opening the file", "err", err) + } + }(f) + } + + wg.Wait() +} + +// Clear deletes all temporary files. +func (u *Uploads) Clear(log log.Logger) { + for _, f := range u.list { + if f.TempFilename != "" && exists(f.TempFilename) { + err := os.Remove(f.TempFilename) + if err != nil && log != nil { + log.Error("error removing the file", "err", err) + } + } + } +} + +// FileUpload represents singular file NewUpload. +type FileUpload struct { + // ID contains filename specified by the client. + Name string `json:"name"` + + // Mime contains mime-type provided by the client. + Mime string `json:"mime"` + + // Size of the uploaded file. + Size int64 `json:"size"` + + // Error indicates file upload error (if any). See http://php.net/manual/en/features.file-upload.errors.php + Error int `json:"error"` + + // TempFilename points to temporary file location. + TempFilename string `json:"tmpName"` + + // associated file header + header *multipart.FileHeader +} + +// NewUpload wraps net/http upload into PRS-7 compatible structure. +func NewUpload(f *multipart.FileHeader) *FileUpload { + return &FileUpload{ + Name: f.Filename, + Mime: f.Header.Get("Content-Type"), + Error: UploadErrorOK, + header: f, + } +} + +// Open moves file content into temporary file available for PHP. +// NOTE: +// There is 2 deferred functions, and in case of getting 2 errors from both functions +// error from close of temp file would be overwritten by error from the main file +// STACK +// DEFER FILE CLOSE (2) +// DEFER TMP CLOSE (1) +func (f *FileUpload) Open(cfg UploadsConfig) (err error) { + if cfg.Forbids(f.Name) { + f.Error = UploadErrorExtension + return nil + } + + file, err := f.header.Open() + if err != nil { + f.Error = UploadErrorNoFile + return err + } + + defer func() { + // close the main file + err = file.Close() + }() + + tmp, err := ioutil.TempFile(cfg.TmpDir(), "upload") + if err != nil { + // most likely cause of this issue is missing tmp dir + f.Error = UploadErrorNoTmpDir + return err + } + + f.TempFilename = tmp.Name() + defer func() { + // close the temp file + err = tmp.Close() + }() + + if f.Size, err = io.Copy(tmp, file); err != nil { + f.Error = UploadErrorCantWrite + } + + return err +} + +// exists if file exists. +func exists(path string) bool { + if _, err := os.Stat(path); os.IsNotExist(err) { + return false + } + return true +} diff --git a/plugins/http/uploads_config.go b/plugins/http/uploads_config.go new file mode 100644 index 00000000..4c20c8e8 --- /dev/null +++ b/plugins/http/uploads_config.go @@ -0,0 +1,46 @@ +package http + +import ( + "os" + "path" + "strings" +) + +// UploadsConfig describes file location and controls access to them. +type UploadsConfig struct { + // Dir contains name of directory to control access to. + Dir string + + // Forbid specifies list of file extensions which are forbidden for access. + // Example: .php, .exe, .bat, .htaccess and etc. + Forbid []string +} + +// InitDefaults sets missing values to their default values. +func (cfg *UploadsConfig) InitDefaults() error { + cfg.Forbid = []string{".php", ".exe", ".bat"} + cfg.Dir = os.TempDir() + return nil +} + +// TmpDir returns temporary directory. +func (cfg *UploadsConfig) TmpDir() string { + if cfg.Dir != "" { + return cfg.Dir + } + + return os.TempDir() +} + +// Forbids must return true if file extension is not allowed for the upload. +func (cfg *UploadsConfig) Forbids(filename string) bool { + ext := strings.ToLower(path.Ext(filename)) + + for _, v := range cfg.Forbid { + if ext == v { + return true + } + } + + return false +} diff --git a/plugins/informer/plugin.go b/plugins/informer/plugin.go new file mode 100644 index 00000000..f3013394 --- /dev/null +++ b/plugins/informer/plugin.go @@ -0,0 +1,56 @@ +package informer + +import ( + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/informer" + "github.com/spiral/roadrunner/v2/interfaces/log" +) + +const PluginName = "informer" + +type Plugin struct { + registry map[string]informer.Informer + log log.Logger +} + +func (p *Plugin) Init(log log.Logger) error { + p.registry = make(map[string]informer.Informer) + p.log = log + return nil +} + +// Workers provides WorkerBase slice with workers for the requested plugin +func (p *Plugin) Workers(name string) ([]roadrunner.WorkerBase, error) { + const op = errors.Op("get workers") + svc, ok := p.registry[name] + if !ok { + return nil, errors.E(op, errors.Errorf("no such service: %s", name)) + } + + return svc.Workers(), nil +} + +// CollectTarget resettable service. +func (p *Plugin) CollectTarget(name endure.Named, r informer.Informer) error { + p.registry[name.Name()] = r + return nil +} + +// Collects declares services to be collected. +func (p *Plugin) Collects() []interface{} { + return []interface{}{ + p.CollectTarget, + } +} + +// Name of the service. +func (p *Plugin) Name() string { + return PluginName +} + +// RPCService returns associated rpc service. +func (p *Plugin) RPC() interface{} { + return &rpc{srv: p, log: p.log} +} diff --git a/plugins/informer/rpc.go b/plugins/informer/rpc.go new file mode 100644 index 00000000..d6b7bf01 --- /dev/null +++ b/plugins/informer/rpc.go @@ -0,0 +1,53 @@ +package informer + +import ( + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" +) + +type rpc struct { + srv *Plugin + log log.Logger +} + +// WorkerList contains list of workers. +type WorkerList struct { + // Workers is list of workers. + Workers []roadrunner.ProcessState `json:"workers"` +} + +// List all resettable services. +func (rpc *rpc) List(_ bool, list *[]string) error { + rpc.log.Debug("Started List method") + *list = make([]string, 0, len(rpc.srv.registry)) + + for name := range rpc.srv.registry { + *list = append(*list, name) + } + rpc.log.Debug("list of services", "list", *list) + + rpc.log.Debug("successfully finished List method") + return nil +} + +// Workers state of a given service. +func (rpc *rpc) Workers(service string, list *WorkerList) error { + rpc.log.Debug("started Workers method", "service", service) + workers, err := rpc.srv.Workers(service) + if err != nil { + return err + } + + list.Workers = make([]roadrunner.ProcessState, 0) + for _, w := range workers { + ps, err := roadrunner.WorkerProcessState(w) + if err != nil { + continue + } + + list.Workers = append(list.Workers, ps) + } + rpc.log.Debug("list of workers", "workers", list.Workers) + rpc.log.Debug("successfully finished Workers method") + return nil +} diff --git a/plugins/informer/tests/.rr-informer.yaml b/plugins/informer/tests/.rr-informer.yaml new file mode 100644 index 00000000..83ecd582 --- /dev/null +++ b/plugins/informer/tests/.rr-informer.yaml @@ -0,0 +1,13 @@ +server: + command: "php ../../../tests/client.php echo pipes" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pipes" + relayTimeout: "20s" + +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false
\ No newline at end of file diff --git a/plugins/informer/tests/informer_test.go b/plugins/informer/tests/informer_test.go new file mode 100644 index 00000000..fbe33a7d --- /dev/null +++ b/plugins/informer/tests/informer_test.go @@ -0,0 +1,97 @@ +package tests + +import ( + "net" + "net/rpc" + "os" + "os/signal" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/informer" + "github.com/spiral/roadrunner/v2/plugins/logger" + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestInformerInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + cfg := &config.Viper{ + Path: ".rr-informer.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &server.Plugin{}, + &logger.ZapLogger{}, + &informer.Plugin{}, + &rpcPlugin.Plugin{}, + &Plugin1{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + tt := time.NewTimer(time.Second * 15) + + t.Run("InformerRpcTest", informerRPCTest) + + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } +} + +func informerRPCTest(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6001") + assert.NoError(t, err) + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + // WorkerList contains list of workers. + list := struct { + // Workers is list of workers. + Workers []roadrunner.ProcessState `json:"workers"` + }{} + + err = client.Call("informer.Workers", "informer.plugin1", &list) + assert.NoError(t, err) + assert.Len(t, list.Workers, 10) +} diff --git a/plugins/informer/tests/test_plugin.go b/plugins/informer/tests/test_plugin.go new file mode 100644 index 00000000..473b6de7 --- /dev/null +++ b/plugins/informer/tests/test_plugin.go @@ -0,0 +1,58 @@ +package tests + +import ( + "context" + "time" + + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +var testPoolConfig = roadrunner.PoolConfig{ + NumWorkers: 10, + MaxJobs: 100, + AllocateTimeout: time.Second * 10, + DestroyTimeout: time.Second * 10, + Supervisor: &roadrunner.SupervisorConfig{ + WatchTick: 60, + TTL: 1000, + IdleTTL: 10, + ExecTTL: 10, + MaxWorkerMemory: 1000, + }, +} + +// Gauge ////////////// +type Plugin1 struct { + config config.Configurer + server server.Server +} + +func (p1 *Plugin1) Init(cfg config.Configurer, server server.Server) error { + p1.config = cfg + p1.server = server + return nil +} + +func (p1 *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (p1 *Plugin1) Stop() error { + return nil +} + +func (p1 *Plugin1) Name() string { + return "informer.plugin1" +} + +func (p1 *Plugin1) Workers() []roadrunner.WorkerBase { + pool, err := p1.server.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + panic(err) + } + + return pool.Workers() +} diff --git a/plugins/logger/config.go b/plugins/logger/config.go new file mode 100644 index 00000000..f7a5742c --- /dev/null +++ b/plugins/logger/config.go @@ -0,0 +1,94 @@ +package logger + +import ( + "strings" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// ChannelConfig configures loggers per channel. +type ChannelConfig struct { + // Dedicated channels per logger. By default logger allocated via named logger. + Channels map[string]Config `json:"channels" yaml:"channels"` +} + +type Config struct { + // Mode configures logger based on some default template (development, production, off). + Mode string `json:"mode" yaml:"mode"` + + // Level is the minimum enabled logging level. Note that this is a dynamic + // level, so calling ChannelConfig.Level.SetLevel will atomically change the log + // level of all loggers descended from this config. + Level string `json:"level" yaml:"level"` + + // Encoding sets the logger's encoding. Valid values are "json" and + // "console", as well as any third-party encodings registered via + // RegisterEncoder. + Encoding string `json:"encoding" yaml:"encoding"` + + // Output is a list of URLs or file paths to write logging output to. + // See Open for details. + Output []string `json:"output" yaml:"output"` + + // ErrorOutput is a list of URLs to write internal logger errors to. + // The default is standard error. + // + // Note that this setting only affects internal errors; for sample code that + // sends error-level logs to a different location from info- and debug-level + // logs, see the package-level AdvancedConfiguration example. + ErrorOutput []string `json:"errorOutput" yaml:"errorOutput"` +} + +// ZapConfig converts config into Zap configuration. +func (cfg *Config) BuildLogger() (*zap.Logger, error) { + var zCfg zap.Config + switch strings.ToLower(cfg.Mode) { + case "off", "none": + return zap.NewNop(), nil + case "production": + zCfg = zap.NewProductionConfig() + case "development": + zCfg = zap.NewDevelopmentConfig() + default: + zCfg = zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Encoding: "console", + EncoderConfig: zapcore.EncoderConfig{ + MessageKey: "message", + LevelKey: "level", + TimeKey: "time", + NameKey: "name", + EncodeName: ColoredHashedNameEncoder, + EncodeLevel: ColoredLevelEncoder, + EncodeTime: UTCTimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + } + + if cfg.Level != "" { + level := zap.NewAtomicLevel() + if err := level.UnmarshalText([]byte(cfg.Level)); err == nil { + zCfg.Level = level + } + } + + if cfg.Encoding != "" { + zCfg.Encoding = cfg.Encoding + } + + if len(cfg.Output) != 0 { + zCfg.OutputPaths = cfg.Output + } + + if len(cfg.ErrorOutput) != 0 { + zCfg.ErrorOutputPaths = cfg.ErrorOutput + } + + // todo: https://github.com/uber-go/zap/blob/master/FAQ.md#does-zap-support-log-rotation + + return zCfg.Build() +} diff --git a/plugins/logger/encoder.go b/plugins/logger/encoder.go new file mode 100644 index 00000000..4ff583c4 --- /dev/null +++ b/plugins/logger/encoder.go @@ -0,0 +1,66 @@ +package logger + +import ( + "hash/fnv" + "strings" + "time" + + "github.com/fatih/color" + "go.uber.org/zap/zapcore" +) + +var colorMap = []func(string, ...interface{}) string{ + color.HiYellowString, + color.HiGreenString, + color.HiBlueString, + color.HiRedString, + color.HiCyanString, + color.HiMagentaString, +} + +// ColoredLevelEncoder colorizes log levels. +func ColoredLevelEncoder(level zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { + switch level { + case zapcore.DebugLevel: + enc.AppendString(color.HiWhiteString(level.CapitalString())) + case zapcore.InfoLevel: + enc.AppendString(color.HiCyanString(level.CapitalString())) + case zapcore.WarnLevel: + enc.AppendString(color.HiYellowString(level.CapitalString())) + case zapcore.ErrorLevel, zapcore.DPanicLevel: + enc.AppendString(color.HiRedString(level.CapitalString())) + case zapcore.PanicLevel, zapcore.FatalLevel: + enc.AppendString(color.HiMagentaString(level.CapitalString())) + } +} + +// ColoredNameEncoder colorizes service names. +func ColoredNameEncoder(s string, enc zapcore.PrimitiveArrayEncoder) { + if len(s) < 12 { + s += strings.Repeat(" ", 12-len(s)) + } + + enc.AppendString(color.HiGreenString(s)) +} + +// ColoredHashedNameEncoder colorizes service names and assigns different colors to different names. +func ColoredHashedNameEncoder(s string, enc zapcore.PrimitiveArrayEncoder) { + if len(s) < 12 { + s += strings.Repeat(" ", 12-len(s)) + } + + colorID := stringHash(s, len(colorMap)) + enc.AppendString(colorMap[colorID](s)) +} + +// UTCTimeEncoder encodes time into short UTC specific timestamp. +func UTCTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.UTC().Format("2006/01/02 15:04:05")) +} + +// returns string hash +func stringHash(name string, base int) int { + h := fnv.New32a() + _, _ = h.Write([]byte(name)) + return int(h.Sum32()) % base +} diff --git a/plugins/logger/plugin.go b/plugins/logger/plugin.go new file mode 100644 index 00000000..64b77a64 --- /dev/null +++ b/plugins/logger/plugin.go @@ -0,0 +1,65 @@ +package logger + +import ( + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/plugins/config" + "go.uber.org/zap" +) + +// PluginName declares plugin name. +const PluginName = "logs" + +// ZapLogger manages zap logger. +type ZapLogger struct { + base *zap.Logger + cfg Config + channels ChannelConfig +} + +// Init logger service. +func (z *ZapLogger) Init(cfg config.Configurer) error { + err := cfg.UnmarshalKey(PluginName, &z.cfg) + if err != nil { + return err + } + + err = cfg.UnmarshalKey(PluginName, &z.channels) + if err != nil { + return err + } + + z.base, err = z.cfg.BuildLogger() + return err +} + +// DefaultLogger returns default logger. +func (z *ZapLogger) DefaultLogger() (log.Logger, error) { + return NewZapAdapter(z.base), nil +} + +// NamedLogger returns logger dedicated to the specific channel. Similar to Named() but also reads the core params. +func (z *ZapLogger) NamedLogger(name string) (log.Logger, error) { + if cfg, ok := z.channels.Channels[name]; ok { + l, err := cfg.BuildLogger() + if err != nil { + return nil, err + } + return NewZapAdapter(l), nil + } + + return NewZapAdapter(z.base.Named(name)), nil +} + +// NamedLogger returns logger dedicated to the specific channel. Similar to Named() but also reads the core params. +func (z *ZapLogger) ServiceLogger(n endure.Named) (log.Logger, error) { + return z.NamedLogger(n.Name()) +} + +// Provides declares factory methods. +func (z *ZapLogger) Provides() []interface{} { + return []interface{}{ + z.ServiceLogger, + z.DefaultLogger, + } +} diff --git a/plugins/logger/tests/.rr.yaml b/plugins/logger/tests/.rr.yaml new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/plugins/logger/tests/.rr.yaml diff --git a/plugins/logger/tests/logger_test.go b/plugins/logger/tests/logger_test.go new file mode 100644 index 00000000..1df74c47 --- /dev/null +++ b/plugins/logger/tests/logger_test.go @@ -0,0 +1,74 @@ +package tests + +import ( + "os" + "os/signal" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/stretchr/testify/assert" +) + +func TestLogger(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = ".rr.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + // stop after 10 seconds + tt := time.NewTicker(time.Second * 10) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } +} diff --git a/plugins/logger/tests/plugin.go b/plugins/logger/tests/plugin.go new file mode 100644 index 00000000..32238f63 --- /dev/null +++ b/plugins/logger/tests/plugin.go @@ -0,0 +1,40 @@ +package tests + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +type Plugin struct { + config config.Configurer + log log.Logger +} + +func (p1 *Plugin) Init(cfg config.Configurer, log log.Logger) error { + p1.config = cfg + p1.log = log + return nil +} + +func (p1 *Plugin) Serve() chan error { + errCh := make(chan error, 1) + p1.log.Error("error", "test", errors.E(errors.Str("test"))) + p1.log.Info("error", "test", errors.E(errors.Str("test"))) + p1.log.Debug("error", "test", errors.E(errors.Str("test"))) + p1.log.Warn("error", "test", errors.E(errors.Str("test"))) + + p1.log.Error("error", "test") + p1.log.Info("error", "test") + p1.log.Debug("error", "test") + p1.log.Warn("error", "test") + return errCh +} + +func (p1 *Plugin) Stop() error { + return nil +} + +func (p1 *Plugin) Name() string { + return "logger_plugin" +} diff --git a/plugins/logger/zap_adapter.go b/plugins/logger/zap_adapter.go new file mode 100644 index 00000000..41e6d8f9 --- /dev/null +++ b/plugins/logger/zap_adapter.go @@ -0,0 +1,57 @@ +package logger + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/interfaces/log" + "go.uber.org/zap" +) + +type ZapAdapter struct { + zl *zap.Logger +} + +// Create NewZapAdapter which uses general log interface +func NewZapAdapter(zapLogger *zap.Logger) *ZapAdapter { + return &ZapAdapter{ + zl: zapLogger.WithOptions(zap.AddCallerSkip(1)), + } +} + +func (log *ZapAdapter) fields(keyvals []interface{}) []zap.Field { + // we should have even number of keys and values + if len(keyvals)%2 != 0 { + return []zap.Field{zap.Error(fmt.Errorf("odd number of keyvals pairs: %v", keyvals))} + } + + var fields []zap.Field + for i := 0; i < len(keyvals); i += 2 { + key, ok := keyvals[i].(string) + if !ok { + key = fmt.Sprintf("%v", keyvals[i]) + } + fields = append(fields, zap.Any(key, keyvals[i+1])) + } + + return fields +} + +func (log *ZapAdapter) Debug(msg string, keyvals ...interface{}) { + log.zl.Debug(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Info(msg string, keyvals ...interface{}) { + log.zl.Info(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Warn(msg string, keyvals ...interface{}) { + log.zl.Warn(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Error(msg string, keyvals ...interface{}) { + log.zl.Error(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) With(keyvals ...interface{}) log.Logger { + return NewZapAdapter(log.zl.With(log.fields(keyvals)...)) +} diff --git a/plugins/metrics/config.go b/plugins/metrics/config.go new file mode 100644 index 00000000..933b7eb8 --- /dev/null +++ b/plugins/metrics/config.go @@ -0,0 +1,135 @@ +package metrics + +import ( + "fmt" + + "github.com/prometheus/client_golang/prometheus" +) + +// Config configures metrics service. +type Config struct { + // Address to listen + Address string + + // Collect define application specific metrics. + Collect map[string]Collector +} + +type NamedCollector struct { + // Name of the collector + Name string `json:"name"` + + // Collector structure + Collector `json:"collector"` +} + +// CollectorType represents prometheus collector types +type CollectorType string + +const ( + // Histogram type + Histogram CollectorType = "histogram" + + // Gauge type + Gauge CollectorType = "gauge" + + // Counter type + Counter CollectorType = "counter" + + // Summary type + Summary CollectorType = "summary" +) + +// Collector describes single application specific metric. +type Collector struct { + // Namespace of the metric. + Namespace string `json:"namespace"` + // Subsystem of the metric. + Subsystem string `json:"subsystem"` + // Collector type (histogram, gauge, counter, summary). + Type CollectorType `json:"type"` + // Help of collector. + Help string `json:"help"` + // Labels for vectorized metrics. + Labels []string `json:"labels"` + // Buckets for histogram metric. + Buckets []float64 `json:"buckets"` +} + +// register application specific metrics. +func (c *Config) getCollectors() (map[string]prometheus.Collector, error) { + if c.Collect == nil { + return nil, nil + } + + collectors := make(map[string]prometheus.Collector) + + for name, m := range c.Collect { + var collector prometheus.Collector + switch m.Type { + case Histogram: + opts := prometheus.HistogramOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + Buckets: m.Buckets, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewHistogramVec(opts, m.Labels) + } else { + collector = prometheus.NewHistogram(opts) + } + case Gauge: + opts := prometheus.GaugeOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewGaugeVec(opts, m.Labels) + } else { + collector = prometheus.NewGauge(opts) + } + case Counter: + opts := prometheus.CounterOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewCounterVec(opts, m.Labels) + } else { + collector = prometheus.NewCounter(opts) + } + case Summary: + opts := prometheus.SummaryOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewSummaryVec(opts, m.Labels) + } else { + collector = prometheus.NewSummary(opts) + } + default: + return nil, fmt.Errorf("invalid metric type `%s` for `%s`", m.Type, name) + } + + collectors[name] = collector + } + + return collectors, nil +} + +func (c *Config) InitDefaults() { + +} diff --git a/plugins/metrics/config_test.go b/plugins/metrics/config_test.go new file mode 100644 index 00000000..665ec9cd --- /dev/null +++ b/plugins/metrics/config_test.go @@ -0,0 +1,89 @@ +package metrics + +import ( + "bytes" + "testing" + + j "github.com/json-iterator/go" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +func Test_Config_Hydrate_Error1(t *testing.T) { + cfg := `{"request": {"From": "Something"}}` + c := &Config{} + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } +} + +func Test_Config_Hydrate_Error2(t *testing.T) { + cfg := `{"dir": "/dir/"` + c := &Config{} + + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + assert.Error(t, err) +} + +func Test_Config_Metrics(t *testing.T) { + cfg := `{ +"collect":{ + "metric1":{"type": "gauge"}, + "metric2":{ "type": "counter"}, + "metric3":{"type": "summary"}, + "metric4":{"type": "histogram"} +} +}` + c := &Config{} + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } + + m, err := c.getCollectors() + assert.NoError(t, err) + + assert.IsType(t, prometheus.NewGauge(prometheus.GaugeOpts{}), m["metric1"]) + assert.IsType(t, prometheus.NewCounter(prometheus.CounterOpts{}), m["metric2"]) + assert.IsType(t, prometheus.NewSummary(prometheus.SummaryOpts{}), m["metric3"]) + assert.IsType(t, prometheus.NewHistogram(prometheus.HistogramOpts{}), m["metric4"]) +} + +func Test_Config_MetricsVector(t *testing.T) { + cfg := `{ +"collect":{ + "metric1":{"type": "gauge","labels":["label"]}, + "metric2":{ "type": "counter","labels":["label"]}, + "metric3":{"type": "summary","labels":["label"]}, + "metric4":{"type": "histogram","labels":["label"]} +} +}` + c := &Config{} + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } + + m, err := c.getCollectors() + assert.NoError(t, err) + + assert.IsType(t, prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{}), m["metric1"]) + assert.IsType(t, prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{}), m["metric2"]) + assert.IsType(t, prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{}), m["metric3"]) + assert.IsType(t, prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{}), m["metric4"]) +} diff --git a/plugins/metrics/doc.go b/plugins/metrics/doc.go new file mode 100644 index 00000000..1abe097a --- /dev/null +++ b/plugins/metrics/doc.go @@ -0,0 +1 @@ +package metrics diff --git a/plugins/metrics/plugin.go b/plugins/metrics/plugin.go new file mode 100644 index 00000000..c115826b --- /dev/null +++ b/plugins/metrics/plugin.go @@ -0,0 +1,230 @@ +package metrics + +import ( + "context" + "crypto/tls" + "net/http" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/interfaces/metrics" + "github.com/spiral/roadrunner/v2/plugins/config" + "golang.org/x/sys/cpu" +) + +const ( + // PluginName declares plugin name. + PluginName = "metrics" + // maxHeaderSize declares max header size for prometheus server + maxHeaderSize = 1024 * 1024 * 100 // 104MB +) + +type statsProvider struct { + collectors []prometheus.Collector + name string +} + +// Plugin to manage application metrics using Prometheus. +type Plugin struct { + cfg Config + log log.Logger + mu sync.Mutex // all receivers are pointers + http *http.Server + collectors sync.Map // all receivers are pointers + registry *prometheus.Registry +} + +// Init service. +func (m *Plugin) Init(cfg config.Configurer, log log.Logger) error { + const op = errors.Op("Metrics Init") + err := cfg.UnmarshalKey(PluginName, &m.cfg) + if err != nil { + return err + } + + // TODO figure out what is Init + m.cfg.InitDefaults() + + m.log = log + m.registry = prometheus.NewRegistry() + + // Default + err = m.registry.Register(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{})) + if err != nil { + return errors.E(op, err) + } + + // Default + err = m.registry.Register(prometheus.NewGoCollector()) + if err != nil { + return errors.E(op, err) + } + + collectors, err := m.cfg.getCollectors() + if err != nil { + return errors.E(op, err) + } + + // Register invocation will be later in the Serve method + for k, v := range collectors { + m.collectors.Store(k, statsProvider{ + collectors: []prometheus.Collector{v}, + name: k, + }) + } + return nil +} + +// Register new prometheus collector. +func (m *Plugin) Register(c prometheus.Collector) error { + return m.registry.Register(c) +} + +// Serve prometheus metrics service. +func (m *Plugin) Serve() chan error { + errCh := make(chan error, 1) + m.collectors.Range(func(key, value interface{}) bool { + // key - name + // value - statsProvider struct + c := value.(statsProvider) + for _, v := range c.collectors { + if err := m.registry.Register(v); err != nil { + errCh <- err + return false + } + } + + return true + }) + + var topCipherSuites []uint16 + var defaultCipherSuitesTLS13 []uint16 + + hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + + if hasGCMAsm { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } + + DefaultCipherSuites := make([]uint16, 0, 22) + DefaultCipherSuites = append(DefaultCipherSuites, topCipherSuites...) + DefaultCipherSuites = append(DefaultCipherSuites, defaultCipherSuitesTLS13...) + + m.http = &http.Server{ + Addr: m.cfg.Address, + Handler: promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{}), + IdleTimeout: time.Hour * 24, + ReadTimeout: time.Minute * 60, + MaxHeaderBytes: maxHeaderSize, + ReadHeaderTimeout: time.Minute * 60, + WriteTimeout: time.Minute * 60, + TLSConfig: &tls.Config{ + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + tls.X25519, + }, + CipherSuites: DefaultCipherSuites, + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + }, + } + + go func() { + err := m.http.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + errCh <- err + return + } + }() + + return errCh +} + +// Stop prometheus metrics service. +func (m *Plugin) Stop() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.http != nil { + // timeout is 10 seconds + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + err := m.http.Shutdown(ctx) + if err != nil { + // Function should be Stop() error + m.log.Error("stop error", "error", errors.Errorf("error shutting down the metrics server: error %v", err)) + } + } + return nil +} + +// Collects used to collect all plugins which implement metrics.StatProvider interface (and Named) +func (m *Plugin) Collects() []interface{} { + return []interface{}{ + m.AddStatProvider, + } +} + +// Collector returns application specific collector by name or nil if collector not found. +func (m *Plugin) AddStatProvider(name endure.Named, stat metrics.StatProvider) error { + m.collectors.Store(name.Name(), statsProvider{ + collectors: stat.MetricsCollector(), + name: name.Name(), + }) + return nil +} + +// RPC interface satisfaction +func (m *Plugin) Name() string { + return PluginName +} + +// RPC interface satisfaction +func (m *Plugin) RPC() interface{} { + return &rpcServer{ + svc: m, + log: m.log, + } +} diff --git a/plugins/metrics/rpc.go b/plugins/metrics/rpc.go new file mode 100644 index 00000000..b8897098 --- /dev/null +++ b/plugins/metrics/rpc.go @@ -0,0 +1,294 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" +) + +type rpcServer struct { + svc *Plugin + log log.Logger +} + +// Metric represent single metric produced by the application. +type Metric struct { + // Collector name. + Name string + + // Collector value. + Value float64 + + // Labels associated with metric. Only for vector metrics. Must be provided in a form of label values. + Labels []string +} + +// Add new metric to the designated collector. +func (rpc *rpcServer) Add(m *Metric, ok *bool) error { + const op = errors.Op("Add metric") + rpc.log.Info("Adding metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "collector", m.Name) + return errors.E(op, errors.Errorf("undefined collector %s, try first Declare the desired collector", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Add(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector", "collector", m.Name) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Add(m.Value) + case prometheus.Counter: + c.Add(m.Value) + + case *prometheus.CounterVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Add(m.Value) + + default: + return errors.E(op, errors.Errorf("collector %s does not support method `Add`", m.Name)) + } + + // RPC, set ok to true as return value. Need by rpc.Call reply argument + *ok = true + rpc.log.Info("new metric successfully added", "name", m.Name, "labels", m.Labels, "value", m.Value) + return nil +} + +// Sub subtract the value from the specific metric (gauge only). +func (rpc *rpcServer) Sub(m *Metric, ok *bool) error { + const op = errors.Op("Subtracting metric") + rpc.log.Info("Subtracting value from metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "name", m.Name, "value", m.Value, "labels", m.Labels) + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + // can it be nil ??? I guess can't + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Sub(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector, but none was provided", "name", m.Name, "value", m.Value) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Sub(m.Value) + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method `Sub`", m.Name)) + } + rpc.log.Info("Subtracting operation applied successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} + +// Observe the value (histogram and summary only). +func (rpc *rpcServer) Observe(m *Metric, ok *bool) error { + const op = errors.Op("Observe metrics") + rpc.log.Info("Observing metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "name", m.Name, "value", m.Value, "labels", m.Labels) + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case *prometheus.SummaryVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + observer, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + return errors.E(op, err) + } + observer.Observe(m.Value) + + case prometheus.Histogram: + c.Observe(m.Value) + + case *prometheus.HistogramVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + observer, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + observer.Observe(m.Value) + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method `Observe`", m.Name)) + } + + rpc.log.Info("observe operation finished successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} + +// Declare is used to register new collector in prometheus +// THE TYPES ARE: +// NamedCollector -> Collector with the name +// bool -> RPC reply value +// RETURNS: +// error +func (rpc *rpcServer) Declare(nc *NamedCollector, ok *bool) error { + const op = errors.Op("Declare metric") + rpc.log.Info("Declaring new metric", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + _, exist := rpc.svc.collectors.Load(nc.Name) + if exist { + rpc.log.Error("metric with provided name already exist", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + return errors.E(op, errors.Errorf("tried to register existing collector with the name `%s`", nc.Name)) + } + + var collector prometheus.Collector + switch nc.Type { + case Histogram: + opts := prometheus.HistogramOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + Buckets: nc.Buckets, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewHistogramVec(opts, nc.Labels) + } else { + collector = prometheus.NewHistogram(opts) + } + case Gauge: + opts := prometheus.GaugeOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewGaugeVec(opts, nc.Labels) + } else { + collector = prometheus.NewGauge(opts) + } + case Counter: + opts := prometheus.CounterOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewCounterVec(opts, nc.Labels) + } else { + collector = prometheus.NewCounter(opts) + } + case Summary: + opts := prometheus.SummaryOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewSummaryVec(opts, nc.Labels) + } else { + collector = prometheus.NewSummary(opts) + } + + default: + return errors.E(op, errors.Errorf("unknown collector type %s", nc.Type)) + } + + // add collector to sync.Map + rpc.svc.collectors.Store(nc.Name, collector) + // that method might panic, we handle it by recover + err := rpc.svc.Register(collector) + if err != nil { + *ok = false + return errors.E(op, err) + } + + rpc.log.Info("metric successfully added", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + + *ok = true + return nil +} + +// Set the metric value (only for gaude). +func (rpc *rpcServer) Set(m *Metric, ok *bool) (err error) { + const op = errors.Op("Set metric") + rpc.log.Info("Observing metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Set(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector", "collector", m.Name) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Set(m.Value) + + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method Set", m.Name)) + } + + rpc.log.Info("set operation finished successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} diff --git a/plugins/metrics/tests/.rr-test.yaml b/plugins/metrics/tests/.rr-test.yaml new file mode 100644 index 00000000..79343e3c --- /dev/null +++ b/plugins/metrics/tests/.rr-test.yaml @@ -0,0 +1,13 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false + +metrics: + # prometheus client address (path /metrics added automatically) + address: localhost:2112 + collect: + app_metric: + type: histogram + help: "Custom application metric" + labels: [ "type" ] + buckets: [ 0.1, 0.2, 0.3, 1.0 ]
\ No newline at end of file diff --git a/plugins/metrics/tests/docker-compose.yml b/plugins/metrics/tests/docker-compose.yml new file mode 100644 index 00000000..610633b4 --- /dev/null +++ b/plugins/metrics/tests/docker-compose.yml @@ -0,0 +1,7 @@ +version: '3.7' + +services: + prometheus: + image: prom/prometheus + ports: + - 9090:9090 diff --git a/plugins/metrics/tests/metrics_test.go b/plugins/metrics/tests/metrics_test.go new file mode 100644 index 00000000..f9014c95 --- /dev/null +++ b/plugins/metrics/tests/metrics_test.go @@ -0,0 +1,739 @@ +package tests + +import ( + "io/ioutil" + "net" + "net/http" + "net/rpc" + "os" + "os/signal" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/metrics" + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/stretchr/testify/assert" +) + +const dialAddr = "127.0.0.1:6001" +const dialNetwork = "tcp" +const getAddr = "http://localhost:2112/metrics" + +// get request and return body +func get() (string, error) { + r, err := http.Get(getAddr) + if err != nil { + return "", err + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return "", err + } + + err = r.Body.Close() + if err != nil { + return "", err + } + // unsafe + return string(b), err +} + +func TestMetricsInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + cfg := &config.Viper{} + cfg.Prefix = "rr" + cfg.Path = ".rr-test.yaml" + + err = cont.RegisterAll( + cfg, + &metrics.Plugin{}, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &Plugin1{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + tt := time.NewTimer(time.Second * 5) + + out, err := get() + assert.NoError(t, err) + + assert.Contains(t, out, "go_gc_duration_seconds") + + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } +} + +func TestMetricsGaugeCollector(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + cfg := &config.Viper{} + cfg.Prefix = "rr" + cfg.Path = ".rr-test.yaml" + + err = cont.RegisterAll( + cfg, + &metrics.Plugin{}, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + &Plugin1{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + time.Sleep(time.Second) + tt := time.NewTimer(time.Second * 5) + + out, err := get() + assert.NoError(t, err) + assert.Contains(t, out, "my_gauge 100") + assert.Contains(t, out, "my_gauge2 100") + + out, err = get() + assert.NoError(t, err) + assert.Contains(t, out, "go_gc_duration_seconds") + + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } +} + +func TestMetricsDifferentRPCCalls(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + cfg := &config.Viper{} + cfg.Prefix = "rr" + cfg.Path = ".rr-test.yaml" + + err = cont.RegisterAll( + cfg, + &metrics.Plugin{}, + &rpcPlugin.Plugin{}, + &logger.ZapLogger{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + go func() { + tt := time.NewTimer(time.Minute * 3) + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + t.Run("DeclareMetric", declareMetricsTest) + genericOut, err := get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "test_metrics_named_collector") + + t.Run("AddMetric", addMetricsTest) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "test_metrics_named_collector 10000") + + t.Run("SetMetric", setMetric) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "user_gauge_collector 100") + + t.Run("VectorMetric", vectorMetric) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "gauge_2_collector{section=\"first\",type=\"core\"} 100") + + t.Run("MissingSection", missingSection) + t.Run("SetWithoutLabels", setWithoutLabels) + t.Run("SetOnHistogram", setOnHistogram) + t.Run("MetricSub", subMetric) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "sub_gauge_subMetric 1") + + t.Run("SubVector", subVector) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "sub_gauge_subVector{section=\"first\",type=\"core\"} 1") + + t.Run("RegisterHistogram", registerHistogram) + + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, `TYPE histogram_registerHistogram`) + + // check buckets + assert.Contains(t, genericOut, `histogram_registerHistogram_bucket{le="0.1"} 0`) + assert.Contains(t, genericOut, `histogram_registerHistogram_bucket{le="0.2"} 0`) + assert.Contains(t, genericOut, `histogram_registerHistogram_bucket{le="0.5"} 0`) + assert.Contains(t, genericOut, `histogram_registerHistogram_bucket{le="+Inf"} 0`) + assert.Contains(t, genericOut, `histogram_registerHistogram_sum 0`) + assert.Contains(t, genericOut, `histogram_registerHistogram_count 0`) + + t.Run("CounterMetric", counterMetric) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "HELP default_default_counter_CounterMetric test_counter") + assert.Contains(t, genericOut, `default_default_counter_CounterMetric{section="section2",type="type2"}`) + + t.Run("ObserveMetric", observeMetric) + genericOut, err = get() + assert.NoError(t, err) + assert.Contains(t, genericOut, "observe_observeMetric") + + t.Run("ObserveMetricNotEnoughLabels", observeMetricNotEnoughLabels) + + close(sig) +} + +func observeMetricNotEnoughLabels(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "observe_observeMetricNotEnoughLabels", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Help: "test_observe", + Type: metrics.Histogram, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + assert.Error(t, client.Call("metrics.Observe", metrics.Metric{ + Name: "observe_observeMetric", + Value: 100.0, + Labels: []string{"test"}, + }, &ret)) + assert.False(t, ret) +} + +func observeMetric(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "observe_observeMetric", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Help: "test_observe", + Type: metrics.Histogram, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + assert.NoError(t, client.Call("metrics.Observe", metrics.Metric{ + Name: "observe_observeMetric", + Value: 100.0, + Labels: []string{"test", "test2"}, + }, &ret)) + assert.True(t, ret) +} + +func counterMetric(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "counter_CounterMetric", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Help: "test_counter", + Type: metrics.Counter, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + assert.NoError(t, client.Call("metrics.Add", metrics.Metric{ + Name: "counter_CounterMetric", + Value: 100.0, + Labels: []string{"type2", "section2"}, + }, &ret)) + assert.True(t, ret) +} + +func registerHistogram(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "histogram_registerHistogram", + Collector: metrics.Collector{ + Help: "test_histogram", + Type: metrics.Histogram, + Buckets: []float64{0.1, 0.2, 0.5}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + m := metrics.Metric{ + Name: "histogram_registerHistogram", + Value: 10000, + Labels: nil, + } + + err = client.Call("metrics.Add", m, &ret) + assert.Error(t, err) + assert.False(t, ret) +} + +func subVector(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "sub_gauge_subVector", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + m := metrics.Metric{ + Name: "sub_gauge_subVector", + Value: 100000, + Labels: []string{"core", "first"}, + } + + err = client.Call("metrics.Add", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + m = metrics.Metric{ + Name: "sub_gauge_subVector", + Value: 99999, + Labels: []string{"core", "first"}, + } + + err = client.Call("metrics.Sub", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} + +func subMetric(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "sub_gauge_subMetric", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + m := metrics.Metric{ + Name: "sub_gauge_subMetric", + Value: 100000, + } + + err = client.Call("metrics.Add", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + m = metrics.Metric{ + Name: "sub_gauge_subMetric", + Value: 99999, + } + + err = client.Call("metrics.Sub", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} + +func setOnHistogram(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "histogram_setOnHistogram", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Histogram, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + m := metrics.Metric{ + Name: "gauge_setOnHistogram", + Value: 100.0, + } + + err = client.Call("metrics.Set", m, &ret) // expected 2 label values but got 1 in []string{"missing"} + assert.Error(t, err) + assert.False(t, ret) +} + +func setWithoutLabels(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "gauge_setWithoutLabels", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + m := metrics.Metric{ + Name: "gauge_setWithoutLabels", + Value: 100.0, + } + + err = client.Call("metrics.Set", m, &ret) // expected 2 label values but got 1 in []string{"missing"} + assert.Error(t, err) + assert.False(t, ret) +} + +func missingSection(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "gauge_missing_section_collector", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + m := metrics.Metric{ + Name: "gauge_missing_section_collector", + Value: 100.0, + Labels: []string{"missing"}, + } + + err = client.Call("metrics.Set", m, &ret) // expected 2 label values but got 1 in []string{"missing"} + assert.Error(t, err) + assert.False(t, ret) +} + +func vectorMetric(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "gauge_2_collector", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + Labels: []string{"type", "section"}, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + + ret = false + + m := metrics.Metric{ + Name: "gauge_2_collector", + Value: 100.0, + Labels: []string{"core", "first"}, + } + + err = client.Call("metrics.Set", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} + +func setMetric(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "user_gauge_collector", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Gauge, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + m := metrics.Metric{ + Name: "user_gauge_collector", + Value: 100.0, + } + + err = client.Call("metrics.Set", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} + +func addMetricsTest(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + m := metrics.Metric{ + Name: "test_metrics_named_collector", + Value: 10000, + Labels: nil, + } + + err = client.Call("metrics.Add", m, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} + +func declareMetricsTest(t *testing.T) { + conn, err := net.Dial(dialNetwork, dialAddr) + assert.NoError(t, err) + defer func() { + _ = conn.Close() + }() + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret bool + + nc := metrics.NamedCollector{ + Name: "test_metrics_named_collector", + Collector: metrics.Collector{ + Namespace: "default", + Subsystem: "default", + Type: metrics.Counter, + Help: "NO HELP!", + Labels: nil, + Buckets: nil, + }, + } + + err = client.Call("metrics.Declare", nc, &ret) + assert.NoError(t, err) + assert.True(t, ret) +} diff --git a/plugins/metrics/tests/plugin1.go b/plugins/metrics/tests/plugin1.go new file mode 100644 index 00000000..b48c415d --- /dev/null +++ b/plugins/metrics/tests/plugin1.go @@ -0,0 +1,46 @@ +package tests + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// Gauge ////////////// +type Plugin1 struct { + config config.Configurer +} + +func (p1 *Plugin1) Init(cfg config.Configurer) error { + p1.config = cfg + return nil +} + +func (p1 *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (p1 *Plugin1) Stop() error { + return nil +} + +func (p1 *Plugin1) Name() string { + return "metrics_test.plugin1" +} + +func (p1 *Plugin1) MetricsCollector() []prometheus.Collector { + collector := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "my_gauge", + Help: "My gauge value", + }) + + collector.Set(100) + + collector2 := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "my_gauge2", + Help: "My gauge2 value", + }) + + collector2.Set(100) + return []prometheus.Collector{collector, collector2} +} diff --git a/plugins/resetter/plugin.go b/plugins/resetter/plugin.go new file mode 100644 index 00000000..99e02aef --- /dev/null +++ b/plugins/resetter/plugin.go @@ -0,0 +1,54 @@ +package resetter + +import ( + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/interfaces/resetter" +) + +const PluginName = "resetter" + +type Plugin struct { + registry map[string]resetter.Resetter + log log.Logger +} + +func (p *Plugin) Init(log log.Logger) error { + p.registry = make(map[string]resetter.Resetter) + p.log = log + return nil +} + +// Reset named service. +func (p *Plugin) Reset(name string) error { + svc, ok := p.registry[name] + if !ok { + return errors.E("no such service", errors.Str(name)) + } + + return svc.Reset() +} + +// RegisterTarget resettable service. +func (p *Plugin) RegisterTarget(name endure.Named, r resetter.Resetter) error { + p.registry[name.Name()] = r + return nil +} + +// Collects declares services to be collected. +func (p *Plugin) Collects() []interface{} { + return []interface{}{ + p.RegisterTarget, + } +} + +// Name of the service. +func (p *Plugin) Name() string { + return PluginName +} + +// RPCService returns associated rpc service. +func (p *Plugin) RPC() interface{} { + return &rpc{srv: p, log: p.log} +} diff --git a/plugins/resetter/rpc.go b/plugins/resetter/rpc.go new file mode 100644 index 00000000..ecc51bb3 --- /dev/null +++ b/plugins/resetter/rpc.go @@ -0,0 +1,30 @@ +package resetter + +import "github.com/spiral/roadrunner/v2/interfaces/log" + +type rpc struct { + srv *Plugin + log log.Logger +} + +// List all resettable services. +func (rpc *rpc) List(_ bool, list *[]string) error { + rpc.log.Debug("started List method") + *list = make([]string, 0) + + for name := range rpc.srv.registry { + *list = append(*list, name) + } + rpc.log.Debug("services list", "services", *list) + + rpc.log.Debug("finished List method") + return nil +} + +// Reset named service. +func (rpc *rpc) Reset(service string, done *bool) error { + rpc.log.Debug("started Reset method for the service", "service", service) + defer rpc.log.Debug("finished Reset method for the service", "service", service) + *done = true + return rpc.srv.Reset(service) +} diff --git a/plugins/resetter/tests/.rr-resetter.yaml b/plugins/resetter/tests/.rr-resetter.yaml new file mode 100644 index 00000000..83ecd582 --- /dev/null +++ b/plugins/resetter/tests/.rr-resetter.yaml @@ -0,0 +1,13 @@ +server: + command: "php ../../../tests/client.php echo pipes" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pipes" + relayTimeout: "20s" + +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false
\ No newline at end of file diff --git a/plugins/resetter/tests/resetter_test.go b/plugins/resetter/tests/resetter_test.go new file mode 100644 index 00000000..a1873dd4 --- /dev/null +++ b/plugins/resetter/tests/resetter_test.go @@ -0,0 +1,101 @@ +package tests + +import ( + "net" + "net/rpc" + "os" + "os/signal" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/resetter" + rpcPlugin "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestInformerInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + cfg := &config.Viper{ + Path: ".rr-resetter.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &server.Plugin{}, + &logger.ZapLogger{}, + &resetter.Plugin{}, + &rpcPlugin.Plugin{}, + &Plugin1{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + tt := time.NewTimer(time.Second * 15) + + t.Run("InformerRpcTest", resetterRPCTest) + + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } +} + +func resetterRPCTest(t *testing.T) { + conn, err := net.Dial("tcp", "127.0.0.1:6001") + assert.NoError(t, err) + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + // WorkerList contains list of workers. + + var ret bool + err = client.Call("resetter.Reset", "resetter.plugin1", &ret) + assert.NoError(t, err) + assert.True(t, ret) + ret = false + + var services []string + err = client.Call("resetter.List", nil, &services) + assert.NoError(t, err) + if services[0] != "resetter.plugin1" { + t.Fatal("no enough services") + } +} diff --git a/plugins/resetter/tests/test_plugin.go b/plugins/resetter/tests/test_plugin.go new file mode 100644 index 00000000..9f48a43f --- /dev/null +++ b/plugins/resetter/tests/test_plugin.go @@ -0,0 +1,66 @@ +package tests + +import ( + "context" + "time" + + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +var testPoolConfig = roadrunner.PoolConfig{ + NumWorkers: 10, + MaxJobs: 100, + AllocateTimeout: time.Second * 10, + DestroyTimeout: time.Second * 10, + Supervisor: &roadrunner.SupervisorConfig{ + WatchTick: 60, + TTL: 1000, + IdleTTL: 10, + ExecTTL: 10, + MaxWorkerMemory: 1000, + }, +} + +// Gauge ////////////// +type Plugin1 struct { + config config.Configurer + server server.Server +} + +func (p1 *Plugin1) Init(cfg config.Configurer, server server.Server) error { + p1.config = cfg + p1.server = server + return nil +} + +func (p1 *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (p1 *Plugin1) Stop() error { + return nil +} + +func (p1 *Plugin1) Name() string { + return "resetter.plugin1" +} + +func (p1 *Plugin1) Reset() error { + pool, err := p1.server.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + panic(err) + } + pool.Destroy(context.Background()) + + pool, err = p1.server.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + panic(err) + } + + _ = pool + + return nil +} diff --git a/plugins/rpc/config.go b/plugins/rpc/config.go new file mode 100755 index 00000000..719fd5e3 --- /dev/null +++ b/plugins/rpc/config.go @@ -0,0 +1,49 @@ +package rpc + +import ( + "errors" + "net" + "strings" + + "github.com/spiral/roadrunner/v2/util" +) + +// Config defines RPC service config. +type Config struct { + // Listen string + Listen string + + // Disabled disables RPC service. + Disabled bool +} + +// InitDefaults allows to init blank config with pre-defined set of default values. +func (c *Config) InitDefaults() { + if c.Listen == "" { + c.Listen = "tcp://127.0.0.1:6001" + } +} + +// Valid returns nil if config is valid. +func (c *Config) Valid() error { + if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 { + return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + } + + return nil +} + +// Listener creates new rpc socket Listener. +func (c *Config) Listener() (net.Listener, error) { + return util.CreateListener(c.Listen) +} + +// Dialer creates rpc socket Dialer. +func (c *Config) Dialer() (net.Conn, error) { + dsn := strings.Split(c.Listen, "://") + if len(dsn) != 2 { + return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + } + + return net.Dial(dsn[0], dsn[1]) +} diff --git a/plugins/rpc/config_test.go b/plugins/rpc/config_test.go new file mode 100755 index 00000000..8b1d974a --- /dev/null +++ b/plugins/rpc/config_test.go @@ -0,0 +1,138 @@ +package rpc + +import ( + "testing" + + j "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +type testCfg struct{ cfg string } + +func (cfg *testCfg) Unmarshal(out interface{}) error { + return json.Unmarshal([]byte(cfg.cfg), out) +} + +func TestConfig_Listener(t *testing.T) { + cfg := &Config{Listen: "tcp://:18001"} + + ln, err := cfg.Listener() + assert.NoError(t, err) + assert.NotNil(t, ln) + defer func() { + err := ln.Close() + if err != nil { + t.Errorf("error closing the listener: error %v", err) + } + }() + + assert.Equal(t, "tcp", ln.Addr().Network()) + assert.Equal(t, "0.0.0.0:18001", ln.Addr().String()) +} + +func TestConfig_ListenerUnix(t *testing.T) { + cfg := &Config{Listen: "unix://file.sock"} + + ln, err := cfg.Listener() + assert.NoError(t, err) + assert.NotNil(t, ln) + defer func() { + err := ln.Close() + if err != nil { + t.Errorf("error closing the listener: error %v", err) + } + }() + + assert.Equal(t, "unix", ln.Addr().Network()) + assert.Equal(t, "file.sock", ln.Addr().String()) +} + +func Test_Config_Error(t *testing.T) { + cfg := &Config{Listen: "uni:unix.sock"} + ln, err := cfg.Listener() + assert.Nil(t, ln) + assert.Error(t, err) + assert.Equal(t, "invalid DSN (tcp://:6001, unix://file.sock)", err.Error()) +} + +func Test_Config_ErrorMethod(t *testing.T) { + cfg := &Config{Listen: "xinu://unix.sock"} + + ln, err := cfg.Listener() + assert.Nil(t, ln) + assert.Error(t, err) +} + +func TestConfig_Dialer(t *testing.T) { + cfg := &Config{Listen: "tcp://:18001"} + + ln, _ := cfg.Listener() + defer func() { + err := ln.Close() + if err != nil { + t.Errorf("error closing the listener: error %v", err) + } + }() + + conn, err := cfg.Dialer() + assert.NoError(t, err) + assert.NotNil(t, conn) + defer func() { + err := conn.Close() + if err != nil { + t.Errorf("error closing the connection: error %v", err) + } + }() + + assert.Equal(t, "tcp", conn.RemoteAddr().Network()) + assert.Equal(t, "127.0.0.1:18001", conn.RemoteAddr().String()) +} + +func TestConfig_DialerUnix(t *testing.T) { + cfg := &Config{Listen: "unix://file.sock"} + + ln, _ := cfg.Listener() + defer func() { + err := ln.Close() + if err != nil { + t.Errorf("error closing the listener: error %v", err) + } + }() + + conn, err := cfg.Dialer() + assert.NoError(t, err) + assert.NotNil(t, conn) + defer func() { + err := conn.Close() + if err != nil { + t.Errorf("error closing the connection: error %v", err) + } + }() + + assert.Equal(t, "unix", conn.RemoteAddr().Network()) + assert.Equal(t, "file.sock", conn.RemoteAddr().String()) +} + +func Test_Config_DialerError(t *testing.T) { + cfg := &Config{Listen: "uni:unix.sock"} + ln, err := cfg.Dialer() + assert.Nil(t, ln) + assert.Error(t, err) + assert.Equal(t, "invalid socket DSN (tcp://:6001, unix://file.sock)", err.Error()) +} + +func Test_Config_DialerErrorMethod(t *testing.T) { + cfg := &Config{Listen: "xinu://unix.sock"} + + ln, err := cfg.Dialer() + assert.Nil(t, ln) + assert.Error(t, err) +} + +func Test_Config_Defaults(t *testing.T) { + c := &Config{} + c.InitDefaults() + assert.Equal(t, "tcp://127.0.0.1:6001", c.Listen) +} diff --git a/plugins/rpc/doc/plugin_arch.drawio b/plugins/rpc/doc/plugin_arch.drawio new file mode 100755 index 00000000..dec5f0b2 --- /dev/null +++ b/plugins/rpc/doc/plugin_arch.drawio @@ -0,0 +1 @@ +<mxfile host="Electron" modified="2020-10-19T17:14:19.125Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/13.7.9 Chrome/85.0.4183.121 Electron/10.1.3 Safari/537.36" etag="2J39x4EyFr1zaE9BXKM4" version="13.7.9" type="device"><diagram id="q2oMKs6VHyn7y0AfAXBL" name="Page-1">7Vttc9o4EP41zLQfksE2GPIxQHPXu7RlQntt7ptiC1sX2XJlOUB//a1sGdtIJDQFnE6YyUys1YutfR7trlai44yj5R8cJeEH5mPasbv+suNMOrZtORcO/JOSVSEZWv1CEHDiq0aVYEZ+YCXsKmlGfJw2GgrGqCBJU+ixOMaeaMgQ52zRbDZntPnWBAVYE8w8RHXpV+KLUEkt96Kq+BOTIFSvHtqDoiJCZWM1kzREPlvURM67jjPmjIniKVqOMZXKK/VS9LvaUrv+MI5jsUuHL/zu0yx7//HT3Pln8vfN59vvS/usVHMqVuWMsQ8KUEXGRcgCFiP6rpKOOMtiH8thu1Cq2lwzloDQAuF/WIiVQhNlgoEoFBFVtXhJxLfa860c6ryvSpOlGjkvrMpCLPjqW71Q6yWLVbe8VPabs1hcoYhQKRizjBPMYcIf8UJVqq+8gGKhC6mArTpWohQG8lSrfz88xF8/ds/+uiLe7MsXtLiyZ2clVxEPsHik3WDNBFhCmEUYvh36cUyRIA/N70CKy8G6XQU3PCjEfwZ9q030K8RvazVPoV8BftvA+7dE33KOBP9jX/mAaKbedDOFkbpTmgUk1qjRBH4REoFnCcr1sADj3wT55xVv0PMD5gIvayJdU6rWGSi3otyMYw3OlWRRme21VwlrFtsdHEi9jqbe9zERha+ak0DTL0xVNJWIKAliePZAMaA+ZyQVQsA5XaqKiPh+sShxSn6gu3woiU7CSCzyCfVHnf5EjgXrMC103go+3Q18hho6QwM4pfPcOzg9DZwJTnDspyBk8Rqk8ylnDxCB8N8DLcveD1z2BlxWWa4vpu4x8epreOmuK/YvZcQnIaAoTYm34XeO5kMMun/aFRjdj45QDYG+AYBStrMHUW+YSgpWBOgNtxCgHKJwgapXPercGKhvbwxkbQxUKEYbKCfJetrP542r8aa0vt0U9gsE1rpzKfWVeK97ia+Xc41glolhB1viA32Jj+3O5YhIXc9loAHFEczdpRKWO95Ay/2eyZ1UrqqzQq8S14tkmeurrIanQP0vRvmVQYA052WwVAwHE7+rXrHBp/bCI3f4tPu1jMGReyCwLT06KoLPVPDMExnHmvrSBYkoinGpIVWz07oUcm8y8kJC/Wu0YpmcXiqQd1+WRiHj5AcMi0qIoJqXMNhuo8VM9lQLO1/oeFqiY22IPqBlo+E1SoUSeIxSlKTkbj2NCGwhiUdMCBbt0/k8P47uuQarULapE8Vye4diytDg+ke7R2hAKHaPx4wyIMYkZgWBCKUbopJDFM/FVgalsOEhcXCdt5n0KsmNUoUUMeg7p3kgEoI/wHG+axZIbPUHI9DyWIYl4BnsMZStqpw7iwT22WMWw1wQycHFwKMFTsUvU+Tx1fk0cUr34e7GE/tQBqV0SxpNpJGeYf6QK+VNjMX5TeK9PbGlTbb07ZbZYl1sYUsKTCEeltvAIlKr+aNuSqHqxJw2mTMwBC7HZY6eOSiYMydYni3IeHH8aILnxIk9c8Lq9tomxQ7pCUpyqAszUZ4lWc/iw3qXqQjwOc+8n1kaSRydJI6BEBTdYTqF3WixH57woq1h0/ryueDsGLAOD0UFPeNQ2AcYPmT+G7FK8NvCTMjHkzdply1HdCfmIzhDHvMIR3Av9jDVrKTOjjnUCzPaRzpN1Ra+Ciafk9Xo/nK6wmAsfpMMhrZ+DazZmsHoNTNdPcvgD1xDpmuwB4dgpIX9dLxY8aTKdZ78wp7osn2t/lQyw8SZg3kFPTmqcSZGkTIsgNeJLS2yxZTMOCpb9IizMigcByQFmyITGlYxV4A2o0iqyc+PvOGvYYPmTNbl2Xgzq17Wgdie/Ia1cYFkqO8pHftAx2FGVPUMVVJkul8VLK61cXJl67gc6pTSbAvcVgJ245259TW5Vm5M1k6i9xPlO7uG+b1Ww3zdOVdXCk5h/pHsgtM0C64p7WNywqWz3j8tdsgLX0tXHJ+itiNFbVsu176UIN/SL7xMOQOFR2lOl7a9fN3MP4rYHpbzxq7dsGk/1O1QMzT6nYOAqSAZFqaPvY78hYecQIBjzJGQgbNgsk2UeaH8Ji93RdLvefdY3ohDeZyNlx7G8iGjJMqvA5/pV61fE9YGy93fU6ANxer3NcWNwupXSs67/wE=</diagram></mxfile>
\ No newline at end of file diff --git a/plugins/rpc/plugin.go b/plugins/rpc/plugin.go new file mode 100755 index 00000000..24624d91 --- /dev/null +++ b/plugins/rpc/plugin.go @@ -0,0 +1,165 @@ +package rpc + +import ( + "net" + "net/rpc" + "sync/atomic" + + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/goridge/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" + rpc_ "github.com/spiral/roadrunner/v2/interfaces/rpc" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// PluginName contains default plugin name. +const PluginName = "RPC" + +type pluggable struct { + service rpc_.RPCer + name string +} + +// Plugin is RPC service. +type Plugin struct { + cfg Config + log log.Logger + rpc *rpc.Server + services []pluggable + listener net.Listener + closed *uint32 +} + +// Init rpc service. Must return true if service is enabled. +func (s *Plugin) Init(cfg config.Configurer, log log.Logger) error { + const op = errors.Op("RPC Init") + if !cfg.Has(PluginName) { + return errors.E(op, errors.Disabled) + } + + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return err + } + s.cfg.InitDefaults() + + if s.cfg.Disabled { + return errors.E(op, errors.Disabled) + } + + s.log = log + state := uint32(0) + s.closed = &state + atomic.StoreUint32(s.closed, 0) + + return s.cfg.Valid() +} + +// Serve serves the service. +func (s *Plugin) Serve() chan error { + const op = errors.Op("register service") + errCh := make(chan error, 1) + + s.rpc = rpc.NewServer() + + services := make([]string, 0, len(s.services)) + + // Attach all services + for i := 0; i < len(s.services); i++ { + err := s.Register(s.services[i].name, s.services[i].service.RPC()) + if err != nil { + errCh <- errors.E(op, err) + return errCh + } + + services = append(services, s.services[i].name) + } + + var err error + s.listener, err = s.cfg.Listener() + if err != nil { + errCh <- err + return errCh + } + + s.log.Debug("Started RPC service", "address", s.cfg.Listen, "services", services) + + go func() { + for { + conn, err := s.listener.Accept() + if err != nil { + if atomic.LoadUint32(s.closed) == 1 { + // just log and continue, this is not a critical issue, we just called Stop + s.log.Error("listener accept error, connection closed", "error", err) + return + } + + s.log.Error("listener accept error", "error", err) + errCh <- errors.E(errors.Op("listener accept"), errors.Serve, err) + return + } + + go s.rpc.ServeCodec(goridge.NewCodec(conn)) + } + }() + + return errCh +} + +// Stop stops the service. +func (s *Plugin) Stop() error { + // store closed state + atomic.StoreUint32(s.closed, 1) + err := s.listener.Close() + if err != nil { + return errors.E(errors.Op("stop RPC socket"), err) + } + return nil +} + +// Name contains service name. +func (s *Plugin) Name() string { + return PluginName +} + +// Depends declares services to collect for RPC. +func (s *Plugin) Collects() []interface{} { + return []interface{}{ + s.RegisterPlugin, + } +} + +// RegisterPlugin registers RPC service plugin. +func (s *Plugin) RegisterPlugin(name endure.Named, p rpc_.RPCer) { + s.services = append(s.services, pluggable{ + service: p, + name: name.Name(), + }) +} + +// Register publishes in the server the set of methods of the +// receiver value that satisfy the following conditions: +// - exported method of exported type +// - two arguments, both of exported type +// - the second argument is a pointer +// - one return value, of type error +// It returns an error if the receiver is not an exported type or has +// no suitable methods. It also logs the error using package log. +func (s *Plugin) Register(name string, svc interface{}) error { + if s.rpc == nil { + return errors.E("RPC service is not configured") + } + + return s.rpc.RegisterName(name, svc) +} + +// Client creates new RPC client. +func (s *Plugin) Client() (*rpc.Client, error) { + conn, err := s.cfg.Dialer() + if err != nil { + return nil, err + } + + return rpc.NewClientWithCodec(goridge.NewClientCodec(conn)), nil +} diff --git a/plugins/rpc/tests/.rr-rpc-disabled.yaml b/plugins/rpc/tests/.rr-rpc-disabled.yaml new file mode 100644 index 00000000..624fb3c5 --- /dev/null +++ b/plugins/rpc/tests/.rr-rpc-disabled.yaml @@ -0,0 +1,3 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: true
\ No newline at end of file diff --git a/plugins/rpc/tests/.rr.yaml b/plugins/rpc/tests/.rr.yaml new file mode 100644 index 00000000..76e8b440 --- /dev/null +++ b/plugins/rpc/tests/.rr.yaml @@ -0,0 +1,3 @@ +rpc: + listen: tcp://127.0.0.1:6001 + disabled: false
\ No newline at end of file diff --git a/plugins/rpc/tests/plugin1.go b/plugins/rpc/tests/plugin1.go new file mode 100644 index 00000000..79e98ed4 --- /dev/null +++ b/plugins/rpc/tests/plugin1.go @@ -0,0 +1,42 @@ +package tests + +import ( + "fmt" + + "github.com/spiral/roadrunner/v2/plugins/config" +) + +type Plugin1 struct { + config config.Configurer +} + +func (p1 *Plugin1) Init(cfg config.Configurer) error { + p1.config = cfg + return nil +} + +func (p1 *Plugin1) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (p1 *Plugin1) Stop() error { + return nil +} + +func (p1 *Plugin1) Name() string { + return "rpc_test.plugin1" +} + +func (p1 *Plugin1) RPC() interface{} { + return &PluginRPC{srv: p1} +} + +type PluginRPC struct { + srv *Plugin1 +} + +func (r *PluginRPC) Hello(in string, out *string) error { + *out = fmt.Sprintf("Hello, username: %s", in) + return nil +} diff --git a/plugins/rpc/tests/plugin2.go b/plugins/rpc/tests/plugin2.go new file mode 100644 index 00000000..854bf097 --- /dev/null +++ b/plugins/rpc/tests/plugin2.go @@ -0,0 +1,54 @@ +package tests + +import ( + "net" + "net/rpc" + "time" + + "github.com/spiral/errors" + "github.com/spiral/goridge/v2" +) + +// plugin2 makes a call to the plugin1 via RPC +// this is just a simulation of external call FOR TEST +// you don't need to do such things :) +type Plugin2 struct { +} + +func (p2 *Plugin2) Init() error { + return nil +} + +func (p2 *Plugin2) Serve() chan error { + errCh := make(chan error, 1) + + go func() { + time.Sleep(time.Second * 3) + + conn, err := net.Dial("tcp", "127.0.0.1:6001") + if err != nil { + errCh <- errors.E(errors.Serve, err) + return + } + client := rpc.NewClientWithCodec(goridge.NewClientCodec(conn)) + var ret string + err = client.Call("rpc_test.plugin1.Hello", "Valery", &ret) + if err != nil { + errCh <- err + return + } + if ret != "Hello, username: Valery" { + errCh <- errors.E("wrong response") + return + } + // to stop exec + errCh <- errors.E(errors.Disabled) + return + }() + + return errCh +} + +func (p2 *Plugin2) Stop() error { + return nil +} diff --git a/plugins/rpc/tests/rpc_test.go b/plugins/rpc/tests/rpc_test.go new file mode 100644 index 00000000..88267dfb --- /dev/null +++ b/plugins/rpc/tests/rpc_test.go @@ -0,0 +1,169 @@ +package tests + +import ( + "os" + "os/signal" + "syscall" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/rpc" + "github.com/stretchr/testify/assert" +) + +// graph https://bit.ly/3ensdNb +func TestRpcInit(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&Plugin1{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&Plugin2{}) + if err != nil { + t.Fatal(err) + } + + v := &config.Viper{} + v.Path = ".rr.yaml" + v.Prefix = "rr" + err = cont.Register(v) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&rpc.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + tt := time.NewTimer(time.Second * 10) + + for { + select { + case e := <-ch: + // just stop, this is ok + if errors.Is(errors.Disabled, e.Error) { + return + } + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + assert.Fail(t, "timeout") + } + } +} + +// graph https://bit.ly/3ensdNb +func TestRpcDisabled(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&Plugin1{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&Plugin2{}) + if err != nil { + t.Fatal(err) + } + + v := &config.Viper{} + v.Path = ".rr-rpc-disabled.yaml" + v.Prefix = "rr" + err = cont.Register(v) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&rpc.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + if err != nil { + t.Fatal(err) + } + + sig := make(chan os.Signal, 1) + + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + tt := time.NewTimer(time.Second * 20) + + for { + select { + case e := <-ch: + // RPC is turned off, should be and dial error + if errors.Is(errors.Disabled, e.Error) { + assert.FailNow(t, "should not be disabled error") + } + assert.Error(t, e.Error) + err = cont.Stop() + assert.Error(t, err) + return + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + return + } + } +} diff --git a/plugins/server/config.go b/plugins/server/config.go new file mode 100644 index 00000000..147ae0f7 --- /dev/null +++ b/plugins/server/config.go @@ -0,0 +1,41 @@ +package server + +import ( + "time" + + "github.com/spiral/roadrunner/v2/interfaces/server" +) + +// Config config combines factory, pool and cmd configurations. +type Config struct { + // Command to run as application. + Command string + + // User to run application under. + User string + + // Group to run application under. + Group string + + // Env represents application environment. + Env server.Env + + // Listen defines connection method and factory to be used to connect to workers: + // "pipes", "tcp://:6001", "unix://rr.sock" + // This config section must not change on re-configuration. + Relay string + + // RelayTimeout defines for how long socket factory will be waiting for worker connection. This config section + // must not change on re-configuration. Defaults to 60s. + RelayTimeout time.Duration +} + +func (cfg *Config) InitDefaults() { + if cfg.Relay == "" { + cfg.Relay = "pipes" + } + + if cfg.RelayTimeout == 0 { + cfg.RelayTimeout = time.Second * 60 + } +} diff --git a/plugins/server/plugin.go b/plugins/server/plugin.go new file mode 100644 index 00000000..ea6d42eb --- /dev/null +++ b/plugins/server/plugin.go @@ -0,0 +1,176 @@ +package server + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/util" +) + +const PluginName = "server" + +// Plugin manages worker +type Plugin struct { + cfg Config + log log.Logger + factory roadrunner.Factory +} + +// Init application provider. +func (server *Plugin) Init(cfg config.Configurer, log log.Logger) error { + const op = errors.Op("Init") + err := cfg.UnmarshalKey(PluginName, &server.cfg) + if err != nil { + return errors.E(op, errors.Init, err) + } + server.cfg.InitDefaults() + server.log = log + + server.factory, err = server.initFactory() + if err != nil { + return errors.E(errors.Op("Init factory"), err) + } + + return nil +} + +// Name contains service name. +func (server *Plugin) Name() string { + return PluginName +} + +func (server *Plugin) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (server *Plugin) Stop() error { + if server.factory == nil { + return nil + } + + return server.factory.Close(context.Background()) +} + +// CmdFactory provides worker command factory assocated with given context. +func (server *Plugin) CmdFactory(env server.Env) (func() *exec.Cmd, error) { + const op = errors.Op("cmd factory") + var cmdArgs []string + + // create command according to the config + cmdArgs = append(cmdArgs, strings.Split(server.cfg.Command, " ")...) + if len(cmdArgs) < 2 { + return nil, errors.E(op, errors.Str("should be in form of `php <script>")) + } + if cmdArgs[0] != "php" { + return nil, errors.E(op, errors.Str("first arg in command should be `php`")) + } + return func() *exec.Cmd { + cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //nolint:gosec + util.IsolateProcess(cmd) + + // if user is not empty, and OS is linux or macos + // execute php worker from that particular user + if server.cfg.User != "" { + err := util.ExecuteFromUser(cmd, server.cfg.User) + if err != nil { + return nil + } + } + + cmd.Env = server.setEnv(env) + + return cmd + }, nil +} + +// NewWorker issues new standalone worker. +func (server *Plugin) NewWorker(ctx context.Context, env server.Env) (roadrunner.WorkerBase, error) { + const op = errors.Op("new worker") + spawnCmd, err := server.CmdFactory(env) + if err != nil { + return nil, errors.E(op, err) + } + + w, err := server.factory.SpawnWorkerWithContext(ctx, spawnCmd()) + if err != nil { + return nil, errors.E(op, err) + } + + w.AddListener(server.collectLogs) + + return w, nil +} + +// NewWorkerPool issues new worker pool. +func (server *Plugin) NewWorkerPool(ctx context.Context, opt roadrunner.PoolConfig, env server.Env) (roadrunner.Pool, error) { + spawnCmd, err := server.CmdFactory(env) + if err != nil { + return nil, err + } + + p, err := roadrunner.NewPool(ctx, spawnCmd, server.factory, opt) + if err != nil { + return nil, err + } + + p.AddListener(server.collectLogs) + + return p, nil +} + +// creates relay and worker factory. +func (server *Plugin) initFactory() (roadrunner.Factory, error) { + const op = errors.Op("network factory init") + if server.cfg.Relay == "" || server.cfg.Relay == "pipes" { + return roadrunner.NewPipeFactory(), nil + } + + dsn := strings.Split(server.cfg.Relay, "://") + if len(dsn) != 2 { + return nil, errors.E(op, errors.Network, errors.Str("invalid DSN (tcp://:6001, unix://file.sock)")) + } + + lsn, err := util.CreateListener(server.cfg.Relay) + if err != nil { + return nil, errors.E(op, errors.Network, err) + } + + switch dsn[0] { + // sockets group + case "unix": + return roadrunner.NewSocketServer(lsn, server.cfg.RelayTimeout), nil + case "tcp": + return roadrunner.NewSocketServer(lsn, server.cfg.RelayTimeout), nil + default: + return nil, errors.E(op, errors.Network, errors.Str("invalid DSN (tcp://:6001, unix://file.sock)")) + } +} + +func (server *Plugin) setEnv(e server.Env) []string { + env := append(os.Environ(), fmt.Sprintf("RR_RELAY=%s", server.cfg.Relay)) + for k, v := range e { + env = append(env, fmt.Sprintf("%s=%s", strings.ToUpper(k), v)) + } + + return env +} + +func (server *Plugin) collectLogs(event interface{}) { + if we, ok := event.(roadrunner.WorkerEvent); ok { + switch we.Event { + case roadrunner.EventWorkerError: + server.log.Error(we.Payload.(error).Error(), "pid", we.Worker.Pid()) + case roadrunner.EventWorkerLog: + server.log.Debug(strings.TrimRight(string(we.Payload.([]byte)), " \n\t"), "pid", we.Worker.Pid()) + } + } +} diff --git a/plugins/server/tests/configs/.rr-no-app-section.yaml b/plugins/server/tests/configs/.rr-no-app-section.yaml new file mode 100644 index 00000000..b6e3ea93 --- /dev/null +++ b/plugins/server/tests/configs/.rr-no-app-section.yaml @@ -0,0 +1,9 @@ +server: + command: "php ../../../tests/client.php echo pipes" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pipes" + relayTimeout: "20s"
\ No newline at end of file diff --git a/plugins/server/tests/configs/.rr-sockets.yaml b/plugins/server/tests/configs/.rr-sockets.yaml new file mode 100644 index 00000000..ab1239aa --- /dev/null +++ b/plugins/server/tests/configs/.rr-sockets.yaml @@ -0,0 +1,9 @@ +server: + command: "php socket.php" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "unix://unix.sock" + relayTimeout: "20s"
\ No newline at end of file diff --git a/plugins/server/tests/configs/.rr-tcp.yaml b/plugins/server/tests/configs/.rr-tcp.yaml new file mode 100644 index 00000000..f53bffcc --- /dev/null +++ b/plugins/server/tests/configs/.rr-tcp.yaml @@ -0,0 +1,9 @@ +server: + command: "php tcp.php" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "tcp://localhost:9999" + relayTimeout: "20s"
\ No newline at end of file diff --git a/plugins/server/tests/configs/.rr-wrong-command.yaml b/plugins/server/tests/configs/.rr-wrong-command.yaml new file mode 100644 index 00000000..d2c087a6 --- /dev/null +++ b/plugins/server/tests/configs/.rr-wrong-command.yaml @@ -0,0 +1,9 @@ +server: + command: "php some_absent_file.php" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pipes" + relayTimeout: "20s" diff --git a/plugins/server/tests/configs/.rr-wrong-relay.yaml b/plugins/server/tests/configs/.rr-wrong-relay.yaml new file mode 100644 index 00000000..1dd73d73 --- /dev/null +++ b/plugins/server/tests/configs/.rr-wrong-relay.yaml @@ -0,0 +1,9 @@ +server: + command: "php ../../../tests/client.php echo pipes" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pupes" + relayTimeout: "20s"
\ No newline at end of file diff --git a/plugins/server/tests/configs/.rr.yaml b/plugins/server/tests/configs/.rr.yaml new file mode 100644 index 00000000..b6e3ea93 --- /dev/null +++ b/plugins/server/tests/configs/.rr.yaml @@ -0,0 +1,9 @@ +server: + command: "php ../../../tests/client.php echo pipes" + user: "" + group: "" + env: + "RR_CONFIG": "/some/place/on/the/C134" + "RR_CONFIG2": "C138" + relay: "pipes" + relayTimeout: "20s"
\ No newline at end of file diff --git a/plugins/server/tests/plugin_pipes.go b/plugins/server/tests/plugin_pipes.go new file mode 100644 index 00000000..fbd37e12 --- /dev/null +++ b/plugins/server/tests/plugin_pipes.go @@ -0,0 +1,131 @@ +package tests + +import ( + "context" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" + plugin "github.com/spiral/roadrunner/v2/plugins/server" +) + +const ConfigSection = "server" +const Response = "test" + +var testPoolConfig = roadrunner.PoolConfig{ + NumWorkers: 10, + MaxJobs: 100, + AllocateTimeout: time.Second * 10, + DestroyTimeout: time.Second * 10, + Supervisor: &roadrunner.SupervisorConfig{ + WatchTick: 60, + TTL: 1000, + IdleTTL: 10, + ExecTTL: 10, + MaxWorkerMemory: 1000, + }, +} + +type Foo struct { + configProvider config.Configurer + wf server.Server + pool roadrunner.Pool +} + +func (f *Foo) Init(p config.Configurer, workerFactory server.Server) error { + f.configProvider = p + f.wf = workerFactory + return nil +} + +func (f *Foo) Serve() chan error { + const op = errors.Op("serve") + + // test payload for echo + r := roadrunner.Payload{ + Context: nil, + Body: []byte(Response), + } + + errCh := make(chan error, 1) + + conf := &plugin.Config{} + var err error + err = f.configProvider.UnmarshalKey(ConfigSection, conf) + if err != nil { + errCh <- err + return errCh + } + + // test CMDFactory + cmd, err := f.wf.CmdFactory(nil) + if err != nil { + errCh <- err + return errCh + } + if cmd == nil { + errCh <- errors.E(op, "command is nil") + return errCh + } + + // test worker creation + w, err := f.wf.NewWorker(context.Background(), nil) + if err != nil { + errCh <- err + return errCh + } + + // test that our worker is functional + sw, err := roadrunner.NewSyncWorker(w) + if err != nil { + errCh <- err + return errCh + } + + rsp, err := sw.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + // should not be errors + err = sw.Stop(context.Background()) + if err != nil { + errCh <- err + return errCh + } + + // test pool + f.pool, err = f.wf.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + errCh <- err + return errCh + } + + // test pool execution + rsp, err = f.pool.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + // echo of the "test" should be -> test + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + return errCh +} + +func (f *Foo) Stop() error { + f.pool.Destroy(context.Background()) + return nil +} diff --git a/plugins/server/tests/plugin_sockets.go b/plugins/server/tests/plugin_sockets.go new file mode 100644 index 00000000..4942d4c5 --- /dev/null +++ b/plugins/server/tests/plugin_sockets.go @@ -0,0 +1,112 @@ +package tests + +import ( + "context" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" + plugin "github.com/spiral/roadrunner/v2/plugins/server" +) + +type Foo2 struct { + configProvider config.Configurer + wf server.Server + pool roadrunner.Pool +} + +func (f *Foo2) Init(p config.Configurer, workerFactory server.Server) error { + f.configProvider = p + f.wf = workerFactory + return nil +} + +func (f *Foo2) Serve() chan error { + const op = errors.Op("serve") + var err error + errCh := make(chan error, 1) + conf := &plugin.Config{} + + // test payload for echo + r := roadrunner.Payload{ + Context: nil, + Body: []byte(Response), + } + + err = f.configProvider.UnmarshalKey(ConfigSection, conf) + if err != nil { + errCh <- err + return errCh + } + + // test CMDFactory + cmd, err := f.wf.CmdFactory(nil) + if err != nil { + errCh <- err + return errCh + } + if cmd == nil { + errCh <- errors.E(op, "command is nil") + return errCh + } + + // test worker creation + w, err := f.wf.NewWorker(context.Background(), nil) + if err != nil { + errCh <- err + return errCh + } + + // test that our worker is functional + sw, err := roadrunner.NewSyncWorker(w) + if err != nil { + errCh <- err + return errCh + } + + rsp, err := sw.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + // should not be errors + err = sw.Stop(context.Background()) + if err != nil { + errCh <- err + return errCh + } + + // test pool + f.pool, err = f.wf.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + errCh <- err + return errCh + } + + // test pool execution + rsp, err = f.pool.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + // echo of the "test" should be -> test + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + return errCh +} + +func (f *Foo2) Stop() error { + f.pool.Destroy(context.Background()) + return nil +} diff --git a/plugins/server/tests/plugin_tcp.go b/plugins/server/tests/plugin_tcp.go new file mode 100644 index 00000000..89757a02 --- /dev/null +++ b/plugins/server/tests/plugin_tcp.go @@ -0,0 +1,112 @@ +package tests + +import ( + "context" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2" + "github.com/spiral/roadrunner/v2/interfaces/server" + "github.com/spiral/roadrunner/v2/plugins/config" + plugin "github.com/spiral/roadrunner/v2/plugins/server" +) + +type Foo3 struct { + configProvider config.Configurer + wf server.Server + pool roadrunner.Pool +} + +func (f *Foo3) Init(p config.Configurer, workerFactory server.Server) error { + f.configProvider = p + f.wf = workerFactory + return nil +} + +func (f *Foo3) Serve() chan error { + const op = errors.Op("serve") + var err error + errCh := make(chan error, 1) + conf := &plugin.Config{} + + // test payload for echo + r := roadrunner.Payload{ + Context: nil, + Body: []byte(Response), + } + + err = f.configProvider.UnmarshalKey(ConfigSection, conf) + if err != nil { + errCh <- err + return errCh + } + + // test CMDFactory + cmd, err := f.wf.CmdFactory(nil) + if err != nil { + errCh <- err + return errCh + } + if cmd == nil { + errCh <- errors.E(op, "command is nil") + return errCh + } + + // test worker creation + w, err := f.wf.NewWorker(context.Background(), nil) + if err != nil { + errCh <- err + return errCh + } + + // test that our worker is functional + sw, err := roadrunner.NewSyncWorker(w) + if err != nil { + errCh <- err + return errCh + } + + rsp, err := sw.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + // should not be errors + err = sw.Stop(context.Background()) + if err != nil { + errCh <- err + return errCh + } + + // test pool + f.pool, err = f.wf.NewWorkerPool(context.Background(), testPoolConfig, nil) + if err != nil { + errCh <- err + return errCh + } + + // test pool execution + rsp, err = f.pool.Exec(r) + if err != nil { + errCh <- err + return errCh + } + + // echo of the "test" should be -> test + if string(rsp.Body) != Response { + errCh <- errors.E("response from worker is wrong", errors.Errorf("response: %s", rsp.Body)) + return errCh + } + + return errCh +} + +func (f *Foo3) Stop() error { + f.pool.Destroy(context.Background()) + return nil +} diff --git a/plugins/server/tests/server_test.go b/plugins/server/tests/server_test.go new file mode 100644 index 00000000..bc374a9e --- /dev/null +++ b/plugins/server/tests/server_test.go @@ -0,0 +1,356 @@ +package tests + +import ( + "os" + "os/signal" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/stretchr/testify/assert" +) + +func TestAppPipes(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + // stop after 10 seconds + tt := time.NewTicker(time.Second * 10) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } +} + +func TestAppSockets(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr-sockets.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo2{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + // stop after 10 seconds + tt := time.NewTicker(time.Second * 10) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } +} + +func TestAppTCP(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr-tcp.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo3{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + // stop after 10 seconds + tt := time.NewTicker(time.Second * 10) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } +} + +func TestAppWrongConfig(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rrrrrrrrrr.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo3{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + assert.Error(t, container.Init()) +} + +func TestAppWrongRelay(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr-wrong-relay.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo3{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + assert.Error(t, err) + + _, err = container.Serve() + assert.Error(t, err) +} + +func TestAppWrongCommand(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr-wrong-command.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo3{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + _, err = container.Serve() + assert.Error(t, err) +} + +func TestAppNoAppSectionInConfig(t *testing.T) { + container, err := endure.NewContainer(nil, endure.RetryOnFail(true), endure.SetLogLevel(endure.DebugLevel)) + if err != nil { + t.Fatal(err) + } + // config plugin + vp := &config.Viper{} + vp.Path = "configs/.rr-wrong-command.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&server.Plugin{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo3{}) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&logger.ZapLogger{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + _, err = container.Serve() + assert.Error(t, err) +} diff --git a/plugins/server/tests/socket.php b/plugins/server/tests/socket.php new file mode 100644 index 00000000..143c3ce4 --- /dev/null +++ b/plugins/server/tests/socket.php @@ -0,0 +1,25 @@ +<?php +/** + * @var Goridge\RelayInterface $relay + */ + +use Spiral\Goridge; +use Spiral\RoadRunner; + +require dirname(__DIR__) . "/../../vendor_php/autoload.php"; + +$relay = new Goridge\SocketRelay( + "unix.sock", + null, + Goridge\SocketRelay::SOCK_UNIX + ); + +$rr = new RoadRunner\Worker($relay); + +while ($in = $rr->receive($ctx)) { + try { + $rr->send((string)$in); + } catch (\Throwable $e) { + $rr->error((string)$e); + } +} diff --git a/plugins/server/tests/tcp.php b/plugins/server/tests/tcp.php new file mode 100644 index 00000000..2d6fb00a --- /dev/null +++ b/plugins/server/tests/tcp.php @@ -0,0 +1,20 @@ +<?php +/** + * @var Goridge\RelayInterface $relay + */ + +use Spiral\Goridge; +use Spiral\RoadRunner; + +require dirname(__DIR__) . "/../../vendor_php/autoload.php"; + +$relay = new Goridge\SocketRelay("localhost", 9999); +$rr = new RoadRunner\Worker($relay); + +while ($in = $rr->receive($ctx)) { + try { + $rr->send((string)$in); + } catch (\Throwable $e) { + $rr->error((string)$e); + } +}
\ No newline at end of file diff --git a/plugins/static/config.go b/plugins/static/config.go new file mode 100644 index 00000000..f5d26b2d --- /dev/null +++ b/plugins/static/config.go @@ -0,0 +1,76 @@ +package static + +import ( + "os" + "path" + "strings" + + "github.com/spiral/errors" +) + +// Config describes file location and controls access to them. +type Config struct { + Static struct { + // Dir contains name of directory to control access to. + Dir string + + // Forbid specifies list of file extensions which are forbidden for access. + // Example: .php, .exe, .bat, .htaccess and etc. + Forbid []string + + // Always specifies list of extensions which must always be served by static + // service, even if file not found. + Always []string + + // Request headers to add to every static. + Request map[string]string + + // Response headers to add to every static. + Response map[string]string + } +} + +// Valid returns nil if config is valid. +func (c *Config) Valid() error { + const op = errors.Op("static plugin validation") + st, err := os.Stat(c.Static.Dir) + if err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("root directory '%s' does not exists", c.Static.Dir)) + } + + return err + } + + if !st.IsDir() { + return errors.E(op, errors.Errorf("invalid root directory '%s'", c.Static.Dir)) + } + + return nil +} + +// AlwaysForbid must return true if file extension is not allowed for the upload. +func (c *Config) AlwaysForbid(filename string) bool { + ext := strings.ToLower(path.Ext(filename)) + + for _, v := range c.Static.Forbid { + if ext == v { + return true + } + } + + return false +} + +// AlwaysServe must indicate that file is expected to be served by static service. +func (c *Config) AlwaysServe(filename string) bool { + ext := strings.ToLower(path.Ext(filename)) + + for _, v := range c.Static.Always { + if ext == v { + return true + } + } + + return false +} diff --git a/plugins/static/config_test.go b/plugins/static/config_test.go new file mode 100644 index 00000000..de88ded3 --- /dev/null +++ b/plugins/static/config_test.go @@ -0,0 +1,48 @@ +package static + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfig_Forbids(t *testing.T) { + cfg := Config{Static: struct { + Dir string + Forbid []string + Always []string + Request map[string]string + Response map[string]string + }{Dir: "", Forbid: []string{".php"}, Always: nil, Request: nil, Response: nil}} + + assert.True(t, cfg.AlwaysForbid("index.php")) + assert.True(t, cfg.AlwaysForbid("index.PHP")) + assert.True(t, cfg.AlwaysForbid("phpadmin/index.bak.php")) + assert.False(t, cfg.AlwaysForbid("index.html")) +} + +func TestConfig_Valid(t *testing.T) { + assert.NoError(t, (&Config{Static: struct { + Dir string + Forbid []string + Always []string + Request map[string]string + Response map[string]string + }{Dir: "./"}}).Valid()) + + assert.Error(t, (&Config{Static: struct { + Dir string + Forbid []string + Always []string + Request map[string]string + Response map[string]string + }{Dir: "./config.go"}}).Valid()) + + assert.Error(t, (&Config{Static: struct { + Dir string + Forbid []string + Always []string + Request map[string]string + Response map[string]string + }{Dir: "./dir/"}}).Valid()) +} diff --git a/plugins/static/plugin.go b/plugins/static/plugin.go new file mode 100644 index 00000000..cf5cee25 --- /dev/null +++ b/plugins/static/plugin.go @@ -0,0 +1,110 @@ +package static + +import ( + "net/http" + "path" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/log" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// ID contains default service name. +const PluginName = "static" + +const RootPluginName = "http" + +// Plugin serves static files. Potentially convert into middleware? +type Plugin struct { + // server configuration (location, forbidden files and etc) + cfg *Config + + log log.Logger + + // root is initiated http directory + root http.Dir +} + +// Init must return configure service and return true if service hasStatus enabled. Must return error in case of +// misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer, log log.Logger) error { + const op = errors.Op("static plugin init") + err := cfg.UnmarshalKey(RootPluginName, &s.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + s.log = log + s.root = http.Dir(s.cfg.Static.Dir) + + err = s.cfg.Valid() + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + return nil +} + +func (s *Plugin) Name() string { + return PluginName +} + +// middleware must return true if request/response pair is handled within the middleware. +func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if s.cfg.Static.Request != nil { + for k, v := range s.cfg.Static.Request { + r.Header.Add(k, v) + } + } + + if s.cfg.Static.Response != nil { + for k, v := range s.cfg.Static.Response { + w.Header().Set(k, v) + } + } + + if !s.handleStatic(w, r) { + next.ServeHTTP(w, r) + } + } +} + +func (s *Plugin) handleStatic(w http.ResponseWriter, r *http.Request) bool { + fPath := path.Clean(r.URL.Path) + + if s.cfg.AlwaysForbid(fPath) { + return false + } + + f, err := s.root.Open(fPath) + if err != nil { + s.log.Error("file open error", "error", err) + if s.cfg.AlwaysServe(fPath) { + w.WriteHeader(404) + return true + } + + return false + } + defer func() { + err = f.Close() + if err != nil { + s.log.Error("file closing error", "error", err) + } + }() + + d, err := f.Stat() + if err != nil { + return false + } + + // do not serve directories + if d.IsDir() { + return false + } + + http.ServeContent(w, r, d.Name(), d.ModTime(), f) + return true +} diff --git a/plugins/static/tests/configs/.rr-http-static-disabled.yaml b/plugins/static/tests/configs/.rr-http-static-disabled.yaml new file mode 100644 index 00000000..d0b9b388 --- /dev/null +++ b/plugins/static/tests/configs/.rr-http-static-disabled.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php pid pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:21234 + maxRequestSize: 1024 + middleware: [ "gzip", "static" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + static: + dir: "abc" #not exists + forbid: [ ".php", ".htaccess" ] + request: + "Example-Request-Header": "Value" + # Automatically add headers to every response. + response: + "X-Powered-By": "RoadRunner" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/static/tests/configs/.rr-http-static-files-disable.yaml b/plugins/static/tests/configs/.rr-http-static-files-disable.yaml new file mode 100644 index 00000000..a3d814a3 --- /dev/null +++ b/plugins/static/tests/configs/.rr-http-static-files-disable.yaml @@ -0,0 +1,30 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:45877 + maxRequestSize: 1024 + middleware: [ "gzip", "static" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + static: + dir: "../../../tests" + forbid: [ ".php" ] + request: + "Example-Request-Header": "Value" + # Automatically add headers to every response. + response: + "X-Powered-By": "RoadRunner" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/static/tests/configs/.rr-http-static-files.yaml b/plugins/static/tests/configs/.rr-http-static-files.yaml new file mode 100644 index 00000000..35938b80 --- /dev/null +++ b/plugins/static/tests/configs/.rr-http-static-files.yaml @@ -0,0 +1,31 @@ +server: + command: "php ../../../tests/http/client.php echo pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:34653 + maxRequestSize: 1024 + middleware: [ "gzip", "static" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + static: + dir: "../../../tests" + forbid: [ ".php", ".htaccess" ] + always: [ ".ico" ] + request: + "Example-Request-Header": "Value" + # Automatically add headers to every response. + response: + "X-Powered-By": "RoadRunner" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/static/tests/configs/.rr-http-static.yaml b/plugins/static/tests/configs/.rr-http-static.yaml new file mode 100644 index 00000000..80a1fa7f --- /dev/null +++ b/plugins/static/tests/configs/.rr-http-static.yaml @@ -0,0 +1,29 @@ +server: + command: "php ../../../tests/http/client.php pid pipes" + user: "" + group: "" + env: + "RR_HTTP": "true" + relay: "pipes" + relayTimeout: "20s" + +http: + debug: true + address: 127.0.0.1:21603 + maxRequestSize: 1024 + middleware: [ "gzip", "static" ] + trustedSubnets: [ "10.0.0.0/8", "127.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "::1/128", "fc00::/7", "fe80::/10" ] + uploads: + forbid: [ ".php", ".exe", ".bat" ] + static: + dir: "../../../tests" + forbid: [ "" ] + request: + "input": "custom-header" + response: + "output": "output-header" + pool: + numWorkers: 2 + maxJobs: 0 + allocateTimeout: 60s + destroyTimeout: 60s
\ No newline at end of file diff --git a/plugins/static/tests/static_plugin_test.go b/plugins/static/tests/static_plugin_test.go new file mode 100644 index 00000000..528d5eea --- /dev/null +++ b/plugins/static/tests/static_plugin_test.go @@ -0,0 +1,423 @@ +package tests + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/spiral/endure" + "github.com/spiral/roadrunner/v2/mocks" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/gzip" + httpPlugin "github.com/spiral/roadrunner/v2/plugins/http" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/spiral/roadrunner/v2/plugins/static" + "github.com/stretchr/testify/assert" +) + +func TestStaticPlugin(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-static.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + &static.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("ServeSample", serveStaticSample) + t.Run("StaticNotForbid", staticNotForbid) + t.Run("StaticHeaders", staticHeaders) + wg.Wait() +} + +func staticHeaders(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost:21603/client.php", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.Header.Get("Output") != "output-header" { + t.Fatal("can't find output header in response") + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, all("../../../tests/client.php"), string(b)) + assert.Equal(t, all("../../../tests/client.php"), string(b)) +} + +func staticNotForbid(t *testing.T) { + b, r, err := get("http://localhost:21603/client.php") + assert.NoError(t, err) + assert.Equal(t, all("../../../tests/client.php"), b) + assert.Equal(t, all("../../../tests/client.php"), b) + _ = r.Body.Close() +} + +func serveStaticSample(t *testing.T) { + b, r, err := get("http://localhost:21603/sample.txt") + assert.NoError(t, err) + assert.Equal(t, "sample", b) + _ = r.Body.Close() +} + +func TestStaticDisabled(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-static-disabled.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + &static.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("StaticDisabled", staticDisabled) + wg.Wait() +} + +func staticDisabled(t *testing.T) { + _, r, err := get("http://localhost:21234/sample.txt") + assert.Error(t, err) + assert.Nil(t, r) +} + +func TestStaticFilesDisabled(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-static-files-disable.yaml", + Prefix: "rr", + } + + err = cont.RegisterAll( + cfg, + &logger.ZapLogger{}, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + &static.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("StaticFilesDisabled", staticFilesDisabled) + wg.Wait() +} + +func staticFilesDisabled(t *testing.T) { + b, r, err := get("http://localhost:45877/client.php?hello=world") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "WORLD", b) + _ = r.Body.Close() +} + +func TestStaticFilesForbid(t *testing.T) { + cont, err := endure.NewContainer(nil, endure.SetLogLevel(endure.DebugLevel)) + assert.NoError(t, err) + + cfg := &config.Viper{ + Path: "configs/.rr-http-static-files.yaml", + Prefix: "rr", + } + + controller := gomock.NewController(t) + mockLogger := mocks.NewMockLogger(controller) + + mockLogger.EXPECT().Debug("http handler response received", "elapsed", gomock.Any(), "remote address", "127.0.0.1").AnyTimes() + mockLogger.EXPECT().Error("file open error", "error", gomock.Any()).AnyTimes() + + err = cont.RegisterAll( + cfg, + mockLogger, + &server.Plugin{}, + &httpPlugin.Plugin{}, + &gzip.Gzip{}, + &static.Plugin{}, + ) + assert.NoError(t, err) + + err = cont.Init() + if err != nil { + t.Fatal(err) + } + + ch, err := cont.Serve() + assert.NoError(t, err) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + wg := &sync.WaitGroup{} + wg.Add(1) + + tt := time.NewTimer(time.Second * 10) + + go func() { + defer wg.Done() + for { + select { + case e := <-ch: + assert.Fail(t, "error", e.Error.Error()) + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + case <-sig: + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + case <-tt.C: + // timeout + err = cont.Stop() + if err != nil { + assert.FailNow(t, "error", err.Error()) + } + return + } + } + }() + + time.Sleep(time.Second) + t.Run("StaticTestFilesDir", staticTestFilesDir) + t.Run("StaticNotFound", staticNotFound) + t.Run("StaticFilesForbid", staticFilesForbid) + t.Run("StaticFilesAlways", staticFilesAlways) + wg.Wait() +} + +func staticTestFilesDir(t *testing.T) { + b, r, err := get("http://localhost:34653/http?hello=world") + assert.NoError(t, err) + assert.Equal(t, "WORLD", b) + _ = r.Body.Close() +} + +func staticNotFound(t *testing.T) { + b, _, _ := get("http://localhost:34653/client.XXX?hello=world") + assert.Equal(t, "WORLD", b) +} + +func staticFilesAlways(t *testing.T) { + _, r, err := get("http://localhost:34653/favicon.ico") + assert.NoError(t, err) + assert.Equal(t, 404, r.StatusCode) + _ = r.Body.Close() +} + +func staticFilesForbid(t *testing.T) { + b, r, err := get("http://localhost:34653/client.php?hello=world") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "WORLD", b) + _ = r.Body.Close() +} + +// HELPERS +func get(url string) (string, *http.Response, error) { + r, err := http.Get(url) + if err != nil { + return "", nil, err + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return "", nil, err + } + + err = r.Body.Close() + if err != nil { + return "", nil, err + } + + return string(b), r, err +} + +func all(fn string) string { + f, _ := os.Open(fn) + + b := new(bytes.Buffer) + _, err := io.Copy(b, f) + if err != nil { + return "" + } + + err = f.Close() + if err != nil { + return "" + } + + return b.String() +} |