From 8b64d4afc6f55acb4ee05ca455af193f064b9aad Mon Sep 17 00:00:00 2001 From: Vivek R Date: Wed, 29 May 2024 12:39:00 +0530 Subject: [PATCH] refactor: conv package to v3 spec --- conv/conv.go | 3 +- examples/fastglue-goredis/go.mod | 2 +- examples/fastglue-goredis/main.go | 2 +- examples/fasthttp-redis/go.mod | 2 +- examples/fasthttp-redis/main.go | 2 +- examples/nethttp-inmemory/go.mod | 2 +- examples/nethttp-inmemory/main.go | 2 +- examples/nethttp-redis/go.mod | 2 +- examples/nethttp-redis/main.go | 2 +- examples/nethttp-secure-cookie/go.mod | 2 +- examples/nethttp-secure-cookie/main.go | 2 +- go.mod | 2 +- go.work | 1 - stores/goredis/go.mod | 20 - stores/goredis/store.go | 341 -------------- stores/goredis/store_test.go | 477 ------------------- stores/redis/go.mod | 7 +- stores/redis/store.go | 262 ++++++----- stores/redis/store_test.go | 608 +++++++++++-------------- 19 files changed, 404 insertions(+), 1337 deletions(-) delete mode 100644 stores/goredis/go.mod delete mode 100644 stores/goredis/store.go delete mode 100644 stores/goredis/store_test.go diff --git a/conv/conv.go b/conv/conv.go index c78783d..6d46369 100644 --- a/conv/conv.go +++ b/conv/conv.go @@ -9,9 +9,8 @@ var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrNil = &Err{code: 2, msg: "nil returned"} ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} ) type Err struct { diff --git a/examples/fastglue-goredis/go.mod b/examples/fastglue-goredis/go.mod index 7c4f562..aeb2984 100644 --- a/examples/fastglue-goredis/go.mod +++ b/examples/fastglue-goredis/go.mod @@ -8,6 +8,6 @@ require ( github.com/redis/go-redis/v9 v9.5.1 github.com/valyala/fasthttp v1.52.0 github.com/vividvilla/simplesessions/stores/goredis/v9 v9.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 + github.com/vividvilla/simplesessions/v3 v2.0.0 github.com/zerodha/fastglue v1.8.0 ) diff --git a/examples/fastglue-goredis/main.go b/examples/fastglue-goredis/main.go index 6475d68..1f96f85 100644 --- a/examples/fastglue-goredis/main.go +++ b/examples/fastglue-goredis/main.go @@ -9,7 +9,7 @@ import ( "github.com/redis/go-redis/v9" "github.com/valyala/fasthttp" redisstore "github.com/vividvilla/simplesessions/stores/goredis/v9" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/v3" "github.com/zerodha/fastglue" ) diff --git a/examples/fasthttp-redis/go.mod b/examples/fasthttp-redis/go.mod index 84fe957..384a419 100644 --- a/examples/fasthttp-redis/go.mod +++ b/examples/fasthttp-redis/go.mod @@ -6,7 +6,7 @@ require ( github.com/gomodule/redigo v2.0.0+incompatible github.com/valyala/fasthttp v1.52.0 github.com/vividvilla/simplesessions/stores/redis/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 + github.com/vividvilla/simplesessions/v3 v2.0.0 ) require ( diff --git a/examples/fasthttp-redis/main.go b/examples/fasthttp-redis/main.go index 0ac0035..9ba97e9 100644 --- a/examples/fasthttp-redis/main.go +++ b/examples/fasthttp-redis/main.go @@ -8,7 +8,7 @@ import ( "github.com/gomodule/redigo/redis" "github.com/valyala/fasthttp" redisstore "github.com/vividvilla/simplesessions/stores/redis/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/v3" ) var ( diff --git a/examples/nethttp-inmemory/go.mod b/examples/nethttp-inmemory/go.mod index 2ed9491..875c58b 100644 --- a/examples/nethttp-inmemory/go.mod +++ b/examples/nethttp-inmemory/go.mod @@ -4,5 +4,5 @@ go 1.14 require ( github.com/vividvilla/simplesessions/stores/memory/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 + github.com/vividvilla/simplesessions/v3 v2.0.0 ) diff --git a/examples/nethttp-inmemory/main.go b/examples/nethttp-inmemory/main.go index 6382b34..1b67e1f 100644 --- a/examples/nethttp-inmemory/main.go +++ b/examples/nethttp-inmemory/main.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/vividvilla/simplesessions/stores/memory/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/v3" ) var ( diff --git a/examples/nethttp-redis/go.mod b/examples/nethttp-redis/go.mod index ee59b93..31053d9 100644 --- a/examples/nethttp-redis/go.mod +++ b/examples/nethttp-redis/go.mod @@ -5,5 +5,5 @@ go 1.14 require ( github.com/gomodule/redigo v2.0.0+incompatible github.com/vividvilla/simplesessions/stores/redis/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 + github.com/vividvilla/simplesessions/v3 v2.0.0 ) diff --git a/examples/nethttp-redis/main.go b/examples/nethttp-redis/main.go index db7fc68..6e0dd64 100644 --- a/examples/nethttp-redis/main.go +++ b/examples/nethttp-redis/main.go @@ -8,7 +8,7 @@ import ( "github.com/gomodule/redigo/redis" redisstore "github.com/vividvilla/simplesessions/stores/redis/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/v3" ) var ( diff --git a/examples/nethttp-secure-cookie/go.mod b/examples/nethttp-secure-cookie/go.mod index f8debc2..c55d285 100644 --- a/examples/nethttp-secure-cookie/go.mod +++ b/examples/nethttp-secure-cookie/go.mod @@ -4,7 +4,7 @@ go 1.14 require ( github.com/vividvilla/simplesessions/stores/securecookie/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 + github.com/vividvilla/simplesessions/v3 v2.0.0 ) require github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/examples/nethttp-secure-cookie/main.go b/examples/nethttp-secure-cookie/main.go index 29fd47d..f956d45 100644 --- a/examples/nethttp-secure-cookie/main.go +++ b/examples/nethttp-secure-cookie/main.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/vividvilla/simplesessions/stores/securecookie/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/v3" ) var ( diff --git a/go.mod b/go.mod index 3aa6170..df9cb90 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/vividvilla/simplesessions/v2 +module github.com/vividvilla/simplesessions/v3 require github.com/stretchr/testify v1.9.0 diff --git a/go.work b/go.work index f692f89..fbc2fb4 100644 --- a/go.work +++ b/go.work @@ -3,7 +3,6 @@ go 1.14 use ( . ./conv - ./stores/goredis ./stores/memory ./stores/redis ./stores/securecookie diff --git a/stores/goredis/go.mod b/stores/goredis/go.mod deleted file mode 100644 index b268fed..0000000 --- a/stores/goredis/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/vividvilla/simplesessions/stores/goredis/v9 - -go 1.18 - -require ( - github.com/alicebob/miniredis/v2 v2.32.1 - github.com/redis/go-redis/v9 v9.5.1 - github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions/conv v1.0.0 -) - -require ( - github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/yuin/gopher-lua v1.1.1 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/stores/goredis/store.go b/stores/goredis/store.go deleted file mode 100644 index 82c159f..0000000 --- a/stores/goredis/store.go +++ /dev/null @@ -1,341 +0,0 @@ -package goredis - -import ( - "context" - "crypto/rand" - "time" - "unicode" - - "github.com/redis/go-redis/v9" - "github.com/vividvilla/simplesessions/conv" -) - -var ( - // Error codes for store errors. This should match the codes - // defined in the /simplesessions package exactly. - ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} - ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} -) - -type Err struct { - code int - msg string -} - -func (e *Err) Error() string { - return e.msg -} - -func (e *Err) Code() int { - return e.code -} - -// Store represents redis session store for simple sessions. -// Each session is stored as redis hashmap. -type Store struct { - // Maximum lifetime sessions has to be persisted. - ttl time.Duration - - // Prefix for session id. - prefix string - - // Redis client - client redis.UniversalClient - clientCtx context.Context -} - -const ( - // Default prefix used to store session redis - defaultPrefix = "session:" - sessionIDLen = 32 -) - -// New creates a new Redis store instance. -func New(ctx context.Context, client redis.UniversalClient) *Store { - return &Store{ - clientCtx: ctx, - client: client, - prefix: defaultPrefix, - } -} - -// SetPrefix sets session id prefix in backend -func (s *Store) SetPrefix(val string) { - s.prefix = val -} - -// SetTTL sets TTL for session in redis. -func (s *Store) SetTTL(d time.Duration) { - s.ttl = d -} - -// Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err - } - - return id, err -} - -// Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised -func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HGet(s.clientCtx, s.prefix+id, key) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return nil, err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { - return nil, ErrInvalidSession - } - - v, err := get.Result() - if err != nil && err == redis.Nil { - return nil, ErrFieldNotFound - } - - return v, nil -} - -// GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. -func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HMGet(s.clientCtx, s.prefix+id, keys...) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return nil, err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { - return nil, ErrInvalidSession - } - - v, err := get.Result() - if err != nil { - return nil, err - } - - // Form a map with returned results - res := make(map[string]interface{}) - for i, k := range keys { - if v[i] == nil { - res[k] = ErrFieldNotFound - } else { - res[k] = v[i] - } - } - - return res, err -} - -// GetAll gets all fields from hashmap. -func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - get := pipe.HGetAll(s.clientCtx, s.prefix+id) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return nil, err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return nil, err - } else if ex == 0 { - return nil, ErrInvalidSession - } - - res, err := get.Result() - if err != nil { - return nil, err - } - - // Convert results to type `map[string]interface{}` - out := make(map[string]interface{}, len(res)) - for k, v := range res { - out[k] = v - } - - return out, nil -} - -// Set sets a value to given session. -func (s *Store) Set(id, key string, val interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - - pipe := s.client.TxPipeline() - pipe.HSet(s.clientCtx, s.prefix+id, key, val) - - // Set expiry of key only if 'ttl' is set, this is to - // ensure that the key remains valid indefinitely like - // how redis handles it by default - if s.ttl > 0 { - pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) - } - - _, err := pipe.Exec(s.clientCtx) - return err -} - -// Set sets a value to given session. -func (s *Store) SetMulti(id string, data map[string]interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - - // Make slice of arguments to be passed in HGETALL command - args := []interface{}{} - for k, v := range data { - args = append(args, k, v) - } - - pipe := s.client.TxPipeline() - pipe.HMSet(s.clientCtx, s.prefix+id, args...) - // Set expiry of key only if 'ttl' is set, this is to - // ensure that the key remains valid indefinitely like - // how redis handles it by default - if s.ttl > 0 { - pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) - } - - _, err := pipe.Exec(s.clientCtx) - return err -} - -// Delete deletes a key from redis session hashmap. -func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - - pipe := s.client.TxPipeline() - exists := pipe.Exists(s.clientCtx, s.prefix+id) - del := pipe.HDel(s.clientCtx, s.prefix+id, key) - _, err := pipe.Exec(s.clientCtx) - // redis.Nil is returned if a field does not exist. - // Ignore the error and check for key existence check. - if err != nil && err != redis.Nil { - return err - } - - // Check if key exists and return ErrInvalidSession if not. - if ex, err := exists.Result(); err != nil { - return err - } else if ex == 0 { - return ErrInvalidSession - } - - if v, err := del.Result(); err != nil { - return err - } else if v == 0 { - return ErrFieldNotFound - } - - return nil -} - -// Clear clears session in redis. -func (s *Store) Clear(id string) error { - if !validateID(id) { - return ErrInvalidSession - } - - return s.client.Del(s.clientCtx, s.prefix+id).Err() -} - -// Int returns redis reply as integer. -func (s *Store) Int(r interface{}, err error) (int, error) { - return conv.Int(r, err) -} - -// Int64 returns redis reply as Int64. -func (s *Store) Int64(r interface{}, err error) (int64, error) { - return conv.Int64(r, err) -} - -// UInt64 returns redis reply as UInt64. -func (s *Store) UInt64(r interface{}, err error) (uint64, error) { - return conv.UInt64(r, err) -} - -// Float64 returns redis reply as Float64. -func (s *Store) Float64(r interface{}, err error) (float64, error) { - return conv.Float64(r, err) -} - -// String returns redis reply as String. -func (s *Store) String(r interface{}, err error) (string, error) { - return conv.String(r, err) -} - -// Bytes returns redis reply as Bytes. -func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { - return conv.Bytes(r, err) -} - -// Bool returns redis reply as Bool. -func (s *Store) Bool(r interface{}, err error) (bool, error) { - return conv.Bool(r, err) -} - -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false - } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false - } - } - - return true -} - -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] - } - - return string(bytes), nil -} diff --git a/stores/goredis/store_test.go b/stores/goredis/store_test.go deleted file mode 100644 index 332b54e..0000000 --- a/stores/goredis/store_test.go +++ /dev/null @@ -1,477 +0,0 @@ -package goredis - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" - "github.com/stretchr/testify/assert" -) - -var ( - mockRedis *miniredis.Miniredis -) - -func init() { - var err error - mockRedis, err = miniredis.Run() - if err != nil { - panic(err) - } -} - -func getRedisClient() redis.UniversalClient { - return redis.NewClient(&redis.Options{ - Addr: mockRedis.Addr(), - }) -} - -func TestNew(t *testing.T) { - client := getRedisClient() - ctx := context.Background() - str := New(ctx, client) - assert.Equal(t, str.prefix, defaultPrefix) - assert.Equal(t, str.client, client) - assert.Equal(t, str.clientCtx, ctx) -} - -func TestSetPrefix(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - str.SetPrefix("test") - assert.Equal(t, str.prefix, "test") -} - -func TestSetTTL(t *testing.T) { - testDur := time.Second * 10 - str := New(context.TODO(), getRedisClient()) - str.SetTTL(testDur) - assert.Equal(t, str.ttl, testDur) -} - -func TestCreate(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - id, err := str.Create() - assert.Nil(t, err) - assert.Equal(t, len(id), sessionIDLen) -} - -func TestGet(t *testing.T) { - key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somekey" - value := 100 - client := getRedisClient() - - // Set a key - err := client.HSet(context.TODO(), defaultPrefix+key, field, value).Err() - assert.NoError(t, err) - - str := New(context.TODO(), client) - - val, err := str.Int(str.Get(key, field)) - assert.NoError(t, err) - assert.Equal(t, val, value) - - // Check for invalid key. - _, err = str.Int(str.Get(key, "invalidfield")) - assert.ErrorIs(t, ErrFieldNotFound, err) -} - -func TestGetInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, err, ErrInvalidSession) - - id := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err = str.Get(id, "invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) -} - -func TestGetMultiInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somefield" - _, err = str.GetMulti(key, field) - assert.ErrorIs(t, err, ErrInvalidSession) -} - -func TestGetMulti(t *testing.T) { - var ( - key = "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 = "somekey" - value1 = 100 - field2 = "someotherkey" - value2 = "abc123" - field3 = "thishouldntbethere" - value3 = 100.10 - invalidField = "foo" - client = getRedisClient() - ) - - // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(t, err) - - str := New(context.TODO(), client) - vals, err := str.GetMulti(key, field1, field2, invalidField) - assert.NoError(t, err) - assert.Contains(t, vals, field1) - assert.Contains(t, vals, field2) - assert.NotContains(t, vals, field3) - - val1, err := str.Int(vals[field1], nil) - assert.NoError(t, err) - assert.Equal(t, val1, value1) - - val2, err := str.String(vals[field2], nil) - assert.NoError(t, err) - assert.Equal(t, val2, value2) - - // Check for invalid key. - _, err = str.String(vals[invalidField], nil) - assert.ErrorIs(t, ErrFieldNotFound, err) -} - -func TestGetAllInvalidSession(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - val, err := str.GetAll("invalidkey") - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err = str.GetAll(key) - assert.Nil(t, val) - assert.ErrorIs(t, ErrInvalidSession, err) -} - -func TestGetAll(t *testing.T) { - key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - client := getRedisClient() - - // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(t, err) - - str := New(context.TODO(), client) - - vals, err := str.GetAll(key) - assert.NoError(t, err) - assert.Contains(t, vals, field1) - assert.Contains(t, vals, field2) - assert.Contains(t, vals, field3) - - val1, err := str.Int(vals[field1], nil) - assert.NoError(t, err) - assert.Equal(t, val1, value1) - - val2, err := str.String(vals[field2], nil) - assert.NoError(t, err) - assert.Equal(t, val2, value2) - - val3, err := str.Float64(vals[field3], nil) - assert.NoError(t, err) - assert.Equal(t, val3, value3) -} - -func TestSetInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Set("invalidid", "key", "value") - assert.ErrorIs(t, ErrInvalidSession, err) -} - -func TestSet(t *testing.T) { - // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - ttl := time.Second * 10 - str.SetTTL(ttl) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field := "somekey" - value := 100 - - err := str.Set(key, field, value) - assert.NoError(t, err) - - // Check ifs not commited to redis - v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.Equal(t, int64(1), v1) - - v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, v2) - - dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.Equal(t, dur, ttl) -} - -func TestSetMulti(t *testing.T) { - // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - ttl := time.Second * 10 - str.SetTTL(ttl) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field1 := "somekey1" - value1 := 100 - field2 := "somekey2" - value2 := "somevalue" - - err := str.SetMulti(key, map[string]interface{}{ - field1: value1, - field2: value2, - }) - assert.NoError(t, err) - - // Check ifs not commited to redis - v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.Equal(t, int64(1), v1) - - v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field1).Result()) - assert.NoError(t, err) - assert.Equal(t, value1, v2) - - dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.Equal(t, dur, ttl) -} - -func TestDeleteInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Delete("invalidkey", "somefield") - assert.ErrorIs(t, ErrInvalidSession, err) - - str = New(context.TODO(), getRedisClient()) - err = str.Delete("8dIHy6S2uBuKaNnTUszB2180898ikGY1", "somefield") - assert.ErrorIs(t, ErrInvalidSession, err) -} - -func TestDelete(t *testing.T) { - // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(t, err) - - err = str.Delete(key, field1) - assert.NoError(t, err) - - val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() - assert.False(t, val) - assert.NoError(t, err) - - val, err = client.HExists(context.TODO(), defaultPrefix+key, field2).Result() - assert.True(t, val) - assert.NoError(t, err) - - err = str.Delete(key, "xxxxx") - assert.ErrorIs(t, err, ErrFieldNotFound) -} - -func TestClearInvalidSessionError(t *testing.T) { - str := New(context.TODO(), getRedisClient()) - err := str.Clear("invalidkey") - assert.ErrorIs(t, ErrInvalidSession, err) -} - -func TestClear(t *testing.T) { - // Test should only set in internal map and not in redis - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(t, err) - - // Check if its set - val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.NotEqual(t, val, int64(0)) - - err = str.Clear(key) - assert.NoError(t, err) - - val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(t, err) - assert.Equal(t, val, int64(0)) -} - -func TestInt(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.Int(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.Int(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestInt64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value int64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.Int64(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.Int64(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestUInt64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value uint64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.UInt64(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.UInt64(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestFloat64(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value float64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.Float64(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.Float64(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestString(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := "abc123" - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.String(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.String(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestBytes(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value []byte = []byte("abc123") - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.Bytes(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.Bytes(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestBool(t *testing.T) { - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := true - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(t, err) - - val, err := str.Bool(client.Get(context.TODO(), field).Result()) - assert.NoError(t, err) - assert.Equal(t, value, val) - - testError := errors.New("test error") - _, err = str.Bool(value, testError) - assert.ErrorIs(t, testError, err) -} - -func TestValidateID(t *testing.T) { - ok := validateID("xxxx") - assert.False(t, ok) - - ok = validateID("8dIHy6S2uBuKaNnTUszB2180898ikGY&") - assert.False(t, ok) - - id, err := generateID(sessionIDLen) - assert.NoError(t, err) - ok = validateID(id) - assert.True(t, ok) -} diff --git a/stores/redis/go.mod b/stores/redis/go.mod index b15a078..b268fed 100644 --- a/stores/redis/go.mod +++ b/stores/redis/go.mod @@ -1,16 +1,19 @@ -module github.com/vividvilla/simplesessions/stores/redis/v2 +module github.com/vividvilla/simplesessions/stores/goredis/v9 go 1.18 require ( github.com/alicebob/miniredis/v2 v2.32.1 - github.com/gomodule/redigo v1.9.2 + github.com/redis/go-redis/v9 v9.5.1 github.com/stretchr/testify v1.9.0 + github.com/vividvilla/simplesessions/conv v1.0.0 ) require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/stores/redis/store.go b/stores/redis/store.go index 50977dd..82c159f 100644 --- a/stores/redis/store.go +++ b/stores/redis/store.go @@ -1,13 +1,13 @@ -package redis +package goredis import ( + "context" "crypto/rand" - "errors" - "sync" "time" "unicode" - "github.com/gomodule/redigo/redis" + "github.com/redis/go-redis/v9" + "github.com/vividvilla/simplesessions/conv" ) var ( @@ -41,12 +41,9 @@ type Store struct { // Prefix for session id. prefix string - // Temp map to store values before commit. - tempSetMap map[string]map[string]interface{} - mu sync.RWMutex - - // Redis pool - pool *redis.Pool + // Redis client + client redis.UniversalClient + clientCtx context.Context } const ( @@ -56,11 +53,11 @@ const ( ) // New creates a new Redis store instance. -func New(pool *redis.Pool) *Store { +func New(ctx context.Context, client redis.UniversalClient) *Store { return &Store{ - pool: pool, - prefix: defaultPrefix, - tempSetMap: make(map[string]map[string]interface{}), + clientCtx: ctx, + client: client, + prefix: defaultPrefix, } } @@ -90,15 +87,29 @@ func (s *Store) Get(id, key string) (interface{}, error) { return nil, ErrInvalidSession } - conn := s.pool.Get() - defer conn.Close() + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HGet(s.clientCtx, s.prefix+id, key) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } - v, err := conn.Do("HGET", s.prefix+id, key) - if v == nil || err == redis.ErrNil { + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession + } + + v, err := get.Result() + if err != nil && err == redis.Nil { return nil, ErrFieldNotFound } - return v, err + return v, nil } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. @@ -107,26 +118,36 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err return nil, ErrInvalidSession } - conn := s.pool.Get() - defer conn.Close() + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HMGet(s.clientCtx, s.prefix+id, keys...) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } - // Make list of args for HMGET - args := make([]interface{}, len(keys)+1) - args[0] = s.prefix + id - for i := range keys { - args[i+1] = keys[i] + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession } - v, err := redis.Values(conn.Do("HMGET", args...)) - // If field is not found then return map with fields as nil - if len(v) == 0 || err == redis.ErrNil { - v = make([]interface{}, len(keys)) + v, err := get.Result() + if err != nil { + return nil, err } // Form a map with returned results res := make(map[string]interface{}) for i, k := range keys { - res[k] = v[i] + if v[i] == nil { + res[k] = ErrFieldNotFound + } else { + res[k] = v[i] + } } return res, err @@ -138,89 +159,80 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) { return nil, ErrInvalidSession } - conn := s.pool.Get() - defer conn.Close() + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HGetAll(s.clientCtx, s.prefix+id) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } + + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession + } + + res, err := get.Result() + if err != nil { + return nil, err + } + + // Convert results to type `map[string]interface{}` + out := make(map[string]interface{}, len(res)) + for k, v := range res { + out[k] = v + } - return s.interfaceMap(conn.Do("HGETALL", s.prefix+id)) + return out, nil } -// Set sets a value to given session but stored only on commit +// Set sets a value to given session. func (s *Store) Set(id, key string, val interface{}) error { if !validateID(id) { return ErrInvalidSession } - s.mu.Lock() - defer s.mu.Unlock() + pipe := s.client.TxPipeline() + pipe.HSet(s.clientCtx, s.prefix+id, key, val) - // Create session map if doesn't exist - if _, ok := s.tempSetMap[id]; !ok { - s.tempSetMap[id] = make(map[string]interface{}) + // Set expiry of key only if 'ttl' is set, this is to + // ensure that the key remains valid indefinitely like + // how redis handles it by default + if s.ttl > 0 { + pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) } - // set value to map - s.tempSetMap[id][key] = val - - return nil + _, err := pipe.Exec(s.clientCtx) + return err } -// Commit sets all set values -func (s *Store) Commit(id string) error { +// Set sets a value to given session. +func (s *Store) SetMulti(id string, data map[string]interface{}) error { if !validateID(id) { return ErrInvalidSession } - s.mu.RLock() - vals, ok := s.tempSetMap[id] - if !ok { - // Nothing to commit - s.mu.RUnlock() - return nil - } - // Make slice of arguments to be passed in HGETALL command - args := make([]interface{}, len(vals)*2+1, len(vals)*2+1) - args[0] = s.prefix + id - - c := 1 - for k, v := range s.tempSetMap[id] { - args[c] = k - args[c+1] = v - c += 2 + args := []interface{}{} + for k, v := range data { + args = append(args, k, v) } - s.mu.RUnlock() - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() - - // Set to redis - conn := s.pool.Get() - defer conn.Close() - - conn.Send("MULTI") - conn.Send("HMSET", args...) + pipe := s.client.TxPipeline() + pipe.HMSet(s.clientCtx, s.prefix+id, args...) // Set expiry of key only if 'ttl' is set, this is to // ensure that the key remains valid indefinitely like // how redis handles it by default if s.ttl > 0 { - conn.Send("EXPIRE", args[0], s.ttl.Seconds()) - } - - res, err := redis.Values(conn.Do("EXEC")) - if err != nil { - return err - } - - for _, r := range res { - if v, ok := r.(redis.Error); ok { - return v - } + pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) } - return nil + _, err := pipe.Exec(s.clientCtx) + return err } // Delete deletes a key from redis session hashmap. @@ -229,16 +241,30 @@ func (s *Store) Delete(id string, key string) error { return ErrInvalidSession } - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + del := pipe.HDel(s.clientCtx, s.prefix+id, key) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return err + } - conn := s.pool.Get() - defer conn.Close() + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return err + } else if ex == 0 { + return ErrInvalidSession + } - _, err := conn.Do("HDEL", s.prefix+id, key) - return err + if v, err := del.Result(); err != nil { + return err + } else if v == 0 { + return ErrFieldNotFound + } + + return nil } // Clear clears session in redis. @@ -247,70 +273,42 @@ func (s *Store) Clear(id string) error { return ErrInvalidSession } - conn := s.pool.Get() - defer conn.Close() - - _, err := conn.Do("DEL", s.prefix+id) - return err -} - -// interfaceMap is a helper method which converts HGETALL reply to map of string interface -func (s *Store) interfaceMap(result interface{}, err error) (map[string]interface{}, error) { - values, err := redis.Values(result, err) - if err != nil { - return nil, err - } - - if len(values)%2 != 0 { - return nil, errors.New("redigo: StringMap expects even number of values result") - } - - m := make(map[string]interface{}, len(values)/2) - for i := 0; i < len(values); i += 2 { - key, ok := values[i].([]byte) - if !ok { - return nil, errors.New("redigo: StringMap key not a bulk string value") - } - - m[string(key)] = values[i+1] - } - - return m, nil + return s.client.Del(s.clientCtx, s.prefix+id).Err() } // Int returns redis reply as integer. func (s *Store) Int(r interface{}, err error) (int, error) { - return redis.Int(r, err) + return conv.Int(r, err) } // Int64 returns redis reply as Int64. func (s *Store) Int64(r interface{}, err error) (int64, error) { - return redis.Int64(r, err) + return conv.Int64(r, err) } // UInt64 returns redis reply as UInt64. func (s *Store) UInt64(r interface{}, err error) (uint64, error) { - return redis.Uint64(r, err) + return conv.UInt64(r, err) } // Float64 returns redis reply as Float64. func (s *Store) Float64(r interface{}, err error) (float64, error) { - return redis.Float64(r, err) + return conv.Float64(r, err) } // String returns redis reply as String. func (s *Store) String(r interface{}, err error) (string, error) { - return redis.String(r, err) + return conv.String(r, err) } // Bytes returns redis reply as Bytes. func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { - return redis.Bytes(r, err) + return conv.Bytes(r, err) } // Bool returns redis reply as Bool. func (s *Store) Bool(r interface{}, err error) (bool, error) { - return redis.Bool(r, err) + return conv.Bool(r, err) } func validateID(id string) bool { diff --git a/stores/redis/store_test.go b/stores/redis/store_test.go index f9bcd5e..332b54e 100644 --- a/stores/redis/store_test.go +++ b/stores/redis/store_test.go @@ -1,12 +1,13 @@ -package redis +package goredis import ( + "context" "errors" "testing" "time" "github.com/alicebob/miniredis/v2" - "github.com/gomodule/redigo/redis" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" ) @@ -22,156 +23,136 @@ func init() { } } -func getRedisPool() *redis.Pool { - return &redis.Pool{ - Wait: true, - Dial: func() (redis.Conn, error) { - c, err := redis.Dial( - "tcp", - mockRedis.Addr(), - ) - - return c, err - }, - } +func getRedisClient() redis.UniversalClient { + return redis.NewClient(&redis.Options{ + Addr: mockRedis.Addr(), + }) } func TestNew(t *testing.T) { - assert := assert.New(t) - rPool := getRedisPool() - str := New(rPool) - assert.Equal(str.prefix, defaultPrefix) - assert.Equal(str.pool, rPool) - assert.NotNil(str.tempSetMap) + client := getRedisClient() + ctx := context.Background() + str := New(ctx, client) + assert.Equal(t, str.prefix, defaultPrefix) + assert.Equal(t, str.client, client) + assert.Equal(t, str.clientCtx, ctx) } func TestSetPrefix(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + str := New(context.TODO(), getRedisClient()) str.SetPrefix("test") - assert.Equal(str.prefix, "test") + assert.Equal(t, str.prefix, "test") } func TestSetTTL(t *testing.T) { - assert := assert.New(t) testDur := time.Second * 10 - str := New(getRedisPool()) + str := New(context.TODO(), getRedisClient()) str.SetTTL(testDur) - assert.Equal(str.ttl, testDur) + assert.Equal(t, str.ttl, testDur) } func TestCreate(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - + str := New(context.TODO(), getRedisClient()) id, err := str.Create() - assert.Nil(err) - assert.Equal(len(id), sessionIDLen) -} - -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + assert.Nil(t, err) + assert.Equal(t, len(id), sessionIDLen) } func TestGet(t *testing.T) { - assert := assert.New(t) key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" field := "somekey" value := 100 - redisPool := getRedisPool() + client := getRedisClient() // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HSET", defaultPrefix+key, field, value) - assert.NoError(err) - - str := New(redisPool) + err := client.HSet(context.TODO(), defaultPrefix+key, field, value).Err() + assert.NoError(t, err) - val, err := redis.Int(str.Get(key, field)) - assert.NoError(err) - assert.Equal(val, value) -} + str := New(context.TODO(), client) -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val, err := str.Int(str.Get(key, field)) + assert.NoError(t, err) + assert.Equal(t, val, value) - key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(key, "invalidkey") - assert.Nil(val) - assert.Error(err, ErrFieldNotFound.Error()) + // Check for invalid key. + _, err = str.Int(str.Get(key, "invalidfield")) + assert.ErrorIs(t, ErrFieldNotFound, err) } -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) +func TestGetInvalidSession(t *testing.T) { + str := New(context.TODO(), getRedisClient()) + val, err := str.Get("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + id := "10IHy6S2uBuKaNnTUszB218L898ikGY1" + val, err = str.Get(id, "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) } -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) +func TestGetMultiInvalidSession(t *testing.T) { + str := New(context.TODO(), getRedisClient()) + val, err := str.GetMulti("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" field := "somefield" - _, err := str.GetMulti(key, field) - assert.Nil(err) + _, err = str.GetMulti(key, field) + assert.ErrorIs(t, err, ErrInvalidSession) } func TestGetMulti(t *testing.T) { - assert := assert.New(t) - key := "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - redisPool := getRedisPool() + var ( + key = "5dIHy6S2uBuKaNnTUszB218L898ikGY1" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + field3 = "thishouldntbethere" + value3 = 100.10 + invalidField = "foo" + client = getRedisClient() + ) // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2, field3, value3) - assert.NoError(err) - - str := New(redisPool) - - vals, err := str.GetMulti(key, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) - - val1, err := redis.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := redis.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) -} + err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() + assert.NoError(t, err) + + str := New(context.TODO(), client) + vals, err := str.GetMulti(key, field1, field2, invalidField) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.NotContains(t, vals, field3) + + val1, err := str.Int(vals[field1], nil) + assert.NoError(t, err) + assert.Equal(t, val1, value1) + + val2, err := str.String(vals[field2], nil) + assert.NoError(t, err) + assert.Equal(t, val2, value2) -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + // Check for invalid key. + _, err = str.String(vals[invalidField], nil) + assert.ErrorIs(t, ErrFieldNotFound, err) +} +func TestGetAllInvalidSession(t *testing.T) { + str := New(context.TODO(), getRedisClient()) val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) + + key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" + val, err = str.GetAll(key) + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestGetAll(t *testing.T) { - assert := assert.New(t) key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" field1 := "somekey" value1 := 100 @@ -179,136 +160,116 @@ func TestGetAll(t *testing.T) { value2 := "abc123" field3 := "thishouldntbethere" value3 := 100.10 - redisPool := getRedisPool() + client := getRedisClient() // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2, field3, value3) - assert.NoError(err) + err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() + assert.NoError(t, err) - str := New(redisPool) + str := New(context.TODO(), client) vals, err := str.GetAll(key) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.Contains(vals, field3) - - val1, err := redis.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := redis.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) - - val3, err := redis.Float64(vals[field3], nil) - assert.NoError(err) - assert.Equal(val3, value3) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.Contains(t, vals, field3) + + val1, err := str.Int(vals[field1], nil) + assert.NoError(t, err) + assert.Equal(t, val1, value1) + + val2, err := str.String(vals[field2], nil) + assert.NoError(t, err) + assert.Equal(t, val2, value2) + + val3, err := str.Float64(vals[field3], nil) + assert.NoError(t, err) + assert.Equal(t, val3, value3) } func TestSetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - + str := New(context.TODO(), getRedisClient()) err := str.Set("invalidid", "key", "value") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) + ttl := time.Second * 10 + str.SetTTL(ttl) // this key is unique across all tests key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" field := "somekey" value := 100 - assert.NotNil(str.tempSetMap) - assert.NotContains(str.tempSetMap, key) - err := str.Set(key, field, value) - assert.NoError(err) - assert.Contains(str.tempSetMap, key) - assert.Contains(str.tempSetMap[key], field) - assert.Equal(str.tempSetMap[key][field], value) + assert.NoError(t, err) // Check ifs not commited to redis - conn := redisPool.Get() - defer conn.Close() - val, err := conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.Equal(val, int64(-2)) -} - -func TestCommitInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - - err := str.Commit("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) -} + v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) -func TestEmptyCommit(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, v2) - err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") - assert.NoError(err) + dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) } -func TestCommit(t *testing.T) { - // Test should commit in redis with expiry on key - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - str.SetTTL(10 * time.Second) +func TestSetMulti(t *testing.T) { + // Test should only set in internal map and not in redis + client := getRedisClient() + str := New(context.TODO(), client) + ttl := time.Second * 10 + str.SetTTL(ttl) // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" + key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" + field1 := "somekey1" value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := str.Set(key, field1, value1) - assert.NoError(err) - - err = str.Set(key, field2, value2) - assert.NoError(err) + field2 := "somekey2" + value2 := "somevalue" - err = str.Commit(key) - assert.NoError(err) + err := str.SetMulti(key, map[string]interface{}{ + field1: value1, + field2: value2, + }) + assert.NoError(t, err) - conn := redisPool.Get() - defer conn.Close() - vals, err := redis.Values(conn.Do("HGETALL", defaultPrefix+key)) - assert.Equal(2*2, len(vals)) + // Check ifs not commited to redis + v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) - ttl, err := redis.Int(conn.Do("TTL", defaultPrefix+key)) - assert.NoError(err) + v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field1).Result()) + assert.NoError(t, err) + assert.Equal(t, value1, v2) - assert.Equal(true, ttl > 0 && ttl <= 10) + dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) } func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - + str := New(context.TODO(), getRedisClient()) err := str.Delete("invalidkey", "somefield") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) + + str = New(context.TODO(), getRedisClient()) + err = str.Delete("8dIHy6S2uBuKaNnTUszB2180898ikGY1", "somefield") + assert.ErrorIs(t, ErrInvalidSession, err) } func TestDelete(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -317,34 +278,34 @@ func TestDelete(t *testing.T) { field2 := "someotherkey" value2 := "abc123" - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) + err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() + assert.NoError(t, err) err = str.Delete(key, field1) - assert.NoError(err) + assert.NoError(t, err) + + val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() + assert.False(t, val) + assert.NoError(t, err) - val, err := redis.Bool(conn.Do("HEXISTS", defaultPrefix+key, field1)) - assert.False(val) + val, err = client.HExists(context.TODO(), defaultPrefix+key, field2).Result() + assert.True(t, val) + assert.NoError(t, err) - val, err = redis.Bool(conn.Do("HEXISTS", defaultPrefix+key, field2)) - assert.True(val) + err = str.Delete(key, "xxxxx") + assert.ErrorIs(t, err, ErrFieldNotFound) } func TestClearInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - + str := New(context.TODO(), getRedisClient()) err := str.Clear("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestClear(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) // this key is unique across all tests key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" @@ -353,219 +314,164 @@ func TestClear(t *testing.T) { field2 := "someotherkey" value2 := "abc123" - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) + err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() + assert.NoError(t, err) // Check if its set - val, err := conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.NotEqual(val, int64(-2)) + val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.NotEqual(t, val, int64(0)) err = str.Clear(key) - assert.NoError(err) + assert.NoError(t, err) - val, err = conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.Equal(val, int64(-2)) -} - -func TestInterfaceMap(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) - - vals, err := str.interfaceMap(conn.Do("HGETALL", defaultPrefix+key)) - assert.Contains(vals, field1) - assert.Contains(vals, field2) -} - -func TestInterfaceMapWithError(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - testError := errors.New("test error") - vals, err := str.interfaceMap(nil, testError) - assert.Nil(vals) - assert.Error(err, testError.Error()) - - valsInfSlice := []interface{}{nil, nil, nil} - vals, err = str.interfaceMap(valsInfSlice, nil) - assert.Nil(vals) - assert.Equal(err.Error(), "redigo: StringMap expects even number of values result") - - valsInfSlice = []interface{}{"abc123", 123} - vals, err = str.interfaceMap(valsInfSlice, nil) - assert.Nil(vals) - assert.Equal(err.Error(), "redigo: StringMap key not a bulk string value") + val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, val, int64(0)) } func TestInt(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" value := 100 - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.Int(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.Int(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Int(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Int(value, testError) + assert.ErrorIs(t, testError, err) } func TestInt64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" var value int64 = 100 - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.Int64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.Int64(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Int64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Int64(value, testError) + assert.ErrorIs(t, testError, err) } func TestUInt64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" var value uint64 = 100 - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.UInt64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.UInt64(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.UInt64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.UInt64(value, testError) + assert.ErrorIs(t, testError, err) } func TestFloat64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" var value float64 = 100 - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.Float64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.Float64(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Float64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Float64(value, testError) + assert.ErrorIs(t, testError, err) } func TestString(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" value := "abc123" - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.String(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.String(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.String(value, testError) - assert.Error(err, testError.Error()) + _, err = str.String(value, testError) + assert.ErrorIs(t, testError, err) } func TestBytes(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" var value []byte = []byte("abc123") - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.Bytes(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.Bytes(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Bytes(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Bytes(value, testError) + assert.ErrorIs(t, testError, err) } func TestBool(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + client := getRedisClient() + str := New(context.TODO(), client) field := "somekey" value := true - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + err := client.Set(context.TODO(), field, value, 0).Err() + assert.NoError(t, err) - val, err := str.Bool(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + val, err := str.Bool(client.Get(context.TODO(), field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Bool(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Bool(value, testError) + assert.ErrorIs(t, testError, err) +} + +func TestValidateID(t *testing.T) { + ok := validateID("xxxx") + assert.False(t, ok) + + ok = validateID("8dIHy6S2uBuKaNnTUszB2180898ikGY&") + assert.False(t, ok) + + id, err := generateID(sessionIDLen) + assert.NoError(t, err) + ok = validateID(id) + assert.True(t, ok) }