Diffstat (limited to 'plugins')
65 files changed, 7939 insertions, 0 deletions
diff --git a/plugins/checker/config.go b/plugins/checker/config.go new file mode 100644 index 00000000..5f952592 --- /dev/null +++ b/plugins/checker/config.go @@ -0,0 +1,5 @@ +package checker + +type Config struct { + Address string +}
diff --git a/plugins/checker/interface.go b/plugins/checker/interface.go new file mode 100644 index 00000000..dd9dcada --- /dev/null +++ b/plugins/checker/interface.go @@ -0,0 +1,11 @@ +package checker + +// Status contains the status code returned by the service +type Status struct { + Code int +} + +// Checker is the interface used to get the latest status from a plugin +type Checker interface { + Status() Status +}
diff --git a/plugins/checker/plugin.go b/plugins/checker/plugin.go new file mode 100644 index 00000000..95f4f68c --- /dev/null +++ b/plugins/checker/plugin.go @@ -0,0 +1,151 @@ +package checker + +import ( + "fmt" + "net/http" + "time" + + "github.com/gofiber/fiber/v2" + fiberLogger "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const ( + // PluginName declares public plugin name. + PluginName = "status" +) + +type Plugin struct { + registry map[string]Checker + server *fiber.App + log logger.Logger + cfg *Config +} + +func (c *Plugin) Init(log logger.Logger, cfg config.Configurer) error { + const op = errors.Op("status plugin init") + err := cfg.UnmarshalKey(PluginName, &c.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + if c.cfg == nil { + return errors.E(errors.Disabled) + } + + c.registry = make(map[string]Checker) + c.log = log + return nil +} + +func (c *Plugin) Serve() chan error { + errCh := make(chan error, 1) + c.server = fiber.New(fiber.Config{ + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + IdleTimeout: time.Second * 5, + }) + c.server.Group("/v1", c.healthHandler) + c.server.Use(fiberLogger.New()) + c.server.Use("/health", c.healthHandler) + + go func() { + err := c.server.Listen(c.cfg.Address) + if err != nil { + errCh <- err + } + }() + + return errCh +} + +func (c *Plugin) Stop() error { + const op = errors.Op("checker stop") + err := c.server.Shutdown() + if err != nil { + return errors.E(op, err) + } + return nil +} + +// Status returns the current status of the named service. +func (c *Plugin) Status(name string) (Status, error) { + const op = errors.Op("get status") + svc, ok := c.registry[name] + if !ok { + return Status{}, errors.E(op, errors.Errorf("no such service: %s", name)) + } + + return svc.Status(), nil +} + +// CollectTarget collects services which can provide Status. +func (c *Plugin) CollectTarget(name endure.Named, r Checker) error { + c.registry[name.Name()] = r + return nil +} + +// Collects declares services to be collected. +func (c *Plugin) Collects() []interface{} { + return []interface{}{ + c.CollectTarget, + } +} + +// Name of the service. +func (c *Plugin) Name() string { + return PluginName +}
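Any service can opt into these health checks: the status plugin collects every named plugin that satisfies the Checker interface via CollectTarget. A minimal sketch of such a provider (the package and plugin name here are illustrative, not part of this commit):

```go
package myplugin

import (
	"net/http"

	"github.com/spiral/roadrunner/v2/plugins/checker"
)

type Plugin struct{}

func (p *Plugin) Init() error { return nil }

// Name satisfies endure.Named, so CollectTarget can register the plugin.
func (p *Plugin) Name() string { return "myplugin" }

// Status reports the plugin as healthy; a code >= 500 would mark the
// /v1/health response as failed.
func (p *Plugin) Status() checker.Status {
	return checker.Status{Code: http.StatusOK}
}
```

With that in place, `GET /v1/health?plugin=myplugin` reports `Service: myplugin: Status: 200`.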
+// RPC returns the associated rpc service. +func (c *Plugin) RPC() interface{} { + return &rpc{srv: c, log: c.log} +} + +type Plugins struct { + Plugins []string `query:"plugin"` +} + +const template string = "Service: %s: Status: %d\n" + +func (c *Plugin) healthHandler(ctx *fiber.Ctx) error { + const op = errors.Op("health_handler") + plugins := &Plugins{} + err := ctx.QueryParser(plugins) + if err != nil { + return errors.E(op, err) + } + + if len(plugins.Plugins) == 0 { + ctx.Status(http.StatusOK) + _, _ = ctx.WriteString("No plugins provided in query. Query should be in form of: /v1/health?plugin=plugin1&plugin=plugin2 \n") + return nil + } + + failed := false + // iterate over all provided plugins + for i := 0; i < len(plugins.Plugins); i++ { + // check if the plugin exists + if plugin, ok := c.registry[plugins.Plugins[i]]; ok { + st := plugin.Status() + if st.Code >= 500 { + failed = true + continue + } else if st.Code >= 100 && st.Code <= 400 { + _, _ = ctx.WriteString(fmt.Sprintf(template, plugins.Plugins[i], st.Code)) + } + } else { + _, _ = ctx.WriteString(fmt.Sprintf("Service: %s not found", plugins.Plugins[i])) + } + } + if failed { + ctx.Status(http.StatusInternalServerError) + return nil + } + + ctx.Status(http.StatusOK) + return nil +}
diff --git a/plugins/checker/rpc.go b/plugins/checker/rpc.go new file mode 100644 index 00000000..0daa62fe --- /dev/null +++ b/plugins/checker/rpc.go @@ -0,0 +1,27 @@ +package checker + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +type rpc struct { + srv *Plugin + log logger.Logger +} + +// Status returns the current status of the provided plugin +func (rpc *rpc) Status(service string, status *Status) error { + const op = errors.Op("status") + rpc.log.Debug("started Status method", "service", service) + st, err := rpc.srv.Status(service) + if err != nil { + return errors.E(op, err) + } + + *status = st + + rpc.log.Debug("status code", "code", st.Code) + rpc.log.Debug("successfully finished Status method") + return nil +}
diff --git a/plugins/config/interface.go b/plugins/config/interface.go new file mode 100644 index 00000000..23279f53 --- /dev/null +++ b/plugins/config/interface.go @@ -0,0 +1,26 @@ +package config + +type Configurer interface { + // UnmarshalKey takes a single key and unmarshals it into a Struct: + // + // func (h *HttpService) Init(cp config.Configurer) error { + // h.config = &HttpConfig{} + // if err := cp.UnmarshalKey("http", h.config); err != nil { + // return err + // } + // return nil + // } + UnmarshalKey(name string, out interface{}) error + + // Unmarshal unmarshals the config into a Struct. Make sure that the tags + // on the fields of the structure are properly set. + Unmarshal(out interface{}) error + + // Get returns a config section by name + Get(name string) interface{} + + // Overwrite overwrites particular values in the unmarshaled config + Overwrite(values map[string]interface{}) error + + // Has checks if config section exists. + Has(name string) bool +}
diff --git a/plugins/config/plugin.go b/plugins/config/plugin.go new file mode 100755 index 00000000..9cecf9f9 --- /dev/null +++ b/plugins/config/plugin.go @@ -0,0 +1,84 @@ +package config + +import ( + "bytes" + "strings" + + "github.com/spf13/viper" + "github.com/spiral/errors" +) + +type Viper struct { + viper *viper.Viper + Path string + Prefix string + Type string + ReadInCfg []byte +}
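Seen from the consumer side, the Configurer contract keeps plugins decoupled from viper: they gate on their own section and unmarshal into their own struct. A hedged sketch reusing the checker's section name and Config type from above:

```go
// consumer-side sketch; "status" and Config mirror the checker plugin above
func (p *Plugin) Init(cfg config.Configurer) error {
	if !cfg.Has("status") {
		return errors.E(errors.Disabled)
	}

	c := &Config{}
	if err := cfg.UnmarshalKey("status", c); err != nil {
		return err
	}

	// c.Address now carries status.address from the config file
	return nil
}
```

+// Init initializes the config provider.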
+func (v *Viper) Init() error { + const op = errors.Op("viper plugin init") + v.viper = viper.New() + // If user provided []byte data with config, read it and ignore Path and Prefix + if v.ReadInCfg != nil && v.Type != "" { + v.viper.SetConfigType("yaml") + return v.viper.ReadConfig(bytes.NewBuffer(v.ReadInCfg)) + } + + // read in environment variables that match + v.viper.AutomaticEnv() + if v.Prefix == "" { + return errors.E(op, errors.Str("prefix should be set")) + } + + v.viper.SetEnvPrefix(v.Prefix) + if v.Path == "" { + return errors.E(op, errors.Str("path should be set")) + } + + v.viper.SetConfigFile(v.Path) + v.viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + + return v.viper.ReadInConfig() +} + +// Overwrite overwrites existing config with provided values +func (v *Viper) Overwrite(values map[string]interface{}) error { + if len(values) != 0 { + for key, value := range values { + v.viper.Set(key, value) + } + } + + return nil +} + +// UnmarshalKey reads configuration section into configuration object. +func (v *Viper) UnmarshalKey(name string, out interface{}) error { + const op = errors.Op("unmarshal key") + err := v.viper.UnmarshalKey(name, &out) + if err != nil { + return errors.E(op, err) + } + return nil +} + +func (v *Viper) Unmarshal(out interface{}) error { + const op = errors.Op("config unmarshal") + err := v.viper.Unmarshal(&out) + if err != nil { + return errors.E(op, err) + } + return nil +} + +// Get raw config in a form of config section. +func (v *Viper) Get(name string) interface{} { + return v.viper.Get(name) +} + +// Has checks if config section exists. +func (v *Viper) Has(name string) bool { + return v.viper.IsSet(name) +} diff --git a/plugins/doc/graphviz.svg b/plugins/doc/graphviz.svg new file mode 100644 index 00000000..86f6ab5c --- /dev/null +++ b/plugins/doc/graphviz.svg @@ -0,0 +1,169 @@ +<?xml version="1.0" encoding="UTF-8" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><!-- Generated by graphviz version 2.40.1 (20161225.0304) + --><!-- Title: endure Pages: 1 --><svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="625pt" height="479pt" viewBox="0.00 0.00 624.94 478.79"> +<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 474.786)"> +<title>endure</title> +<polygon fill="#ffffff" stroke="transparent" points="-4,4 -4,-474.786 620.9357,-474.786 620.9357,4 -4,4"/> +<!-- checker --> +<g id="node1" class="node"> +<title>checker</title> +<ellipse fill="none" stroke="#000000" cx="412.2429" cy="-377.2862" rx="41.1103" ry="18"/> +<text text-anchor="middle" x="412.2429" y="-373.0862" font-family="Times,serif" font-size="14.00" fill="#000000">checker</text> +</g> +<!-- config --> +<g id="node2" class="node"> +<title>config</title> +<ellipse fill="none" stroke="#000000" cx="463.8878" cy="-235.393" rx="35.9154" ry="18"/> +<text text-anchor="middle" x="463.8878" y="-231.193" font-family="Times,serif" font-size="14.00" fill="#000000">config</text> +</g> +<!-- checker->config --> +<g id="edge1" class="edge"> +<title>checker->config</title> +<path fill="none" stroke="#000000" d="M418.7837,-359.3154C427.6313,-335.0068 443.4953,-291.4209 453.8554,-262.9568"/> +<polygon fill="#000000" stroke="#000000" points="457.2687,-263.812 457.4,-253.218 450.6908,-261.4178 457.2687,-263.812"/> +</g> +<!-- logger --> +<g id="node3" class="node"> +<title>logger</title> +<ellipse fill="none" stroke="#000000" cx="35.7071" cy="-310.8928" rx="35.9154" 
ry="18"/> +<text text-anchor="middle" x="35.7071" y="-306.6928" font-family="Times,serif" font-size="14.00" fill="#000000">logger</text> +</g> +<!-- checker->logger --> +<g id="edge2" class="edge"> +<title>checker->logger</title> +<path fill="none" stroke="#000000" d="M374.0665,-370.5547C303.7112,-358.1492 154.0014,-331.7513 79.586,-318.6299"/> +<polygon fill="#000000" stroke="#000000" points="80.0574,-315.1591 69.6015,-316.8693 78.8418,-322.0527 80.0574,-315.1591"/> +</g> +<!-- logger->config --> +<g id="edge4" class="edge"> +<title>logger->config</title> +<path fill="none" stroke="#000000" d="M69.6636,-304.9054C146.6435,-291.3317 334.3698,-258.2305 420.0048,-243.1308"/> +<polygon fill="#000000" stroke="#000000" points="420.6875,-246.5645 429.9277,-241.3811 419.4719,-239.6708 420.6875,-246.5645"/> +</g> +<!-- gzip --> +<g id="node4" class="node"> +<title>gzip</title> +<ellipse fill="none" stroke="#000000" cx="531.6651" cy="-102.393" rx="27.8286" ry="18"/> +<text text-anchor="middle" x="531.6651" y="-98.193" font-family="Times,serif" font-size="14.00" fill="#000000">gzip</text> +</g> +<!-- headers --> +<g id="node5" class="node"> +<title>headers</title> +<ellipse fill="none" stroke="#000000" cx="576.4118" cy="-235.393" rx="40.548" ry="18"/> +<text text-anchor="middle" x="576.4118" y="-231.193" font-family="Times,serif" font-size="14.00" fill="#000000">headers</text> +</g> +<!-- headers->config --> +<g id="edge3" class="edge"> +<title>headers->config</title> +<path fill="none" stroke="#000000" d="M535.788,-235.393C527.3742,-235.393 518.4534,-235.393 509.8639,-235.393"/> +<polygon fill="#000000" stroke="#000000" points="509.607,-231.8931 499.607,-235.393 509.607,-238.8931 509.607,-231.8931"/> +</g> +<!-- metrics --> +<g id="node6" class="node"> +<title>metrics</title> +<ellipse fill="none" stroke="#000000" cx="412.2429" cy="-93.4998" rx="39.4196" ry="18"/> +<text text-anchor="middle" x="412.2429" y="-89.2998" font-family="Times,serif" font-size="14.00" fill="#000000">metrics</text> +</g> +<!-- metrics->config --> +<g id="edge6" class="edge"> +<title>metrics->config</title> +<path fill="none" stroke="#000000" d="M418.7837,-111.4707C427.6313,-135.7792 443.4953,-179.3651 453.8554,-207.8292"/> +<polygon fill="#000000" stroke="#000000" points="450.6908,-209.3682 457.4,-217.5681 457.2687,-206.974 450.6908,-209.3682"/> +</g> +<!-- metrics->logger --> +<g id="edge5" class="edge"> +<title>metrics->logger</title> +<path fill="none" stroke="#000000" d="M387.5373,-107.7636C321.7958,-145.7194 142.5487,-249.2078 68.4432,-291.9926"/> +<polygon fill="#000000" stroke="#000000" points="66.4391,-289.1082 59.5289,-297.1393 69.9391,-295.1704 66.4391,-289.1082"/> +</g> +<!-- redis --> +<g id="node7" class="node"> +<title>redis</title> +<ellipse fill="none" stroke="#000000" cx="281.4734" cy="-18" rx="29.6127" ry="18"/> +<text text-anchor="middle" x="281.4734" y="-13.8" font-family="Times,serif" font-size="14.00" fill="#000000">redis</text> +</g> +<!-- redis->config --> +<g id="edge8" class="edge"> +<title>redis->config</title> +<path fill="none" stroke="#000000" d="M295.1841,-34.3398C326.9308,-72.174 405.6399,-165.9759 443.2445,-210.7914"/> +<polygon fill="#000000" stroke="#000000" points="440.6581,-213.1541 449.7672,-218.5648 446.0204,-208.6545 440.6581,-213.1541"/> +</g> +<!-- redis->logger --> +<g id="edge7" class="edge"> +<title>redis->logger</title> +<path fill="none" stroke="#000000" d="M267.9098,-34.1644C227.1471,-82.7435 105.5381,-227.6715 56.5241,-286.0841"/> +<polygon fill="#000000" stroke="#000000" 
points="53.5843,-284.1426 49.8376,-294.0528 58.9466,-288.6421 53.5843,-284.1426"/> +</g> +<!-- reload --> +<g id="node8" class="node"> +<title>reload</title> +<ellipse fill="none" stroke="#000000" cx="281.4734" cy="-452.786" rx="35.3315" ry="18"/> +<text text-anchor="middle" x="281.4734" y="-448.586" font-family="Times,serif" font-size="14.00" fill="#000000">reload</text> +</g> +<!-- reload->config --> +<g id="edge10" class="edge"> +<title>reload->config</title> +<path fill="none" stroke="#000000" d="M295.4842,-436.0885C327.4495,-397.9939 405.8819,-304.5217 443.3335,-259.8887"/> +<polygon fill="#000000" stroke="#000000" points="446.0824,-262.0576 449.8292,-252.1474 440.7201,-257.5581 446.0824,-262.0576"/> +</g> +<!-- reload->logger --> +<g id="edge9" class="edge"> +<title>reload->logger</title> +<path fill="none" stroke="#000000" d="M257.9083,-439.1807C213.6848,-413.6483 118.2025,-358.5216 68.0211,-329.5493"/> +<polygon fill="#000000" stroke="#000000" points="69.6111,-326.4259 59.2009,-324.457 66.1111,-332.4881 69.6111,-326.4259"/> +</g> +<!-- resetter --> +<g id="node9" class="node"> +<title>resetter</title> +<ellipse fill="none" stroke="#000000" cx="132.7678" cy="-426.5652" rx="39.3984" ry="18"/> +<text text-anchor="middle" x="132.7678" y="-422.3652" font-family="Times,serif" font-size="14.00" fill="#000000">resetter</text> +</g> +<!-- reload->resetter --> +<g id="edge11" class="edge"> +<title>reload->resetter</title> +<path fill="none" stroke="#000000" d="M248.1009,-446.9016C227.9026,-443.3401 201.8366,-438.7439 179.5962,-434.8224"/> +<polygon fill="#000000" stroke="#000000" points="180.1376,-431.3639 169.6817,-433.0742 178.922,-438.2575 180.1376,-431.3639"/> +</g> +<!-- resetter->logger --> +<g id="edge12" class="edge"> +<title>resetter->logger</title> +<path fill="none" stroke="#000000" d="M118.4461,-409.4974C102.0084,-389.9077 74.9173,-357.6218 56.2379,-335.3605"/> +<polygon fill="#000000" stroke="#000000" points="58.881,-333.0653 49.7719,-327.6546 53.5187,-337.5649 58.881,-333.0653"/> +</g> +<!-- rpc --> +<g id="node10" class="node"> +<title>rpc</title> +<ellipse fill="none" stroke="#000000" cx="132.7678" cy="-44.2208" rx="27" ry="18"/> +<text text-anchor="middle" x="132.7678" y="-40.0208" font-family="Times,serif" font-size="14.00" fill="#000000">rpc</text> +</g> +<!-- rpc->config --> +<g id="edge13" class="edge"> +<title>rpc->config</title> +<path fill="none" stroke="#000000" d="M153.4808,-56.1795C209.3277,-88.4227 363.359,-177.3527 431.1448,-216.4889"/> +<polygon fill="#000000" stroke="#000000" points="429.7078,-219.7006 440.1181,-221.6696 433.2078,-213.6384 429.7078,-219.7006"/> +</g> +<!-- rpc->logger --> +<g id="edge14" class="edge"> +<title>rpc->logger</title> +<path fill="none" stroke="#000000" d="M126.3994,-61.7179C109.8827,-107.097 65.5725,-228.8383 45.6502,-283.5745"/> +<polygon fill="#000000" stroke="#000000" points="42.3576,-282.3876 42.2262,-292.9816 48.9354,-284.7818 42.3576,-282.3876"/> +</g> +<!-- static --> +<g id="node11" class="node"> +<title>static</title> +<ellipse fill="none" stroke="#000000" cx="35.7071" cy="-159.8932" rx="31.3333" ry="18"/> +<text text-anchor="middle" x="35.7071" y="-155.6932" font-family="Times,serif" font-size="14.00" fill="#000000">static</text> +</g> +<!-- static->config --> +<g id="edge15" class="edge"> +<title>static->config</title> +<path fill="none" stroke="#000000" d="M65.8159,-165.2022C140.1736,-178.3135 332.7753,-212.2743 419.9157,-227.6396"/> +<polygon fill="#000000" stroke="#000000" points="419.5489,-231.1288 430.0048,-229.4185 
420.7645,-224.2351 419.5489,-231.1288"/> +</g> +<!-- static->logger --> +<g id="edge16" class="edge"> +<title>static->logger</title> +<path fill="none" stroke="#000000" d="M35.7071,-178.1073C35.7071,-204.0691 35.7071,-251.9543 35.7071,-282.5696"/> +<polygon fill="#000000" stroke="#000000" points="32.2072,-282.6141 35.7071,-292.6141 39.2072,-282.6142 32.2072,-282.6141"/> +</g> +</g> +</svg>
\ No newline at end of file
diff --git a/plugins/gzip/plugin.go b/plugins/gzip/plugin.go new file mode 100644 index 00000000..e5b9e4f5 --- /dev/null +++ b/plugins/gzip/plugin.go @@ -0,0 +1,25 @@ +package gzip + +import ( + "net/http" + + "github.com/NYTimes/gziphandler" +) + +const PluginName = "gzip" + +type Gzip struct{} + +func (g *Gzip) Init() error { + return nil +} + +func (g *Gzip) Middleware(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + gziphandler.GzipHandler(next).ServeHTTP(w, r) + } +} + +func (g *Gzip) Name() string { + return PluginName +}
diff --git a/plugins/headers/config.go b/plugins/headers/config.go new file mode 100644 index 00000000..8d4e29c2 --- /dev/null +++ b/plugins/headers/config.go @@ -0,0 +1,36 @@ +package headers + +// Config declares headers service configuration. +type Config struct { + Headers struct { + // CORS settings. + CORS *CORSConfig + + // Request headers to add to every payload sent to PHP. + Request map[string]string + + // Response headers to add to every payload generated by PHP. + Response map[string]string + } +} + +// CORSConfig headers configuration. +type CORSConfig struct { + // AllowedOrigin: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowedOrigin string + + // AllowedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowedHeaders string + + // AllowedMethods: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowedMethods string + + // AllowCredentials https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials *bool + + // ExposedHeaders: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers + ExposedHeaders string + + // MaxAge of CORS headers in seconds. + MaxAge int +}
diff --git a/plugins/headers/plugin.go b/plugins/headers/plugin.go new file mode 100644 index 00000000..f1c6e6f3 --- /dev/null +++ b/plugins/headers/plugin.go @@ -0,0 +1,117 @@ +package headers + +import ( + "net/http" + "strconv" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" +) + +// PluginName contains default plugin name. +const PluginName = "headers" +const RootPluginName = "http" + +// Plugin is an http middleware which manages request/response headers and CORS. +type Plugin struct { + // headers configuration (request/response headers and CORS settings) + cfg *Config +} + +// Init must configure the service and return error in case of misconfiguration. Services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer) error { + const op = errors.Op("headers plugin init") + err := cfg.UnmarshalKey(RootPluginName, &s.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + return nil +}
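What the middleware does to a request/response pair can be pinned down with an in-package test, since cfg is unexported; the header names below are illustrative:

```go
package headers

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestHeadersMiddleware(t *testing.T) {
	p := &Plugin{cfg: &Config{}}
	p.cfg.Headers.Request = map[string]string{"X-App": "rr"}
	p.cfg.Headers.Response = map[string]string{"X-Served-By": "roadrunner"}

	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// the request header is injected before the handler runs
		if r.Header.Get("X-App") != "rr" {
			t.Error("request header was not injected")
		}
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://localhost/", nil)
	p.Middleware(next).ServeHTTP(rec, req)

	if rec.Header().Get("X-Served-By") != "roadrunner" {
		t.Error("response header was not set")
	}
}
```

+// Middleware adds the configured request/response headers and CORS headers to every request/response pair.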
+func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc { + // Define the http.HandlerFunc + return func(w http.ResponseWriter, r *http.Request) { + if s.cfg.Headers.Request != nil { + for k, v := range s.cfg.Headers.Request { + r.Header.Add(k, v) + } + } + + if s.cfg.Headers.Response != nil { + for k, v := range s.cfg.Headers.Response { + w.Header().Set(k, v) + } + } + + if s.cfg.Headers.CORS != nil { + if r.Method == http.MethodOptions { + s.preflightRequest(w) + return + } + s.corsHeaders(w) + } + + next.ServeHTTP(w, r) + } +} + +func (s *Plugin) Name() string { + return PluginName +} + +// configure OPTIONS response +func (s *Plugin) preflightRequest(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.AllowedMethods != "" { + headers.Set("Access-Control-Allow-Methods", s.cfg.Headers.CORS.AllowedMethods) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } + + if s.cfg.Headers.CORS.MaxAge > 0 { + headers.Set("Access-Control-Max-Age", strconv.Itoa(s.cfg.Headers.CORS.MaxAge)) + } + + w.WriteHeader(http.StatusOK) +} + +// configure CORS headers +func (s *Plugin) corsHeaders(w http.ResponseWriter) { + headers := w.Header() + + headers.Add("Vary", "Origin") + + if s.cfg.Headers.CORS.AllowedOrigin != "" { + headers.Set("Access-Control-Allow-Origin", s.cfg.Headers.CORS.AllowedOrigin) + } + + if s.cfg.Headers.CORS.AllowedHeaders != "" { + headers.Set("Access-Control-Allow-Headers", s.cfg.Headers.CORS.AllowedHeaders) + } + + if s.cfg.Headers.CORS.ExposedHeaders != "" { + headers.Set("Access-Control-Expose-Headers", s.cfg.Headers.CORS.ExposedHeaders) + } + + if s.cfg.Headers.CORS.AllowCredentials != nil { + headers.Set("Access-Control-Allow-Credentials", strconv.FormatBool(*s.cfg.Headers.CORS.AllowCredentials)) + } +} diff --git a/plugins/http/attributes/attributes.go b/plugins/http/attributes/attributes.go new file mode 100644 index 00000000..4c453766 --- /dev/null +++ b/plugins/http/attributes/attributes.go @@ -0,0 +1,85 @@ +package attributes + +import ( + "context" + "errors" + "net/http" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return k.name } + +var ( + // PsrContextKey is a context key. It can be used in the http attributes + PsrContextKey = &contextKey{"psr_attributes"} +) + +type attrs map[string]interface{} + +func (v attrs) get(key string) interface{} { + if v == nil { + return "" + } + + return v[key] +} + +func (v attrs) set(key string, value interface{}) { + v[key] = value +} + +func (v attrs) del(key string) { + delete(v, key) +} + +// Init returns request with new context and attribute bag. +func Init(r *http.Request) *http.Request { + return r.WithContext(context.WithValue(r.Context(), PsrContextKey, attrs{})) +} + +// All returns all context attributes. 
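The bag is a plain map stored in the request context; a runnable sketch of the accessor API (the key name is illustrative):

```go
package main

import (
	"fmt"
	"log"
	"net/http/httptest"

	"github.com/spiral/roadrunner/v2/plugins/http/attributes"
)

func main() {
	// attach the attribute bag, set a value, read it back
	r := httptest.NewRequest("GET", "http://localhost/", nil)
	r = attributes.Init(r)

	if err := attributes.Set(r, "user_id", "42"); err != nil {
		log.Fatal(err)
	}

	fmt.Println(attributes.Get(r, "user_id")) // 42
	fmt.Println(attributes.All(r))            // map[user_id:42]
}
```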
+func All(r *http.Request) map[string]interface{} { + v := r.Context().Value(PsrContextKey) + if v == nil { + return attrs{} + } + + return v.(attrs) +} + +// Get gets a value from the request context. +func Get(r *http.Request, key string) interface{} { + v := r.Context().Value(PsrContextKey) + if v == nil { + return nil + } + + return v.(attrs).get(key) +} + +// Set sets the key to value. It replaces any existing values. Context specific. +func Set(r *http.Request, key string, value interface{}) error { + v := r.Context().Value(PsrContextKey) + if v == nil { + return errors.New("unable to find `psr:attributes` context key") + } + + v.(attrs).set(key, value) + return nil +} + +// Delete deletes values associated with attribute key. +func (v attrs) Delete(key string) { + if v == nil { + return + } + + v.del(key) +}
diff --git a/plugins/http/config.go b/plugins/http/config.go new file mode 100644 index 00000000..abde8917 --- /dev/null +++ b/plugins/http/config.go @@ -0,0 +1,294 @@ +package http + +import ( + "net" + "os" + "runtime" + "strings" + "time" + + "github.com/spiral/errors" + poolImpl "github.com/spiral/roadrunner/v2/pkg/pool" +) + +// Cidrs is a slice of IPNet addresses +type Cidrs []*net.IPNet + +// IsTrusted checks if the ip address is within one of the configured subnets +func (c *Cidrs) IsTrusted(ip string) bool { + if len(*c) == 0 { + return false + } + + i := net.ParseIP(ip) + if i == nil { + return false + } + + for _, cidr := range *c { + if cidr.Contains(i) { + return true + } + } + + return false +} + +// Config configures RoadRunner HTTP server. +type Config struct { + // Address and port to handle as http server. + Address string + + // SSL defines https server options. + SSL *SSLConfig + + // FCGI configuration. You can use FastCGI without HTTP server. + FCGI *FCGIConfig + + // HTTP2 configuration + HTTP2 *HTTP2Config + + // MaxRequestSize specifies max size for payload body in megabytes, set to 0 for unlimited. + MaxRequestSize uint64 `yaml:"max_request_size"` + + // TrustedSubnets declare IP subnets which are allowed to set ip using X-Real-Ip and X-Forwarded-For + TrustedSubnets []string `yaml:"trusted_subnets"` + + // Uploads configures file upload handling. + Uploads *UploadsConfig + + // Pool configures worker pool. + Pool *poolImpl.Config + + // Env is environment variables passed to the http pool + Env map[string]string + + // List of the middleware names (order will be preserved) + Middleware []string + + // slice of net.IPNet + cidrs Cidrs +} + +// FCGIConfig for FastCGI server. +type FCGIConfig struct { + // Address and port to handle as http server. + Address string +} + +// HTTP2Config HTTP/2 server customizations. +type HTTP2Config struct { + // Enable or disable HTTP/2 extension, default enable. + Enabled bool + + // H2C enables HTTP/2 over TCP + H2C bool + + // MaxConcurrentStreams defaults to 128. + MaxConcurrentStreams uint32 `yaml:"max_concurrent_streams"` +} + +// InitDefaults sets default values for HTTP/2 configuration. +func (cfg *HTTP2Config) InitDefaults() error { + cfg.Enabled = true + cfg.MaxConcurrentStreams = 128 + + return nil +} + +// SSLConfig defines https server configuration. +type SSLConfig struct { + // Port to listen as HTTPS server, defaults to 443. + Port int + + // Redirect when enabled forces all http connections to switch to https. + Redirect bool + + // Key defines the private server key. + Key string + + // Cert is https certificate.
+ Cert string + + // Root CA file + RootCA string +} + +// EnableHTTP is true when http server must run. +func (c *Config) EnableHTTP() bool { + return c.Address != "" +} + +// EnableTLS returns true if pool must listen for TLS connections. +func (c *Config) EnableTLS() bool { + return c.SSL.Key != "" || c.SSL.Cert != "" || c.SSL.RootCA != "" +} + +// EnableHTTP2 when HTTP/2 extension must be enabled (only with TLS). +func (c *Config) EnableHTTP2() bool { + return c.HTTP2.Enabled +} + +// EnableH2C when HTTP/2 extension must be enabled on TCP. +func (c *Config) EnableH2C() bool { + return c.HTTP2.H2C +} + +// EnableFCGI is true when FastCGI server must be enabled. +func (c *Config) EnableFCGI() bool { + return c.FCGI.Address != "" +} + +// InitDefaults must populate Config values using given Config source. Must return error if Config is not valid. +func (c *Config) InitDefaults() error { + if c.Pool == nil { + // default pool + c.Pool = &poolImpl.Config{ + Debug: false, + NumWorkers: int64(runtime.NumCPU()), + MaxJobs: 1000, + AllocateTimeout: time.Second * 60, + DestroyTimeout: time.Second * 60, + Supervisor: nil, + } + } + + if c.HTTP2 == nil { + c.HTTP2 = &HTTP2Config{} + } + + if c.FCGI == nil { + c.FCGI = &FCGIConfig{} + } + + if c.Uploads == nil { + c.Uploads = &UploadsConfig{} + } + + if c.SSL == nil { + c.SSL = &SSLConfig{} + } + + if c.SSL.Port == 0 { + c.SSL.Port = 443 + } + + err := c.HTTP2.InitDefaults() + if err != nil { + return err + } + err = c.Uploads.InitDefaults() + if err != nil { + return err + } + + if c.TrustedSubnets == nil { + // @see https://en.wikipedia.org/wiki/Reserved_IP_addresses + c.TrustedSubnets = []string{ + "10.0.0.0/8", + "127.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fc00::/7", + "fe80::/10", + } + } + + cidrs, err := ParseCIDRs(c.TrustedSubnets) + if err != nil { + return err + } + c.cidrs = cidrs + + return c.Valid() +} + +// ParseCIDRs parses IPNet addresses and returns a slice of them +func ParseCIDRs(subnets []string) (Cidrs, error) { + c := make(Cidrs, 0, len(subnets)) + for _, cidr := range subnets { + _, cr, err := net.ParseCIDR(cidr) + if err != nil { + return nil, err + } + + c = append(c, cr) + } + + return c, nil +} + +// IsTrusted reports whether the ip falls into one of the trusted subnets, so X-Real-Ip and X-Forwarded-For can be honored +func (c *Config) IsTrusted(ip string) bool { + if c.cidrs == nil { + return false + } + + i := net.ParseIP(ip) + if i == nil { + return false + } + + for _, cidr := range c.cidrs { + if cidr.Contains(i) { + return true + } + } + + return false +}
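A runnable sketch of the defaults-then-validation flow (the import alias is an assumption; values are illustrative):

```go
package main

import (
	"fmt"
	"log"

	rrhttp "github.com/spiral/roadrunner/v2/plugins/http"
)

func main() {
	cfg := &rrhttp.Config{
		Address:        ":8080",
		TrustedSubnets: []string{"10.0.0.0/8"},
	}

	// fills Pool/SSL/HTTP2/FCGI/Uploads defaults, parses the trusted
	// subnets into CIDRs and finishes with Valid()
	if err := cfg.InitDefaults(); err != nil {
		log.Fatal(err)
	}

	fmt.Println(cfg.IsTrusted("10.1.2.3")) // true: proxy headers will be honored
	fmt.Println(cfg.IsTrusted("8.8.8.8"))  // false
}
```

+// Valid validates the configuration.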
+func (c *Config) Valid() error { + const op = errors.Op("validation") + if c.Uploads == nil { + return errors.E(op, errors.Str("malformed uploads config")) + } + + if c.HTTP2 == nil { + return errors.E(op, errors.Str("malformed http2 config")) + } + + if c.Pool == nil { + return errors.E(op, "malformed pool config") + } + + if !c.EnableHTTP() && !c.EnableTLS() && !c.EnableFCGI() { + return errors.E(op, errors.Str("unable to run http service, no method has been specified (http, https, http/2 or FastCGI)")) + } + + if c.Address != "" && !strings.Contains(c.Address, ":") { + return errors.E(op, errors.Str("malformed http server address")) + } + + if c.EnableTLS() { + if _, err := os.Stat(c.SSL.Key); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("key file '%s' does not exist", c.SSL.Key)) + } + + return err + } + + if _, err := os.Stat(c.SSL.Cert); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("cert file '%s' does not exist", c.SSL.Cert)) + } + + return err + } + + // RootCA is optional, but if provided - check it + if c.SSL.RootCA != "" { + if _, err := os.Stat(c.SSL.RootCA); err != nil { + if os.IsNotExist(err) { + return errors.E(op, errors.Errorf("root ca path provided, but path '%s' does not exist", c.SSL.RootCA)) + } + return err + } + } + } + + return nil +}
diff --git a/plugins/http/constants.go b/plugins/http/constants.go new file mode 100644 index 00000000..c3d5c589 --- /dev/null +++ b/plugins/http/constants.go @@ -0,0 +1,8 @@ +package http + +import "net/http" + +var http2pushHeaderKey = http.CanonicalHeaderKey("http2-push") + +// TrailerHeaderKey http header key +var TrailerHeaderKey = http.CanonicalHeaderKey("trailer")
diff --git a/plugins/http/errors.go b/plugins/http/errors.go new file mode 100644 index 00000000..fb8762ef --- /dev/null +++ b/plugins/http/errors.go @@ -0,0 +1,25 @@ +// +build !windows + +package http + +import ( + "errors" + "net" + "os" + "syscall" +) + +// Broken pipe +var errEPIPE = errors.New("EPIPE(32) -> connection reset by peer") + +// handleWriteError checks if the error was caused by an aborted connection on linux +func handleWriteError(err error) error { + if netErr, ok2 := err.(*net.OpError); ok2 { + if syscallErr, ok3 := netErr.Err.(*os.SyscallError); ok3 { + if syscallErr.Err == syscall.EPIPE { + return errEPIPE + } + } + } + return err +}
diff --git a/plugins/http/errors_windows.go b/plugins/http/errors_windows.go new file mode 100644 index 00000000..3d0ba04c --- /dev/null +++ b/plugins/http/errors_windows.go @@ -0,0 +1,27 @@ +// +build windows + +package http + +import ( + "errors" + "net" + "os" + "syscall" +) + +// Software caused connection abort. +// An established connection was aborted by the software in your host computer, +// possibly due to a data transmission time-out or protocol error.
+var errEPIPE = errors.New("WSAECONNABORTED (10053) -> an established connection was aborted by peer") + +// handleWriteError just check if error was caused by aborted connection on windows +func handleWriteError(err error) error { + if netErr, ok2 := err.(*net.OpError); ok2 { + if syscallErr, ok3 := netErr.Err.(*os.SyscallError); ok3 { + if syscallErr.Err == syscall.WSAECONNABORTED { + return errEPIPE + } + } + } + return err +} diff --git a/plugins/http/handler.go b/plugins/http/handler.go new file mode 100644 index 00000000..9c40cdfc --- /dev/null +++ b/plugins/http/handler.go @@ -0,0 +1,237 @@ +package http + +import ( + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/events" + "github.com/spiral/roadrunner/v2/interfaces/pool" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const ( + // EventResponse thrown after the request been processed. See ErrorEvent as payload. + EventResponse = iota + 500 + + // EventError thrown on any non job error provided by road runner server. + EventError +) + +// MB is 1024 bytes +const MB = 1024 * 1024 + +// ErrorEvent represents singular http error event. +type ErrorEvent struct { + // Request contains client request, must not be stored. + Request *http.Request + + // Error - associated error, if any. + Error error + + // event timings + start time.Time + elapsed time.Duration +} + +// Elapsed returns duration of the invocation. +func (e *ErrorEvent) Elapsed() time.Duration { + return e.elapsed +} + +// ResponseEvent represents singular http response event. +type ResponseEvent struct { + // Request contains client request, must not be stored. + Request *Request + + // Response contains service response. + Response *Response + + // event timings + start time.Time + elapsed time.Duration +} + +// Elapsed returns duration of the invocation. +func (e *ResponseEvent) Elapsed() time.Duration { + return e.elapsed +} + +// Handler serves http connections to underlying PHP application using PSR-7 protocol. Context will include request headers, +// parsed files and query, payload will include parsed form dataTree (if any). +type Handler struct { + maxRequestSize uint64 + uploads UploadsConfig + trusted Cidrs + log logger.Logger + pool pool.Pool + mul sync.Mutex + lsn events.Listener +} + +// NewHandler return handle interface implementation +func NewHandler(maxReqSize uint64, uploads UploadsConfig, trusted Cidrs, pool pool.Pool) (*Handler, error) { + if pool == nil { + return nil, errors.E(errors.Str("pool should be initialized")) + } + return &Handler{ + maxRequestSize: maxReqSize * MB, + uploads: uploads, + pool: pool, + trusted: trusted, + }, nil +} + +// AddListener attaches handler event controller. +func (h *Handler) AddListener(l events.Listener) { + h.mul.Lock() + defer h.mul.Unlock() + + h.lsn = l +} + +// mdwr serve using PSR-7 requests passed to underlying application. Attempts to serve static files first if enabled. 
+// ServeHTTP serves PSR-7 requests passed to the underlying PHP application. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + const op = errors.Op("ServeHTTP") + start := time.Now() + + // validating request size + if h.maxRequestSize != 0 { + err := h.maxSize(w, r, start, op) + if err != nil { + return + } + } + + req, err := NewRequest(r, h.uploads) + if err != nil { + h.handleError(w, r, err, start) + return + } + + // proxy IP resolution + h.resolveIP(req) + + req.Open(h.log) + defer req.Close(h.log) + + p, err := req.Payload() + if err != nil { + h.handleError(w, r, err, start) + return + } + + rsp, err := h.pool.Exec(p) + if err != nil { + h.handleError(w, r, err, start) + return + } + + resp, err := NewResponse(rsp) + if err != nil { + h.handleError(w, r, err, start) + return + } + + h.handleResponse(req, resp, start) + err = resp.Write(w) + if err != nil { + h.handleError(w, r, err, start) + } +} + +func (h *Handler) maxSize(w http.ResponseWriter, r *http.Request, start time.Time, op errors.Op) error { + if length := r.Header.Get("content-length"); length != "" { + if size, err := strconv.ParseInt(length, 10, 64); err != nil { + h.handleError(w, r, err, start) + return err + } else if size > int64(h.maxRequestSize) { + // err is nil in this branch, so build and return the actual error + err := errors.E(op, errors.Str("request body max size is exceeded")) + h.handleError(w, r, err, start) + return err + } + } + return nil +} + +// handleError sends error. +func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error, start time.Time) { + h.mul.Lock() + defer h.mul.Unlock() + // if the pipe is broken there is no sense in writing the header; we just report the error + if err == errEPIPE { + h.throw(ErrorEvent{Request: r, Error: err, start: start, elapsed: time.Since(start)}) + return + } + err = multierror.Append(err) + // ResponseWriter is ok, write the error code + w.WriteHeader(500) + _, err2 := w.Write([]byte(err.Error())) + // error during the writing to the ResponseWriter + if err2 != nil { + err = multierror.Append(err2, err) + // concat original error with ResponseWriter error + h.throw(ErrorEvent{Request: r, Error: errors.E(err), start: start, elapsed: time.Since(start)}) + return + } + h.throw(ErrorEvent{Request: r, Error: err, start: start, elapsed: time.Since(start)}) +} + +// handleResponse triggers response event. +func (h *Handler) handleResponse(req *Request, resp *Response, start time.Time) { + h.throw(ResponseEvent{Request: req, Response: resp, start: start, elapsed: time.Since(start)}) +} + +// throw invokes event handler if any. +func (h *Handler) throw(event interface{}) { + if h.lsn != nil { + h.lsn(event) + } +} + +// resolveIP resolves the real remote IP when the request passed through trusted proxies +func (h *Handler) resolveIP(r *Request) { + if !h.trusted.IsTrusted(r.RemoteAddr) { + return + } + + if r.Header.Get("X-Forwarded-For") != "" { + ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",") + ipCount := len(ips) + + for i := ipCount - 1; i >= 0; i-- { + addr := strings.TrimSpace(ips[i]) + if net.ParseIP(addr) != nil { + r.RemoteAddr = addr + return + } + } + + return + } + + // The logic here is the following: + // In the general case we only expect the X-Real-Ip header. If it exists, we take the IP address from the header and set the request remote address. + // But if there is no X-Real-Ip header, we also try to check the CloudFlare headers: + // True-Client-IP is a general CF header into which CF copies the value of X-Real-Ip. + // CF-Connecting-IP is an Enterprise feature and we check it last in order. + // These operations are near O(1) because the Header struct is a map type -> type MIMEHeader map[string][]string + if r.Header.Get("X-Real-Ip") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("X-Real-Ip")) + return + } + + if r.Header.Get("True-Client-IP") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("True-Client-IP")) + return + } + + if r.Header.Get("CF-Connecting-IP") != "" { + r.RemoteAddr = fetchIP(r.Header.Get("CF-Connecting-IP")) + } +}
diff --git a/plugins/http/parse.go b/plugins/http/parse.go new file mode 100644 index 00000000..d4a1604b --- /dev/null +++ b/plugins/http/parse.go @@ -0,0 +1,147 @@ +package http + +import ( + "net/http" +) + +// MaxLevel defines maximum tree depth for incoming request data and files. +const MaxLevel = 127 + +type dataTree map[string]interface{} +type fileTree map[string]interface{} + +// parseData parses incoming request body into data tree. +func parseData(r *http.Request) dataTree { + data := make(dataTree) + if r.PostForm != nil { + for k, v := range r.PostForm { + data.push(k, v) + } + } + + if r.MultipartForm != nil { + for k, v := range r.MultipartForm.Value { + data.push(k, v) + } + } + + return data +} + +// push pushes value into the data tree. +func (d dataTree) push(k string, v []string) { + keys := FetchIndexes(k) + if len(keys) <= MaxLevel { + d.mount(keys, v) + } +} + +// mount mounts data tree recursively. +func (d dataTree) mount(i []string, v []string) { + if len(i) == 1 { + // single value context (last element) + d[i[0]] = v[len(v)-1] + return + } + + if len(i) == 2 && i[1] == "" { + // non associated array of elements + d[i[0]] = v + return + } + + if p, ok := d[i[0]]; ok { + p.(dataTree).mount(i[1:], v) + return + } + + d[i[0]] = make(dataTree) + d[i[0]].(dataTree).mount(i[1:], v) +} + +// parseUploads parses uploaded files from the multipart form into a file tree +func parseUploads(r *http.Request, cfg UploadsConfig) *Uploads { + u := &Uploads{ + cfg: cfg, + tree: make(fileTree), + list: make([]*FileUpload, 0), + } + + for k, v := range r.MultipartForm.File { + files := make([]*FileUpload, 0, len(v)) + for _, f := range v { + files = append(files, NewUpload(f)) + } + + u.list = append(u.list, files...) + u.tree.push(k, files) + } + + return u +} + +// push pushes a new file upload into its proper place. +func (d fileTree) push(k string, v []*FileUpload) { + keys := FetchIndexes(k) + if len(keys) <= MaxLevel { + d.mount(keys, v) + } +} + +// mount mounts data tree recursively. +func (d fileTree) mount(i []string, v []*FileUpload) { + if len(i) == 1 { + // single value context + d[i[0]] = v[0] + return + } + + if len(i) == 2 && i[1] == "" { + // non associated array of elements + d[i[0]] = v + return + } + + if p, ok := d[i[0]]; ok { + p.(fileTree).mount(i[1:], v) + return + } + + d[i[0]] = make(fileTree) + d[i[0]].(fileTree).mount(i[1:], v) +}
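Both trees rely on FetchIndexes (below) to turn PHP-style field names into key paths; for example:

```go
fmt.Println(FetchIndexes("user[address][street]")) // [user address street]
fmt.Println(FetchIndexes("files[]"))               // [files ] – the trailing "" marks a non-associative array
fmt.Println(FetchIndexes("simple"))                // [simple]
```

+// FetchIndexes parses the input name and splits it into a separate list of indexes.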
+func FetchIndexes(s string) []string { + var ( + pos int + ch string + keys = make([]string, 1) + ) + + for _, c := range s { + ch = string(c) + switch ch { + case " ": + // ignore all spaces + continue + case "[": + pos = 1 + continue + case "]": + if pos == 1 { + keys = append(keys, "") + } + pos = 2 + default: + if pos == 1 || pos == 2 { + keys = append(keys, "") + } + + keys[len(keys)-1] += ch + pos = 0 + } + } + + return keys +}
diff --git a/plugins/http/plugin.go b/plugins/http/plugin.go new file mode 100644 index 00000000..e6aba78b --- /dev/null +++ b/plugins/http/plugin.go @@ -0,0 +1,532 @@ +package http + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "net/http/fcgi" + "net/url" + "strings" + "sync" + + "github.com/hashicorp/go-multierror" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/interfaces/pool" + "github.com/spiral/roadrunner/v2/interfaces/worker" + poolImpl "github.com/spiral/roadrunner/v2/pkg/pool" + "github.com/spiral/roadrunner/v2/plugins/checker" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/server" + "github.com/spiral/roadrunner/v2/utils" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "golang.org/x/sys/cpu" +) + +const ( + // PluginName declares plugin name. + PluginName = "http" + + // EventInitSSL thrown at moment of https initialization. SSL server passed as context. + EventInitSSL = 750 +) + +// Middleware interface +type Middleware interface { + Middleware(f http.Handler) http.HandlerFunc +} + +type middleware map[string]Middleware + +// Plugin manages the worker pool and the http/https/fcgi servers. The main http plugin structure +type Plugin struct { + sync.RWMutex + + configurer config.Configurer + server server.Server + log logger.Logger + + cfg *Config + // middlewares to chain + mdwr middleware + + // Pool which attached to all servers + pool pool.Pool + + // servers RR handler + handler *Handler + + // servers + http *http.Server + https *http.Server + fcgi *http.Server +} + +// Init fetches the configuration and prepares the worker pool. Must return error in case of misconfiguration; services must not be used without proper configuration pushed first. +func (s *Plugin) Init(cfg config.Configurer, log logger.Logger, server server.Server) error { + const op = errors.Op("http Init") + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, err) + } + + err = s.cfg.InitDefaults() + if err != nil { + return errors.E(op, err) + } + + s.configurer = cfg + s.log = log + s.mdwr = make(map[string]Middleware) + + if !s.cfg.EnableHTTP() && !s.cfg.EnableTLS() && !s.cfg.EnableFCGI() { + return errors.E(op, errors.Disabled) + } + + s.pool, err = server.NewWorkerPool(context.Background(), poolImpl.Config{ + Debug: s.cfg.Pool.Debug, + NumWorkers: s.cfg.Pool.NumWorkers, + MaxJobs: s.cfg.Pool.MaxJobs, + AllocateTimeout: s.cfg.Pool.AllocateTimeout, + DestroyTimeout: s.cfg.Pool.DestroyTimeout, + Supervisor: s.cfg.Pool.Supervisor, + }, s.cfg.Env, s.logCallback) + if err != nil { + return errors.E(op, err) + } + + s.server = server + + return nil +} + +func (s *Plugin) logCallback(event interface{}) { + if ev, ok := event.(ResponseEvent); ok { + s.log.Debug("http handler response received", "elapsed", ev.Elapsed().String(), "remote address", ev.Request.RemoteAddr) + } +}
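Every entry in the http middleware config list must resolve to a collected plugin implementing the Middleware interface above. A sketch of such a third-party middleware plugin (package, plugin name and header are illustrative):

```go
package secheaders

import "net/http"

const PluginName = "sec_headers"

type Plugin struct{}

func (p *Plugin) Init() error { return nil }

// Name lets the http plugin register the middleware under this key.
func (p *Plugin) Name() string { return PluginName }

func (p *Plugin) Middleware(next http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Content-Type-Options", "nosniff")
		next.ServeHTTP(w, r)
	}
}
```

+// Serve serves the svc.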
+func (s *Plugin) Serve() chan error { + s.Lock() + defer s.Unlock() + + const op = errors.Op("serve http") + errCh := make(chan error, 2) + + var err error + s.handler, err = NewHandler( + s.cfg.MaxRequestSize, + *s.cfg.Uploads, + s.cfg.cidrs, + s.pool, + ) + if err != nil { + errCh <- errors.E(op, err) + return errCh + } + + s.handler.AddListener(s.logCallback) + + if s.cfg.EnableHTTP() { + if s.cfg.EnableH2C() { + s.http = &http.Server{Addr: s.cfg.Address, Handler: h2c.NewHandler(s, &http2.Server{})} + } else { + s.http = &http.Server{Addr: s.cfg.Address, Handler: s} + } + } + + if s.cfg.EnableTLS() { + s.https = s.initSSL() + if s.cfg.SSL.RootCA != "" { + err = s.appendRootCa() + if err != nil { + errCh <- errors.E(op, err) + return errCh + } + } + + if s.cfg.EnableHTTP2() { + if err := s.initHTTP2(); err != nil { + errCh <- errors.E(op, err) + return errCh + } + } + } + + if s.cfg.EnableFCGI() { + s.fcgi = &http.Server{Handler: s} + } + + // apply middlewares before starting the server + if len(s.mdwr) > 0 { + s.addMiddlewares() + } + + if s.http != nil { + go func() { + httpErr := s.http.ListenAndServe() + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + if s.https != nil { + go func() { + httpErr := s.https.ListenAndServeTLS( + s.cfg.SSL.Cert, + s.cfg.SSL.Key, + ) + + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + if s.fcgi != nil { + go func() { + httpErr := s.serveFCGI() + if httpErr != nil && httpErr != http.ErrServerClosed { + errCh <- errors.E(op, httpErr) + return + } + }() + } + + return errCh +} + +// Stop stops the http. +func (s *Plugin) Stop() error { + s.Lock() + defer s.Unlock() + + var err error + if s.fcgi != nil { + err = s.fcgi.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the fcgi server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + if s.https != nil { + err = s.https.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the https server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + if s.http != nil { + err = s.http.Shutdown(context.Background()) + if err != nil && err != http.ErrServerClosed { + s.log.Error("error shutting down the http server", "error", err) + // write error and try to stop other transport + err = multierror.Append(err) + } + } + + s.pool.Destroy(context.Background()) + + return err +} + +// ServeHTTP handles connection using set of middleware and pool PSR-7 server. 
+func (s *Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if headerContainsUpgrade(r, s) { + http.Error(w, "server does not support upgrade header", http.StatusInternalServerError) + return + } + + if s.redirect(w, r) { + return + } + + if s.https != nil && r.TLS != nil { + w.Header().Add("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") + } + + r = attributes.Init(r) + // protect the case, when user send Reset and we are replacing handler with pool + s.RLock() + s.handler.ServeHTTP(w, r) + s.RUnlock() +} + +// Workers returns associated pool workers +func (s *Plugin) Workers() []worker.BaseProcess { + return s.pool.Workers() +} + +// Name returns endure.Named interface implementation +func (s *Plugin) Name() string { + return PluginName +} + +// Reset destroys the old pool and replaces it with new one, waiting for old pool to die +func (s *Plugin) Reset() error { + s.Lock() + defer s.Unlock() + const op = errors.Op("http reset") + s.log.Info("HTTP plugin got restart request. Restarting...") + s.pool.Destroy(context.Background()) + s.pool = nil + + // re-read the config + err := s.configurer.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, err) + } + + s.pool, err = s.server.NewWorkerPool(context.Background(), poolImpl.Config{ + Debug: s.cfg.Pool.Debug, + NumWorkers: s.cfg.Pool.NumWorkers, + MaxJobs: s.cfg.Pool.MaxJobs, + AllocateTimeout: s.cfg.Pool.AllocateTimeout, + DestroyTimeout: s.cfg.Pool.DestroyTimeout, + Supervisor: s.cfg.Pool.Supervisor, + }, s.cfg.Env, s.logCallback) + if err != nil { + return errors.E(op, err) + } + + s.log.Info("HTTP listeners successfully re-added") + + s.log.Info("HTTP workers Pool successfully restarted") + s.handler, err = NewHandler( + s.cfg.MaxRequestSize, + *s.cfg.Uploads, + s.cfg.cidrs, + s.pool, + ) + if err != nil { + return errors.E(op, err) + } + + s.log.Info("HTTP plugin successfully restarted") + return nil +} + +// Collects collecting http middlewares +func (s *Plugin) Collects() []interface{} { + return []interface{}{ + s.AddMiddleware, + } +} + +// AddMiddleware is base requirement for the middleware (name and Middleware) +func (s *Plugin) AddMiddleware(name endure.Named, m Middleware) { + s.mdwr[name.Name()] = m +} + +// Status return status of the particular plugin +func (s *Plugin) Status() checker.Status { + workers := s.Workers() + for i := 0; i < len(workers); i++ { + if workers[i].State().IsActive() { + return checker.Status{ + Code: http.StatusOK, + } + } + } + // if there are no workers, threat this as error + return checker.Status{ + Code: http.StatusInternalServerError, + } +} + +func (s *Plugin) redirect(w http.ResponseWriter, r *http.Request) bool { + if s.https != nil && r.TLS == nil && s.cfg.SSL.Redirect { + target := &url.URL{ + Scheme: "https", + Host: s.tlsAddr(r.Host, false), + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, + } + + http.Redirect(w, r, target.String(), http.StatusTemporaryRedirect) + return true + } + return false +} + +func headerContainsUpgrade(r *http.Request, s *Plugin) bool { + if _, ok := r.Header["Upgrade"]; ok { + // https://golang.org/pkg/net/http/#Hijacker + s.log.Error("server does not support Upgrade header") + return true + } + return false +} + +// append RootCA to the https server TLS config +func (s *Plugin) appendRootCa() error { + const op = errors.Op("append root CA") + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil + } + if rootCAs == nil { + rootCAs = x509.NewCertPool() + } + + CA, err := 
ioutil.ReadFile(s.cfg.SSL.RootCA) + if err != nil { + return err + } + + // should append our CA cert + ok := rootCAs.AppendCertsFromPEM(CA) + if !ok { + return errors.E(op, errors.Str("could not append Certs from PEM")) + } + // disable "G402 (CWE-295): TLS MinVersion too low. (Confidence: HIGH, Severity: HIGH)" + // #nosec G402 + cfg := &tls.Config{ + InsecureSkipVerify: false, + RootCAs: rootCAs, + } + s.http.TLSConfig = cfg + + return nil +} + +// Init https server +func (s *Plugin) initSSL() *http.Server { + var topCipherSuites []uint16 + var defaultCipherSuitesTLS13 []uint16 + + hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + + if hasGCMAsm { + // If AES-GCM hardware is provided then priorities AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } + + DefaultCipherSuites := make([]uint16, 0, 22) + DefaultCipherSuites = append(DefaultCipherSuites, topCipherSuites...) + DefaultCipherSuites = append(DefaultCipherSuites, defaultCipherSuitesTLS13...) + + server := &http.Server{ + Addr: s.tlsAddr(s.cfg.Address, true), + Handler: s, + TLSConfig: &tls.Config{ + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + tls.X25519, + }, + CipherSuites: DefaultCipherSuites, + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + }, + } + + return server +} + +// init http/2 server +func (s *Plugin) initHTTP2() error { + return http2.ConfigureServer(s.https, &http2.Server{ + MaxConcurrentStreams: s.cfg.HTTP2.MaxConcurrentStreams, + }) +} + +// serveFCGI starts FastCGI server. +func (s *Plugin) serveFCGI() error { + l, err := utils.CreateListener(s.cfg.FCGI.Address) + if err != nil { + return err + } + + err = fcgi.Serve(l, s.fcgi.Handler) + if err != nil { + return err + } + + return nil +} + +// tlsAddr replaces listen or host port with port configured by SSL config. 
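For instance (an in-package sketch, since tlsAddr is unexported; the ports are illustrative):

```go
// assuming s.cfg.SSL.Port == 8443:
s.tlsAddr("example.com:8080", false) // "example.com:8443" – a non-default SSL port always rewrites
s.tlsAddr("example.com", true)       // "example.com:8443"

// with the default port 443 and forcePort == false the port is stripped:
s.tlsAddr("example.com:8080", false) // "example.com"

// note: the ":" split assumes a hostname or IPv4 host, not a bare IPv6 literal
```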
+func (s *Plugin) tlsAddr(host string, forcePort bool) string { + // remove current forcePort first + host = strings.Split(host, ":")[0] + + if forcePort || s.cfg.SSL.Port != 443 { + host = fmt.Sprintf("%s:%v", host, s.cfg.SSL.Port) + } + + return host +} + +func (s *Plugin) addMiddlewares() { + if s.http != nil { + applyMiddlewares(s.http, s.mdwr, s.cfg.Middleware, s.log) + } + if s.https != nil { + applyMiddlewares(s.https, s.mdwr, s.cfg.Middleware, s.log) + } + + if s.fcgi != nil { + applyMiddlewares(s.fcgi, s.mdwr, s.cfg.Middleware, s.log) + } +} + +func applyMiddlewares(server *http.Server, middlewares map[string]Middleware, order []string, log logger.Logger) { + for i := 0; i < len(order); i++ { + if mdwr, ok := middlewares[order[i]]; ok { + server.Handler = mdwr.Middleware(server.Handler) + } else { + log.Warn("requested middleware does not exist", "requested", order[i]) + } + } +} diff --git a/plugins/http/request.go b/plugins/http/request.go new file mode 100644 index 00000000..3983fdde --- /dev/null +++ b/plugins/http/request.go @@ -0,0 +1,186 @@ +package http + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + + j "github.com/json-iterator/go" + "github.com/spiral/roadrunner/v2/pkg/payload" + "github.com/spiral/roadrunner/v2/plugins/http/attributes" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +const ( + defaultMaxMemory = 32 << 20 // 32 MB + contentNone = iota + 900 + contentStream + contentMultipart + contentFormData +) + +// Request maps net/http requests to PSR7 compatible structure and managed state of temporary uploaded files. +type Request struct { + // RemoteAddr contains ip address of client, make sure to check X-Real-Ip and X-Forwarded-For for real client address. + RemoteAddr string `json:"remoteAddr"` + + // Protocol includes HTTP protocol version. + Protocol string `json:"protocol"` + + // Method contains name of HTTP method used for the request. + Method string `json:"method"` + + // URI contains full request URI with scheme and query. + URI string `json:"uri"` + + // Header contains list of request headers. + Header http.Header `json:"headers"` + + // Cookies contains list of request cookies. + Cookies map[string]string `json:"cookies"` + + // RawQuery contains non parsed query string (to be parsed on php end). + RawQuery string `json:"rawQuery"` + + // Parsed indicates that request body has been parsed on RR end. + Parsed bool `json:"parsed"` + + // Uploads contains list of uploaded files, their names, sized and associations with temporary files. + Uploads *Uploads `json:"uploads"` + + // Attributes can be set by chained mdwr to safely pass value from Golang to PHP. See: GetAttribute, SetAttribute functions. + Attributes map[string]interface{} `json:"attributes"` + + // request body can be parsedData or []byte + body interface{} +} + +func fetchIP(pair string) string { + if !strings.ContainsRune(pair, ':') { + return pair + } + + addr, _, _ := net.SplitHostPort(pair) + return addr +} + +// NewRequest creates new PSR7 compatible request using net/http request. 
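A runnable sketch of that translation (the import alias is an assumption; the form field is illustrative):

```go
package main

import (
	"fmt"
	"log"
	"net/http/httptest"
	"strings"

	rrhttp "github.com/spiral/roadrunner/v2/plugins/http"
)

func main() {
	r := httptest.NewRequest("POST", "http://localhost/login?a=b",
		strings.NewReader("user=rr"))
	r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	req, err := rrhttp.NewRequest(r, rrhttp.UploadsConfig{})
	if err != nil {
		log.Fatal(err)
	}

	p, err := req.Payload()
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(string(p.Context)) // request meta: method, uri, headers, cookies…
	fmt.Println(string(p.Body))    // {"user":"rr"} – the parsed form tree
}
```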
+func NewRequest(r *http.Request, cfg UploadsConfig) (*Request, error) {
+	req := &Request{
+		RemoteAddr: fetchIP(r.RemoteAddr),
+		Protocol:   r.Proto,
+		Method:     r.Method,
+		URI:        uri(r),
+		Header:     r.Header,
+		Cookies:    make(map[string]string),
+		RawQuery:   r.URL.RawQuery,
+		Attributes: attributes.All(r),
+	}
+
+	for _, c := range r.Cookies() {
+		if v, err := url.QueryUnescape(c.Value); err == nil {
+			req.Cookies[c.Name] = v
+		}
+	}
+
+	switch req.contentType() {
+	case contentNone:
+		return req, nil
+
+	case contentStream:
+		var err error
+		req.body, err = ioutil.ReadAll(r.Body)
+		return req, err
+
+	case contentMultipart:
+		if err := r.ParseMultipartForm(defaultMaxMemory); err != nil {
+			return nil, err
+		}
+
+		req.Uploads = parseUploads(r, cfg)
+		fallthrough
+	case contentFormData:
+		if err := r.ParseForm(); err != nil {
+			return nil, err
+		}
+
+		req.body = parseData(r)
+	}
+
+	req.Parsed = true
+	return req, nil
+}
+
+// Open moves all uploaded files to the temporary directory so they can be handed to PHP later.
+func (r *Request) Open(log logger.Logger) {
+	if r.Uploads == nil {
+		return
+	}
+
+	r.Uploads.Open(log)
+}
+
+// Close clears all temp file uploads
+func (r *Request) Close(log logger.Logger) {
+	if r.Uploads == nil {
+		return
+	}
+
+	r.Uploads.Clear(log)
+}
+
+// Payload returns the request as a marshaled RoadRunner payload based on PSR7 data. Values are encoded as JSON.
+// Make sure to open files prior to calling this method.
+func (r *Request) Payload() (payload.Payload, error) {
+	p := payload.Payload{}
+
+	var err error
+	if p.Context, err = json.Marshal(r); err != nil {
+		return payload.Payload{}, err
+	}
+
+	if r.Parsed {
+		if p.Body, err = json.Marshal(r.body); err != nil {
+			return payload.Payload{}, err
+		}
+	} else if r.body != nil {
+		p.Body = r.body.([]byte)
+	}
+
+	return p, nil
+}
+
+// contentType returns the payload content type.
+func (r *Request) contentType() int {
+	if r.Method == "HEAD" || r.Method == "OPTIONS" {
+		return contentNone
+	}
+
+	ct := r.Header.Get("content-type")
+	if strings.Contains(ct, "application/x-www-form-urlencoded") {
+		return contentFormData
+	}
+
+	if strings.Contains(ct, "multipart/form-data") {
+		return contentMultipart
+	}
+
+	return contentStream
+}
+
+// uri fetches the full URI from the request in the form of a string (including the https scheme if TLS is enabled).
+func uri(r *http.Request) string {
+	if r.URL.Host != "" {
+		return r.URL.String()
+	}
+	if r.TLS != nil {
+		return fmt.Sprintf("https://%s%s", r.Host, r.URL.String())
+	}
+
+	return fmt.Sprintf("http://%s%s", r.Host, r.URL.String())
+}
diff --git a/plugins/http/response.go b/plugins/http/response.go
new file mode 100644
index 00000000..17049ce1
--- /dev/null
+++ b/plugins/http/response.go
@@ -0,0 +1,105 @@
+package http
+
+import (
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+
+	"github.com/spiral/roadrunner/v2/pkg/payload"
+)
+
+// Response handles PSR7 response logic.
+type Response struct {
+	// Status contains response status.
+	Status int `json:"status"`
+
+	// Headers contains the list of response headers.
+	Headers map[string][]string `json:"headers"`
+
+	// associated Body payload.
+	Body interface{}
+	sync.Mutex
+}
+
+// NewResponse creates new response based on given pool payload.
+func NewResponse(p payload.Payload) (*Response, error) {
+	r := &Response{Body: p.Body}
+	if err := json.Unmarshal(p.Context, r); err != nil {
+		return nil, err
+	}
+
+	return r, nil
+}
+
+// Write writes response headers, status and body into the ResponseWriter.
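+// A typical flow on the serving side (sketch, assuming pld is the payload returned by a worker):
+//
+//	resp, err := NewResponse(pld)
+//	if err != nil {
+//		return err
+//	}
+//	return resp.Write(w) // w is the http.ResponseWriter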
+func (r *Response) Write(w http.ResponseWriter) error {
+	// NOTE: maps are reference types in Go, so handlePushHeaders mutates r.Headers
+	p := handlePushHeaders(r.Headers)
+	if pusher, ok := w.(http.Pusher); ok {
+		for _, v := range p {
+			err := pusher.Push(v, nil)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	handleTrailers(r.Headers)
+	for n, h := range r.Headers {
+		for _, v := range h {
+			w.Header().Add(n, v)
+		}
+	}
+
+	w.WriteHeader(r.Status)
+
+	if data, ok := r.Body.([]byte); ok {
+		_, err := w.Write(data)
+		if err != nil {
+			return handleWriteError(err)
+		}
+	}
+
+	if rc, ok := r.Body.(io.Reader); ok {
+		if _, err := io.Copy(w, rc); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func handlePushHeaders(h map[string][]string) []string {
+	var p []string
+	pushHeader, ok := h[http2pushHeaderKey]
+	if !ok {
+		return p
+	}
+
+	p = append(p, pushHeader...)
+
+	delete(h, http2pushHeaderKey)
+
+	return p
+}
+
+func handleTrailers(h map[string][]string) {
+	trailers, ok := h[TrailerHeaderKey]
+	if !ok {
+		return
+	}
+
+	for _, tr := range trailers {
+		for _, n := range strings.Split(tr, ",") {
+			n = strings.Trim(n, "\t ")
+			if v, ok := h[n]; ok {
+				h["Trailer:"+n] = v
+
+				delete(h, n)
+			}
+		}
+	}
+
+	delete(h, TrailerHeaderKey)
+}
diff --git a/plugins/http/uploads.go b/plugins/http/uploads.go
new file mode 100644
index 00000000..d5196844
--- /dev/null
+++ b/plugins/http/uploads.go
@@ -0,0 +1,158 @@
+package http
+
+import (
+	"io"
+	"io/ioutil"
+	"mime/multipart"
+	"os"
+	"sync"
+
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+const (
+	// UploadErrorOK - no error, the file uploaded with success.
+	UploadErrorOK = 0
+
+	// UploadErrorNoFile - no file was uploaded.
+	UploadErrorNoFile = 4
+
+	// UploadErrorNoTmpDir - missing a temporary folder.
+	UploadErrorNoTmpDir = 6
+
+	// UploadErrorCantWrite - failed to write file to disk.
+	UploadErrorCantWrite = 7
+
+	// UploadErrorExtension - forbidden file extension.
+	UploadErrorExtension = 8
+)
+
+// Uploads manages the tree of uploaded files and the temporary files behind it.
+type Uploads struct {
+	// associated temp directory and forbidden extensions.
+	cfg UploadsConfig
+
+	// preprocessed data tree for Uploads.
+	tree fileTree
+
+	// flat list of all file Uploads.
+	list []*FileUpload
+}
+
+// MarshalJSON marshals the upload tree into JSON.
+func (u *Uploads) MarshalJSON() ([]byte, error) {
+	return json.Marshal(u.tree)
+}
+
+// Open moves all uploaded files to the temp directory; per-file errors are logged
+// and recorded on each upload individually.
+func (u *Uploads) Open(log logger.Logger) {
+	var wg sync.WaitGroup
+	for _, f := range u.list {
+		wg.Add(1)
+		go func(f *FileUpload) {
+			defer wg.Done()
+			err := f.Open(u.cfg)
+			if err != nil && log != nil {
+				log.Error("error opening the file", "err", err)
+			}
+		}(f)
+	}
+
+	wg.Wait()
+}
+
+// Clear deletes all temporary files.
+func (u *Uploads) Clear(log logger.Logger) {
+	for _, f := range u.list {
+		if f.TempFilename != "" && exists(f.TempFilename) {
+			err := os.Remove(f.TempFilename)
+			if err != nil && log != nil {
+				log.Error("error removing the file", "err", err)
+			}
+		}
+	}
+}
+
+// FileUpload represents a single file upload.
+type FileUpload struct {
+	// Name contains the filename specified by the client.
+	Name string `json:"name"`
+
+	// Mime contains the mime-type provided by the client.
+	Mime string `json:"mime"`
+
+	// Size of the uploaded file.
+	Size int64 `json:"size"`
+
+	// Error indicates a file upload error (if any). See http://php.net/manual/en/features.file-upload.errors.php
+	Error int `json:"error"`
+
+	// TempFilename points to the temporary file location.
+	TempFilename string `json:"tmpName"`
+
+	// associated file header
+	header *multipart.FileHeader
+}
+
+// NewUpload wraps a net/http upload into a PSR-7 compatible structure.
+func NewUpload(f *multipart.FileHeader) *FileUpload {
+	return &FileUpload{
+		Name:   f.Filename,
+		Mime:   f.Header.Get("Content-Type"),
+		Error:  UploadErrorOK,
+		header: f,
+	}
+}
+
+// Open moves the file content into a temporary file available to PHP.
+// NOTE:
+// there are two deferred functions here; if both return an error,
+// the temp-file close error is overwritten by the close error of the main file.
+// Defer stack (LIFO):
+// DEFER FILE CLOSE (runs second)
+// DEFER TMP CLOSE (runs first)
+func (f *FileUpload) Open(cfg UploadsConfig) (err error) {
+	if cfg.Forbids(f.Name) {
+		f.Error = UploadErrorExtension
+		return nil
+	}
+
+	file, err := f.header.Open()
+	if err != nil {
+		f.Error = UploadErrorNoFile
+		return err
+	}
+
+	defer func() {
+		// close the main file
+		err = file.Close()
+	}()
+
+	tmp, err := ioutil.TempFile(cfg.TmpDir(), "upload")
+	if err != nil {
+		// the most likely cause of this issue is a missing tmp dir
+		f.Error = UploadErrorNoTmpDir
+		return err
+	}
+
+	f.TempFilename = tmp.Name()
+	defer func() {
+		// close the temp file
+		err = tmp.Close()
+	}()
+
+	if f.Size, err = io.Copy(tmp, file); err != nil {
+		f.Error = UploadErrorCantWrite
+	}
+
+	return err
+}
+
+// exists reports whether the file exists.
+func exists(path string) bool {
+	if _, err := os.Stat(path); os.IsNotExist(err) {
+		return false
+	}
+	return true
+}
diff --git a/plugins/http/uploads_config.go b/plugins/http/uploads_config.go
new file mode 100644
index 00000000..4c20c8e8
--- /dev/null
+++ b/plugins/http/uploads_config.go
@@ -0,0 +1,46 @@
+package http
+
+import (
+	"os"
+	"path"
+	"strings"
+)
+
+// UploadsConfig describes the file location and controls access to uploads.
+type UploadsConfig struct {
+	// Dir contains the name of the directory to control access to.
+	Dir string
+
+	// Forbid specifies a list of file extensions which are forbidden for access.
+	// Example: .php, .exe, .bat, .htaccess, etc.
+	Forbid []string
+}
+
+// InitDefaults sets missing values to their default values.
+func (cfg *UploadsConfig) InitDefaults() error {
+	cfg.Forbid = []string{".php", ".exe", ".bat"}
+	cfg.Dir = os.TempDir()
+	return nil
+}
+
+// TmpDir returns the temporary directory.
+func (cfg *UploadsConfig) TmpDir() string {
+	if cfg.Dir != "" {
+		return cfg.Dir
+	}
+
+	return os.TempDir()
+}
+
+// Forbids returns true if the file extension is not allowed for the upload.
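+// For example, with the default Forbid list from InitDefaults:
+//
+//	cfg.Forbids("exploit.PHP") // true, the extension check is case-insensitive
+//	cfg.Forbids("image.png")   // false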
+func (cfg *UploadsConfig) Forbids(filename string) bool {
+	ext := strings.ToLower(path.Ext(filename))
+
+	for _, v := range cfg.Forbid {
+		if ext == v {
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/plugins/informer/interface.go b/plugins/informer/interface.go
new file mode 100644
index 00000000..27139ae1
--- /dev/null
+++ b/plugins/informer/interface.go
@@ -0,0 +1,8 @@
+package informer
+
+import "github.com/spiral/roadrunner/v2/interfaces/worker"
+
+// Informer is used to get workers from a particular plugin or set of plugins
+type Informer interface {
+	Workers() []worker.BaseProcess
+}
diff --git a/plugins/informer/plugin.go b/plugins/informer/plugin.go
new file mode 100644
index 00000000..3359cd7e
--- /dev/null
+++ b/plugins/informer/plugin.go
@@ -0,0 +1,55 @@
+package informer
+
+import (
+	"github.com/spiral/endure"
+	"github.com/spiral/errors"
+	"github.com/spiral/roadrunner/v2/interfaces/worker"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+const PluginName = "informer"
+
+type Plugin struct {
+	registry map[string]Informer
+	log      logger.Logger
+}
+
+func (p *Plugin) Init(log logger.Logger) error {
+	p.registry = make(map[string]Informer)
+	p.log = log
+	return nil
+}
+
+// Workers provides a BaseProcess slice with workers for the requested plugin
+func (p *Plugin) Workers(name string) ([]worker.BaseProcess, error) {
+	const op = errors.Op("get workers")
+	svc, ok := p.registry[name]
+	if !ok {
+		return nil, errors.E(op, errors.Errorf("no such service: %s", name))
+	}
+
+	return svc.Workers(), nil
+}
+
+// CollectTarget adds a service able to report its workers to the registry.
+func (p *Plugin) CollectTarget(name endure.Named, r Informer) error {
+	p.registry[name.Name()] = r
+	return nil
+}
+
+// Collects declares services to be collected.
+func (p *Plugin) Collects() []interface{} {
+	return []interface{}{
+		p.CollectTarget,
+	}
+}
+
+// Name of the service.
+func (p *Plugin) Name() string {
+	return PluginName
+}
+
+// RPC returns the associated rpc service.
+func (p *Plugin) RPC() interface{} {
+	return &rpc{srv: p, log: p.log}
+}
diff --git a/plugins/informer/rpc.go b/plugins/informer/rpc.go
new file mode 100644
index 00000000..98b5681c
--- /dev/null
+++ b/plugins/informer/rpc.go
@@ -0,0 +1,54 @@
+package informer
+
+import (
+	"github.com/spiral/roadrunner/v2/interfaces/worker"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+	"github.com/spiral/roadrunner/v2/tools"
+)
+
+type rpc struct {
+	srv *Plugin
+	log logger.Logger
+}
+
+// WorkerList contains a list of workers.
+type WorkerList struct {
+	// Workers is the list of workers.
+	Workers []tools.ProcessState `json:"workers"`
+}
+
+// List returns all services registered in the informer.
+func (rpc *rpc) List(_ bool, list *[]string) error {
+	rpc.log.Debug("Started List method")
+	*list = make([]string, 0, len(rpc.srv.registry))
+
+	for name := range rpc.srv.registry {
+		*list = append(*list, name)
+	}
+	rpc.log.Debug("list of services", "list", *list)
+
+	rpc.log.Debug("successfully finished List method")
+	return nil
+}
+
+// Workers returns the state of workers for a given service.
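+// Invoked over RPC; a client-side sketch using net/rpc naming (the service
+// name and "http" argument are assumed for illustration):
+//
+//	var list WorkerList
+//	err := client.Call("informer.Workers", "http", &list)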
+func (rpc *rpc) Workers(service string, list *WorkerList) error { + rpc.log.Debug("started Workers method", "service", service) + workers, err := rpc.srv.Workers(service) + if err != nil { + return err + } + + list.Workers = make([]tools.ProcessState, 0) + for _, w := range workers { + ps, err := tools.WorkerProcessState(w.(worker.BaseProcess)) + if err != nil { + continue + } + + list.Workers = append(list.Workers, ps) + } + rpc.log.Debug("list of workers", "workers", list.Workers) + rpc.log.Debug("successfully finished Workers method") + return nil +} diff --git a/plugins/kv/boltdb/config.go b/plugins/kv/boltdb/config.go new file mode 100644 index 00000000..b2e1e636 --- /dev/null +++ b/plugins/kv/boltdb/config.go @@ -0,0 +1,24 @@ +package boltdb + +type Config struct { + // Dir is a directory to store the DB files + Dir string + // File is boltDB file. No need to create it by your own, + // boltdb driver is able to create the file, or read existing + File string + // Bucket to store data in boltDB + Bucket string + // db file permissions + Permissions int + // timeout + Interval uint `yaml:"interval"` +} + +// InitDefaults initializes default values for the boltdb +func (s *Config) InitDefaults() { + s.Dir = "." // current dir + s.Bucket = "rr" // default bucket name + s.File = "rr.db" // default file name + s.Permissions = 0777 // free for all + s.Interval = 60 // default is 60 seconds timeout +} diff --git a/plugins/kv/boltdb/plugin.go b/plugins/kv/boltdb/plugin.go new file mode 100644 index 00000000..6cfc49f6 --- /dev/null +++ b/plugins/kv/boltdb/plugin.go @@ -0,0 +1,452 @@ +package boltdb + +import ( + "bytes" + "encoding/gob" + "os" + "path" + "strings" + "sync" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" + bolt "go.etcd.io/bbolt" +) + +const PluginName = "boltdb" + +// BoltDB K/V storage. 
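+// Driven by a `boltdb` config section along these lines (a sketch; the keys mirror
+// the Config fields above and the values are the InitDefaults ones):
+//
+//	boltdb:
+//	  dir: "."
+//	  file: "rr.db"
+//	  bucket: "rr"
+//	  permissions: 0777
+//	  interval: 60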
+type Plugin struct {
+	// db instance
+	DB *bolt.DB
+	// bucket name, should be UTF-8
+	bucket []byte
+
+	// config for RR integration
+	cfg *Config
+
+	// logger
+	log logger.Logger
+
+	// gc contains keys with their timeouts
+	gc *sync.Map
+	// default timeout for cache cleanup is 1 minute
+	timeout time.Duration
+
+	// stop is used to stop keys GC and close the boltdb connection
+	stop chan struct{}
+}
+
+func (s *Plugin) Init(log logger.Logger, cfg config.Configurer) error {
+	const op = errors.Op("boltdb plugin init")
+	s.cfg = &Config{}
+
+	s.cfg.InitDefaults()
+
+	err := cfg.UnmarshalKey(PluginName, &s.cfg)
+	if err != nil {
+		return errors.E(op, errors.Disabled, err)
+	}
+
+	// set the logger
+	s.log = log
+
+	db, err := bolt.Open(path.Join(s.cfg.Dir, s.cfg.File), os.FileMode(s.cfg.Permissions), nil)
+	if err != nil {
+		return errors.E(op, err)
+	}
+
+	// create bucket if it does not exist
+	// tx.Commit invokes via the db.Update
+	err = db.Update(func(tx *bolt.Tx) error {
+		const upOp = errors.Op("boltdb Update")
+		_, err = tx.CreateBucketIfNotExists([]byte(s.cfg.Bucket))
+		if err != nil {
+			return errors.E(op, upOp, err)
+		}
+		return nil
+	})
+
+	if err != nil {
+		return errors.E(op, err)
+	}
+
+	s.DB = db
+	s.bucket = []byte(s.cfg.Bucket)
+	s.stop = make(chan struct{})
+	s.timeout = time.Duration(s.cfg.Interval) * time.Second
+	s.gc = &sync.Map{}
+
+	return nil
+}
+
+func (s *Plugin) Serve() chan error {
+	errCh := make(chan error, 1)
+	// start the TTL gc
+	go s.gcPhase()
+
+	return errCh
+}
+
+func (s *Plugin) Stop() error {
+	const op = errors.Op("boltdb stop")
+	err := s.Close()
+	if err != nil {
+		return errors.E(op, err)
+	}
+	return nil
+}
+
+func (s *Plugin) Has(keys ...string) (map[string]bool, error) {
+	const op = errors.Op("boltdb Has")
+	s.log.Debug("boltdb HAS method called", "args", keys)
+	if keys == nil {
+		return nil, errors.E(op, errors.NoKeys)
+	}
+
+	m := make(map[string]bool, len(keys))
+
+	// this is a read-only transaction
+	err := s.DB.View(func(tx *bolt.Tx) error {
+		// Get retrieves the value for a key in the bucket.
+		// Returns a nil value if the key does not exist or if the key is a nested bucket.
+		// The returned value is only valid for the life of the transaction.
+		for i := range keys {
+			keyTrimmed := strings.TrimSpace(keys[i])
+			if keyTrimmed == "" {
+				return errors.E(op, errors.EmptyKey)
+			}
+			b := tx.Bucket(s.bucket)
+			if b == nil {
+				return errors.E(op, errors.NoSuchBucket)
+			}
+			exist := b.Get([]byte(keys[i]))
+			if exist != nil {
+				m[keys[i]] = true
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	s.log.Debug("boltdb HAS method finished")
+	return m, nil
+}
+
+// Get retrieves the value for a key in the bucket.
+// Returns a nil value if the key does not exist or if the key is a nested bucket.
+// The returned value is only valid for the life of the transaction.
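+// A usage sketch:
+//
+//	val, err := s.Get("user:1") // val == nil with a nil error when the key is absent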
+func (s *Plugin) Get(key string) ([]byte, error) {
+	const op = errors.Op("boltdb Get")
+	// to catch cases like " "
+	keyTrimmed := strings.TrimSpace(key)
+	if keyTrimmed == "" {
+		return nil, errors.E(op, errors.EmptyKey)
+	}
+
+	var val []byte
+	err := s.DB.View(func(tx *bolt.Tx) error {
+		b := tx.Bucket(s.bucket)
+		if b == nil {
+			return errors.E(op, errors.NoSuchBucket)
+		}
+		val = b.Get([]byte(key))
+
+		// try to decode the value
+		if val != nil {
+			buf := bytes.NewReader(val)
+			decoder := gob.NewDecoder(buf)
+
+			var i string
+			err := decoder.Decode(&i)
+			if err != nil {
+				return errors.E(op, err)
+			}
+
+			// set the value; unsafe (w/o runes) convert
+			val = []byte(i)
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	return val, nil
+}
+
+func (s *Plugin) MGet(keys ...string) (map[string]interface{}, error) {
+	const op = errors.Op("boltdb MGet")
+	// defense
+	if keys == nil {
+		return nil, errors.E(op, errors.NoKeys)
+	}
+
+	// keys should not be empty
+	for i := range keys {
+		keyTrimmed := strings.TrimSpace(keys[i])
+		if keyTrimmed == "" {
+			return nil, errors.E(op, errors.EmptyKey)
+		}
+	}
+
+	m := make(map[string]interface{}, len(keys))
+
+	err := s.DB.View(func(tx *bolt.Tx) error {
+		b := tx.Bucket(s.bucket)
+		if b == nil {
+			return errors.E(op, errors.NoSuchBucket)
+		}
+
+		// allocate enough space up front
+		buf := new(bytes.Buffer)
+		var out string
+		buf.Grow(100)
+		for i := range keys {
+			value := b.Get([]byte(keys[i]))
+			if value != nil {
+				buf.Write(value)
+				// a fresh decoder per value: every value carries its own gob stream header
+				dec := gob.NewDecoder(buf)
+				err := dec.Decode(&out)
+				if err != nil {
+					return errors.E(op, err)
+				}
+				m[keys[i]] = out
+				buf.Reset()
+				out = ""
+			}
+		}
+
+		return nil
+	})
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	return m, nil
+}
+
+// Set puts the K/V pairs into bolt
+func (s *Plugin) Set(items ...kv.Item) error {
+	const op = errors.Op("boltdb Set")
+	if items == nil {
+		return errors.E(op, errors.NoKeys)
+	}
+
+	// start a writable transaction
+	tx, err := s.DB.Begin(true)
+	if err != nil {
+		return errors.E(op, err)
+	}
+	defer func() {
+		err = tx.Commit()
+		if err != nil {
+			errRb := tx.Rollback()
+			if errRb != nil {
+				s.log.Error("during the commit, Rollback error occurred", "commit error", err, "rollback error", errRb)
+			}
+		}
+	}()
+
+	b := tx.Bucket(s.bucket)
+	// use access by index to avoid copying
+	for i := range items {
+		// reject empty items
+		if items[i] == (kv.Item{}) {
+			return errors.E(op, errors.EmptyItem)
+		}
+
+		// performance note: buf and the encoder can't be moved out of the loop;
+		// buf could be reset, but the encoder would keep the past stream state without re-init
+		buf := bytes.Buffer{}
+		encoder := gob.NewEncoder(&buf)
+
+		// encode the value
+		err = encoder.Encode(&items[i].Value)
+		if err != nil {
+			return errors.E(op, err)
+		}
+		// buf.Bytes will copy the underlying slice; revisit in case of performance problems
+		err = b.Put([]byte(items[i].Key), buf.Bytes())
+		if err != nil {
+			return errors.E(op, err)
+		}
+
+		// if there were no errors and a TTL is set, we put the key with its timeout into the map for future GC checks
+		// we do not need a mutex here, since we use sync.Map
+		if items[i].TTL != "" {
+			// check correctness of the provided TTL
+			_, err := time.Parse(time.RFC3339, items[i].TTL)
+			if err != nil {
+				return errors.E(op, err)
+			}
+			// store the key TTL in the separate map
+			s.gc.Store(items[i].Key, items[i].TTL)
+		}
+
+		buf.Reset()
+	}
+
+	return nil
+}
+
+// Delete removes the given keys from the DB
+func (s *Plugin) Delete(keys ...string) error {
+	const op = errors.Op("boltdb Delete")
+	if keys == nil {
+		return errors.E(op, errors.NoKeys)
+	}
+
+	// keys should not be empty
+	for _, key := range keys {
+		keyTrimmed := strings.TrimSpace(key)
+		if keyTrimmed == "" {
+			return errors.E(op, errors.EmptyKey)
+		}
+	}
+
+	// start a writable transaction
+	tx, err := s.DB.Begin(true)
+	if err != nil {
+		return errors.E(op, err)
+	}
+
+	defer func() {
+		err = tx.Commit()
+		if err != nil {
+			errRb := tx.Rollback()
+			if errRb != nil {
+				s.log.Error("during the commit, Rollback error occurred", "commit error", err, "rollback error", errRb)
+			}
+		}
+	}()
+
+	b := tx.Bucket(s.bucket)
+	if b == nil {
+		return errors.E(op, errors.NoSuchBucket)
+	}
+
+	for _, key := range keys {
+		err = b.Delete([]byte(key))
+		if err != nil {
+			return errors.E(op, err)
+		}
+	}
+
+	return nil
+}
+
+// MExpire sets the expiration time for a key.
+// If the key already has an expiration time, it will be overwritten.
+func (s *Plugin) MExpire(items ...kv.Item) error {
+	const op = errors.Op("boltdb MExpire")
+	for i := range items {
+		if items[i].TTL == "" || strings.TrimSpace(items[i].Key) == "" {
+			return errors.E(op, errors.Str("should set timeout and at least one key"))
+		}
+
+		// verify the provided TTL
+		_, err := time.Parse(time.RFC3339, items[i].TTL)
+		if err != nil {
+			return errors.E(op, err)
+		}
+
+		s.gc.Store(items[i].Key, items[i].TTL)
+	}
+	return nil
+}
+
+func (s *Plugin) TTL(keys ...string) (map[string]interface{}, error) {
+	const op = errors.Op("boltdb TTL")
+	if keys == nil {
+		return nil, errors.E(op, errors.NoKeys)
+	}
+
+	// keys should not be empty
+	for i := range keys {
+		keyTrimmed := strings.TrimSpace(keys[i])
+		if keyTrimmed == "" {
+			return nil, errors.E(op, errors.EmptyKey)
+		}
+	}
+
+	m := make(map[string]interface{}, len(keys))
+
+	for i := range keys {
+		if item, ok := s.gc.Load(keys[i]); ok {
+			// the type assertion is safe here: only kv.Item.TTL strings are stored in the gc map
+			m[keys[i]] = item.(string)
+		}
+	}
+	return m, nil
+}
+
+// Close the DB connection
+func (s *Plugin) Close() error {
+	// stop the keys GC
+	s.stop <- struct{}{}
+	return s.DB.Close()
+}
+
+// RPC returns the associated rpc service.
+func (s *Plugin) RPC() interface{} { + return kv.NewRPCServer(s, s.log) +} + +// Name returns plugin name +func (s *Plugin) Name() string { + return PluginName +} + +// ========================= PRIVATE ================================= + +func (s *Plugin) gcPhase() { + t := time.NewTicker(s.timeout) + defer t.Stop() + for { + select { + case <-t.C: + // calculate current time before loop started to be fair + now := time.Now() + s.gc.Range(func(key, value interface{}) bool { + const op = errors.Op("gcPhase") + k := key.(string) + v, err := time.Parse(time.RFC3339, value.(string)) + if err != nil { + return false + } + + if now.After(v) { + // time expired + s.gc.Delete(k) + s.log.Debug("key deleted", "key", k) + err := s.DB.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(s.bucket) + if b == nil { + return errors.E(op, errors.NoSuchBucket) + } + err := b.Delete([]byte(k)) + if err != nil { + return errors.E(op, err) + } + return nil + }) + if err != nil { + s.log.Error("error during the gc phase of update", "error", err) + // todo this error is ignored, it means, that timer still be active + // to prevent this, we need to invoke t.Stop() + return false + } + } + return true + }) + case <-s.stop: + return + } + } +} diff --git a/plugins/kv/boltdb/plugin_unit_test.go b/plugins/kv/boltdb/plugin_unit_test.go new file mode 100644 index 00000000..2459e493 --- /dev/null +++ b/plugins/kv/boltdb/plugin_unit_test.go @@ -0,0 +1,531 @@ +package boltdb + +import ( + "os" + "strconv" + "sync" + "testing" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/stretchr/testify/assert" + bolt "go.etcd.io/bbolt" + "go.uber.org/zap" +) + +// NewBoltClient instantiate new BOLTDB client +// The parameters are: +// path string -- path to database file (can be placed anywhere), if file is not exist, it will be created +// perm os.FileMode -- file permissions, for example 0777 +// options *bolt.Options -- boltDB options, such as timeouts, noGrows options and other +// bucket string -- name of the bucket to use, should be UTF-8 +func newBoltClient(path string, perm os.FileMode, options *bolt.Options, bucket string, ttl time.Duration) (kv.Storage, error) { + const op = errors.Op("newBoltClient") + db, err := bolt.Open(path, perm, options) + if err != nil { + return nil, errors.E(op, err) + } + + // bucket should be SET + if bucket == "" { + return nil, errors.E(op, errors.Str("bucket should be set")) + } + + // create bucket if it does not exist + // tx.Commit invokes via the db.Update + err = db.Update(func(tx *bolt.Tx) error { + _, err = tx.CreateBucketIfNotExists([]byte(bucket)) + if err != nil { + return errors.E(op, err) + } + return nil + }) + if err != nil { + return nil, errors.E(op, err) + } + + // if TTL is not set, make it default + if ttl == 0 { + ttl = time.Minute + } + + l, _ := zap.NewDevelopment() + s := &Plugin{ + DB: db, + bucket: []byte(bucket), + stop: make(chan struct{}), + timeout: ttl, + gc: &sync.Map{}, + log: logger.NewZapAdapter(l), + } + + // start the TTL gc + go s.gcPhase() + + return s, nil +} + +func initStorage() kv.Storage { + storage, err := newBoltClient("rr.db", 0777, nil, "rr", time.Second) + if err != nil { + panic(err) + } + return storage +} + +func cleanup(t *testing.T, path string) { + err := os.RemoveAll(path) + if err != nil { + t.Fatal(err) + } +} + +func TestStorage_Has(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + 
} + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) +} + +func TestStorage_Has_Set_Has(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) +} + +func TestConcurrentReadWriteTransactions(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + wg := &sync.WaitGroup{} + wg.Add(3) + + m := &sync.RWMutex{} + // concurrently set the keys + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + // set is writable transaction + // it should stop readable + assert.NoError(t, s.Set(kv.Item{ + Key: "key" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + TTL: "", + }, kv.Item{ + Key: "key2" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + TTL: "", + })) + m.Unlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.RLock() + v, err = s.Has("key") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + m.RUnlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + err = s.Delete("key" + strconv.Itoa(i)) + assert.NoError(t, err) + m.Unlock() + } + }(s) + + wg.Wait() +} + +func TestStorage_Has_Set_MGet(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestStorage_Has_Set_Get(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world2", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := s.Get("key") + assert.NoError(t, err) + + if string(res) != "hello world" { + t.Fatal("wrong value by key") + } +} + +func TestStorage_Set_Del_Get(t *testing.T) { + s := 
initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + // check that keys are present + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) + + assert.NoError(t, s.Delete("key", "key2")) + // check that keys are not present + res, err = s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 0) +} + +func TestStorage_Set_GetM(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestNilAndWrongArgs(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + // check + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + _, err = s.Has("") + assert.Error(t, err) + + _, err = s.Get("") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", "") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", " ") + assert.Error(t, err) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + })) + + assert.Error(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "asdf", + })) + + _, err = s.Has("key") + assert.NoError(t, err) + + assert.Error(t, s.Set(kv.Item{})) + + err = s.Delete("") + assert.Error(t, err) + + err = s.Delete("key", "") + assert.Error(t, err) + + err = s.Delete("key", " ") + assert.Error(t, err) + + err = s.Delete("key") + assert.NoError(t, err) +} + +func TestStorage_MExpire_TTL(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + // set timeout to 5 sec + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + i1 := kv.Item{ + Key: "key", + Value: "", + TTL: nowPlusFive, + } + i2 := kv.Item{ + Key: "key2", + Value: "", + TTL: nowPlusFive, + } + assert.NoError(t, s.MExpire(i1, i2)) + + time.Sleep(time.Second * 6) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} + +func TestStorage_SetExpire_TTL(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + cleanup(t, "rr.db") + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + 
assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + // set timeout to 5 sec + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: nowPlusFive, + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: nowPlusFive, + })) + + time.Sleep(time.Second * 2) + m, err := s.TTL("key", "key2") + assert.NoError(t, err) + + // remove a precision 4.02342342 -> 4 + keyTTL, err := strconv.Atoi(m["key"].(string)[0:1]) + if err != nil { + t.Fatal(err) + } + + // remove a precision 4.02342342 -> 4 + key2TTL, err := strconv.Atoi(m["key"].(string)[0:1]) + if err != nil { + t.Fatal(err) + } + + assert.True(t, keyTTL < 5) + assert.True(t, key2TTL < 5) + + time.Sleep(time.Second * 4) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} diff --git a/plugins/kv/interface.go b/plugins/kv/interface.go new file mode 100644 index 00000000..c1367cdf --- /dev/null +++ b/plugins/kv/interface.go @@ -0,0 +1,41 @@ +package kv + +// Item represents general storage item +type Item struct { + // Key of item + Key string + // Value of item + Value string + // live until time provided by TTL in RFC 3339 format + TTL string +} + +// Storage represents single abstract storage. +type Storage interface { + // Has checks if value exists. + Has(keys ...string) (map[string]bool, error) + + // Get loads value content into a byte slice. + Get(key string) ([]byte, error) + + // MGet loads content of multiple values + // Returns the map with existing keys and associated values + MGet(keys ...string) (map[string]interface{}, error) + + // Set used to upload item to KV with TTL + // 0 value in TTL means no TTL + Set(items ...Item) error + + // MExpire sets the TTL for multiply keys + MExpire(items ...Item) error + + // TTL return the rest time to live for provided keys + // Not supported for the memcached and boltdb + TTL(keys ...string) (map[string]interface{}, error) + + // Delete one or multiple keys. + Delete(keys ...string) error + + // Close closes the storage and underlying resources. + Close() error +} diff --git a/plugins/kv/memcached/config.go b/plugins/kv/memcached/config.go new file mode 100644 index 00000000..62f29ef2 --- /dev/null +++ b/plugins/kv/memcached/config.go @@ -0,0 +1,10 @@ +package memcached + +type Config struct { + // Addr is url for memcached, 11211 port is used by default + Addr []string +} + +func (s *Config) InitDefaults() { + s.Addr = []string{"localhost:11211"} // default url for memcached // init logger +} diff --git a/plugins/kv/memcached/plugin.go b/plugins/kv/memcached/plugin.go new file mode 100644 index 00000000..f5111c04 --- /dev/null +++ b/plugins/kv/memcached/plugin.go @@ -0,0 +1,252 @@ +package memcached + +import ( + "strings" + "time" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const PluginName = "memcached" + +var EmptyItem = kv.Item{} + +type Plugin struct { + // config + cfg *Config + // logger + log logger.Logger + // memcached client + client *memcache.Client +} + +// NewMemcachedClient returns a memcache client using the provided server(s) +// with equal weight. 
+// with equal weight. If a server is listed multiple times,
+// it gets a proportional amount of weight.
+func NewMemcachedClient(url string) kv.Storage {
+	m := memcache.New(url)
+	return &Plugin{
+		client: m,
+	}
+}
+
+func (s *Plugin) Init(log logger.Logger, cfg config.Configurer) error {
+	const op = errors.Op("memcached init")
+	s.cfg = &Config{}
+	s.cfg.InitDefaults()
+	err := cfg.UnmarshalKey(PluginName, &s.cfg)
+	if err != nil {
+		return errors.E(op, err)
+	}
+	s.log = log
+	return nil
+}
+
+func (s *Plugin) Serve() chan error {
+	errCh := make(chan error, 1)
+	s.client = memcache.New(s.cfg.Addr...)
+	return errCh
+}
+
+// Stop is a no-op: memcached has no stop/close or anything similar to close the connection
+func (s *Plugin) Stop() error {
+	return nil
+}
+
+// RPC returns the associated rpc service.
+func (s *Plugin) RPC() interface{} {
+	return kv.NewRPCServer(s, s.log)
+}
+
+// Name returns the plugin's user-friendly name
+func (s *Plugin) Name() string {
+	return PluginName
+}
+
+// Has checks the keys for existence
+func (s Plugin) Has(keys ...string) (map[string]bool, error) {
+	const op = errors.Op("memcached Has")
+	if keys == nil {
+		return nil, errors.E(op, errors.NoKeys)
+	}
+	m := make(map[string]bool, len(keys))
+	for i := range keys {
+		keyTrimmed := strings.TrimSpace(keys[i])
+		if keyTrimmed == "" {
+			return nil, errors.E(op, errors.EmptyKey)
+		}
+		exist, err := s.client.Get(keys[i])
+		// ErrCacheMiss means that a Get failed because the item wasn't present.
+		if err != nil && err != memcache.ErrCacheMiss {
+			return nil, err
+		}
+		if exist != nil {
+			m[keys[i]] = true
+		}
+	}
+	return m, nil
+}
+
+// Get gets the item for the given key. ErrCacheMiss is returned for a
+// memcache cache miss. The key must be at most 250 bytes in length.
+func (s Plugin) Get(key string) ([]byte, error) {
+	const op = errors.Op("memcached Get")
+	// to catch cases like " "
+	keyTrimmed := strings.TrimSpace(key)
+	if keyTrimmed == "" {
+		return nil, errors.E(op, errors.EmptyKey)
+	}
+	data, err := s.client.Get(key)
+	// ErrCacheMiss means that a Get failed because the item wasn't present.
+	if err != nil && err != memcache.ErrCacheMiss {
+		return nil, err
+	}
+	if data != nil {
+		// return the value by the key
+		return data.Value, nil
+	}
+	// data is nil for some reason and the error is also nil
+	return nil, nil
+}
+
+// MGet returns a map where the keys are strings
+// and the values are []byte
+func (s Plugin) MGet(keys ...string) (map[string]interface{}, error) {
+	const op = errors.Op("memcached MGet")
+	if keys == nil {
+		return nil, errors.E(op, errors.NoKeys)
+	}
+
+	// keys should not be empty
+	for i := range keys {
+		keyTrimmed := strings.TrimSpace(keys[i])
+		if keyTrimmed == "" {
+			return nil, errors.E(op, errors.EmptyKey)
+		}
+	}
+
+	m := make(map[string]interface{}, len(keys))
+	for i := range keys {
+		// MultiGet could also be used here
+		data, err := s.client.Get(keys[i])
+		// ErrCacheMiss means that a Get failed because the item wasn't present.
+		if err != nil && err != memcache.ErrCacheMiss {
+			return nil, err
+		}
+		if data != nil {
+			m[keys[i]] = data.Value
+		}
+	}
+
+	return m, nil
+}
+
+// Set sets the KV pairs. Keys should be 250 bytes maximum
+// TTL:
+// Expiration is the cache expiration time, in seconds: either a relative
+// time from now (up to 1 month), or an absolute Unix epoch time.
+// Zero means the Item has no expiration time.
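+// A usage sketch with an absolute RFC3339 deadline:
+//
+//	deadline := time.Now().Add(time.Hour).Format(time.RFC3339)
+//	err := s.Set(kv.Item{Key: "greeting", Value: "hello", TTL: deadline})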
+func (s Plugin) Set(items ...kv.Item) error {
+	const op = errors.Op("memcached Set")
+	if items == nil {
+		return errors.E(op, errors.NoKeys)
+	}
+
+	for i := range items {
+		if items[i] == EmptyItem {
+			return errors.E(op, errors.EmptyItem)
+		}
+
+		// pre-allocate item
+		memcachedItem := &memcache.Item{
+			Key: items[i].Key,
+			// unsafe convert
+			Value: []byte(items[i].Value),
+			Flags: 0,
+		}
+
+		// set the expiration when the TTL isn't empty
+		if items[i].TTL != "" {
+			// verify the TTL
+			t, err := time.Parse(time.RFC3339, items[i].TTL)
+			if err != nil {
+				return err
+			}
+			memcachedItem.Expiration = int32(t.Unix())
+		}
+
+		err := s.client.Set(memcachedItem)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// MExpire sets the expiration time.
+// Expiration is the cache expiration time, in seconds: either a relative
+// time from now (up to 1 month), or an absolute Unix epoch time.
+// Zero means the Item has no expiration time.
+func (s Plugin) MExpire(items ...kv.Item) error {
+	const op = errors.Op("memcached MExpire")
+	for i := range items {
+		if items[i].TTL == "" || strings.TrimSpace(items[i].Key) == "" {
+			return errors.E(op, errors.Str("should set timeout and at least one key"))
+		}
+
+		// verify the provided TTL
+		t, err := time.Parse(time.RFC3339, items[i].TTL)
+		if err != nil {
+			return err
+		}
+
+		// Touch updates the expiry for the given key. The seconds parameter is either
+		// a Unix timestamp or, if seconds is less than 1 month, the number of seconds
+		// into the future at which time the item will expire. Zero means the item has
+		// no expiration time. ErrCacheMiss is returned if the key is not in the cache.
+		// The key must be at most 250 bytes in length.
+		err = s.client.Touch(items[i].Key, int32(t.Unix()))
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// TTL returns the time to live for the given keys; memcached has no such command
+func (s Plugin) TTL(keys ...string) (map[string]interface{}, error) {
+	const op = errors.Op("memcached TTL")
+	return nil, errors.E(op, errors.Str("not valid request for memcached, see https://github.com/memcached/memcached/issues/239"))
+}
+
+func (s Plugin) Delete(keys ...string) error {
+	const op = errors.Op("memcached Delete")
+	if keys == nil {
+		return errors.E(op, errors.NoKeys)
+	}
+
+	// keys should not be empty
+	for i := range keys {
+		keyTrimmed := strings.TrimSpace(keys[i])
+		if keyTrimmed == "" {
+			return errors.E(op, errors.EmptyKey)
+		}
+	}
+
+	for i := range keys {
+		err := s.client.Delete(keys[i])
+		// ErrCacheMiss means that a Get failed because the item wasn't present.
+		if err != nil && err != memcache.ErrCacheMiss {
+			return err
+		}
+	}
+	return nil
+}
+
+func (s Plugin) Close() error {
+	return nil
+}
diff --git a/plugins/kv/memcached/plugin_unit_test.go b/plugins/kv/memcached/plugin_unit_test.go
new file mode 100644
index 00000000..3d37748b
--- /dev/null
+++ b/plugins/kv/memcached/plugin_unit_test.go
@@ -0,0 +1,432 @@
+package memcached
+
+import (
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/spiral/roadrunner/v2/plugins/kv"
+	"github.com/stretchr/testify/assert"
+)
+
+func initStorage() kv.Storage {
+	return NewMemcachedClient("localhost:11211")
+}
+
+func cleanup(t *testing.T, s kv.Storage, keys ...string) {
+	err := s.Delete(keys...)
+ if err != nil { + t.Fatalf("error during cleanup: %s", err.Error()) + } +} + +func TestStorage_Has(t *testing.T) { + s := initStorage() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) +} + +func TestStorage_Has_Set_Has(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) +} + +func TestStorage_Has_Set_MGet(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestStorage_Has_Set_Get(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := s.Get("key") + assert.NoError(t, err) + + if string(res) != "hello world" { + t.Fatal("wrong value by key") + } +} + +func TestStorage_Set_Del_Get(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + // check that keys are present + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) + + assert.NoError(t, s.Delete("key", "key2")) + // check that keys are not present + res, err = s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 0) +} + +func TestStorage_Set_GetM(t *testing.T) { + s := initStorage() + + defer func() { + cleanup(t, s, "key", "key2") + + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestStorage_MExpire_TTL(t *testing.T) { + s := initStorage() + 
defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + // set timeout to 5 sec + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + i1 := kv.Item{ + Key: "key", + Value: "", + TTL: nowPlusFive, + } + i2 := kv.Item{ + Key: "key2", + Value: "", + TTL: nowPlusFive, + } + assert.NoError(t, s.MExpire(i1, i2)) + + time.Sleep(time.Second * 6) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} + +func TestNilAndWrongArgs(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key") + if err := s.Close(); err != nil { + panic(err) + } + }() + + // check + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + _, err = s.Has("") + assert.Error(t, err) + + _, err = s.Get("") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", "") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", " ") + assert.Error(t, err) + + assert.Error(t, s.Set(kv.Item{})) + + err = s.Delete("") + assert.Error(t, err) + + err = s.Delete("key", "") + assert.Error(t, err) + + err = s.Delete("key", " ") + assert.Error(t, err) + + err = s.Delete("key") + assert.NoError(t, err) +} + +func TestStorage_SetExpire_TTL(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + // set timeout to 5 sec + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: nowPlusFive, + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: nowPlusFive, + })) + + time.Sleep(time.Second * 6) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} + +func TestConcurrentReadWriteTransactions(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + wg := &sync.WaitGroup{} + wg.Add(3) + + m := &sync.RWMutex{} + // concurrently set the keys + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + // set is writable transaction + // it should stop readable + assert.NoError(t, s.Set(kv.Item{ + Key: "key" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + 
TTL: "", + }, kv.Item{ + Key: "key2" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + TTL: "", + })) + m.Unlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.RLock() + v, err = s.Has("key") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + m.RUnlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + err = s.Delete("key" + strconv.Itoa(i)) + assert.NoError(t, err) + m.Unlock() + } + }(s) + + wg.Wait() +} diff --git a/plugins/kv/memory/config.go b/plugins/kv/memory/config.go new file mode 100644 index 00000000..0816f734 --- /dev/null +++ b/plugins/kv/memory/config.go @@ -0,0 +1,15 @@ +package memory + +// Config is default config for the in-memory driver +type Config struct { + // Enabled or disabled (true or false) + Enabled bool + // Interval for the check + Interval int +} + +// InitDefaults by default driver is turned off +func (c *Config) InitDefaults() { + c.Enabled = false + c.Interval = 60 // seconds +} diff --git a/plugins/kv/memory/plugin.go b/plugins/kv/memory/plugin.go new file mode 100644 index 00000000..d2d3721b --- /dev/null +++ b/plugins/kv/memory/plugin.go @@ -0,0 +1,262 @@ +package memory + +import ( + "strings" + "sync" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +// PluginName is user friendly name for the plugin +const PluginName = "memory" + +type Plugin struct { + // heap is user map for the key-value pairs + heap sync.Map + stop chan struct{} + + log logger.Logger + cfg *Config +} + +func (s *Plugin) Init(cfg config.Configurer, log logger.Logger) error { + const op = errors.Op("in-memory storage init") + s.cfg = &Config{} + s.cfg.InitDefaults() + + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, err) + } + s.log = log + + s.stop = make(chan struct{}, 1) + return nil +} + +func (s *Plugin) Serve() chan error { + errCh := make(chan error, 1) + // start in-memory gc for kv + go s.gc() + + return errCh +} + +func (s *Plugin) Stop() error { + const op = errors.Op("in-memory storage stop") + err := s.Close() + if err != nil { + return errors.E(op, err) + } + return nil +} + +func (s *Plugin) Has(keys ...string) (map[string]bool, error) { + const op = errors.Op("in-memory storage Has") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + m := make(map[string]bool) + for i := range keys { + keyTrimmed := strings.TrimSpace(keys[i]) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + + if _, ok := s.heap.Load(keys[i]); ok { + m[keys[i]] = true + } + } + + return m, nil +} + +func (s *Plugin) Get(key string) ([]byte, error) { + const op = errors.Op("in-memory storage Get") + // to get cases like " " + keyTrimmed := strings.TrimSpace(key) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + + if data, exist := s.heap.Load(key); exist { + // here might be a panic + // but data only could be a string, see Set function + return []byte(data.(kv.Item).Value), nil + } + return nil, nil +} + +func (s *Plugin) MGet(keys ...string) (map[string]interface{}, error) { + const op = errors.Op("in-memory storage MGet") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for i := range keys { + keyTrimmed := 
strings.TrimSpace(keys[i]) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + } + + m := make(map[string]interface{}, len(keys)) + + for i := range keys { + if value, ok := s.heap.Load(keys[i]); ok { + m[keys[i]] = value.(kv.Item).Value + } + } + + return m, nil +} + +func (s *Plugin) Set(items ...kv.Item) error { + const op = errors.Op("in-memory storage Set") + if items == nil { + return errors.E(op, errors.NoKeys) + } + + for i := range items { + // TTL is set + if items[i].TTL != "" { + // check the TTL in the item + _, err := time.Parse(time.RFC3339, items[i].TTL) + if err != nil { + return err + } + } + + s.heap.Store(items[i].Key, items[i]) + } + return nil +} + +// MExpire sets the expiration time to the key +// If key already has the expiration time, it will be overwritten +func (s *Plugin) MExpire(items ...kv.Item) error { + const op = errors.Op("in-memory storage MExpire") + for i := range items { + if items[i].TTL == "" || strings.TrimSpace(items[i].Key) == "" { + return errors.E(op, errors.Str("should set timeout and at least one key")) + } + + // if key exist, overwrite it value + if pItem, ok := s.heap.Load(items[i].Key); ok { + // check that time is correct + _, err := time.Parse(time.RFC3339, items[i].TTL) + if err != nil { + return errors.E(op, err) + } + tmp := pItem.(kv.Item) + // guess that t is in the future + // in memory is just FOR TESTING PURPOSES + // LOGIC ISN'T IDEAL + s.heap.Store(items[i].Key, kv.Item{ + Key: items[i].Key, + Value: tmp.Value, + TTL: items[i].TTL, + }) + } + } + + return nil +} + +func (s *Plugin) TTL(keys ...string) (map[string]interface{}, error) { + const op = errors.Op("in-memory storage TTL") + if keys == nil { + return nil, errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for i := range keys { + keyTrimmed := strings.TrimSpace(keys[i]) + if keyTrimmed == "" { + return nil, errors.E(op, errors.EmptyKey) + } + } + + m := make(map[string]interface{}, len(keys)) + + for i := range keys { + if item, ok := s.heap.Load(keys[i]); ok { + m[keys[i]] = item.(kv.Item).TTL + } + } + return m, nil +} + +func (s *Plugin) Delete(keys ...string) error { + const op = errors.Op("in-memory storage Delete") + if keys == nil { + return errors.E(op, errors.NoKeys) + } + + // should not be empty keys + for i := range keys { + keyTrimmed := strings.TrimSpace(keys[i]) + if keyTrimmed == "" { + return errors.E(op, errors.EmptyKey) + } + } + + for i := range keys { + s.heap.Delete(keys[i]) + } + return nil +} + +// Close clears the in-memory storage +func (s *Plugin) Close() error { + s.stop <- struct{}{} + return nil +} + +// RPCService returns associated rpc service. 
+func (s *Plugin) RPC() interface{} { + return kv.NewRPCServer(s, s.log) +} + +// Name returns plugin user-friendly name +func (s *Plugin) Name() string { + return PluginName +} + +// ================================== PRIVATE ====================================== + +func (s *Plugin) gc() { + // TODO check + ticker := time.NewTicker(time.Duration(s.cfg.Interval) * time.Second) + for { + select { + case <-s.stop: + ticker.Stop() + return + case now := <-ticker.C: + // check every second + s.heap.Range(func(key, value interface{}) bool { + v := value.(kv.Item) + if v.TTL == "" { + return true + } + + t, err := time.Parse(time.RFC3339, v.TTL) + if err != nil { + return false + } + + if now.After(t) { + s.log.Debug("key deleted", "key", key) + s.heap.Delete(key) + } + return true + }) + } + } +} diff --git a/plugins/kv/memory/plugin_unit_test.go b/plugins/kv/memory/plugin_unit_test.go new file mode 100644 index 00000000..d3b24860 --- /dev/null +++ b/plugins/kv/memory/plugin_unit_test.go @@ -0,0 +1,473 @@ +package memory + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/spiral/roadrunner/v2/plugins/kv" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func initStorage() kv.Storage { + p := &Plugin{ + stop: make(chan struct{}), + } + p.cfg = &Config{ + Enabled: true, + Interval: 1, + } + + l, _ := zap.NewDevelopment() + p.log = logger.NewZapAdapter(l) + + go p.gc() + + return p +} + +func cleanup(t *testing.T, s kv.Storage, keys ...string) { + err := s.Delete(keys...) + if err != nil { + t.Fatalf("error during cleanup: %s", err.Error()) + } +} + +func TestStorage_Has(t *testing.T) { + s := initStorage() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) +} + +func TestStorage_Has_Set_Has(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) +} + +func TestStorage_Has_Set_MGet(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestStorage_Has_Set_Get(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + res, err := 
s.Get("key") + assert.NoError(t, err) + + if string(res) != "value" { + t.Fatal("wrong value by key") + } +} + +func TestStorage_Set_Del_Get(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + panic(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + // check that keys are present + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) + + assert.NoError(t, s.Delete("key", "key2")) + // check that keys are not presents -eo state,uid,pid,ppid,rtprio,time,comm + res, err = s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 0) +} + +func TestStorage_Set_GetM(t *testing.T) { + s := initStorage() + + defer func() { + cleanup(t, s, "key", "key2") + + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: "", + })) + + res, err := s.MGet("key", "key2") + assert.NoError(t, err) + assert.Len(t, res, 2) +} + +func TestStorage_MExpire_TTL(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + // set timeout to 5 sec + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + i1 := kv.Item{ + Key: "key", + Value: "", + TTL: nowPlusFive, + } + i2 := kv.Item{ + Key: "key2", + Value: "", + TTL: nowPlusFive, + } + assert.NoError(t, s.MExpire(i1, i2)) + + time.Sleep(time.Second * 6) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} + +func TestNilAndWrongArgs(t *testing.T) { + s := initStorage() + defer func() { + if err := s.Close(); err != nil { + panic(err) + } + }() + + // check + v, err := s.Has("key") + assert.NoError(t, err) + assert.False(t, v["key"]) + + _, err = s.Has("") + assert.Error(t, err) + + _, err = s.Get("") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.Get(" ") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", "") + assert.Error(t, err) + + _, err = s.MGet("key", "key2", " ") + assert.Error(t, err) + + assert.NoError(t, s.Set(kv.Item{})) + _, err = s.Has("key") + assert.NoError(t, err) + + err = s.Delete("") + assert.Error(t, err) + + err = s.Delete("key", "") + assert.Error(t, err) + + err = s.Delete("key", " ") + assert.Error(t, err) + + err = s.Delete("key") + assert.NoError(t, err) +} + +func TestStorage_SetExpire_TTL(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + // ensure that storage is clean + v, err := s.Has("key", "key2") + assert.NoError(t, err) + 
assert.False(t, v["key"]) + assert.False(t, v["key2"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, + kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339) + + // set timeout to 5 sec + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "value", + TTL: nowPlusFive, + }, + kv.Item{ + Key: "key2", + Value: "value", + TTL: nowPlusFive, + })) + + time.Sleep(time.Second * 2) + m, err := s.TTL("key", "key2") + assert.NoError(t, err) + + // remove a precision 4.02342342 -> 4 + keyTTL, err := strconv.Atoi(m["key"].(string)[0:1]) + if err != nil { + t.Fatal(err) + } + + // remove a precision 4.02342342 -> 4 + key2TTL, err := strconv.Atoi(m["key"].(string)[0:1]) + if err != nil { + t.Fatal(err) + } + + assert.True(t, keyTTL < 5) + assert.True(t, key2TTL < 5) + + time.Sleep(time.Second * 4) + + // ensure that storage is clean + v, err = s.Has("key", "key2") + assert.NoError(t, err) + assert.False(t, v["key"]) + assert.False(t, v["key2"]) +} + +func TestConcurrentReadWriteTransactions(t *testing.T) { + s := initStorage() + defer func() { + cleanup(t, s, "key", "key2") + if err := s.Close(); err != nil { + t.Fatal(err) + } + }() + + v, err := s.Has("key") + assert.NoError(t, err) + // no such key + assert.False(t, v["key"]) + + assert.NoError(t, s.Set(kv.Item{ + Key: "key", + Value: "hello world", + TTL: "", + }, kv.Item{ + Key: "key2", + Value: "hello world", + TTL: "", + })) + + v, err = s.Has("key", "key2") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + assert.True(t, v["key2"]) + + wg := &sync.WaitGroup{} + wg.Add(3) + + m := &sync.RWMutex{} + // concurrently set the keys + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + // set is writable transaction + // it should stop readable + assert.NoError(t, s.Set(kv.Item{ + Key: "key" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + TTL: "", + }, kv.Item{ + Key: "key2" + strconv.Itoa(i), + Value: "hello world" + strconv.Itoa(i), + TTL: "", + })) + m.Unlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.RLock() + v, err = s.Has("key") + assert.NoError(t, err) + // no such key + assert.True(t, v["key"]) + m.RUnlock() + } + }(s) + + // should be no errors + go func(s kv.Storage) { + defer wg.Done() + for i := 0; i <= 1000; i++ { + m.Lock() + err = s.Delete("key" + strconv.Itoa(i)) + assert.NoError(t, err) + m.Unlock() + } + }(s) + + wg.Wait() +} diff --git a/plugins/kv/rpc.go b/plugins/kv/rpc.go new file mode 100644 index 00000000..751f0d12 --- /dev/null +++ b/plugins/kv/rpc.go @@ -0,0 +1,110 @@ +package kv + +import ( + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +// Wrapper for the plugin +type RPCServer struct { + // svc is a plugin implementing Storage interface + svc Storage + // Logger + log logger.Logger +} + +// NewRPCServer construct RPC server for the particular plugin +func NewRPCServer(srv Storage, log logger.Logger) *RPCServer { + return &RPCServer{ + svc: srv, + log: log, + } +} + +// data Data +func (r *RPCServer) Has(in []string, res *map[string]bool) error { + const op = errors.Op("rpc server Has") + ret, err := r.svc.Has(in...) 
+ if err != nil { + return errors.E(op, err) + } + + // update the value in the pointer + *res = ret + return nil +} + +// in SetData +func (r *RPCServer) Set(in []Item, ok *bool) error { + const op = errors.Op("rpc server Set") + + err := r.svc.Set(in...) + if err != nil { + return errors.E(op, err) + } + + *ok = true + return nil +} + +// in Data +func (r *RPCServer) MGet(in []string, res *map[string]interface{}) error { + const op = errors.Op("rpc server MGet") + ret, err := r.svc.MGet(in...) + if err != nil { + return errors.E(op, err) + } + + // update return value + *res = ret + return nil +} + +// in Data +func (r *RPCServer) MExpire(in []Item, ok *bool) error { + const op = errors.Op("rpc server MExpire") + + err := r.svc.MExpire(in...) + if err != nil { + return errors.E(op, err) + } + + *ok = true + return nil +} + +// in Data +func (r *RPCServer) TTL(in []string, res *map[string]interface{}) error { + const op = errors.Op("rpc server TTL") + + ret, err := r.svc.TTL(in...) + if err != nil { + return errors.E(op, err) + } + + *res = ret + return nil +} + +// in Data +func (r *RPCServer) Delete(in []string, ok *bool) error { + const op = errors.Op("rpc server Delete") + err := r.svc.Delete(in...) + if err != nil { + return errors.E(op, err) + } + *ok = true + return nil +} + +// in string, storages +func (r *RPCServer) Close(storage string, ok *bool) error { + const op = errors.Op("rpc server Close") + err := r.svc.Close() + if err != nil { + return errors.E(op, err) + } + *ok = true + + return nil +} diff --git a/plugins/logger/config.go b/plugins/logger/config.go new file mode 100644 index 00000000..f7a5742c --- /dev/null +++ b/plugins/logger/config.go @@ -0,0 +1,94 @@ +package logger + +import ( + "strings" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// ChannelConfig configures loggers per channel. +type ChannelConfig struct { + // Dedicated channels per logger. By default logger allocated via named logger. + Channels map[string]Config `json:"channels" yaml:"channels"` +} + +type Config struct { + // Mode configures logger based on some default template (development, production, off). + Mode string `json:"mode" yaml:"mode"` + + // Level is the minimum enabled logging level. Note that this is a dynamic + // level, so calling ChannelConfig.Level.SetLevel will atomically change the log + // level of all loggers descended from this config. + Level string `json:"level" yaml:"level"` + + // Encoding sets the logger's encoding. Valid values are "json" and + // "console", as well as any third-party encodings registered via + // RegisterEncoder. + Encoding string `json:"encoding" yaml:"encoding"` + + // Output is a list of URLs or file paths to write logging output to. + // See Open for details. + Output []string `json:"output" yaml:"output"` + + // ErrorOutput is a list of URLs to write internal logger errors to. + // The default is standard error. + // + // Note that this setting only affects internal errors; for sample code that + // sends error-level logs to a different location from info- and debug-level + // logs, see the package-level AdvancedConfiguration example. + ErrorOutput []string `json:"errorOutput" yaml:"errorOutput"` +} + +// ZapConfig converts config into Zap configuration. 
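+//
+// A minimal sketch of the mapping (field values below are examples):
+//
+//	cfg := &Config{Mode: "production", Level: "warn", Encoding: "json"}
+//	log, err := cfg.BuildLogger()
+//	if err != nil {
+//		panic(err)
+//	}
+//	log.Warn("logger built from config")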
+func (cfg *Config) BuildLogger() (*zap.Logger, error) { + var zCfg zap.Config + switch strings.ToLower(cfg.Mode) { + case "off", "none": + return zap.NewNop(), nil + case "production": + zCfg = zap.NewProductionConfig() + case "development": + zCfg = zap.NewDevelopmentConfig() + default: + zCfg = zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Encoding: "console", + EncoderConfig: zapcore.EncoderConfig{ + MessageKey: "message", + LevelKey: "level", + TimeKey: "time", + NameKey: "name", + EncodeName: ColoredHashedNameEncoder, + EncodeLevel: ColoredLevelEncoder, + EncodeTime: UTCTimeEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }, + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + } + + if cfg.Level != "" { + level := zap.NewAtomicLevel() + if err := level.UnmarshalText([]byte(cfg.Level)); err == nil { + zCfg.Level = level + } + } + + if cfg.Encoding != "" { + zCfg.Encoding = cfg.Encoding + } + + if len(cfg.Output) != 0 { + zCfg.OutputPaths = cfg.Output + } + + if len(cfg.ErrorOutput) != 0 { + zCfg.ErrorOutputPaths = cfg.ErrorOutput + } + + // todo: https://github.com/uber-go/zap/blob/master/FAQ.md#does-zap-support-log-rotation + + return zCfg.Build() +} diff --git a/plugins/logger/encoder.go b/plugins/logger/encoder.go new file mode 100644 index 00000000..4ff583c4 --- /dev/null +++ b/plugins/logger/encoder.go @@ -0,0 +1,66 @@ +package logger + +import ( + "hash/fnv" + "strings" + "time" + + "github.com/fatih/color" + "go.uber.org/zap/zapcore" +) + +var colorMap = []func(string, ...interface{}) string{ + color.HiYellowString, + color.HiGreenString, + color.HiBlueString, + color.HiRedString, + color.HiCyanString, + color.HiMagentaString, +} + +// ColoredLevelEncoder colorizes log levels. +func ColoredLevelEncoder(level zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { + switch level { + case zapcore.DebugLevel: + enc.AppendString(color.HiWhiteString(level.CapitalString())) + case zapcore.InfoLevel: + enc.AppendString(color.HiCyanString(level.CapitalString())) + case zapcore.WarnLevel: + enc.AppendString(color.HiYellowString(level.CapitalString())) + case zapcore.ErrorLevel, zapcore.DPanicLevel: + enc.AppendString(color.HiRedString(level.CapitalString())) + case zapcore.PanicLevel, zapcore.FatalLevel: + enc.AppendString(color.HiMagentaString(level.CapitalString())) + } +} + +// ColoredNameEncoder colorizes service names. +func ColoredNameEncoder(s string, enc zapcore.PrimitiveArrayEncoder) { + if len(s) < 12 { + s += strings.Repeat(" ", 12-len(s)) + } + + enc.AppendString(color.HiGreenString(s)) +} + +// ColoredHashedNameEncoder colorizes service names and assigns different colors to different names. +func ColoredHashedNameEncoder(s string, enc zapcore.PrimitiveArrayEncoder) { + if len(s) < 12 { + s += strings.Repeat(" ", 12-len(s)) + } + + colorID := stringHash(s, len(colorMap)) + enc.AppendString(colorMap[colorID](s)) +} + +// UTCTimeEncoder encodes time into short UTC specific timestamp. 
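+//
+// For example, 2020-11-08T14:03:07+02:00 renders as "2020/11/08 12:03:07":
+// the time is first converted to UTC, then formatted with the layout
+// "2006/01/02 15:04:05".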
+func UTCTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.UTC().Format("2006/01/02 15:04:05")) +} + +// returns string hash +func stringHash(name string, base int) int { + h := fnv.New32a() + _, _ = h.Write([]byte(name)) + return int(h.Sum32()) % base +} diff --git a/plugins/logger/interface.go b/plugins/logger/interface.go new file mode 100644 index 00000000..876629a9 --- /dev/null +++ b/plugins/logger/interface.go @@ -0,0 +1,16 @@ +package logger + +type ( + // Logger is an general RR log interface + Logger interface { + Debug(msg string, keyvals ...interface{}) + Info(msg string, keyvals ...interface{}) + Warn(msg string, keyvals ...interface{}) + Error(msg string, keyvals ...interface{}) + } +) + +// With creates a child logger and adds structured context to it +type WithLogger interface { + With(keyvals ...interface{}) Logger +} diff --git a/plugins/logger/plugin.go b/plugins/logger/plugin.go new file mode 100644 index 00000000..01bf5cc0 --- /dev/null +++ b/plugins/logger/plugin.go @@ -0,0 +1,69 @@ +package logger + +import ( + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "go.uber.org/zap" +) + +// PluginName declares plugin name. +const PluginName = "logs" + +// ZapLogger manages zap logger. +type ZapLogger struct { + base *zap.Logger + cfg Config + channels ChannelConfig +} + +// Init logger service. +func (z *ZapLogger) Init(cfg config.Configurer) error { + const op = errors.Op("zap logger init") + err := cfg.UnmarshalKey(PluginName, &z.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + err = cfg.UnmarshalKey(PluginName, &z.channels) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + z.base, err = z.cfg.BuildLogger() + if err != nil { + return errors.E(op, errors.Disabled, err) + } + return nil +} + +// DefaultLogger returns default logger. +func (z *ZapLogger) DefaultLogger() (Logger, error) { + return NewZapAdapter(z.base), nil +} + +// NamedLogger returns logger dedicated to the specific channel. Similar to Named() but also reads the core params. +func (z *ZapLogger) NamedLogger(name string) (Logger, error) { + if cfg, ok := z.channels.Channels[name]; ok { + l, err := cfg.BuildLogger() + if err != nil { + return nil, err + } + return NewZapAdapter(l), nil + } + + return NewZapAdapter(z.base.Named(name)), nil +} + +// NamedLogger returns logger dedicated to the specific channel. Similar to Named() but also reads the core params. +func (z *ZapLogger) ServiceLogger(n endure.Named) (Logger, error) { + return z.NamedLogger(n.Name()) +} + +// Provides declares factory methods. 
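+//
+// Endure matches these constructors against dependencies requested by other
+// plugins. An illustrative (hypothetical) consumer:
+//
+//	type Consumer struct{ log Logger }
+//
+//	func (c *Consumer) Init(log Logger) error {
+//		c.log = log // injected via DefaultLogger, or ServiceLogger for Named plugins
+//		return nil
+//	}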
+func (z *ZapLogger) Provides() []interface{} { + return []interface{}{ + z.ServiceLogger, + z.DefaultLogger, + } +} diff --git a/plugins/logger/zap_adapter.go b/plugins/logger/zap_adapter.go new file mode 100644 index 00000000..0a0855b8 --- /dev/null +++ b/plugins/logger/zap_adapter.go @@ -0,0 +1,56 @@ +package logger + +import ( + "fmt" + + "go.uber.org/zap" +) + +type ZapAdapter struct { + zl *zap.Logger +} + +// Create NewZapAdapter which uses general log interface +func NewZapAdapter(zapLogger *zap.Logger) *ZapAdapter { + return &ZapAdapter{ + zl: zapLogger.WithOptions(zap.AddCallerSkip(1)), + } +} + +func (log *ZapAdapter) fields(keyvals []interface{}) []zap.Field { + // we should have even number of keys and values + if len(keyvals)%2 != 0 { + return []zap.Field{zap.Error(fmt.Errorf("odd number of keyvals pairs: %v", keyvals))} + } + + var fields []zap.Field + for i := 0; i < len(keyvals); i += 2 { + key, ok := keyvals[i].(string) + if !ok { + key = fmt.Sprintf("%v", keyvals[i]) + } + fields = append(fields, zap.Any(key, keyvals[i+1])) + } + + return fields +} + +func (log *ZapAdapter) Debug(msg string, keyvals ...interface{}) { + log.zl.Debug(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Info(msg string, keyvals ...interface{}) { + log.zl.Info(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Warn(msg string, keyvals ...interface{}) { + log.zl.Warn(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) Error(msg string, keyvals ...interface{}) { + log.zl.Error(msg, log.fields(keyvals)...) +} + +func (log *ZapAdapter) With(keyvals ...interface{}) Logger { + return NewZapAdapter(log.zl.With(log.fields(keyvals)...)) +} diff --git a/plugins/metrics/config.go b/plugins/metrics/config.go new file mode 100644 index 00000000..9459bc9b --- /dev/null +++ b/plugins/metrics/config.go @@ -0,0 +1,138 @@ +package metrics + +import ( + "fmt" + + "github.com/prometheus/client_golang/prometheus" +) + +// Config configures metrics service. +type Config struct { + // Address to listen + Address string + + // Collect define application specific metrics. + Collect map[string]Collector +} + +type NamedCollector struct { + // Name of the collector + Name string `json:"name"` + + // Collector structure + Collector `json:"collector"` +} + +// CollectorType represents prometheus collector types +type CollectorType string + +const ( + // Histogram type + Histogram CollectorType = "histogram" + + // Gauge type + Gauge CollectorType = "gauge" + + // Counter type + Counter CollectorType = "counter" + + // Summary type + Summary CollectorType = "summary" +) + +// Collector describes single application specific metric. +type Collector struct { + // Namespace of the metric. + Namespace string `json:"namespace"` + // Subsystem of the metric. + Subsystem string `json:"subsystem"` + // Collector type (histogram, gauge, counter, summary). + Type CollectorType `json:"type"` + // Help of collector. + Help string `json:"help"` + // Labels for vectorized metrics. + Labels []string `json:"labels"` + // Buckets for histogram metric. + Buckets []float64 `json:"buckets"` + // Objectives for the summary opts + Objectives map[float64]float64 `json:"objectives"` +} + +// register application specific metrics. 
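+//
+// An illustrative `collect` section this method understands (metric names
+// and values are examples only):
+//
+//	"collect": {
+//		"app_requests": {"type": "counter", "labels": ["status"]},
+//		"app_latency":  {"type": "histogram", "buckets": [0.1, 0.5, 1]}
+//	}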
+func (c *Config) getCollectors() (map[string]prometheus.Collector, error) { + if c.Collect == nil { + return nil, nil + } + + collectors := make(map[string]prometheus.Collector) + + for name, m := range c.Collect { + var collector prometheus.Collector + switch m.Type { + case Histogram: + opts := prometheus.HistogramOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + Buckets: m.Buckets, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewHistogramVec(opts, m.Labels) + } else { + collector = prometheus.NewHistogram(opts) + } + case Gauge: + opts := prometheus.GaugeOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewGaugeVec(opts, m.Labels) + } else { + collector = prometheus.NewGauge(opts) + } + case Counter: + opts := prometheus.CounterOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewCounterVec(opts, m.Labels) + } else { + collector = prometheus.NewCounter(opts) + } + case Summary: + opts := prometheus.SummaryOpts{ + Name: name, + Namespace: m.Namespace, + Subsystem: m.Subsystem, + Help: m.Help, + Objectives: m.Objectives, + } + + if len(m.Labels) != 0 { + collector = prometheus.NewSummaryVec(opts, m.Labels) + } else { + collector = prometheus.NewSummary(opts) + } + default: + return nil, fmt.Errorf("invalid metric type `%s` for `%s`", m.Type, name) + } + + collectors[name] = collector + } + + return collectors, nil +} + +func (c *Config) InitDefaults() { + +} diff --git a/plugins/metrics/config_test.go b/plugins/metrics/config_test.go new file mode 100644 index 00000000..665ec9cd --- /dev/null +++ b/plugins/metrics/config_test.go @@ -0,0 +1,89 @@ +package metrics + +import ( + "bytes" + "testing" + + j "github.com/json-iterator/go" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +var json = j.ConfigCompatibleWithStandardLibrary + +func Test_Config_Hydrate_Error1(t *testing.T) { + cfg := `{"request": {"From": "Something"}}` + c := &Config{} + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } +} + +func Test_Config_Hydrate_Error2(t *testing.T) { + cfg := `{"dir": "/dir/"` + c := &Config{} + + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + assert.Error(t, err) +} + +func Test_Config_Metrics(t *testing.T) { + cfg := `{ +"collect":{ + "metric1":{"type": "gauge"}, + "metric2":{ "type": "counter"}, + "metric3":{"type": "summary"}, + "metric4":{"type": "histogram"} +} +}` + c := &Config{} + f := new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } + + m, err := c.getCollectors() + assert.NoError(t, err) + + assert.IsType(t, prometheus.NewGauge(prometheus.GaugeOpts{}), m["metric1"]) + assert.IsType(t, prometheus.NewCounter(prometheus.CounterOpts{}), m["metric2"]) + assert.IsType(t, prometheus.NewSummary(prometheus.SummaryOpts{}), m["metric3"]) + assert.IsType(t, prometheus.NewHistogram(prometheus.HistogramOpts{}), m["metric4"]) +} + +func Test_Config_MetricsVector(t *testing.T) { + cfg := `{ +"collect":{ + "metric1":{"type": "gauge","labels":["label"]}, + "metric2":{ "type": "counter","labels":["label"]}, + "metric3":{"type": "summary","labels":["label"]}, + "metric4":{"type": "histogram","labels":["label"]} +} +}` + c := &Config{} + f := 
new(bytes.Buffer) + f.WriteString(cfg) + + err := json.Unmarshal(f.Bytes(), &c) + if err != nil { + t.Fatal(err) + } + + m, err := c.getCollectors() + assert.NoError(t, err) + + assert.IsType(t, prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{}), m["metric1"]) + assert.IsType(t, prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{}), m["metric2"]) + assert.IsType(t, prometheus.NewSummaryVec(prometheus.SummaryOpts{}, []string{}), m["metric3"]) + assert.IsType(t, prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{}), m["metric4"]) +} diff --git a/plugins/metrics/doc.go b/plugins/metrics/doc.go new file mode 100644 index 00000000..1abe097a --- /dev/null +++ b/plugins/metrics/doc.go @@ -0,0 +1 @@ +package metrics diff --git a/plugins/metrics/interface.go b/plugins/metrics/interface.go new file mode 100644 index 00000000..87ba4017 --- /dev/null +++ b/plugins/metrics/interface.go @@ -0,0 +1,7 @@ +package metrics + +import "github.com/prometheus/client_golang/prometheus" + +type StatProvider interface { + MetricsCollector() []prometheus.Collector +} diff --git a/plugins/metrics/plugin.go b/plugins/metrics/plugin.go new file mode 100644 index 00000000..fb9096a1 --- /dev/null +++ b/plugins/metrics/plugin.go @@ -0,0 +1,229 @@ +package metrics + +import ( + "context" + "crypto/tls" + "net/http" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/spiral/endure" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "golang.org/x/sys/cpu" +) + +const ( + // PluginName declares plugin name. + PluginName = "metrics" + // maxHeaderSize declares max header size for prometheus server + maxHeaderSize = 1024 * 1024 * 100 // 104MB +) + +type statsProvider struct { + collectors []prometheus.Collector + name string +} + +// Plugin to manage application metrics using Prometheus. +type Plugin struct { + cfg Config + log logger.Logger + mu sync.Mutex // all receivers are pointers + http *http.Server + collectors sync.Map // all receivers are pointers + registry *prometheus.Registry +} + +// Init service. +func (m *Plugin) Init(cfg config.Configurer, log logger.Logger) error { + const op = errors.Op("metrics init") + err := cfg.UnmarshalKey(PluginName, &m.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + // TODO figure out what is Init + m.cfg.InitDefaults() + + m.log = log + m.registry = prometheus.NewRegistry() + + // Default + err = m.registry.Register(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{})) + if err != nil { + return errors.E(op, err) + } + + // Default + err = m.registry.Register(prometheus.NewGoCollector()) + if err != nil { + return errors.E(op, err) + } + + collectors, err := m.cfg.getCollectors() + if err != nil { + return errors.E(op, err) + } + + // Register invocation will be later in the Serve method + for k, v := range collectors { + m.collectors.Store(k, statsProvider{ + collectors: []prometheus.Collector{v}, + name: k, + }) + } + return nil +} + +// Register new prometheus collector. +func (m *Plugin) Register(c prometheus.Collector) error { + return m.registry.Register(c) +} + +// Serve prometheus metrics service. 
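+//
+// The registry is exposed over plain HTTP at the configured address; since
+// the promhttp handler is mounted at the server root, any path serves the
+// metrics, e.g. (address is an example):
+//
+//	curl http://127.0.0.1:2112/metrics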
+func (m *Plugin) Serve() chan error { + errCh := make(chan error, 1) + m.collectors.Range(func(key, value interface{}) bool { + // key - name + // value - statsProvider struct + c := value.(statsProvider) + for _, v := range c.collectors { + if err := m.registry.Register(v); err != nil { + errCh <- err + return false + } + } + + return true + }) + + var topCipherSuites []uint16 + var defaultCipherSuitesTLS13 []uint16 + + hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) + + hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + + if hasGCMAsm { + // If AES-GCM hardware is provided then prioritise AES-GCM + // cipher suites. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } else { + // Without AES-GCM hardware, we put the ChaCha20-Poly1305 + // cipher suites first. + topCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + } + defaultCipherSuitesTLS13 = []uint16{ + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + } + } + + DefaultCipherSuites := make([]uint16, 0, 22) + DefaultCipherSuites = append(DefaultCipherSuites, topCipherSuites...) + DefaultCipherSuites = append(DefaultCipherSuites, defaultCipherSuitesTLS13...) + + m.http = &http.Server{ + Addr: m.cfg.Address, + Handler: promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{}), + IdleTimeout: time.Hour * 24, + ReadTimeout: time.Minute * 60, + MaxHeaderBytes: maxHeaderSize, + ReadHeaderTimeout: time.Minute * 60, + WriteTimeout: time.Minute * 60, + TLSConfig: &tls.Config{ + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + tls.X25519, + }, + CipherSuites: DefaultCipherSuites, + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + }, + } + + go func() { + err := m.http.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + errCh <- err + return + } + }() + + return errCh +} + +// Stop prometheus metrics service. +func (m *Plugin) Stop() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.http != nil { + // timeout is 10 seconds + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + err := m.http.Shutdown(ctx) + if err != nil { + // Function should be Stop() error + m.log.Error("stop error", "error", errors.Errorf("error shutting down the metrics server: error %v", err)) + } + } + return nil +} + +// Collects used to collect all plugins which implement metrics.StatProvider interface (and Named) +func (m *Plugin) Collects() []interface{} { + return []interface{}{ + m.AddStatProvider, + } +} + +// Collector returns application specific collector by name or nil if collector not found. 
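+//
+// An illustrative (hypothetical) provider collected via this method:
+//
+//	type MyPlugin struct{ hits prometheus.Counter }
+//
+//	func (p *MyPlugin) Name() string { return "my_plugin" }
+//
+//	func (p *MyPlugin) MetricsCollector() []prometheus.Collector {
+//		return []prometheus.Collector{p.hits}
+//	}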
+func (m *Plugin) AddStatProvider(name endure.Named, stat StatProvider) error { + m.collectors.Store(name.Name(), statsProvider{ + collectors: stat.MetricsCollector(), + name: name.Name(), + }) + return nil +} + +// RPC interface satisfaction +func (m *Plugin) Name() string { + return PluginName +} + +// RPC interface satisfaction +func (m *Plugin) RPC() interface{} { + return &rpcServer{ + svc: m, + log: m.log, + } +} diff --git a/plugins/metrics/rpc.go b/plugins/metrics/rpc.go new file mode 100644 index 00000000..f9c6accb --- /dev/null +++ b/plugins/metrics/rpc.go @@ -0,0 +1,294 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +type rpcServer struct { + svc *Plugin + log logger.Logger +} + +// Metric represent single metric produced by the application. +type Metric struct { + // Collector name. + Name string + + // Collector value. + Value float64 + + // Labels associated with metric. Only for vector metrics. Must be provided in a form of label values. + Labels []string +} + +// Add new metric to the designated collector. +func (rpc *rpcServer) Add(m *Metric, ok *bool) error { + const op = errors.Op("Add metric") + rpc.log.Info("Adding metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "collector", m.Name) + return errors.E(op, errors.Errorf("undefined collector %s, try first Declare the desired collector", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Add(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector", "collector", m.Name) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Add(m.Value) + case prometheus.Counter: + c.Add(m.Value) + + case *prometheus.CounterVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Add(m.Value) + + default: + return errors.E(op, errors.Errorf("collector %s does not support method `Add`", m.Name)) + } + + // RPC, set ok to true as return value. Need by rpc.Call reply argument + *ok = true + rpc.log.Info("new metric successfully added", "name", m.Name, "labels", m.Labels, "value", m.Value) + return nil +} + +// Sub subtract the value from the specific metric (gauge only). +func (rpc *rpcServer) Sub(m *Metric, ok *bool) error { + const op = errors.Op("Subtracting metric") + rpc.log.Info("Subtracting value from metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "name", m.Name, "value", m.Value, "labels", m.Labels) + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + // can it be nil ??? 
I guess can't + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Sub(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector, but none was provided", "name", m.Name, "value", m.Value) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Sub(m.Value) + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method `Sub`", m.Name)) + } + rpc.log.Info("Subtracting operation applied successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} + +// Observe the value (histogram and summary only). +func (rpc *rpcServer) Observe(m *Metric, ok *bool) error { + const op = errors.Op("Observe metrics") + rpc.log.Info("Observing metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + rpc.log.Error("undefined collector", "name", m.Name, "value", m.Value, "labels", m.Labels) + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case *prometheus.SummaryVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + observer, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + return errors.E(op, err) + } + observer.Observe(m.Value) + + case prometheus.Histogram: + c.Observe(m.Value) + + case *prometheus.HistogramVec: + if len(m.Labels) == 0 { + return errors.E(op, errors.Errorf("required labels for collector `%s`", m.Name)) + } + + observer, err := c.GetMetricWithLabelValues(m.Labels...) 
+ if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + observer.Observe(m.Value) + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method `Observe`", m.Name)) + } + + rpc.log.Info("observe operation finished successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} + +// Declare is used to register new collector in prometheus +// THE TYPES ARE: +// NamedCollector -> Collector with the name +// bool -> RPC reply value +// RETURNS: +// error +func (rpc *rpcServer) Declare(nc *NamedCollector, ok *bool) error { + const op = errors.Op("Declare metric") + rpc.log.Info("Declaring new metric", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + _, exist := rpc.svc.collectors.Load(nc.Name) + if exist { + rpc.log.Error("metric with provided name already exist", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + return errors.E(op, errors.Errorf("tried to register existing collector with the name `%s`", nc.Name)) + } + + var collector prometheus.Collector + switch nc.Type { + case Histogram: + opts := prometheus.HistogramOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + Buckets: nc.Buckets, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewHistogramVec(opts, nc.Labels) + } else { + collector = prometheus.NewHistogram(opts) + } + case Gauge: + opts := prometheus.GaugeOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewGaugeVec(opts, nc.Labels) + } else { + collector = prometheus.NewGauge(opts) + } + case Counter: + opts := prometheus.CounterOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewCounterVec(opts, nc.Labels) + } else { + collector = prometheus.NewCounter(opts) + } + case Summary: + opts := prometheus.SummaryOpts{ + Name: nc.Name, + Namespace: nc.Namespace, + Subsystem: nc.Subsystem, + Help: nc.Help, + } + + if len(nc.Labels) != 0 { + collector = prometheus.NewSummaryVec(opts, nc.Labels) + } else { + collector = prometheus.NewSummary(opts) + } + + default: + return errors.E(op, errors.Errorf("unknown collector type %s", nc.Type)) + } + + // add collector to sync.Map + rpc.svc.collectors.Store(nc.Name, collector) + // that method might panic, we handle it by recover + err := rpc.svc.Register(collector) + if err != nil { + *ok = false + return errors.E(op, err) + } + + rpc.log.Info("metric successfully added", "name", nc.Name, "type", nc.Type, "namespace", nc.Namespace) + + *ok = true + return nil +} + +// Set the metric value (only for gaude). 
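+// Applies to gauge and gauge-vector collectors only. An illustrative call
+// (the "metrics" service name and the client wiring are assumptions):
+//
+//	var ok bool
+//	err := client.Call("metrics.Set", &Metric{Name: "queue_depth", Value: 42}, &ok)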
+func (rpc *rpcServer) Set(m *Metric, ok *bool) (err error) { + const op = errors.Op("Set metric") + rpc.log.Info("Observing metric", "name", m.Name, "value", m.Value, "labels", m.Labels) + + c, exist := rpc.svc.collectors.Load(m.Name) + if !exist { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + if c == nil { + return errors.E(op, errors.Errorf("undefined collector %s", m.Name)) + } + + switch c := c.(type) { + case prometheus.Gauge: + c.Set(m.Value) + + case *prometheus.GaugeVec: + if len(m.Labels) == 0 { + rpc.log.Error("required labels for collector", "collector", m.Name) + return errors.E(op, errors.Errorf("required labels for collector %s", m.Name)) + } + + gauge, err := c.GetMetricWithLabelValues(m.Labels...) + if err != nil { + rpc.log.Error("failed to get metrics with label values", "collector", m.Name, "labels", m.Labels) + return errors.E(op, err) + } + gauge.Set(m.Value) + + default: + return errors.E(op, errors.Errorf("collector `%s` does not support method Set", m.Name)) + } + + rpc.log.Info("set operation finished successfully", "name", m.Name, "labels", m.Labels, "value", m.Value) + + *ok = true + return nil +} diff --git a/plugins/redis/config.go b/plugins/redis/config.go new file mode 100644 index 00000000..ebcefed1 --- /dev/null +++ b/plugins/redis/config.go @@ -0,0 +1,32 @@ +package redis + +import "time" + +type Config struct { + Addrs []string `yaml:"addrs"` + DB int `yaml:"db"` + Username string `yaml:"username"` + Password string `yaml:"password"` + MasterName string `yaml:"master_name"` + SentinelPassword string `yaml:"sentinel_password"` + RouteByLatency bool `yaml:"route_by_latency"` + RouteRandomly bool `yaml:"route_randomly"` + MaxRetries int `yaml:"max_retries"` + DialTimeout time.Duration `yaml:"dial_timeout"` + MinRetryBackoff time.Duration `yaml:"min_retry_backoff"` + MaxRetryBackoff time.Duration `yaml:"max_retry_backoff"` + PoolSize int `yaml:"pool_size"` + MinIdleConns int `yaml:"min_idle_conns"` + MaxConnAge time.Duration `yaml:"max_conn_age"` + ReadTimeout time.Duration `yaml:"read_timeout"` + WriteTimeout time.Duration `yaml:"write_timeout"` + PoolTimeout time.Duration `yaml:"pool_timeout"` + IdleTimeout time.Duration `yaml:"idle_timeout"` + IdleCheckFreq time.Duration `yaml:"idle_check_freq"` + ReadOnly bool `yaml:"read_only"` +} + +// InitDefaults initializing fill config with default values +func (s *Config) InitDefaults() { + s.Addrs = []string{"localhost:6379"} // default addr is pointing to local storage +} diff --git a/plugins/redis/interface.go b/plugins/redis/interface.go new file mode 100644 index 00000000..909c8ca4 --- /dev/null +++ b/plugins/redis/interface.go @@ -0,0 +1,9 @@ +package redis + +import "github.com/go-redis/redis/v8" + +// Redis in the redis KV plugin interface +type Redis interface { + // GetClient + GetClient() redis.UniversalClient +} diff --git a/plugins/redis/plugin.go b/plugins/redis/plugin.go new file mode 100644 index 00000000..fe465340 --- /dev/null +++ b/plugins/redis/plugin.go @@ -0,0 +1,75 @@ +package redis + +import ( + "github.com/go-redis/redis/v8" + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +const PluginName = "redis" + +type Plugin struct { + // config for RR integration + cfg *Config + // logger + log logger.Logger + // redis universal client + universalClient redis.UniversalClient +} + +func (s *Plugin) GetClient() redis.UniversalClient { + return s.universalClient +} + +func (s *Plugin) 
Init(cfg config.Configurer, log logger.Logger) error { + const op = errors.Op("redis plugin init") + s.cfg = &Config{} + s.cfg.InitDefaults() + + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + return errors.E(op, errors.Disabled, err) + } + + s.log = log + + s.universalClient = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: s.cfg.Addrs, + DB: s.cfg.DB, + Username: s.cfg.Username, + Password: s.cfg.Password, + SentinelPassword: s.cfg.SentinelPassword, + MaxRetries: s.cfg.MaxRetries, + MinRetryBackoff: s.cfg.MaxRetryBackoff, + MaxRetryBackoff: s.cfg.MaxRetryBackoff, + DialTimeout: s.cfg.DialTimeout, + ReadTimeout: s.cfg.ReadTimeout, + WriteTimeout: s.cfg.WriteTimeout, + PoolSize: s.cfg.PoolSize, + MinIdleConns: s.cfg.MinIdleConns, + MaxConnAge: s.cfg.MaxConnAge, + PoolTimeout: s.cfg.PoolTimeout, + IdleTimeout: s.cfg.IdleTimeout, + IdleCheckFrequency: s.cfg.IdleCheckFreq, + ReadOnly: s.cfg.ReadOnly, + RouteByLatency: s.cfg.RouteByLatency, + RouteRandomly: s.cfg.RouteRandomly, + MasterName: s.cfg.MasterName, + }) + + return nil +} + +func (s *Plugin) Serve() chan error { + errCh := make(chan error, 1) + return errCh +} + +func (s Plugin) Stop() error { + return s.universalClient.Close() +} + +func (s *Plugin) Name() string { + return PluginName +} diff --git a/plugins/reload/config.go b/plugins/reload/config.go new file mode 100644 index 00000000..9ca2c0dc --- /dev/null +++ b/plugins/reload/config.go @@ -0,0 +1,58 @@ +package reload + +import ( + "time" + + "github.com/spiral/errors" +) + +// Config is a Reload configuration point. +type Config struct { + // Interval is a global refresh interval + Interval time.Duration + + // Patterns is a global file patterns to watch. It will be applied to every directory in project + Patterns []string + + // Services is set of services which would be reloaded in case of FS changes + Services map[string]ServiceConfig +} + +type ServiceConfig struct { + // Enabled indicates that service must be watched, doest not required when any other option specified + Enabled bool + + // Recursive is options to use nested files from root folder + Recursive bool + + // Patterns is per-service specific files to watch + Patterns []string + + // Dirs is per-service specific dirs which will be combined with Patterns + Dirs []string + + // Ignore is set of files which would not be watched + Ignore []string +} + +// InitDefaults sets missing values to their default values. +func InitDefaults(c *Config) { + c.Interval = time.Second + c.Patterns = []string{".php"} +} + +// Valid validates the configuration. +func (c *Config) Valid() error { + const op = errors.Op("config validation [reload plugin]") + if c.Interval < time.Second { + return errors.E(op, errors.Str("too short interval")) + } + + if c.Services == nil { + return errors.E(op, errors.Str("should add at least 1 service")) + } else if len(c.Services) == 0 { + return errors.E(op, errors.Str("service initialized, however, no config added")) + } + + return nil +} diff --git a/plugins/reload/plugin.go b/plugins/reload/plugin.go new file mode 100644 index 00000000..eb1b61b2 --- /dev/null +++ b/plugins/reload/plugin.go @@ -0,0 +1,159 @@ +package reload + +import ( + "os" + "strings" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/config" + "github.com/spiral/roadrunner/v2/plugins/logger" + "github.com/spiral/roadrunner/v2/plugins/resetter" +) + +// PluginName contains default plugin name. 
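+//
+// Under this key the plugin expects a section like the following
+// (service names, directories, and values are illustrative):
+//
+//	reload:
+//	  interval: 1s
+//	  patterns: [".php"]
+//	  services:
+//	    http:
+//	      recursive: true
+//	      dirs: ["."]
+//	      ignore: ["vendor"]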
+const PluginName string = "reload" +const thresholdChanBuffer uint = 1000 + +type Plugin struct { + cfg *Config + log logger.Logger + watcher *Watcher + services map[string]interface{} + res resetter.Resetter + stopc chan struct{} +} + +// Init controller service +func (s *Plugin) Init(cfg config.Configurer, log logger.Logger, res resetter.Resetter) error { + const op = errors.Op("reload plugin init") + s.cfg = &Config{} + InitDefaults(s.cfg) + err := cfg.UnmarshalKey(PluginName, &s.cfg) + if err != nil { + // disable plugin in case of error + return errors.E(op, errors.Disabled, err) + } + + s.log = log + s.res = res + s.stopc = make(chan struct{}, 1) + s.services = make(map[string]interface{}) + + var configs []WatcherConfig + + for serviceName, serviceConfig := range s.cfg.Services { + ignored, err := ConvertIgnored(serviceConfig.Ignore) + if err != nil { + return errors.E(op, err) + } + configs = append(configs, WatcherConfig{ + ServiceName: serviceName, + Recursive: serviceConfig.Recursive, + Directories: serviceConfig.Dirs, + FilterHooks: func(filename string, patterns []string) error { + for i := 0; i < len(patterns); i++ { + if strings.Contains(filename, patterns[i]) { + return nil + } + } + return errors.E(op, errors.SkipFile) + }, + Files: make(map[string]os.FileInfo), + Ignored: ignored, + FilePatterns: append(serviceConfig.Patterns, s.cfg.Patterns...), + }) + } + + s.watcher, err = NewWatcher(configs, s.log) + if err != nil { + return errors.E(op, err) + } + + return nil +} + +func (s *Plugin) Serve() chan error { + const op = errors.Op("reload plugin serve") + errCh := make(chan error, 1) + if s.cfg.Interval < time.Second { + errCh <- errors.E(op, errors.Str("reload interval is too fast")) + return errCh + } + + // make a map with unique services + // so, if we would have a 100 events from http service + // in map we would see only 1 key and it's config + treshholdc := make(chan struct { + serviceConfig ServiceConfig + service string + }, thresholdChanBuffer) + + // use the same interval + timer := time.NewTimer(s.cfg.Interval) + + go func() { + for e := range s.watcher.Event { + treshholdc <- struct { + serviceConfig ServiceConfig + service string + }{serviceConfig: s.cfg.Services[e.service], service: e.service} + } + }() + + // map with configs by services + updated := make(map[string]ServiceConfig, len(s.cfg.Services)) + + go func() { + for { + select { + case cfg := <-treshholdc: + // logic is following: + // restart + timer.Stop() + // replace previous value in map by more recent without adding new one + updated[cfg.service] = cfg.serviceConfig + // if we getting a lot of events, we shouldn't restart particular service on each of it (user doing batch move or very fast typing) + // instead, we are resetting the timer and wait for s.cfg.Interval time + // If there is no more events, we restart service only once + timer.Reset(s.cfg.Interval) + case <-timer.C: + if len(updated) > 0 { + for name := range updated { + err := s.res.ResetByName(name) + if err != nil { + timer.Stop() + errCh <- errors.E(op, err) + return + } + } + // zero map + updated = make(map[string]ServiceConfig, len(s.cfg.Services)) + } + case <-s.stopc: + timer.Stop() + return + } + } + }() + + go func() { + err := s.watcher.StartPolling(s.cfg.Interval) + if err != nil { + errCh <- errors.E(op, err) + return + } + }() + + return errCh +} + +func (s *Plugin) Stop() error { + s.watcher.Stop() + s.stopc <- struct{}{} + return nil +} + +func (s *Plugin) Name() string { + return PluginName +} diff --git 
a/plugins/reload/watcher.go b/plugins/reload/watcher.go new file mode 100644 index 00000000..08c85af9 --- /dev/null +++ b/plugins/reload/watcher.go @@ -0,0 +1,374 @@ +package reload + +import ( + "io/ioutil" + "os" + "path/filepath" + "sync" + "time" + + "github.com/spiral/errors" + "github.com/spiral/roadrunner/v2/plugins/logger" +) + +// SimpleHook is used to filter by simple criteria, CONTAINS +type SimpleHook func(filename string, pattern []string) error + +// An Event describes an event that is received when files or directory +// changes occur. It includes the os.FileInfo of the changed file or +// directory and the type of event that's occurred and the full path of the file. +type Event struct { + Path string + Info os.FileInfo + + service string // type of service, http, grpc, etc... +} + +type WatcherConfig struct { + // service name + ServiceName string + + // Recursive or just add by singe directory + Recursive bool + + // Directories used per-service + Directories []string + + // simple hook, just CONTAINS + FilterHooks func(filename string, pattern []string) error + + // path to file with Files + Files map[string]os.FileInfo + + // Ignored Directories, used map for O(1) amortized get + Ignored map[string]struct{} + + // FilePatterns to ignore + FilePatterns []string +} + +type Watcher struct { + // main event channel + Event chan Event + close chan struct{} + + // ============================= + mu *sync.Mutex + + // indicates is walker started or not + started bool + + // config for each service + // need pointer here to assign files + watcherConfigs map[string]WatcherConfig + + // logger + log logger.Logger +} + +// Options is used to set Watcher Options +type Options func(*Watcher) + +// NewWatcher returns new instance of File Watcher +func NewWatcher(configs []WatcherConfig, log logger.Logger, options ...Options) (*Watcher, error) { + w := &Watcher{ + Event: make(chan Event), + mu: &sync.Mutex{}, + + log: log, + + close: make(chan struct{}), + + //workingDir: workDir, + watcherConfigs: make(map[string]WatcherConfig), + } + + // add watcherConfigs by service names + for _, v := range configs { + w.watcherConfigs[v.ServiceName] = v + } + + // apply options + for _, option := range options { + option(w) + } + err := w.initFs() + if err != nil { + return nil, err + } + + return w, nil +} + +// initFs makes initial map with files +func (w *Watcher) initFs() error { + const op = errors.Op("init fs") + for srvName, config := range w.watcherConfigs { + fileList, err := w.retrieveFileList(srvName, config) + if err != nil { + return errors.E(op, err) + } + // workaround. 
in golang you can't assign to map in struct field + tmp := w.watcherConfigs[srvName] + tmp.Files = fileList + w.watcherConfigs[srvName] = tmp + } + return nil +} + +// ConvertIgnored is used to convert slice to map with ignored files +func ConvertIgnored(ignored []string) (map[string]struct{}, error) { + if len(ignored) == 0 { + return nil, nil + } + + ign := make(map[string]struct{}, len(ignored)) + for i := 0; i < len(ignored); i++ { + abs, err := filepath.Abs(ignored[i]) + if err != nil { + return nil, err + } + ign[abs] = struct{}{} + } + + return ign, nil +} + +// https://en.wikipedia.org/wiki/Inotify +// SetMaxFileEvents sets max file notify events for Watcher +// In case of file watch errors, this value can be increased system-wide +// For linux: set --> fs.inotify.max_user_watches = 600000 (under /etc/<choose_name_here>.conf) +// Add apply: sudo sysctl -p --system +// func SetMaxFileEvents(events int) Options { +// return func(watcher *Watcher) { +// watcher.maxFileWatchEvents = events +// } +// +// } + +// pass map from outside +func (w *Watcher) retrieveFilesSingle(serviceName, path string) (map[string]os.FileInfo, error) { + const op = errors.Op("retrieve") + stat, err := os.Stat(path) + if err != nil { + return nil, err + } + + filesList := make(map[string]os.FileInfo, 10) + filesList[path] = stat + + // if it's not a dir, return + if !stat.IsDir() { + return filesList, nil + } + + fileInfoList, err := ioutil.ReadDir(path) + if err != nil { + return nil, err + } + + // recursive calls are slow in compare to goto + // so, we will add files with goto pattern +outer: + for i := 0; i < len(fileInfoList); i++ { + // if file in ignored --> continue + if _, ignored := w.watcherConfigs[serviceName].Ignored[path]; ignored { + continue + } + + // if filename does not contain pattern --> ignore that file + if w.watcherConfigs[serviceName].FilePatterns != nil && w.watcherConfigs[serviceName].FilterHooks != nil { + err = w.watcherConfigs[serviceName].FilterHooks(fileInfoList[i].Name(), w.watcherConfigs[serviceName].FilePatterns) + if errors.Is(errors.SkipFile, err) { + continue outer + } + } + + filesList[fileInfoList[i].Name()] = fileInfoList[i] + } + + return filesList, nil +} + +func (w *Watcher) StartPolling(duration time.Duration) error { + w.mu.Lock() + const op = errors.Op("start polling") + if w.started { + w.mu.Unlock() + return errors.E(op, errors.Str("already started")) + } + + w.started = true + w.mu.Unlock() + + return w.waitEvent(duration) +} + +// this is blocking operation +func (w *Watcher) waitEvent(d time.Duration) error { + ticker := time.NewTicker(d) + for { + select { + case <-w.close: + ticker.Stop() + // just exit + // no matter for the pollEvents + return nil + case <-ticker.C: + // this is not very effective way + // because we have to wait on Lock + // better is to listen files in parallel, but, since that would be used in debug... 
+// waitEvent blocks, polling all services with the given interval
+func (w *Watcher) waitEvent(d time.Duration) error {
+	ticker := time.NewTicker(d)
+	for {
+		select {
+		case <-w.close:
+			ticker.Stop()
+			// just exit; pollEvents does not need to be notified
+			return nil
+		case <-ticker.C:
+			// TODO: polling all services under one lock is not the most
+			// efficient approach; watching in parallel would be better,
+			// but this code path is only meant for debugging
+			for serviceName := range w.watcherConfigs {
+				// TODO sync approach
+				fileList, _ := w.retrieveFileList(serviceName, w.watcherConfigs[serviceName])
+				w.pollEvents(w.watcherConfigs[serviceName].ServiceName, fileList)
+			}
+		}
+	}
+}
+
+// retrieveFileList gets the file list for a service
+func (w *Watcher) retrieveFileList(serviceName string, config WatcherConfig) (map[string]os.FileInfo, error) {
+	fileList := make(map[string]os.FileInfo)
+	if config.Recursive {
+		// walk through directories recursively
+		for i := 0; i < len(config.Directories); i++ {
+			// full path is workdir/relative_path
+			fullPath, err := filepath.Abs(config.Directories[i])
+			if err != nil {
+				return nil, err
+			}
+			list, err := w.retrieveFilesRecursive(serviceName, fullPath)
+			if err != nil {
+				return nil, err
+			}
+
+			for k := range list {
+				fileList[k] = list[k]
+			}
+		}
+		return fileList, nil
+	}
+
+	for i := 0; i < len(config.Directories); i++ {
+		// full path is workdir/relative_path
+		fullPath, err := filepath.Abs(config.Directories[i])
+		if err != nil {
+			return nil, err
+		}
+
+		// list maps file paths to file info
+		list, err := w.retrieveFilesSingle(serviceName, fullPath)
+		if err != nil {
+			return nil, err
+		}
+
+		for pathToFile, file := range list {
+			fileList[pathToFile] = file
+		}
+	}
+
+	return fileList, nil
+}
+
+func (w *Watcher) retrieveFilesRecursive(serviceName, root string) (map[string]os.FileInfo, error) {
+	fileList := make(map[string]os.FileInfo)
+
+	return fileList, filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+		const op = errors.Op("retrieve files recursive")
+		if err != nil {
+			return errors.E(op, err)
+		}
+
+		// If the path is ignored and it's a directory, skip the whole
+		// directory. If it's a single file, skip just the file.
+		_, ignored := w.watcherConfigs[serviceName].Ignored[path]
+		if ignored {
+			if info.IsDir() {
+				return filepath.SkipDir
+			}
+			return nil
+		}
+
+		// if the filename does not match the patterns --> ignore that file
+		// (guard against a nil hook, mirroring retrieveFilesSingle)
+		if w.watcherConfigs[serviceName].FilterHooks != nil {
+			err = w.watcherConfigs[serviceName].FilterHooks(info.Name(), w.watcherConfigs[serviceName].FilePatterns)
+			if errors.Is(errors.SkipFile, err) {
+				return nil
+			}
+		}
+
+		// Add the path and its info to the file list.
+		fileList[path] = info
+		return nil
+	})
}
+func (w *Watcher) pollEvents(serviceName string, files map[string]os.FileInfo) {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	// Store create and remove events; a matching create/remove pair for the
+	// same file would indicate a rename.
+	creates := make(map[string]os.FileInfo)
+	removes := make(map[string]os.FileInfo)
+
+	// Check for removed files.
+	for pth := range w.watcherConfigs[serviceName].Files {
+		if _, found := files[pth]; !found {
+			removes[pth] = w.watcherConfigs[serviceName].Files[pth]
+			w.log.Debug("file added to the list of removed files", "path", pth, "name", w.watcherConfigs[serviceName].Files[pth].Name(), "size", w.watcherConfigs[serviceName].Files[pth].Size())
+		}
+	}
+
+	// Check for created files, writes and chmods.
+	for pth := range files {
+		if files[pth].IsDir() {
+			continue
+		}
+		oldInfo, found := w.watcherConfigs[serviceName].Files[pth]
+		if !found {
+			// A file was created.
+			creates[pth] = files[pth]
+			w.log.Debug("file was created", "path", pth, "name", files[pth].Name(), "size", files[pth].Size())
+			continue
+		}
+
+		if oldInfo.ModTime() != files[pth].ModTime() || oldInfo.Mode() != files[pth].Mode() {
+			w.watcherConfigs[serviceName].Files[pth] = files[pth]
+			w.log.Debug("file was updated", "path", pth, "name", files[pth].Name(), "size", files[pth].Size())
+			w.Event <- Event{
+				Path:    pth,
+				Info:    files[pth],
+				service: serviceName,
+			}
+		}
+	}
+
+	// Send all the remaining create and remove events.
+	for pth := range creates {
+		// add the file to the watched files of the plugin
+		w.watcherConfigs[serviceName].Files[pth] = creates[pth]
+		w.log.Debug("file was added to watcher", "path", pth, "name", creates[pth].Name(), "size", creates[pth].Size())
+
+		w.Event <- Event{
+			Path:    pth,
+			Info:    creates[pth],
+			service: serviceName,
+		}
+	}
+
+	for pth := range removes {
+		// delete the path from the config
+		delete(w.watcherConfigs[serviceName].Files, pth)
+		w.log.Debug("file was removed from watcher", "path", pth, "name", removes[pth].Name(), "size", removes[pth].Size())
+
+		w.Event <- Event{
+			Path:    pth,
+			Info:    removes[pth],
+			service: serviceName,
+		}
+	}
+}
+
+func (w *Watcher) Stop() {
+	w.close <- struct{}{}
+}
diff --git a/plugins/resetter/interface.go b/plugins/resetter/interface.go new file mode 100644 index 00000000..47d8d791 --- /dev/null +++ b/plugins/resetter/interface.go @@ -0,0 +1,17 @@
+package resetter
+
+// If a plugin implements the Resettable interface, its state can be reset
+// at runtime, without a reload, via RPC/HTTP
+type Resettable interface {
+	// Reset resets the plugin state
+	Reset() error
+}
+
+// Resetter is the main interface of the Resetter plugin
+type Resetter interface {
+	// ResetAll resets all registered plugins
+	ResetAll() error
+	// ResetByName resets a plugin by its name
+	ResetByName(string) error
+	// GetAll returns all registered plugins
+	GetAll() []string
+}
diff --git a/plugins/resetter/plugin.go b/plugins/resetter/plugin.go new file mode 100644 index 00000000..5d294086 --- /dev/null +++ b/plugins/resetter/plugin.go @@ -0,0 +1,80 @@
+package resetter
+
+import (
+	"github.com/spiral/endure"
+	"github.com/spiral/errors"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+const PluginName = "resetter"
+
+type Plugin struct {
+	registry map[string]Resettable
+	log      logger.Logger
+}
+
+func (p *Plugin) ResetAll() error {
+	const op = errors.Op("reset all")
+	for name := range p.registry {
+		err := p.registry[name].Reset()
+		if err != nil {
+			return errors.E(op, err)
+		}
+	}
+	return nil
+}
+
+func (p *Plugin) ResetByName(plugin string) error {
+	const op = errors.Op("reset by name")
+	if svc, ok := p.registry[plugin]; ok {
+		return svc.Reset()
+	}
+	return errors.E(op, errors.Errorf("can't find plugin: %s", plugin))
+}
+
+func (p *Plugin) GetAll() []string {
+	all := make([]string, 0, len(p.registry))
+	for name := range p.registry {
+		all = append(all, name)
+	}
+	return all
+}
+
+func (p *Plugin) Init(log logger.Logger) error {
+	p.registry = make(map[string]Resettable)
+	p.log = log
+	return nil
+}
+
+// Reset resets the named service.
+func (p *Plugin) Reset(name string) error {
+	svc, ok := p.registry[name]
+	if !ok {
+		return errors.E("no such service", errors.Str(name))
+	}
+
+	return svc.Reset()
+}
+
+// RegisterTarget registers a resettable service.
+func (p *Plugin) RegisterTarget(name endure.Named, r Resettable) error {
+	p.registry[name.Name()] = r
+	return nil
+}
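A sketch of a plugin satisfying the Resettable interface above; CachePlugin and its fields are hypothetical. Any named plugin implementing Reset() error is collected by RegisterTarget and becomes reachable through ResetAll/ResetByName and the RPC layer.

    // CachePlugin is a hypothetical plugin holding resettable in-memory state.
    type CachePlugin struct {
    	mu    sync.Mutex
    	items map[string]string
    }

    // Name makes the plugin addressable in the resetter registry.
    func (c *CachePlugin) Name() string { return "cache" }

    // Reset drops the in-memory state without restarting the plugin.
    func (c *CachePlugin) Reset() error {
    	c.mu.Lock()
    	defer c.mu.Unlock()
    	c.items = make(map[string]string)
    	return nil
    }

Over RPC the same plugin could then be reset with client.Call("resetter.Reset", "cache", &done), assuming the RPC plugin exposes this service under the resetter name.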
+// Collects declares services to be collected.
+func (p *Plugin) Collects() []interface{} {
+	return []interface{}{
+		p.RegisterTarget,
+	}
+}
+
+// Name of the service.
+func (p *Plugin) Name() string {
+	return PluginName
+}
+
+// RPC returns the associated RPC service.
+func (p *Plugin) RPC() interface{} {
+	return &rpc{srv: p, log: p.log}
+}
diff --git a/plugins/resetter/rpc.go b/plugins/resetter/rpc.go new file mode 100644 index 00000000..69c955b0 --- /dev/null +++ b/plugins/resetter/rpc.go @@ -0,0 +1,30 @@
+package resetter
+
+import "github.com/spiral/roadrunner/v2/plugins/logger"
+
+type rpc struct {
+	srv *Plugin
+	log logger.Logger
+}
+
+// List all resettable plugins.
+func (rpc *rpc) List(_ bool, list *[]string) error {
+	rpc.log.Debug("started List method")
+	*list = make([]string, 0)
+
+	for name := range rpc.srv.registry {
+		*list = append(*list, name)
+	}
+	rpc.log.Debug("services list", "services", *list)
+
+	rpc.log.Debug("finished List method")
+	return nil
+}
+
+// Reset the named plugin.
+func (rpc *rpc) Reset(service string, done *bool) error {
+	rpc.log.Debug("started Reset method for the service", "service", service)
+	defer rpc.log.Debug("finished Reset method for the service", "service", service)
+	*done = true
+	return rpc.srv.Reset(service)
+}
diff --git a/plugins/rpc/config.go b/plugins/rpc/config.go new file mode 100644 index 00000000..88ad7f0e --- /dev/null +++ b/plugins/rpc/config.go @@ -0,0 +1,46 @@
+package rpc
+
+import (
+	"errors"
+	"net"
+	"strings"
+
+	"github.com/spiral/roadrunner/v2/utils"
+)
+
+// Config defines RPC service config.
+type Config struct {
+	// Listen is the RPC connection DSN, e.g. tcp://127.0.0.1:6001
+	Listen string
+}
+
+// InitDefaults initializes a blank config with a pre-defined set of default values.
+func (c *Config) InitDefaults() {
+	if c.Listen == "" {
+		c.Listen = "tcp://127.0.0.1:6001"
+	}
+}
+
+// Valid returns nil if the config is valid.
+func (c *Config) Valid() error {
+	if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 {
+		return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
+	}
+
+	return nil
+}
+
+// Listener creates a new rpc socket Listener.
+func (c *Config) Listener() (net.Listener, error) {
+	return utils.CreateListener(c.Listen)
+}
+
+// Dialer creates an rpc socket Dialer.
+func (c *Config) Dialer() (net.Conn, error) { + dsn := strings.Split(c.Listen, "://") + if len(dsn) != 2 { + return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)") + } + + return net.Dial(dsn[0], dsn[1]) +} diff --git a/plugins/rpc/doc/plugin_arch.drawio b/plugins/rpc/doc/plugin_arch.drawio new file mode 100644 index 00000000..dec5f0b2 --- /dev/null +++ b/plugins/rpc/doc/plugin_arch.drawio @@ -0,0 +1 @@ +<mxfile host="Electron" modified="2020-10-19T17:14:19.125Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/13.7.9 Chrome/85.0.4183.121 Electron/10.1.3 Safari/537.36" etag="2J39x4EyFr1zaE9BXKM4" version="13.7.9" type="device"><diagram id="q2oMKs6VHyn7y0AfAXBL" name="Page-1">7Vttc9o4EP41zLQfksE2GPIxQHPXu7RlQntt7ptiC1sX2XJlOUB//a1sGdtIJDQFnE6YyUys1YutfR7trlai44yj5R8cJeEH5mPasbv+suNMOrZtORcO/JOSVSEZWv1CEHDiq0aVYEZ+YCXsKmlGfJw2GgrGqCBJU+ixOMaeaMgQ52zRbDZntPnWBAVYE8w8RHXpV+KLUEkt96Kq+BOTIFSvHtqDoiJCZWM1kzREPlvURM67jjPmjIniKVqOMZXKK/VS9LvaUrv+MI5jsUuHL/zu0yx7//HT3Pln8vfN59vvS/usVHMqVuWMsQ8KUEXGRcgCFiP6rpKOOMtiH8thu1Cq2lwzloDQAuF/WIiVQhNlgoEoFBFVtXhJxLfa860c6ryvSpOlGjkvrMpCLPjqW71Q6yWLVbe8VPabs1hcoYhQKRizjBPMYcIf8UJVqq+8gGKhC6mArTpWohQG8lSrfz88xF8/ds/+uiLe7MsXtLiyZ2clVxEPsHik3WDNBFhCmEUYvh36cUyRIA/N70CKy8G6XQU3PCjEfwZ9q030K8RvazVPoV8BftvA+7dE33KOBP9jX/mAaKbedDOFkbpTmgUk1qjRBH4REoFnCcr1sADj3wT55xVv0PMD5gIvayJdU6rWGSi3otyMYw3OlWRRme21VwlrFtsdHEi9jqbe9zERha+ak0DTL0xVNJWIKAliePZAMaA+ZyQVQsA5XaqKiPh+sShxSn6gu3woiU7CSCzyCfVHnf5EjgXrMC103go+3Q18hho6QwM4pfPcOzg9DZwJTnDspyBk8Rqk8ylnDxCB8N8DLcveD1z2BlxWWa4vpu4x8epreOmuK/YvZcQnIaAoTYm34XeO5kMMun/aFRjdj45QDYG+AYBStrMHUW+YSgpWBOgNtxCgHKJwgapXPercGKhvbwxkbQxUKEYbKCfJetrP542r8aa0vt0U9gsE1rpzKfWVeK97ia+Xc41glolhB1viA32Jj+3O5YhIXc9loAHFEczdpRKWO95Ay/2eyZ1UrqqzQq8S14tkmeurrIanQP0vRvmVQYA052WwVAwHE7+rXrHBp/bCI3f4tPu1jMGReyCwLT06KoLPVPDMExnHmvrSBYkoinGpIVWz07oUcm8y8kJC/Wu0YpmcXiqQd1+WRiHj5AcMi0qIoJqXMNhuo8VM9lQLO1/oeFqiY22IPqBlo+E1SoUSeIxSlKTkbj2NCGwhiUdMCBbt0/k8P47uuQarULapE8Vye4diytDg+ke7R2hAKHaPx4wyIMYkZgWBCKUbopJDFM/FVgalsOEhcXCdt5n0KsmNUoUUMeg7p3kgEoI/wHG+axZIbPUHI9DyWIYl4BnsMZStqpw7iwT22WMWw1wQycHFwKMFTsUvU+Tx1fk0cUr34e7GE/tQBqV0SxpNpJGeYf6QK+VNjMX5TeK9PbGlTbb07ZbZYl1sYUsKTCEeltvAIlKr+aNuSqHqxJw2mTMwBC7HZY6eOSiYMydYni3IeHH8aILnxIk9c8Lq9tomxQ7pCUpyqAszUZ4lWc/iw3qXqQjwOc+8n1kaSRydJI6BEBTdYTqF3WixH57woq1h0/ryueDsGLAOD0UFPeNQ2AcYPmT+G7FK8NvCTMjHkzdply1HdCfmIzhDHvMIR3Av9jDVrKTOjjnUCzPaRzpN1Ra+Ciafk9Xo/nK6wmAsfpMMhrZ+DazZmsHoNTNdPcvgD1xDpmuwB4dgpIX9dLxY8aTKdZ78wp7osn2t/lQyw8SZg3kFPTmqcSZGkTIsgNeJLS2yxZTMOCpb9IizMigcByQFmyITGlYxV4A2o0iqyc+PvOGvYYPmTNbl2Xgzq17Wgdie/Ia1cYFkqO8pHftAx2FGVPUMVVJkul8VLK61cXJl67gc6pTSbAvcVgJ245259TW5Vm5M1k6i9xPlO7uG+b1Ww3zdOVdXCk5h/pHsgtM0C64p7WNywqWz3j8tdsgLX0tXHJ+itiNFbVsu176UIN/SL7xMOQOFR2lOl7a9fN3MP4rYHpbzxq7dsGk/1O1QMzT6nYOAqSAZFqaPvY78hYecQIBjzJGQgbNgsk2UeaH8Ji93RdLvefdY3ohDeZyNlx7G8iGjJMqvA5/pV61fE9YGy93fU6ANxer3NcWNwupXSs67/wE=</diagram></mxfile>
\ No newline at end of file
diff --git a/plugins/rpc/interface.go b/plugins/rpc/interface.go new file mode 100644 index 00000000..683fd2ec --- /dev/null +++ b/plugins/rpc/interface.go @@ -0,0 +1,7 @@
+package rpc
+
+// RPCer declares the ability to create a set of public RPC methods.
+type RPCer interface {
+	// RPC provides RPC methods for the given service.
+	RPC() interface{}
+}
diff --git a/plugins/rpc/plugin.go b/plugins/rpc/plugin.go new file mode 100644 index 00000000..c5813e7b --- /dev/null +++ b/plugins/rpc/plugin.go @@ -0,0 +1,161 @@
+package rpc
+
+import (
+	"net"
+	"net/rpc"
+	"sync/atomic"
+
+	"github.com/spiral/endure"
+	"github.com/spiral/errors"
+	goridgeRpc "github.com/spiral/goridge/v3/pkg/rpc"
+	"github.com/spiral/roadrunner/v2/plugins/config"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+// PluginName contains the default plugin name.
+const PluginName = "RPC"
+
+type pluggable struct {
+	service RPCer
+	name    string
+}
+
+// Plugin is the RPC service.
+type Plugin struct {
+	cfg Config
+	log logger.Logger
+	rpc *rpc.Server
+	// set of plugins that implement the RPCer interface and can be plugged into RR via RPC
+	plugins  []pluggable
+	listener net.Listener
+	closed   *uint32
+}
+
+// Init the rpc service; returns errors.Disabled if the service is not configured.
+func (s *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
+	const op = errors.Op("rpc plugin init")
+	if !cfg.Has(PluginName) {
+		return errors.E(op, errors.Disabled)
+	}
+
+	err := cfg.UnmarshalKey(PluginName, &s.cfg)
+	if err != nil {
+		return errors.E(op, errors.Disabled, err)
+	}
+	s.cfg.InitDefaults()
+
+	s.log = log
+	state := uint32(0)
+	s.closed = &state
+	atomic.StoreUint32(s.closed, 0)
+
+	return s.cfg.Valid()
+}
+
+// Serve serves the service.
+func (s *Plugin) Serve() chan error {
+	const op = errors.Op("serve rpc plugin")
+	errCh := make(chan error, 1)
+
+	s.rpc = rpc.NewServer()
+
+	services := make([]string, 0, len(s.plugins))
+
+	// Attach all services
+	for i := 0; i < len(s.plugins); i++ {
+		err := s.Register(s.plugins[i].name, s.plugins[i].service.RPC())
+		if err != nil {
+			errCh <- errors.E(op, err)
+			return errCh
+		}
+
+		services = append(services, s.plugins[i].name)
+	}
+
+	var err error
+	s.listener, err = s.cfg.Listener()
+	if err != nil {
+		errCh <- err
+		return errCh
+	}
+
+	s.log.Debug("Started RPC service", "address", s.cfg.Listen, "services", services)
+
+	go func() {
+		for {
+			conn, err := s.listener.Accept()
+			if err != nil {
+				if atomic.LoadUint32(s.closed) == 1 {
+					// just log and return; this is not a critical issue, we just called Stop
+					s.log.Warn("listener accept error, connection closed", "error", err)
+					return
+				}
+
+				s.log.Error("listener accept error", "error", err)
+				errCh <- errors.E(errors.Op("listener accept"), errors.Serve, err)
+				return
+			}
+
+			go s.rpc.ServeCodec(goridgeRpc.NewCodec(conn))
+		}
+	}()
+
+	return errCh
+}
+
+// Stop stops the service.
+func (s *Plugin) Stop() error {
+	// store closed state
+	atomic.StoreUint32(s.closed, 1)
+	err := s.listener.Close()
+	if err != nil {
+		return errors.E(errors.Op("stop RPC socket"), err)
+	}
+	return nil
+}
+
+// Name contains the service name.
+func (s *Plugin) Name() string {
+	return PluginName
+}
+
+// Collects declares services to be collected for RPC.
+func (s *Plugin) Collects() []interface{} {
+	return []interface{}{
+		s.RegisterPlugin,
+	}
+}
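A sketch of how a plugin would plug into this service; the metrics plugin and its method are hypothetical. The method shape follows the net/rpc rules restated on Register below: an exported method with two exported-type arguments, the second a pointer, and a single error return.

    // metricsRPC is a hypothetical receiver exposed over RPC by a "metrics" plugin.
    type metricsRPC struct{}

    // Add records a named counter; the signature satisfies net/rpc.
    func (m *metricsRPC) Add(name string, ok *bool) error {
    	*ok = true
    	return nil
    }

    // MetricsPlugin implements RPCer (and endure.Named via Name), so
    // RegisterPlugin below would register its receiver as "metrics".
    type MetricsPlugin struct{}

    func (m *MetricsPlugin) Name() string     { return "metrics" }
    func (m *MetricsPlugin) RPC() interface{} { return &metricsRPC{} }

A client obtained from Client() (also below) would then invoke it as client.Call("metrics.Add", "requests", &ok).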
+// RegisterPlugin registers an RPC service plugin.
+func (s *Plugin) RegisterPlugin(name endure.Named, p RPCer) {
+	s.plugins = append(s.plugins, pluggable{
+		service: p,
+		name:    name.Name(),
+	})
+}
+
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method of exported type
+// - two arguments, both of exported type
+// - the second argument is a pointer
+// - one return value, of type error
+// It returns an error if the receiver is not an exported type or has
+// no suitable methods.
+func (s *Plugin) Register(name string, svc interface{}) error {
+	if s.rpc == nil {
+		return errors.E("RPC service is not configured")
+	}
+
+	return s.rpc.RegisterName(name, svc)
+}
+
+// Client creates a new RPC client.
+func (s *Plugin) Client() (*rpc.Client, error) {
+	conn, err := s.cfg.Dialer()
+	if err != nil {
+		return nil, err
+	}
+
+	return rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)), nil
+}
diff --git a/plugins/server/config.go b/plugins/server/config.go new file mode 100644 index 00000000..a990efd3 --- /dev/null +++ b/plugins/server/config.go @@ -0,0 +1,147 @@
+package server
+
+import (
+	"time"
+)
+
+// Config is the whole config (.rr.yaml).
+// Other sections use pointers to distinguish between `empty` and `not present`.
+type Config struct {
+	// Server config section
+	Server struct {
+		// Command to run as application.
+		Command string `yaml:"command"`
+		// User to run application under.
+		User string `yaml:"user"`
+		// Group to run application under.
+		Group string `yaml:"group"`
+		// Env represents application environment.
+		Env Env `yaml:"env"`
+		// Relay defines connection method and factory to be used to connect to workers:
+		// "pipes", "tcp://:6001", "unix://rr.sock"
+		// This config section must not change on re-configuration.
+		Relay string `yaml:"relay"`
+		// RelayTimeout defines for how long the socket factory will wait for a worker connection.
+		// This config section must not change on re-configuration. Defaults to 60s.
+ RelayTimeout time.Duration `yaml:"relayTimeout"` + } `yaml:"server"` + + RPC *struct { + Listen string `yaml:"listen"` + } `yaml:"rpc"` + Logs *struct { + Mode string `yaml:"mode"` + Level string `yaml:"level"` + } `yaml:"logs"` + HTTP *struct { + Address string `yaml:"address"` + MaxRequestSize int `yaml:"max_request_size"` + Middleware []string `yaml:"middleware"` + Uploads struct { + Forbid []string `yaml:"forbid"` + } `yaml:"uploads"` + TrustedSubnets []string `yaml:"trusted_subnets"` + Pool struct { + NumWorkers int `yaml:"num_workers"` + MaxJobs int `yaml:"max_jobs"` + AllocateTimeout string `yaml:"allocate_timeout"` + DestroyTimeout string `yaml:"destroy_timeout"` + Supervisor struct { + WatchTick int `yaml:"watch_tick"` + TTL int `yaml:"ttl"` + IdleTTL int `yaml:"idle_ttl"` + ExecTTL int `yaml:"exec_ttl"` + MaxWorkerMemory int `yaml:"max_worker_memory"` + } `yaml:"supervisor"` + } `yaml:"pool"` + Ssl struct { + Port int `yaml:"port"` + Redirect bool `yaml:"redirect"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` + } `yaml:"ssl"` + Fcgi struct { + Address string `yaml:"address"` + } `yaml:"fcgi"` + HTTP2 struct { + Enabled bool `yaml:"enabled"` + H2C bool `yaml:"h2c"` + MaxConcurrentStreams int `yaml:"max_concurrent_streams"` + } `yaml:"http2"` + } `yaml:"http"` + Redis *struct { + Addrs []string `yaml:"addrs"` + MasterName string `yaml:"master_name"` + Username string `yaml:"username"` + Password string `yaml:"password"` + DB int `yaml:"db"` + SentinelPassword string `yaml:"sentinel_password"` + RouteByLatency bool `yaml:"route_by_latency"` + RouteRandomly bool `yaml:"route_randomly"` + DialTimeout int `yaml:"dial_timeout"` + MaxRetries int `yaml:"max_retries"` + MinRetryBackoff int `yaml:"min_retry_backoff"` + MaxRetryBackoff int `yaml:"max_retry_backoff"` + PoolSize int `yaml:"pool_size"` + MinIdleConns int `yaml:"min_idle_conns"` + MaxConnAge int `yaml:"max_conn_age"` + ReadTimeout int `yaml:"read_timeout"` + WriteTimeout int `yaml:"write_timeout"` + PoolTimeout int `yaml:"pool_timeout"` + IdleTimeout int `yaml:"idle_timeout"` + IdleCheckFreq int `yaml:"idle_check_freq"` + ReadOnly bool `yaml:"read_only"` + } `yaml:"redis"` + Boltdb *struct { + Dir string `yaml:"dir"` + File string `yaml:"file"` + Bucket string `yaml:"bucket"` + Permissions int `yaml:"permissions"` + TTL int `yaml:"TTL"` + } `yaml:"boltdb"` + Memcached *struct { + Addr []string `yaml:"addr"` + } `yaml:"memcached"` + Memory *struct { + Enabled bool `yaml:"enabled"` + Interval int `yaml:"interval"` + } `yaml:"memory"` + Metrics *struct { + Address string `yaml:"address"` + Collect struct { + AppMetric struct { + Type string `yaml:"type"` + Help string `yaml:"help"` + Labels []string `yaml:"labels"` + Buckets []float64 `yaml:"buckets"` + Objectives []struct { + Num2 float64 `yaml:"2,omitempty"` + One4 float64 `yaml:"1.4,omitempty"` + } `yaml:"objectives"` + } `yaml:"app_metric"` + } `yaml:"collect"` + } `yaml:"metrics"` + Reload *struct { + Interval string `yaml:"interval"` + Patterns []string `yaml:"patterns"` + Services struct { + HTTP struct { + Recursive bool `yaml:"recursive"` + Ignore []string `yaml:"ignore"` + Patterns []string `yaml:"patterns"` + Dirs []string `yaml:"dirs"` + } `yaml:"http"` + } `yaml:"services"` + } `yaml:"reload"` +} + +// InitDefaults for the server config +func (cfg *Config) InitDefaults() { + if cfg.Server.Relay == "" { + cfg.Server.Relay = "pipes" + } + + if cfg.Server.RelayTimeout == 0 { + cfg.Server.RelayTimeout = time.Second * 60 + } +} diff --git 
a/plugins/server/interface.go b/plugins/server/interface.go new file mode 100644 index 00000000..a2d8b92b --- /dev/null +++ b/plugins/server/interface.go @@ -0,0 +1,21 @@
+package server
+
+import (
+	"context"
+	"os/exec"
+
+	"github.com/spiral/roadrunner/v2/interfaces/events"
+	"github.com/spiral/roadrunner/v2/interfaces/pool"
+	"github.com/spiral/roadrunner/v2/interfaces/worker"
+	poolImpl "github.com/spiral/roadrunner/v2/pkg/pool"
+)
+
+// Env is a type alias for the environment variables map
+type Env map[string]string
+
+// Server creates workers for the application.
+type Server interface {
+	CmdFactory(env Env) (func() *exec.Cmd, error)
+	NewWorker(ctx context.Context, env Env, listeners ...events.Listener) (worker.BaseProcess, error)
+	NewWorkerPool(ctx context.Context, opt poolImpl.Config, env Env, listeners ...events.Listener) (pool.Pool, error)
+}
diff --git a/plugins/server/plugin.go b/plugins/server/plugin.go new file mode 100644 index 00000000..565c80c4 --- /dev/null +++ b/plugins/server/plugin.go @@ -0,0 +1,257 @@
+package server
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"os/exec"
+	"strings"
+
+	"github.com/spiral/errors"
+	"github.com/spiral/roadrunner/v2/plugins/config"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+
+	// core imports
+	"github.com/spiral/roadrunner/v2/interfaces/events"
+	"github.com/spiral/roadrunner/v2/interfaces/pool"
+	"github.com/spiral/roadrunner/v2/interfaces/worker"
+	"github.com/spiral/roadrunner/v2/pkg/pipe"
+	poolImpl "github.com/spiral/roadrunner/v2/pkg/pool"
+	"github.com/spiral/roadrunner/v2/pkg/socket"
+	"github.com/spiral/roadrunner/v2/utils"
+)
+
+// PluginName for the server
+const PluginName = "server"
+
+// RR_RELAY env variable key (internal)
+const RR_RELAY = "RR_RELAY" //nolint:golint,stylecheck
+// RR_RPC env variable key (internal), set when the RPC plugin is present
+const RR_RPC = "RR_RPC" //nolint:golint,stylecheck
+// RR_HTTP env variable key (internal), set when the HTTP plugin is present
+const RR_HTTP = "RR_HTTP" //nolint:golint,stylecheck
+
+// Plugin manages the worker factory and spawns workers
+type Plugin struct {
+	cfg     Config
+	log     logger.Logger
+	factory worker.Factory
+}
+
+// Init application provider.
+func (server *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
+	const op = errors.Op("Init")
+	err := cfg.Unmarshal(&server.cfg)
+	if err != nil {
+		return errors.E(op, errors.Init, err)
+	}
+	server.cfg.InitDefaults()
+	server.log = log
+
+	server.factory, err = server.initFactory()
+	if err != nil {
+		return errors.E(err)
+	}
+
+	return nil
+}
+
+// Name contains the service name.
+func (server *Plugin) Name() string {
+	return PluginName
+}
+
+// Serve starts the server plugin (a no-op here, required to satisfy the interface)
+func (server *Plugin) Serve() chan error {
+	errCh := make(chan error, 1)
+	return errCh
+}
+
+// Stop closes the factory chosen in the config
+func (server *Plugin) Stop() error {
+	if server.factory == nil {
+		return nil
+	}
+
+	return server.factory.Close()
+}
+// CmdFactory provides a worker command factory associated with the given context.
+func (server *Plugin) CmdFactory(env Env) (func() *exec.Cmd, error) {
+	const op = errors.Op("cmd factory")
+	var cmdArgs []string
+
+	// create the command according to the config
+	cmdArgs = append(cmdArgs, strings.Split(server.cfg.Server.Command, " ")...)
+	if len(cmdArgs) < 2 {
+		return nil, errors.E(op, errors.Str("command should be in the form of `php <script>`"))
+	}
+	if cmdArgs[0] != "php" {
+		return nil, errors.E(op, errors.Str("first arg in command should be `php`"))
+	}
+
+	_, err := os.Stat(cmdArgs[1])
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+	return func() *exec.Cmd {
+		cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //nolint:gosec
+		utils.IsolateProcess(cmd)
+
+		// if the user is not empty and the OS is Linux or macOS,
+		// run the PHP worker as that particular user
+		if server.cfg.Server.User != "" {
+			err := utils.ExecuteFromUser(cmd, server.cfg.Server.User)
+			if err != nil {
+				return nil
+			}
+		}
+
+		cmd.Env = server.setEnv(env)
+
+		return cmd
+	}, nil
+}
+
+// NewWorker issues a new standalone worker.
+func (server *Plugin) NewWorker(ctx context.Context, env Env, listeners ...events.Listener) (worker.BaseProcess, error) {
+	const op = errors.Op("new worker")
+
+	list := make([]events.Listener, 0, len(listeners))
+	list = append(list, server.collectWorkerLogs)
+	// include the caller-provided listeners as well
+	list = append(list, listeners...)
+
+	spawnCmd, err := server.CmdFactory(env)
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	w, err := server.factory.SpawnWorkerWithTimeout(ctx, spawnCmd(), list...)
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	return w, nil
+}
+
+// NewWorkerPool issues a new worker pool.
+func (server *Plugin) NewWorkerPool(ctx context.Context, opt poolImpl.Config, env Env, listeners ...events.Listener) (pool.Pool, error) {
+	const op = errors.Op("server plugins new worker pool")
+	spawnCmd, err := server.CmdFactory(env)
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	list := make([]events.Listener, 0, 1)
+	list = append(list, server.collectPoolLogs)
+	if len(listeners) != 0 {
+		list = append(list, listeners...)
+	}
+
+	p, err := poolImpl.Initialize(ctx, spawnCmd, server.factory, opt, poolImpl.AddListeners(list...))
+	if err != nil {
+		return nil, errors.E(op, err)
+	}
+
+	return p, nil
+}
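A sketch of consuming the factory methods above from another plugin: `srv` stands for the collected *Plugin, the pool option values are illustrative for the pkg/pool Config imported here as poolImpl, and Destroy is assumed to be part of the pool.Pool interface.

    p, err := srv.NewWorkerPool(
    	context.Background(),
    	poolImpl.Config{
    		NumWorkers:      4,
    		AllocateTimeout: time.Minute,
    		DestroyTimeout:  time.Minute,
    	},
    	Env{"APP_ENV": "dev"}, // merged into the worker environment by setEnv (below)
    )
    if err != nil {
    	return err
    }
    defer p.Destroy(context.Background())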
+// initFactory creates the relay and the worker factory.
+func (server *Plugin) initFactory() (worker.Factory, error) {
+	const op = errors.Op("server factory init")
+	if server.cfg.Server.Relay == "" || server.cfg.Server.Relay == "pipes" {
+		return pipe.NewPipeFactory(), nil
+	}
+
+	dsn := strings.Split(server.cfg.Server.Relay, "://")
+	if len(dsn) != 2 {
+		return nil, errors.E(op, errors.Network, errors.Str("invalid DSN (tcp://:6001, unix://file.sock)"))
+	}
+
+	lsn, err := utils.CreateListener(server.cfg.Server.Relay)
+	if err != nil {
+		return nil, errors.E(op, errors.Network, err)
+	}
+
+	switch dsn[0] {
+	// sockets group
+	case "unix", "tcp":
+		return socket.NewSocketServer(lsn, server.cfg.Server.RelayTimeout), nil
+	default:
+		return nil, errors.E(op, errors.Network, errors.Str("invalid DSN (tcp://:6001, unix://file.sock)"))
+	}
+}
+
+func (server *Plugin) setEnv(e Env) []string {
+	env := append(os.Environ(), fmt.Sprintf("%s=%s", RR_RELAY, server.cfg.Server.Relay))
+	for k, v := range e {
+		env = append(env, fmt.Sprintf("%s=%s", strings.ToUpper(k), v))
+	}
+
+	// set internal env variables
+	if server.cfg.HTTP != nil {
+		env = append(env, fmt.Sprintf("%s=%s", RR_HTTP, "true"))
+	}
+	if server.cfg.RPC != nil && server.cfg.RPC.Listen != "" {
+		env = append(env, fmt.Sprintf("%s=%s", RR_RPC, server.cfg.RPC.Listen))
+	}
+
+	// set env variables from the config
+	if len(server.cfg.Server.Env) > 0 {
+		for k, v := range server.cfg.Server.Env {
+			env = append(env, fmt.Sprintf("%s=%s", strings.ToUpper(k), v))
+		}
+	}
+
+	return env
+}
+
+func (server *Plugin) collectPoolLogs(event interface{}) {
+	if we, ok := event.(events.PoolEvent); ok {
+		switch we.Event {
+		case events.EventMaxMemory:
+			server.log.Info("worker max memory reached", "pid", we.Payload.(worker.BaseProcess).Pid())
+		case events.EventNoFreeWorkers:
+			server.log.Info("no free workers in pool", "error", we.Payload.(error).Error())
+		case events.EventPoolError:
+			server.log.Info("pool error", "error", we.Payload.(error).Error())
+		case events.EventSupervisorError:
+			server.log.Info("pool supervisor error", "error", we.Payload.(error).Error())
+		case events.EventTTL:
+			server.log.Info("worker TTL reached", "pid", we.Payload.(worker.BaseProcess).Pid())
+		case events.EventWorkerConstruct:
+			if _, ok := we.Payload.(error); ok {
+				server.log.Error("worker construction error", "error", we.Payload.(error).Error())
+				return
+			}
+			server.log.Info("worker constructed", "pid", we.Payload.(worker.BaseProcess).Pid())
+		case events.EventWorkerDestruct:
+			server.log.Info("worker destructed", "pid", we.Payload.(worker.BaseProcess).Pid())
+		case events.EventExecTTL:
+			server.log.Info("worker exec TTL reached")
+		case events.EventIdleTTL:
+			server.log.Info("worker IDLE timeout reached", "pid", we.Payload.(worker.BaseProcess).Pid())
+		}
+	}
+
+	if we, ok := event.(events.WorkerEvent); ok {
+		switch we.Event {
+		case events.EventWorkerError:
+			server.log.Info(we.Payload.(error).Error(), "pid", we.Worker.(worker.BaseProcess).Pid())
+		case events.EventWorkerLog:
+			server.log.Info(strings.TrimRight(string(we.Payload.([]byte)), " \n\t"), "pid", we.Worker.(worker.BaseProcess).Pid())
+		}
+	}
+}
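A sketch of a custom events.Listener, the same func(event interface{}) shape as collectPoolLogs above and collectWorkerLogs below; the counter is illustrative and a real listener should synchronize access.

    var workerErrors int
    countErrors := func(event interface{}) {
    	if we, ok := event.(events.WorkerEvent); ok && we.Event == events.EventWorkerError {
    		workerErrors++ // illustrative only; not goroutine-safe
    	}
    }

    // extra listeners are appended after the built-in log collector
    w, err := srv.NewWorker(context.Background(), Env{}, countErrors)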
+func (server *Plugin) collectWorkerLogs(event interface{}) {
+	if we, ok := event.(events.WorkerEvent); ok {
+		switch we.Event {
+		case events.EventWorkerError:
+			server.log.Error(we.Payload.(error).Error(), "pid", we.Worker.(worker.BaseProcess).Pid())
+		case events.EventWorkerLog:
+			server.log.Info(strings.TrimRight(string(we.Payload.([]byte)), " \n\t"), "pid", we.Worker.(worker.BaseProcess).Pid())
+		}
+	}
+}
diff --git a/plugins/static/config.go b/plugins/static/config.go new file mode 100644 index 00000000..f5d26b2d --- /dev/null +++ b/plugins/static/config.go @@ -0,0 +1,76 @@
+package static
+
+import (
+	"os"
+	"path"
+	"strings"
+
+	"github.com/spiral/errors"
+)
+
+// Config describes file locations and controls access to them.
+type Config struct {
+	Static struct {
+		// Dir contains the name of the directory to control access to.
+		Dir string
+
+		// Forbid specifies a list of file extensions which are forbidden for access.
+		// Example: .php, .exe, .bat, .htaccess, etc.
+		Forbid []string
+
+		// Always specifies a list of extensions which must always be served by the
+		// static service, even if the file is not found.
+		Always []string
+
+		// Request headers to add to every static request.
+		Request map[string]string
+
+		// Response headers to add to every static response.
+		Response map[string]string
+	}
+}
+
+// Valid returns nil if the config is valid.
+func (c *Config) Valid() error {
+	const op = errors.Op("static plugin validation")
+	st, err := os.Stat(c.Static.Dir)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return errors.E(op, errors.Errorf("root directory '%s' does not exist", c.Static.Dir))
+		}
+
+		return err
+	}
+
+	if !st.IsDir() {
+		return errors.E(op, errors.Errorf("invalid root directory '%s'", c.Static.Dir))
+	}
+
+	return nil
+}
+
+// AlwaysForbid returns true if the file extension is not allowed to be served.
+func (c *Config) AlwaysForbid(filename string) bool {
+	ext := strings.ToLower(path.Ext(filename))
+
+	for _, v := range c.Static.Forbid {
+		if ext == v {
+			return true
+		}
+	}
+
+	return false
+}
+
+// AlwaysServe returns true if the file is expected to be served by the static service.
+func (c *Config) AlwaysServe(filename string) bool {
+	ext := strings.ToLower(path.Ext(filename))
+
+	for _, v := range c.Static.Always {
+		if ext == v {
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/plugins/static/plugin.go b/plugins/static/plugin.go new file mode 100644 index 00000000..06b384df --- /dev/null +++ b/plugins/static/plugin.go @@ -0,0 +1,110 @@
+package static
+
+import (
+	"net/http"
+	"path"
+
+	"github.com/spiral/errors"
+	"github.com/spiral/roadrunner/v2/plugins/config"
+	"github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+// PluginName contains the default service name.
+const PluginName = "static"
+
+const RootPluginName = "http"
+
+// Plugin serves static files. Potentially convert into middleware?
+type Plugin struct {
+	// server configuration (location, forbidden files, etc.)
+	cfg *Config
+
+	log logger.Logger
+
+	// root is the initialized http directory
+	root http.Dir
+}
+
+// Init configures the service and returns an error in case of misconfiguration.
+// The service must not be used until proper configuration has been provided.
+func (s *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
+	const op = errors.Op("static plugin init")
+	err := cfg.UnmarshalKey(RootPluginName, &s.cfg)
+	if err != nil {
+		return errors.E(op, errors.Disabled, err)
+	}
+
+	s.log = log
+	s.root = http.Dir(s.cfg.Static.Dir)
+
+	err = s.cfg.Valid()
+	if err != nil {
+		return errors.E(op, errors.Disabled, err)
+	}
+
+	return nil
+}
+
+func (s *Plugin) Name() string {
+	return PluginName
+}
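A sketch of the static config above in isolation; the directory and extension lists are made up, and Valid requires the directory to exist on disk.

    cfg := &Config{}
    cfg.Static.Dir = "./public"
    cfg.Static.Forbid = []string{".php", ".htaccess"}

    if err := cfg.Valid(); err != nil {
    	return err
    }

    _ = cfg.AlwaysForbid("app/index.php") // true: ".php" is listed in Forbid
    _ = cfg.AlwaysServe("missing.html")   // false unless ".html" is listed in Always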
+// Middleware serves static files; requests it does not handle are passed
+// to the next handler in the chain.
+func (s *Plugin) Middleware(next http.Handler) http.HandlerFunc {
+	// Define the http.HandlerFunc
+	return func(w http.ResponseWriter, r *http.Request) {
+		if s.cfg.Static.Request != nil {
+			for k, v := range s.cfg.Static.Request {
+				r.Header.Add(k, v)
+			}
+		}
+
+		if s.cfg.Static.Response != nil {
+			for k, v := range s.cfg.Static.Response {
+				w.Header().Set(k, v)
+			}
+		}
+
+		if !s.handleStatic(w, r) {
+			next.ServeHTTP(w, r)
+		}
+	}
+}
+
+func (s *Plugin) handleStatic(w http.ResponseWriter, r *http.Request) bool {
+	fPath := path.Clean(r.URL.Path)
+
+	if s.cfg.AlwaysForbid(fPath) {
+		return false
+	}
+
+	f, err := s.root.Open(fPath)
+	if err != nil {
+		s.log.Error("file open error", "error", err)
+		if s.cfg.AlwaysServe(fPath) {
+			w.WriteHeader(http.StatusNotFound)
+			return true
+		}
+
+		return false
+	}
+	defer func() {
+		err = f.Close()
+		if err != nil {
+			s.log.Error("file closing error", "error", err)
+		}
+	}()
+
+	d, err := f.Stat()
+	if err != nil {
+		return false
+	}
+
+	// do not serve directories
+	if d.IsDir() {
+		return false
+	}
+
+	http.ServeContent(w, r, d.Name(), d.ModTime(), f)
+	return true
+}
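Finally, a sketch of wiring the middleware above into a plain net/http server. In RoadRunner itself the HTTP plugin performs this wiring, so the mux, address and fallback handler here are purely illustrative.

    // s is an initialized *static.Plugin
    mux := http.NewServeMux()
    mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
    	_, _ = w.Write([]byte("handled by the application"))
    })

    // static files are served directly; everything else falls through to mux
    err := http.ListenAndServe(":8080", s.Middleware(mux))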