Skip to content

Commit

Permalink
Merge pull request #2 from nhatthm/rw-lock
Browse files Browse the repository at this point in the history
Add `RWMutex`
  • Loading branch information
nhatthm authored Mar 1, 2024
2 parents 36b8c34 + 5eef51c commit a0011fc
Showing 1 changed file with 100 additions and 69 deletions.
169 changes: 100 additions & 69 deletions keyring.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"mime"
"strconv"
"strings"
"sync"

"github.com/zalando/go-keyring"
"go.uber.org/multierr"
Expand All @@ -33,12 +34,63 @@ var (
// KeyringStorage is a storage implementation that uses the OS keyring.
type KeyringStorage[V any] struct {
keyring keyring.Keyring
mu sync.Map
}

func (ss *KeyringStorage[V]) mutex(service, key string) *sync.RWMutex {
m, _ := ss.mu.LoadOrStore(fmt.Sprintf("%s:%s", service, key), &sync.RWMutex{})

return m.(*sync.RWMutex) //nolint: forcetypeassert
}

func (ss *KeyringStorage[V]) withKeyring(keyring keyring.Keyring) {
ss.keyring = keyring
}

func (ss *KeyringStorage[V]) get(service string, key string) (V, error) {
var result V

d, err := ss.keyring.Get(service, key)
if err != nil {
return result, fmt.Errorf("failed to read data from keyring: %w", err)
}

if strings.HasPrefix(d, mimeMultipartSecret) {
_, params, err := mime.ParseMediaType(d)
if err != nil {
return result, fmt.Errorf("failed to get params from data: %w", err)
}

pages, err := strconv.Atoi(params["pages"])
if err != nil {
return result, fmt.Errorf("failed to get pages from data: %w", err)
}

if pages < minPages {
return result, fmt.Errorf("invalid secret pages: %d", pages) //nolint: goerr113
}

var sb strings.Builder

for i := 1; i <= pages; i++ {
p, err := ss.keyring.Get(service, formatPage(key, i))
if err != nil {
return result, fmt.Errorf("failed to read multipart data #%d from keyring: %w", i, err)
}

sb.WriteString(p)
}

d = sb.String()
}

if err := unmarshalData(d, &result); err != nil {
return result, fmt.Errorf("failed to unmarshal data read from keyring: %w", err)
}

return result, nil
}

func (ss *KeyringStorage[V]) set(service string, key string, value string) error {
if err := ss.keyring.Set(service, key, value); err != nil {
return fmt.Errorf("failed to write data to keyring: %w", err)
Expand Down Expand Up @@ -89,75 +141,7 @@ func (ss *KeyringStorage[V]) setMultipart(service string, key string, value stri
return nil
}

// Set sets the value for the given key.
func (ss *KeyringStorage[V]) Set(service string, key string, value V) error {
var err error

d, err := marshalData(value)
if err != nil {
return fmt.Errorf("failed to marshal data for writing to keyring: %w", err)
}

// Delete the data because it could be multipart.
if err = ss.Delete(service, key); err != nil && !errors.Is(err, ErrNotFound) {
return fmt.Errorf("failed to delete old data in keyring: %w", errors.Unwrap(err))
}

length := len(d)
if length <= maxLength {
return ss.set(service, key, d)
}

return ss.setMultipart(service, key, d)
}

// Get gets the value for the given key.
func (ss *KeyringStorage[V]) Get(service string, key string) (V, error) {
var result V

d, err := ss.keyring.Get(service, key)
if err != nil {
return result, fmt.Errorf("failed to read data from keyring: %w", err)
}

if strings.HasPrefix(d, mimeMultipartSecret) {
_, params, err := mime.ParseMediaType(d)
if err != nil {
return result, fmt.Errorf("failed to get params from data: %w", err)
}

pages, err := strconv.Atoi(params["pages"])
if err != nil {
return result, fmt.Errorf("failed to get pages from data: %w", err)
}

if pages < minPages {
return result, fmt.Errorf("invalid secret pages: %d", pages) //nolint: goerr113
}

var sb strings.Builder

for i := 1; i <= pages; i++ {
p, err := ss.keyring.Get(service, formatPage(key, i))
if err != nil {
return result, fmt.Errorf("failed to read multipart data #%d from keyring: %w", i, err)
}

sb.WriteString(p)
}

d = sb.String()
}

if err := unmarshalData(d, &result); err != nil {
return result, fmt.Errorf("failed to unmarshal data read from keyring: %w", err)
}

return result, nil
}

// Delete deletes the value for the given key.
func (ss *KeyringStorage[V]) Delete(service string, key string) error {
func (ss *KeyringStorage[V]) delete(service string, key string) error {
var err error

d, err := ss.keyring.Get(service, key)
Expand Down Expand Up @@ -207,6 +191,53 @@ func (ss *KeyringStorage[V]) Delete(service string, key string) error {
return err
}

// Get gets the value for the given key.
func (ss *KeyringStorage[V]) Get(service string, key string) (V, error) {
mu := ss.mutex(service, key)

mu.RLock()
defer mu.RUnlock()

return ss.get(service, key)
}

// Set sets the value for the given key.
func (ss *KeyringStorage[V]) Set(service string, key string, value V) error {
mu := ss.mutex(service, key)

mu.Lock()
defer mu.Unlock()

var err error

d, err := marshalData(value)
if err != nil {
return fmt.Errorf("failed to marshal data for writing to keyring: %w", err)
}

// Delete the data because it could be multipart.
if err = ss.delete(service, key); err != nil && !errors.Is(err, ErrNotFound) {
return fmt.Errorf("failed to delete old data in keyring: %w", errors.Unwrap(err))
}

length := len(d)
if length <= maxLength {
return ss.set(service, key, d)
}

return ss.setMultipart(service, key, d)
}

// Delete deletes the value for the given key.
func (ss *KeyringStorage[V]) Delete(service string, key string) error {
mu := ss.mutex(service, key)

mu.Lock()
defer mu.Unlock()

return ss.delete(service, key)
}

// NewKeyringStorage creates a new KeyringStorage that uses the OS keyring.
func NewKeyringStorage[V any](opts ...KeyringStorageOption) *KeyringStorage[V] {
s := &KeyringStorage[V]{
Expand Down

0 comments on commit a0011fc

Please sign in to comment.