diff options
Diffstat (limited to 'plugins/config')
-rw-r--r-- | plugins/config/provider.go | 15 | ||||
-rw-r--r-- | plugins/config/tests/.rr.yaml | 28 | ||||
-rw-r--r-- | plugins/config/tests/config_test.go | 67 | ||||
-rw-r--r-- | plugins/config/tests/plugin1.go | 54 | ||||
-rw-r--r-- | plugins/config/viper.go | 86 |
5 files changed, 250 insertions, 0 deletions
diff --git a/plugins/config/provider.go b/plugins/config/provider.go new file mode 100644 index 00000000..bec417e9 --- /dev/null +++ b/plugins/config/provider.go @@ -0,0 +1,15 @@ +package config + +type Provider interface { + // Unmarshal configuration section into configuration object. + // + // func (h *HttpService) Init(cp config.Provider) error { + // h.config := &HttpConfig{} + // if err := configProvider.UnmarshalKey("http", h.config); err != nil { + // return err + // } + // } + UnmarshalKey(name string, out interface{}) error + // Get used to get config section + Get(name string) interface{} +} diff --git a/plugins/config/tests/.rr.yaml b/plugins/config/tests/.rr.yaml new file mode 100644 index 00000000..df9077d0 --- /dev/null +++ b/plugins/config/tests/.rr.yaml @@ -0,0 +1,28 @@ +reload: + # enable or disable file watcher + enabled: true + # sync interval + interval: 1s + # global patterns to sync + patterns: [".php"] + # list of included for sync services + services: + http: + # recursive search for file patterns to add + recursive: true + # ignored folders + ignore: ["vendor"] + # service specific file pattens to sync + patterns: [".php", ".go",".md",] + # directories to sync. If recursive is set to true, + # recursive sync will be applied only to the directories in `dirs` section + dirs: ["."] + jobs: + recursive: false + ignore: ["service/metrics"] + dirs: ["./jobs"] + rpc: + recursive: true + patterns: [".json"] + # to include all project directories from workdir, leave `dirs` empty or add a dot "." + dirs: [""] diff --git a/plugins/config/tests/config_test.go b/plugins/config/tests/config_test.go new file mode 100644 index 00000000..baeafbd2 --- /dev/null +++ b/plugins/config/tests/config_test.go @@ -0,0 +1,67 @@ +package tests + +import ( + "os" + "os/signal" + "testing" + "time" + + "github.com/spiral/endure" + "github.com/stretchr/testify/assert" + "github.com/temporalio/roadrunner-temporal/config" +) + +func TestViperProvider_Init(t *testing.T) { + container, err := endure.NewContainer(endure.DebugLevel, endure.RetryOnFail(true)) + if err != nil { + t.Fatal(err) + } + vp := &config.ViperProvider{} + vp.Path = ".rr.yaml" + vp.Prefix = "rr" + err = container.Register(vp) + if err != nil { + t.Fatal(err) + } + + err = container.Register(&Foo{}) + if err != nil { + t.Fatal(err) + } + + err = container.Init() + if err != nil { + t.Fatal(err) + } + + errCh, err := container.Serve() + if err != nil { + t.Fatal(err) + } + + // stop by CTRL+C + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt) + + tt := time.NewTicker(time.Second * 2) + + for { + select { + case e := <-errCh: + assert.NoError(t, e.Error.Err) + assert.NoError(t, container.Stop()) + return + case <-c: + er := container.Stop() + if er != nil { + panic(er) + } + return + case <-tt.C: + tt.Stop() + assert.NoError(t, container.Stop()) + return + } + } + +} diff --git a/plugins/config/tests/plugin1.go b/plugins/config/tests/plugin1.go new file mode 100644 index 00000000..4e7a5317 --- /dev/null +++ b/plugins/config/tests/plugin1.go @@ -0,0 +1,54 @@ +package tests + +import ( + "errors" + "time" + + "github.com/temporalio/roadrunner-temporal/config" +) + +// ReloadConfig is a Reload configuration point. +type ReloadConfig struct { + Interval time.Duration + Patterns []string + Services map[string]ServiceConfig +} + +type ServiceConfig struct { + Enabled bool + Recursive bool + Patterns []string + Dirs []string + Ignore []string +} + +type Foo struct { + configProvider config.Provider +} + + +// Depends on S2 and DB (S3 in the current case) +func (f *Foo) Init(p config.Provider) error { + f.configProvider = p + return nil +} + +func (f *Foo) Serve() chan error { + errCh := make(chan error, 1) + + r := &ReloadConfig{} + err := f.configProvider.UnmarshalKey("reload", r) + if err != nil { + errCh <- err + } + + if len(r.Patterns) == 0 { + errCh <- errors.New("should be at least one pattern, but got 0") + } + + return errCh +} + +func (f *Foo) Stop() error { + return nil +} diff --git a/plugins/config/viper.go b/plugins/config/viper.go new file mode 100644 index 00000000..0362e79b --- /dev/null +++ b/plugins/config/viper.go @@ -0,0 +1,86 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + "github.com/spf13/viper" +) + +type ViperProvider struct { + viper *viper.Viper + Path string + Prefix string +} + +//////// ENDURE ////////// +func (v *ViperProvider) Init() error { + v.viper = viper.New() + // read in environment variables that match + v.viper.AutomaticEnv() + if v.Prefix == "" { + return errors.New("prefix should be set") + } + v.viper.SetEnvPrefix(v.Prefix) + if v.Path == "" { + return errors.New("path should be set") + } + v.viper.SetConfigFile(v.Path) + v.viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + return v.viper.ReadInConfig() +} + +///////////// VIPER /////////////// + +// Overwrite overwrites existing config with provided values +func (v *ViperProvider) Overwrite(values map[string]string) error { + if len(values) != 0 { + for _, flag := range values { + key, value, err := parseFlag(flag) + if err != nil { + return err + } + v.viper.Set(key, value) + } + } + + return nil +} + +// +func (v *ViperProvider) UnmarshalKey(name string, out interface{}) error { + err := v.viper.UnmarshalKey(name, &out) + if err != nil { + return err + } + return nil +} + +// Get raw config in a form of config section. +func (v *ViperProvider) Get(name string) interface{} { + return v.viper.Get(name) +} + +/////////// PRIVATE ////////////// + +func parseFlag(flag string) (string, string, error) { + if !strings.Contains(flag, "=") { + return "", "", fmt.Errorf("invalid flag `%s`", flag) + } + + parts := strings.SplitN(strings.TrimLeft(flag, " \"'`"), "=", 2) + + return strings.Trim(parts[0], " \n\t"), parseValue(strings.Trim(parts[1], " \n\t")), nil +} + +func parseValue(value string) string { + escape := []rune(value)[0] + + if escape == '"' || escape == '\'' || escape == '`' { + value = strings.Trim(value, string(escape)) + value = strings.Replace(value, fmt.Sprintf("\\%s", string(escape)), string(escape), -1) + } + + return value +} |