Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
control: be more aware of test clocks (#270)
Browse files Browse the repository at this point in the history
This commit fixes a bug with test clocks where customer IDs were cached
based only on the org ID, not by test clocks. This resulted in invalid
cache hits due to org name collisions if the same org name was used
across test clocks because the first customer ID would be cached and
then used in requests targeting different clocks where the customer ID
did not exist.

This also narrows the list of returned customers by Stripe to those in a
test clock, if any.
  • Loading branch information
bmizerany authored Mar 3, 2023
1 parent f366a46 commit 8d43747
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 16 deletions.
34 changes: 22 additions & 12 deletions control/cache.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,48 @@
package control

import (
"strings"
"sync"

"github.com/golang/groupcache/lru"
"github.com/golang/groupcache/singleflight"
"tier.run/lru"
)

type orgKey struct {
account string
clock string // a test clock, if any
name string
}

type memo struct {
m sync.Mutex
lru *lru.Cache
lru *lru.Cache[orgKey, string] // map[orgKey] -> customerID
group singleflight.Group
}

func (m *memo) lookupCache(key string) (string, bool) {
func (m *memo) lookupCache(key orgKey) (string, bool) {
m.m.Lock()
defer m.m.Unlock()
if m.lru == nil {
return "", false
}
v, ok := m.lru.Get(key)
if !ok {
return "", false
}
return v.(string), true
return m.lru.Get(key)
}

func (m *memo) load(key string, fn func() (string, error)) (string, error) {
func (m *memo) load(key orgKey, fn func() (string, error)) (string, error) {
s, cacheHit := m.lookupCache(key)
if cacheHit {
return s, nil
}

v, err := m.group.Do(key, func() (any, error) {
// TODO(bmizerany): make a singleflight with generics to avoid building
// a string instead of using orgKey as a key
var b strings.Builder
b.WriteString(key.account)
b.WriteString(key.clock)
b.WriteString(key.name)

v, err := m.group.Do(b.String(), func() (any, error) {
v, cacheHit := m.lookupCache(key)
if cacheHit {
return v, nil
Expand All @@ -51,11 +61,11 @@ func (m *memo) load(key string, fn func() (string, error)) (string, error) {
return v.(string), nil
}

func (m *memo) add(key, val string) {
func (m *memo) add(key orgKey, val string) {
m.m.Lock()
defer m.m.Unlock()
if m.lru == nil {
m.lru = lru.New(100) // TODO(bmizerany): make configurable
m.lru = lru.New[orgKey, string](100) // TODO(bmizerany): make configurable
}
m.lru.Add(key, val)
}
34 changes: 30 additions & 4 deletions control/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,19 @@ func (c *Client) WhoIs(ctx context.Context, org string) (id string, err error) {
return "", &ValidationError{Message: "org must be prefixed with \"org:\""}
}

cid, err := c.cache.load(org, func() (string, error) {
clockID := clockFromContext(ctx)

key := orgKey{
name: org,
account: c.Stripe.AccountID,
clock: clockID,
}

cid, err := c.cache.load(key, func() (string, error) {
var f stripe.Form
if clockID != "" {
f.Add("test_clock", clockID)
}
cus, err := stripe.List[stripeCustomer](ctx, c.Stripe, "GET", "/v1/customers", f).
Find(func(v stripeCustomer) bool {
return v.Metadata.Org == org
Expand Down Expand Up @@ -934,14 +945,29 @@ func (c *Client) LookupOrg(ctx context.Context, org string) (*OrgInfo, error) {

func (c *Client) createCustomer(ctx context.Context, org string, info *OrgInfo) (id string, err error) {
defer errorfmt.Handlef("createCustomer: %w", &err)
return c.cache.load(org, func() (string, error) {

clockID := clockFromContext(ctx)
key := orgKey{
account: c.Stripe.AccountID,
clock: clockID,
name: org,
}

return c.cache.load(key, func() (string, error) {
var f stripe.Form
f.SetIdempotencyKey("customer:create:" + org)

var b strings.Builder
b.WriteString("customer:create:")
b.WriteString(key.name)
b.WriteString(key.account)
b.WriteString(key.clock)
f.SetIdempotencyKey(b.String())

f.Set("metadata[tier.org]", org)
if err := setOrgInfo(&f, info); err != nil {
return "", err
}
if clockID := clockFromContext(ctx); clockID != "" {
if clockID != "" {
f.Set("test_clock", clockID)
}
var created struct {
Expand Down
29 changes: 29 additions & 0 deletions control/schedule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,35 @@ func TestSchedulePutCustomer(t *testing.T) {
}, nil, nil)
}

func TestClocksWithCache(t *testing.T) {
cc := newTestClient(t)

var want string
for i := 0; i < 3; i++ {
ctx := context.Background()
c, err := cc.NewClock(ctx, t.Name(), time.Now())
if err != nil {
t.Fatal(err)
}

ctx = WithClock(ctx, c.ID())
if err := cc.PutCustomer(ctx, "org:example", nil); err != nil {
t.Fatal(err)
}
got, err := cc.WhoIs(ctx, "org:example")
if err != nil {
t.Fatal(err)
}
if i == 0 {
want = got
} else {
if got == want {
t.Errorf("unexpected match on iteration %d", i)
}
}
}
}

func ciOnly(t *testing.T) {
if os.Getenv("CI") == "" {
t.Skip("not in CI; skipping long test")
Expand Down
115 changes: 115 additions & 0 deletions lru/lru.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Package lru implements a LRU cache. It is a generic vesion of
// github.com/golang/groupcache/lru.
package lru

import "container/list"

// Cache is an LRU cache. It is not safe for concurrent access.
type Cache[K comparable, V any] struct {
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int

// OnEvicted optionally specifies a callback function to be
// executed when an entry is purged from the cache.
OnEvicted func(key K, value V)

ll *list.List
cache map[K]*list.Element
}

type entry[K comparable, V any] struct {
key K
value V
}

// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New[K comparable, V any](maxEntries int) *Cache[K, V] {
return &Cache[K, V]{
MaxEntries: maxEntries,
ll: list.New(),
cache: make(map[K]*list.Element),
}
}

// Add adds a value to the cache.
func (c *Cache[K, V]) Add(key K, value V) {
if c.cache == nil {
c.cache = make(map[K]*list.Element)
c.ll = list.New()
}
if ee, ok := c.cache[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry[K, V]).value = value
return
}
ele := c.ll.PushFront(&entry[K, V]{key, value})
c.cache[key] = ele
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
c.RemoveOldest()
}
}

// Get looks up a key's value from the cache.
func (c *Cache[K, V]) Get(key K) (value V, ok bool) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry[K, V]).value, true
}
return
}

// Remove removes the provided key from the cache.
func (c *Cache[K, V]) Remove(key K) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.removeElement(ele)
}
}

// RemoveOldest removes the oldest item from the cache.
func (c *Cache[K, V]) RemoveOldest() {
if c.cache == nil {
return
}
ele := c.ll.Back()
if ele != nil {
c.removeElement(ele)
}
}

func (c *Cache[K, V]) removeElement(e *list.Element) {
c.ll.Remove(e)
kv := e.Value.(*entry[K, V])
delete(c.cache, kv.key)
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}

// Len returns the number of items in the cache.
func (c *Cache[K, V]) Len() int {
if c.cache == nil {
return 0
}
return c.ll.Len()
}

// Clear purges all stored items from the cache.
func (c *Cache[K, V]) Clear() {
if c.OnEvicted != nil {
for _, e := range c.cache {
kv := e.Value.(*entry[K, V])
c.OnEvicted(kv.key, kv.value)
}
}
c.ll = nil
c.cache = nil
}
97 changes: 97 additions & 0 deletions lru/lru_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright 2013 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package lru

import (
"fmt"
"testing"
)

type simpleStruct struct {
int
string
}

type complexStruct struct {
int
simpleStruct
}

var getTests = []struct {
name string
keyToAdd interface{}
keyToGet interface{}
expectedOk bool
}{
{"string_hit", "myKey", "myKey", true},
{"string_miss", "myKey", "nonsense", false},
{"simple_struct_hit", simpleStruct{1, "two"}, simpleStruct{1, "two"}, true},
{"simple_struct_miss", simpleStruct{1, "two"}, simpleStruct{0, "noway"}, false},
{"complex_struct_hit", complexStruct{1, simpleStruct{2, "three"}},
complexStruct{1, simpleStruct{2, "three"}}, true},
}

func TestGet(t *testing.T) {
for _, tt := range getTests {
lru := New[any, int](0)
lru.Add(tt.keyToAdd, 1234)
val, ok := lru.Get(tt.keyToGet)
if ok != tt.expectedOk {
t.Fatalf("%s: cache hit = %v; want %v", tt.name, ok, !ok)
} else if ok && val != 1234 {
t.Fatalf("%s expected get to return 1234 but got %v", tt.name, val)
}
}
}

func TestRemove(t *testing.T) {
lru := New[string, int](0)
lru.Add("myKey", 1234)
if val, ok := lru.Get("myKey"); !ok {
t.Fatal("TestRemove returned no match")
} else if val != 1234 {
t.Fatalf("TestRemove failed. Expected %d, got %v", 1234, val)
}

lru.Remove("myKey")
if _, ok := lru.Get("myKey"); ok {
t.Fatal("TestRemove returned a removed entry")
}
}

func TestEvict(t *testing.T) {
evictedKeys := make([]any, 0)
onEvictedFun := func(key string, value int) {
evictedKeys = append(evictedKeys, key)
}

lru := New[string, int](20)
lru.OnEvicted = onEvictedFun
for i := 0; i < 22; i++ {
lru.Add(fmt.Sprintf("myKey%d", i), 1234)
}

if len(evictedKeys) != 2 {
t.Fatalf("got %d evicted keys; want 2", len(evictedKeys))
}
if evictedKeys[0] != "myKey0" {
t.Fatalf("got %v in first evicted key; want %s", evictedKeys[0], "myKey0")
}
if evictedKeys[1] != "myKey1" {
t.Fatalf("got %v in second evicted key; want %s", evictedKeys[1], "myKey1")
}
}

0 comments on commit 8d43747

Please sign in to comment.