summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorValery Piashchynski <[email protected]>2024-05-07 18:21:01 +0200
committerValery Piashchynski <[email protected]>2024-05-07 18:21:01 +0200
commitd1928ddc8960e0b95aa0e94fa9d44eff61f7e25e (patch)
tree33c786de27c5885944ff07a3b58d16833cfb9fe2 /internal
parent277a6a99f3aa405fbcf476f475e09edf52439d42 (diff)
fix: correctly expand default bash env syntax
Signed-off-by: Valery Piashchynski <[email protected]>
Diffstat (limited to 'internal')
-rw-r--r--internal/rpc/client.go139
1 files changed, 133 insertions, 6 deletions
diff --git a/internal/rpc/client.go b/internal/rpc/client.go
index 2fe63407..7b3d8f5d 100644
--- a/internal/rpc/client.go
+++ b/internal/rpc/client.go
@@ -18,6 +18,8 @@ import (
const (
prefix string = "rr"
rpcKey string = "rpc.listen"
+ // default envs
+ envDefault = ":-"
)
// NewClient creates client ONLY for internal usage (communication between our application with RR side).
@@ -35,12 +37,7 @@ func NewClient(cfg string, flags []string) (*rpc.Client, error) {
}
// automatically inject ENV variables using ${ENV} pattern
- for _, key := range v.AllKeys() {
- val := v.Get(key)
- if s, ok := val.(string); ok {
- v.Set(key, os.ExpandEnv(s))
- }
- }
+ expandEnvViper(v)
// override config Flags
if len(flags) > 0 {
@@ -108,3 +105,133 @@ func parseValue(value string) string {
return value
}
+
+// ExpandVal replaces ${var} or $var in the string based on the mapping function.
+// For example, os.ExpandEnv(s) is equivalent to os.Expand(s, os.Getenv).
+func ExpandVal(s string, mapping func(string) string) string {
+ var buf []byte
+ // ${} is all ASCII, so bytes are fine for this operation.
+ i := 0
+ for j := 0; j < len(s); j++ {
+ if s[j] == '$' && j+1 < len(s) {
+ if buf == nil {
+ buf = make([]byte, 0, 2*len(s))
+ }
+ buf = append(buf, s[i:j]...)
+ name, w := getShellName(s[j+1:])
+ if name == "" && w > 0 { //nolint:revive
+ // Encountered invalid syntax; eat the
+ // characters.
+ } else if name == "" {
+ // Valid syntax, but $ was not followed by a
+ // name. Leave the dollar character untouched.
+ buf = append(buf, s[j])
+ // parse default syntax
+ } else if idx := strings.Index(s, envDefault); idx != -1 {
+ // ${key:=default} or ${key:-val}
+ substr := strings.Split(name, envDefault)
+ if len(substr) != 2 {
+ return ""
+ }
+
+ key := substr[0]
+ defaultVal := substr[1]
+
+ res := mapping(key)
+ if res == "" {
+ res = defaultVal
+ }
+
+ buf = append(buf, res...)
+ } else {
+ buf = append(buf, mapping(name)...)
+ }
+ j += w
+ i = j + 1
+ }
+ }
+ if buf == nil {
+ return s
+ }
+ return string(buf) + s[i:]
+}
+
+// getShellName returns the name that begins the string and the number of bytes
+// consumed to extract it. If the name is enclosed in {}, it's part of a ${}
+// expansion and two more bytes are needed than the length of the name.
+func getShellName(s string) (string, int) {
+ switch {
+ case s[0] == '{':
+ if len(s) > 2 && isShellSpecialVar(s[1]) && s[2] == '}' {
+ return s[1:2], 3
+ }
+ // Scan to closing brace
+ for i := 1; i < len(s); i++ {
+ if s[i] == '}' {
+ if i == 1 {
+ return "", 2 // Bad syntax; eat "${}"
+ }
+ return s[1:i], i + 1
+ }
+ }
+ return "", 1 // Bad syntax; eat "${"
+ case isShellSpecialVar(s[0]):
+ return s[0:1], 1
+ }
+ // Scan alphanumerics.
+ var i int
+ for i = 0; i < len(s) && isAlphaNum(s[i]); i++ { //nolint:revive
+
+ }
+ return s[:i], i
+}
+
+func expandEnvViper(v *viper.Viper) {
+ for _, key := range v.AllKeys() {
+ val := v.Get(key)
+ switch t := val.(type) {
+ case string:
+ // for string expand it
+ v.Set(key, parseEnvDefault(t))
+ case []any:
+ // for slice -> check if it's a slice of strings
+ strArr := make([]string, 0, len(t))
+ for i := 0; i < len(t); i++ {
+ if valStr, ok := t[i].(string); ok {
+ strArr = append(strArr, parseEnvDefault(valStr))
+ continue
+ }
+
+ v.Set(key, val)
+ }
+
+ // we should set the whole array
+ if len(strArr) > 0 {
+ v.Set(key, strArr)
+ }
+ default:
+ v.Set(key, val)
+ }
+ }
+}
+
+// isShellSpecialVar reports whether the character identifies a special
+// shell variable such as $*.
+func isShellSpecialVar(c uint8) bool {
+ switch c {
+ case '*', '#', '$', '@', '!', '?', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return true
+ }
+ return false
+}
+
+// isAlphaNum reports whether the byte is an ASCII letter, number, or underscore.
+func isAlphaNum(c uint8) bool {
+ return c == '_' || '0' <= c && c <= '9' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z'
+}
+
+func parseEnvDefault(val string) string {
+ // tcp://127.0.0.1:${RPC_PORT:-36643}
+ // for envs like this, part would be tcp://127.0.0.1:
+ return ExpandVal(val, os.Getenv)
+}