summaryrefslogtreecommitdiff
path: root/internal/rpc/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/rpc/client.go')
-rw-r--r--internal/rpc/client.go87
1 files changed, 77 insertions, 10 deletions
diff --git a/internal/rpc/client.go b/internal/rpc/client.go
index 7d945add..1e0fbbac 100644
--- a/internal/rpc/client.go
+++ b/internal/rpc/client.go
@@ -1,19 +1,32 @@
// Package rpc contains wrapper around RPC client ONLY for internal usage.
+// Should be in sync with the RPC plugin
package rpc
import (
+ "errors"
+ "fmt"
+ "net"
"net/rpc"
+ "os"
+ "strings"
- "github.com/roadrunner-server/errors"
goridgeRpc "github.com/roadrunner-server/goridge/v3/pkg/rpc"
rpcPlugin "github.com/roadrunner-server/rpc/v2"
"github.com/spf13/viper"
)
+const (
+ prefix string = "rr"
+ rpcKey string = "rpc.listen"
+)
+
// NewClient creates client ONLY for internal usage (communication between our application with RR side).
// Client will be connected to the RPC.
-func NewClient(cfg string) (*rpc.Client, error) {
+func NewClient(cfg string, flags []string) (*rpc.Client, error) {
v := viper.New()
+ v.AutomaticEnv()
+ v.SetEnvPrefix(prefix)
+ v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetConfigFile(cfg)
err := v.ReadInConfig()
@@ -21,23 +34,77 @@ func NewClient(cfg string) (*rpc.Client, error) {
return nil, err
}
- if !v.IsSet(rpcPlugin.PluginName) {
- return nil, errors.E("rpc service disabled")
+ // automatically inject ENV variables using ${ENV} pattern
+ for _, key := range v.AllKeys() {
+ val := v.Get(key)
+ if s, ok := val.(string); ok {
+ v.Set(key, os.ExpandEnv(s))
+ }
}
- rpcConfig := &rpcPlugin.Config{}
+ // override config Flags
+ if len(flags) > 0 {
+ for _, f := range flags {
+ key, val, errP := parseFlag(f)
+ if errP != nil {
+ return nil, errP
+ }
- err = v.UnmarshalKey(rpcPlugin.PluginName, rpcConfig)
- if err != nil {
- return nil, err
+ v.Set(key, val)
+ }
}
- rpcConfig.InitDefaults()
+ // rpc.listen might be set by the -o flags or env variable
+ if !v.IsSet(rpcPlugin.PluginName) {
+ return nil, errors.New("rpc service not specified in the configuration. Tip: add\n rpc:\n\r listen: rr_rpc_address")
+ }
- conn, err := rpcConfig.Dialer()
+ conn, err := Dialer(v.GetString(rpcKey))
if err != nil {
return nil, err
}
return rpc.NewClientWithCodec(goridgeRpc.NewClientCodec(conn)), nil
}
+
+// Dialer creates rpc socket Dialer.
+func Dialer(addr string) (net.Conn, error) {
+ dsn := strings.Split(addr, "://")
+ if len(dsn) != 2 {
+ return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
+ }
+
+ return net.Dial(dsn[0], dsn[1])
+}
+
+func parseFlag(flag string) (string, string, error) {
+ if !strings.Contains(flag, "=") {
+ return "", "", fmt.Errorf("invalid flag `%s`", flag)
+ }
+
+ parts := strings.SplitN(strings.TrimLeft(flag, " \"'`"), "=", 2)
+ if len(parts) < 2 {
+ return "", "", errors.New("usage: -o key=value")
+ }
+
+ if parts[0] == "" {
+ return "", "", errors.New("key should not be empty")
+ }
+
+ if parts[1] == "" {
+ return "", "", errors.New("value should not be empty")
+ }
+
+ return strings.Trim(parts[0], " \n\t"), parseValue(strings.Trim(parts[1], " \n\t")), nil
+}
+
+func parseValue(value string) string {
+ escape := []rune(value)[0]
+
+ if escape == '"' || escape == '\'' || escape == '`' {
+ value = strings.Trim(value, string(escape))
+ value = strings.ReplaceAll(value, fmt.Sprintf("\\%s", string(escape)), string(escape))
+ }
+
+ return value
+}