Skip to content

Commit

Permalink
Merge pull request #31 from vividvilla/refactor-stores
Browse files Browse the repository at this point in the history
Refactor all stores to remove the simplesessions dependency and adhere
  • Loading branch information
vividvilla authored May 10, 2024
2 parents e4b1649 + cf36607 commit 3214678
Show file tree
Hide file tree
Showing 12 changed files with 448 additions and 580 deletions.
2 changes: 1 addition & 1 deletion stores/goredis/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/alicebob/miniredis/v2 v2.32.1
github.com/redis/go-redis/v9 v9.5.1
github.com/stretchr/testify v1.9.0
github.com/vividvilla/simplesessions v0.2.0
github.com/vividvilla/simplesessions/conv v1.0.0
)

require (
Expand Down
124 changes: 79 additions & 45 deletions stores/goredis/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,37 @@ package goredis

import (
"context"
"crypto/rand"
"sync"
"time"
"unicode"

"github.com/redis/go-redis/v9"
"github.com/vividvilla/simplesessions"
"github.com/vividvilla/simplesessions/conv"
)

var (
// Error codes for store errors. This should match the codes
// defined in the /simplesessions package exactly.
ErrInvalidSession = &Err{code: 1, msg: "invalid session"}
ErrFieldNotFound = &Err{code: 2, msg: "field not found"}
ErrAssertType = &Err{code: 3, msg: "assertion failed"}
ErrNil = &Err{code: 4, msg: "nil returned"}
)

type Err struct {
code int
msg string
}

func (e *Err) Error() string {
return e.msg
}

func (e *Err) Code() int {
return e.code
}

// Store represents redis session store for simple sessions.
// Each session is stored as redis hashmap.
type Store struct {
Expand Down Expand Up @@ -54,20 +77,9 @@ func (s *Store) SetTTL(d time.Duration) {
s.ttl = d
}

// isValidSessionID checks is the given session id is valid.
func (s *Store) isValidSessionID(sess *simplesessions.Session, id string) bool {
return len(id) == sessionIDLen && sess.IsValidRandomString(id)
}

// IsValid checks if the session is set for the id.
func (s *Store) IsValid(sess *simplesessions.Session, id string) (bool, error) {
// Validate session is valid generate string or not
return s.isValidSessionID(sess, id), nil
}

// Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created.
func (s *Store) Create(sess *simplesessions.Session) (string, error) {
id, err := sess.GenerateRandomString(sessionIDLen)
func (s *Store) Create() (string, error) {
id, err := generateID(sessionIDLen)
if err != nil {
return "", err
}
Expand All @@ -76,25 +88,23 @@ func (s *Store) Create(sess *simplesessions.Session) (string, error) {
}

// Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised
func (s *Store) Get(sess *simplesessions.Session, id, key string) (interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) Get(id, key string) (interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result()
if err == redis.Nil {
return nil, simplesessions.ErrFieldNotFound
return nil, ErrFieldNotFound
}

return v, err
}

// GetMulti gets a map for values for multiple keys. If key is not found then its set as nil.
func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string) (map[string]interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result()
Expand All @@ -113,10 +123,9 @@ func (s *Store) GetMulti(sess *simplesessions.Session, id string, keys ...string
}

// GetAll gets all fields from hashmap.
func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]interface{}, error) {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return nil, simplesessions.ErrInvalidSession
func (s *Store) GetAll(id string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result()
Expand All @@ -136,10 +145,9 @@ func (s *Store) GetAll(sess *simplesessions.Session, id string) (map[string]inte
}

// Set sets a value to given session but stored only on commit
func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{}) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Set(id, key string, val interface{}) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.Lock()
Expand All @@ -156,11 +164,10 @@ func (s *Store) Set(sess *simplesessions.Session, id, key string, val interface{
return nil
}

// Commit sets all set values
func (s *Store) Commit(sess *simplesessions.Session, id string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
// Commit sets all set values.
func (s *Store) Commit(id string) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.RLock()
Expand Down Expand Up @@ -200,10 +207,9 @@ func (s *Store) Commit(sess *simplesessions.Session, id string) error {
}

// Delete deletes a key from redis session hashmap.
func (s *Store) Delete(sess *simplesessions.Session, id string, key string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Delete(id string, key string) error {
if !validateID(id) {
return ErrInvalidSession
}

// Clear temp map for given session id
Expand All @@ -213,16 +219,15 @@ func (s *Store) Delete(sess *simplesessions.Session, id string, key string) erro

err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err()
if err == redis.Nil {
return simplesessions.ErrFieldNotFound
return ErrFieldNotFound
}
return err
}

// Clear clears session in redis.
func (s *Store) Clear(sess *simplesessions.Session, id string) error {
// Check if valid session
if !s.isValidSessionID(sess, id) {
return simplesessions.ErrInvalidSession
func (s *Store) Clear(id string) error {
if !validateID(id) {
return ErrInvalidSession
}

return s.client.Del(s.clientCtx, s.prefix+id).Err()
Expand Down Expand Up @@ -262,3 +267,32 @@ func (s *Store) Bytes(r interface{}, err error) ([]byte, error) {
func (s *Store) Bool(r interface{}, err error) (bool, error) {
return conv.Bool(r, err)
}

func validateID(id string) bool {
if len(id) != sessionIDLen {
return false
}

for _, r := range id {
if !unicode.IsDigit(r) && !unicode.IsLetter(r) {
return false
}
}

return true
}

// generateID generates a random alpha-num session ID.
func generateID(n int) (string, error) {
const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, n)
if _, err := rand.Read(bytes); err != nil {
return "", err
}

for k, v := range bytes {
bytes[k] = dict[v%byte(len(dict))]
}

return string(bytes), nil
}
Loading

0 comments on commit 3214678

Please sign in to comment.