Skip to content

Commit

Permalink
improve get/refresh func
Browse files Browse the repository at this point in the history
  • Loading branch information
larscom committed Sep 3, 2023
1 parent 2a7e17a commit e7fdbab
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 41 deletions.
15 changes: 0 additions & 15 deletions .vscode/launch.json

This file was deleted.

46 changes: 20 additions & 26 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/smallnest/safemap"
"golang.org/x/exp/maps"
)

type Entry[Key comparable, Value any] struct {
Expand Down Expand Up @@ -89,7 +88,7 @@ func (c *Cache[Key, Value]) Close() {

func (c *Cache[Key, Value]) Count() int {
e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

Expand All @@ -98,7 +97,7 @@ func (c *Cache[Key, Value]) Count() int {

func (c *Cache[Key, Value]) Channel() <-chan Entry[Key, Value] {
e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

Expand All @@ -119,13 +118,8 @@ func (c *Cache[Key, Value]) Channel() <-chan Entry[Key, Value] {
}

func (c *Cache[Key, Value]) ForEach(fn func(Key, Value)) {
e := c.entries
if c.expireAfterWrite > 0 {
e = c.getActiveEntries()
}

for item := range e.IterBuffered() {
fn(item.Key, item.Val.value)
for item := range c.Channel() {
fn(item.Key, item.Value)
}
}

Expand All @@ -140,7 +134,7 @@ func (c *Cache[Key, Value]) Get(key Key) (Value, error) {

value, err := c.load(key)

unlock()
defer unlock()

if err == nil {
c.Put(key, value)
Expand All @@ -162,11 +156,9 @@ func (c *Cache[Key, Value]) GetIfPresent(key Key) (Value, bool) {

func (c *Cache[Key, Value]) Refresh(key Key) (Value, error) {
unlock := c.loaderMu.lock(key)

defer unlock()
value, err := c.load(key)

unlock()

if err == nil {
c.Put(key, value)
}
Expand All @@ -181,7 +173,7 @@ func (c *Cache[Key, Value]) Has(key Key) bool {

func (c *Cache[Key, Value]) IsEmpty() bool {
e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

Expand All @@ -190,7 +182,7 @@ func (c *Cache[Key, Value]) IsEmpty() bool {

func (c *Cache[Key, Value]) Keys() []Key {
e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

Expand All @@ -206,13 +198,12 @@ func (c *Cache[Key, Value]) Remove(key Key) {
}

func (c *Cache[Key, Value]) ToMap() map[Key]Value {
m := make(map[Key]Value)

e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

m := make(map[Key]Value)
for item := range e.IterBuffered() {
m[item.Key] = item.Val.value
}
Expand All @@ -222,16 +213,15 @@ func (c *Cache[Key, Value]) ToMap() map[Key]Value {

func (c *Cache[Key, Value]) Values() []Value {
e := c.entries
if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
e = c.getActiveEntries()
}

entries := maps.Values(e.Items())
n := len(entries)
values := make([]Value, n)
for i := 0; i < n; i++ {
values[i] = entries[i].value
values := make([]Value, 0)
for item := range e.IterBuffered() {
values = append(values, item.Val.value)
}

return values
}

Expand Down Expand Up @@ -275,7 +265,7 @@ func WithOnExpired[Key comparable, Value any](
func (c *Cache[Key, Value]) newEntry(key Key, value Value) *cacheEntry[Key, Value] {
var expiration time.Time

if c.expireAfterWrite > 0 {
if c.isTimerEnabled() {
expiration = time.Now().Add(c.expireAfterWrite)
}

Expand Down Expand Up @@ -316,3 +306,7 @@ func (c *Cache[Key, Value]) load(key Key) (Value, error) {

return value, err
}

func (c *Cache[Key, Value]) isTimerEnabled() bool {
return c.expireAfterWrite > 0
}
130 changes: 130 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cache
import (
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -399,6 +400,135 @@ func Test_Expiration(t *testing.T) {
}

func Test_WithLoader(t *testing.T) {
t.Run("loader func called once in concurrent environment", func(t *testing.T) {
var (
counter int64
wg sync.WaitGroup
)

c := createCache(WithLoader(func(key int) (int, error) {
atomic.AddInt64(&counter, 1)
time.Sleep(time.Millisecond * 30)
return key, nil
}))

wg.Add(3)
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()

wg.Wait()
assert.Equal(t, int64(1), atomic.LoadInt64(&counter))
})

t.Run("loader func called twice in concurrent environment with 2 different keys", func(t *testing.T) {
var (
counter int64
wg sync.WaitGroup
)

c := createCache(WithLoader(func(key int) (int, error) {
atomic.AddInt64(&counter, 1)
time.Sleep(time.Millisecond * 30)
return key, nil
}))

wg.Add(6)
// key 100
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(100)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()

// key 200
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(200)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(200)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()
go func() {
defer wg.Done()
start := time.Now()
val, err := c.Get(200)
end := time.Since(start)
if err != nil {
t.Error(err)
}
fmt.Println("Took", end, "Value", val)
}()

wg.Wait()
assert.Equal(t, int64(2), atomic.LoadInt64(&counter))
})

t.Run("get from loader", func(t *testing.T) {
c := createCache(WithLoader(func(key int) (int, error) {
return 12345, nil
Expand Down

0 comments on commit e7fdbab

Please sign in to comment.