summaryrefslogtreecommitdiff
path: root/plugins/kv/memory
diff options
context:
space:
mode:
Diffstat (limited to 'plugins/kv/memory')
-rw-r--r--plugins/kv/memory/config.go12
-rw-r--r--plugins/kv/memory/storage.go263
-rw-r--r--plugins/kv/memory/storage_test.go470
3 files changed, 745 insertions, 0 deletions
diff --git a/plugins/kv/memory/config.go b/plugins/kv/memory/config.go
new file mode 100644
index 00000000..329e7fff
--- /dev/null
+++ b/plugins/kv/memory/config.go
@@ -0,0 +1,12 @@
+package memory
+
+// Config is default config for the in-memory driver
+type Config struct {
+ // Enabled or disabled (true or false)
+ Enabled bool
+}
+
+// InitDefaults by default driver is turned off
+func (c *Config) InitDefaults() {
+ c.Enabled = false
+}
diff --git a/plugins/kv/memory/storage.go b/plugins/kv/memory/storage.go
new file mode 100644
index 00000000..1b6cb580
--- /dev/null
+++ b/plugins/kv/memory/storage.go
@@ -0,0 +1,263 @@
+package memory
+
+import (
+ "context"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/spiral/errors"
+ "github.com/spiral/roadrunner/v2/plugins/config"
+ "github.com/spiral/roadrunner/v2/plugins/kv"
+ "github.com/spiral/roadrunner/v2/plugins/logger"
+)
+
+const PluginName = "memory"
+
+type Plugin struct {
+ heap *sync.Map
+ stop chan struct{}
+
+ log logger.Logger
+ cfg *Config
+}
+
+func NewInMemoryStorage() kv.Storage {
+ p := &Plugin{
+ heap: &sync.Map{},
+ stop: make(chan struct{}),
+ }
+
+ go p.gc()
+
+ return p
+}
+
+func (s *Plugin) Init(cfg config.Configurer, log logger.Logger) error {
+ const op = errors.Op("in-memory storage init")
+ s.cfg = &Config{}
+ s.cfg.InitDefaults()
+
+ err := cfg.UnmarshalKey(PluginName, &s.cfg)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ s.log = log
+ // init in-memory
+ s.heap = &sync.Map{}
+ s.stop = make(chan struct{}, 1)
+ return nil
+}
+
+func (s Plugin) Serve() chan error {
+ errCh := make(chan error, 1)
+ // start in-memory gc for kv
+ go s.gc()
+
+ return errCh
+}
+
+func (s Plugin) Stop() error {
+ const op = errors.Op("in-memory storage stop")
+ err := s.Close()
+ if err != nil {
+ return errors.E(op, err)
+ }
+ return nil
+}
+
+func (s Plugin) Has(ctx context.Context, keys ...string) (map[string]bool, error) {
+ const op = errors.Op("in-memory storage Has")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+ m := make(map[string]bool)
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+
+ if _, ok := s.heap.Load(key); ok {
+ m[key] = true
+ }
+ }
+
+ return m, nil
+}
+
+func (s Plugin) Get(ctx context.Context, key string) ([]byte, error) {
+ const op = errors.Op("in-memory storage Get")
+ // to get cases like " "
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+
+ if data, exist := s.heap.Load(key); exist {
+ // here might be a panic
+ // but data only could be a string, see Set function
+ return []byte(data.(kv.Item).Value), nil
+ }
+ return nil, nil
+}
+
+func (s Plugin) MGet(ctx context.Context, keys ...string) (map[string]interface{}, error) {
+ const op = errors.Op("in-memory storage MGet")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+ }
+
+ m := make(map[string]interface{}, len(keys))
+
+ for _, key := range keys {
+ if value, ok := s.heap.Load(key); ok {
+ m[key] = value
+ }
+ }
+
+ return m, nil
+}
+
+func (s Plugin) Set(ctx context.Context, items ...kv.Item) error {
+ const op = errors.Op("in-memory storage Set")
+ if items == nil {
+ return errors.E(op, errors.NoKeys)
+ }
+
+ for _, item := range items {
+ // TTL is set
+ if item.TTL != "" {
+ // check the TTL in the item
+ _, err := time.Parse(time.RFC3339, item.TTL)
+ if err != nil {
+ return err
+ }
+ }
+
+ s.heap.Store(item.Key, item)
+ }
+ return nil
+}
+
+// MExpire sets the expiration time to the key
+// If key already has the expiration time, it will be overwritten
+func (s Plugin) MExpire(ctx context.Context, items ...kv.Item) error {
+ const op = errors.Op("in-memory storage MExpire")
+ for _, item := range items {
+ if item.TTL == "" || strings.TrimSpace(item.Key) == "" {
+ return errors.E(op, errors.Str("should set timeout and at least one key"))
+ }
+
+ // if key exist, overwrite it value
+ if pItem, ok := s.heap.Load(item.Key); ok {
+ // check that time is correct
+ _, err := time.Parse(time.RFC3339, item.TTL)
+ if err != nil {
+ return errors.E(op, err)
+ }
+ tmp := pItem.(kv.Item)
+ // guess that t is in the future
+ // in memory is just FOR TESTING PURPOSES
+ // LOGIC ISN'T IDEAL
+ s.heap.Store(item.Key, kv.Item{
+ Key: item.Key,
+ Value: tmp.Value,
+ TTL: item.TTL,
+ })
+ }
+ }
+
+ return nil
+}
+
+func (s Plugin) TTL(ctx context.Context, keys ...string) (map[string]interface{}, error) {
+ const op = errors.Op("in-memory storage TTL")
+ if keys == nil {
+ return nil, errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return nil, errors.E(op, errors.EmptyKey)
+ }
+ }
+
+ m := make(map[string]interface{}, len(keys))
+
+ for _, key := range keys {
+ if item, ok := s.heap.Load(key); ok {
+ m[key] = item.(kv.Item).TTL
+ }
+ }
+ return m, nil
+}
+
+func (s Plugin) Delete(ctx context.Context, keys ...string) error {
+ const op = errors.Op("in-memory storage Delete")
+ if keys == nil {
+ return errors.E(op, errors.NoKeys)
+ }
+
+ // should not be empty keys
+ for _, key := range keys {
+ keyTrimmed := strings.TrimSpace(key)
+ if keyTrimmed == "" {
+ return errors.E(op, errors.EmptyKey)
+ }
+ }
+
+ for _, key := range keys {
+ s.heap.Delete(key)
+ }
+ return nil
+}
+
+// Close clears the in-memory storage
+func (s Plugin) Close() error {
+ s.heap = &sync.Map{}
+ s.stop <- struct{}{}
+ return nil
+}
+
+// ================================== PRIVATE ======================================
+
+func (s *Plugin) gc() {
+ // TODO check
+ ticker := time.NewTicker(time.Millisecond * 500)
+ for {
+ select {
+ case <-s.stop:
+ ticker.Stop()
+ return
+ case now := <-ticker.C:
+ // check every second
+ s.heap.Range(func(key, value interface{}) bool {
+ v := value.(kv.Item)
+ if v.TTL == "" {
+ return true
+ }
+
+ t, err := time.Parse(time.RFC3339, v.TTL)
+ if err != nil {
+ return false
+ }
+
+ if now.After(t) {
+ s.heap.Delete(key)
+ }
+ return true
+ })
+ }
+ }
+}
diff --git a/plugins/kv/memory/storage_test.go b/plugins/kv/memory/storage_test.go
new file mode 100644
index 00000000..b7b46637
--- /dev/null
+++ b/plugins/kv/memory/storage_test.go
@@ -0,0 +1,470 @@
+package memory
+
+import (
+ "context"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/spiral/roadrunner/v2/plugins/kv"
+ "github.com/stretchr/testify/assert"
+)
+
+func initStorage() kv.Storage {
+ return NewInMemoryStorage()
+}
+
+func cleanup(t *testing.T, s kv.Storage, keys ...string) {
+ err := s.Delete(context.Background(), keys...)
+ if err != nil {
+ t.Fatalf("error during cleanup: %s", err.Error())
+ }
+}
+
+func TestStorage_Has(t *testing.T) {
+ s := initStorage()
+
+ ctx := context.Background()
+
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+}
+
+func TestStorage_Has_Set_Has(t *testing.T) {
+ s := initStorage()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ ctx := context.Background()
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: "",
+ }))
+
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ assert.True(t, v["key2"])
+}
+
+func TestStorage_Has_Set_MGet(t *testing.T) {
+ s := initStorage()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ ctx := context.Background()
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: "",
+ }))
+
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ assert.True(t, v["key2"])
+
+ res, err := s.MGet(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.Len(t, res, 2)
+}
+
+func TestStorage_Has_Set_Get(t *testing.T) {
+ s := initStorage()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ ctx := context.Background()
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: "",
+ }))
+
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ assert.True(t, v["key2"])
+
+ res, err := s.Get(ctx, "key")
+ assert.NoError(t, err)
+
+ if string(res) != "value" {
+ t.Fatal("wrong value by key")
+ }
+}
+
+func TestStorage_Set_Del_Get(t *testing.T) {
+ s := initStorage()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ ctx := context.Background()
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: "",
+ }))
+
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ assert.True(t, v["key2"])
+
+ // check that keys are present
+ res, err := s.MGet(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.Len(t, res, 2)
+
+ assert.NoError(t, s.Delete(ctx, "key", "key2"))
+ // check that keys are not presentps -eo state,uid,pid,ppid,rtprio,time,comm
+ res, err = s.MGet(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.Len(t, res, 0)
+}
+
+func TestStorage_Set_GetM(t *testing.T) {
+ s := initStorage()
+ ctx := context.Background()
+
+ defer func() {
+ cleanup(t, s, "key", "key2")
+
+ if err := s.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: "",
+ }))
+
+ res, err := s.MGet(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.Len(t, res, 2)
+}
+
+func TestStorage_MExpire_TTL(t *testing.T) {
+ s := initStorage()
+ ctx := context.Background()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+
+ if err := s.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ // ensure that storage is clean
+ v, err := s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+ assert.False(t, v["key2"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "hello world",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "hello world",
+ TTL: "",
+ }))
+ // set timeout to 5 sec
+ nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339)
+
+ i1 := kv.Item{
+ Key: "key",
+ Value: "",
+ TTL: nowPlusFive,
+ }
+ i2 := kv.Item{
+ Key: "key2",
+ Value: "",
+ TTL: nowPlusFive,
+ }
+ assert.NoError(t, s.MExpire(ctx, i1, i2))
+
+ time.Sleep(time.Second * 6)
+
+ // ensure that storage is clean
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+ assert.False(t, v["key2"])
+}
+
+func TestNilAndWrongArgs(t *testing.T) {
+ s := initStorage()
+ ctx := context.Background()
+ defer func() {
+ if err := s.Close(); err != nil {
+ panic(err)
+ }
+ }()
+
+ // check
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+
+ _, err = s.Has(ctx, "")
+ assert.Error(t, err)
+
+ _, err = s.Get(ctx, "")
+ assert.Error(t, err)
+
+ _, err = s.Get(ctx, " ")
+ assert.Error(t, err)
+
+ _, err = s.Get(ctx, " ")
+ assert.Error(t, err)
+
+ _, err = s.MGet(ctx, "key", "key2", "")
+ assert.Error(t, err)
+
+ _, err = s.MGet(ctx, "key", "key2", " ")
+ assert.Error(t, err)
+
+ assert.NoError(t, s.Set(ctx, kv.Item{}))
+ _, err = s.Has(ctx, "key")
+ assert.NoError(t, err)
+
+ err = s.Delete(ctx, "")
+ assert.Error(t, err)
+
+ err = s.Delete(ctx, "key", "")
+ assert.Error(t, err)
+
+ err = s.Delete(ctx, "key", " ")
+ assert.Error(t, err)
+
+ err = s.Delete(ctx, "key")
+ assert.NoError(t, err)
+}
+
+func TestStorage_SetExpire_TTL(t *testing.T) {
+ s := initStorage()
+ ctx := context.Background()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ // ensure that storage is clean
+ v, err := s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+ assert.False(t, v["key2"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "hello world",
+ TTL: "",
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "hello world",
+ TTL: "",
+ }))
+
+ nowPlusFive := time.Now().Add(time.Second * 5).Format(time.RFC3339)
+
+ // set timeout to 5 sec
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "value",
+ TTL: nowPlusFive,
+ },
+ kv.Item{
+ Key: "key2",
+ Value: "value",
+ TTL: nowPlusFive,
+ }))
+
+ time.Sleep(time.Second * 2)
+ m, err := s.TTL(ctx, "key", "key2")
+ assert.NoError(t, err)
+
+ // remove a precision 4.02342342 -> 4
+ keyTTL, err := strconv.Atoi(m["key"].(string)[0:1])
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // remove a precision 4.02342342 -> 4
+ key2TTL, err := strconv.Atoi(m["key"].(string)[0:1])
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ assert.True(t, keyTTL < 5)
+ assert.True(t, key2TTL < 5)
+
+ time.Sleep(time.Second * 4)
+
+ // ensure that storage is clean
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ assert.False(t, v["key"])
+ assert.False(t, v["key2"])
+}
+
+func TestConcurrentReadWriteTransactions(t *testing.T) {
+ s := initStorage()
+ defer func() {
+ cleanup(t, s, "key", "key2")
+ if err := s.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ ctx := context.Background()
+ v, err := s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.False(t, v["key"])
+
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key",
+ Value: "hello world",
+ TTL: "",
+ }, kv.Item{
+ Key: "key2",
+ Value: "hello world",
+ TTL: "",
+ }))
+
+ v, err = s.Has(ctx, "key", "key2")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ assert.True(t, v["key2"])
+
+ wg := &sync.WaitGroup{}
+ wg.Add(3)
+
+ m := &sync.RWMutex{}
+ // concurrently set the keys
+ go func(s kv.Storage) {
+ defer wg.Done()
+ for i := 0; i <= 1000; i++ {
+ m.Lock()
+ // set is writable transaction
+ // it should stop readable
+ assert.NoError(t, s.Set(ctx, kv.Item{
+ Key: "key" + strconv.Itoa(i),
+ Value: "hello world" + strconv.Itoa(i),
+ TTL: "",
+ }, kv.Item{
+ Key: "key2" + strconv.Itoa(i),
+ Value: "hello world" + strconv.Itoa(i),
+ TTL: "",
+ }))
+ m.Unlock()
+ }
+ }(s)
+
+ // should be no errors
+ go func(s kv.Storage) {
+ defer wg.Done()
+ for i := 0; i <= 1000; i++ {
+ m.RLock()
+ v, err = s.Has(ctx, "key")
+ assert.NoError(t, err)
+ // no such key
+ assert.True(t, v["key"])
+ m.RUnlock()
+ }
+ }(s)
+
+ // should be no errors
+ go func(s kv.Storage) {
+ defer wg.Done()
+ for i := 0; i <= 1000; i++ {
+ m.Lock()
+ err = s.Delete(ctx, "key"+strconv.Itoa(i))
+ assert.NoError(t, err)
+ m.Unlock()
+ }
+ }(s)
+
+ wg.Wait()
+}