diff options
author | Valery Piashchynski <[email protected]> | 2024-05-07 18:21:01 +0200 |
---|---|---|
committer | Valery Piashchynski <[email protected]> | 2024-05-07 18:21:01 +0200 |
commit | d1928ddc8960e0b95aa0e94fa9d44eff61f7e25e (patch) | |
tree | 33c786de27c5885944ff07a3b58d16833cfb9fe2 /internal | |
parent | 277a6a99f3aa405fbcf476f475e09edf52439d42 (diff) |
fix: correctly expand default bash env syntax
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'internal')
-rw-r--r-- | internal/rpc/client.go | 139 |
1 files changed, 133 insertions, 6 deletions
diff --git a/internal/rpc/client.go b/internal/rpc/client.go index 2fe63407..7b3d8f5d 100644 --- a/internal/rpc/client.go +++ b/internal/rpc/client.go @@ -18,6 +18,8 @@ import ( const ( prefix string = "rr" rpcKey string = "rpc.listen" + // default envs + envDefault = ":-" ) // NewClient creates client ONLY for internal usage (communication between our application with RR side). @@ -35,12 +37,7 @@ func NewClient(cfg string, flags []string) (*rpc.Client, error) { } // automatically inject ENV variables using ${ENV} pattern - for _, key := range v.AllKeys() { - val := v.Get(key) - if s, ok := val.(string); ok { - v.Set(key, os.ExpandEnv(s)) - } - } + expandEnvViper(v) // override config Flags if len(flags) > 0 { @@ -108,3 +105,133 @@ func parseValue(value string) string { return value } + +// ExpandVal replaces ${var} or $var in the string based on the mapping function. +// For example, os.ExpandEnv(s) is equivalent to os.Expand(s, os.Getenv). +func ExpandVal(s string, mapping func(string) string) string { + var buf []byte + // ${} is all ASCII, so bytes are fine for this operation. + i := 0 + for j := 0; j < len(s); j++ { + if s[j] == '$' && j+1 < len(s) { + if buf == nil { + buf = make([]byte, 0, 2*len(s)) + } + buf = append(buf, s[i:j]...) + name, w := getShellName(s[j+1:]) + if name == "" && w > 0 { //nolint:revive + // Encountered invalid syntax; eat the + // characters. + } else if name == "" { + // Valid syntax, but $ was not followed by a + // name. Leave the dollar character untouched. + buf = append(buf, s[j]) + // parse default syntax + } else if idx := strings.Index(s, envDefault); idx != -1 { + // ${key:=default} or ${key:-val} + substr := strings.Split(name, envDefault) + if len(substr) != 2 { + return "" + } + + key := substr[0] + defaultVal := substr[1] + + res := mapping(key) + if res == "" { + res = defaultVal + } + + buf = append(buf, res...) + } else { + buf = append(buf, mapping(name)...) + } + j += w + i = j + 1 + } + } + if buf == nil { + return s + } + return string(buf) + s[i:] +} + +// getShellName returns the name that begins the string and the number of bytes +// consumed to extract it. If the name is enclosed in {}, it's part of a ${} +// expansion and two more bytes are needed than the length of the name. +func getShellName(s string) (string, int) { + switch { + case s[0] == '{': + if len(s) > 2 && isShellSpecialVar(s[1]) && s[2] == '}' { + return s[1:2], 3 + } + // Scan to closing brace + for i := 1; i < len(s); i++ { + if s[i] == '}' { + if i == 1 { + return "", 2 // Bad syntax; eat "${}" + } + return s[1:i], i + 1 + } + } + return "", 1 // Bad syntax; eat "${" + case isShellSpecialVar(s[0]): + return s[0:1], 1 + } + // Scan alphanumerics. + var i int + for i = 0; i < len(s) && isAlphaNum(s[i]); i++ { //nolint:revive + + } + return s[:i], i +} + +func expandEnvViper(v *viper.Viper) { + for _, key := range v.AllKeys() { + val := v.Get(key) + switch t := val.(type) { + case string: + // for string expand it + v.Set(key, parseEnvDefault(t)) + case []any: + // for slice -> check if it's a slice of strings + strArr := make([]string, 0, len(t)) + for i := 0; i < len(t); i++ { + if valStr, ok := t[i].(string); ok { + strArr = append(strArr, parseEnvDefault(valStr)) + continue + } + + v.Set(key, val) + } + + // we should set the whole array + if len(strArr) > 0 { + v.Set(key, strArr) + } + default: + v.Set(key, val) + } + } +} + +// isShellSpecialVar reports whether the character identifies a special +// shell variable such as $*. +func isShellSpecialVar(c uint8) bool { + switch c { + case '*', '#', '$', '@', '!', '?', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return true + } + return false +} + +// isAlphaNum reports whether the byte is an ASCII letter, number, or underscore. +func isAlphaNum(c uint8) bool { + return c == '_' || '0' <= c && c <= '9' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' +} + +func parseEnvDefault(val string) string { + // tcp://127.0.0.1:${RPC_PORT:-36643} + // for envs like this, part would be tcp://127.0.0.1: + return ExpandVal(val, os.Getenv) +} |