diff --git a/cache/cache_test.go b/cache/cache_test.go index a22fcb56b17e..7f759e3c7833 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -142,8 +142,18 @@ func testCache(t *testing.T, cache cache.Cache) { } func TestMemcache(t *testing.T) { - cache := cache.NewMemcached(cache.MemcachedConfig{}, newMockMemcache()) - testCache(t, cache) + t.Run("Unbatched", func(t *testing.T) { + cache := cache.NewMemcached(cache.MemcachedConfig{}, newMockMemcache()) + testCache(t, cache) + }) + + t.Run("Batched", func(t *testing.T) { + cache := cache.NewMemcached(cache.MemcachedConfig{ + BatchSize: 10, + Parallelism: 3, + }, newMockMemcache()) + testCache(t, cache) + }) } func TestDiskcache(t *testing.T) { diff --git a/cache/memcached.go b/cache/memcached.go index 4e477123826e..2f40891e4f93 100644 --- a/cache/memcached.go +++ b/cache/memcached.go @@ -3,11 +3,15 @@ package cache import ( "context" "flag" + "sync" "time" "github.com/bradfitz/gomemcache/memcache" + opentracing "github.com/opentracing/opentracing-go" + otlog "github.com/opentracing/opentracing-go/log" "github.com/prometheus/client_golang/prometheus" instr "github.com/weaveworks/common/instrument" + "github.com/weaveworks/cortex/pkg/util" ) var ( @@ -27,17 +31,25 @@ func init() { // MemcachedConfig is config to make a Memcached type MemcachedConfig struct { Expiration time.Duration + + BatchSize int + Parallelism int } // RegisterFlags adds the flags required to config this to the given FlagSet func (cfg *MemcachedConfig) RegisterFlags(f *flag.FlagSet) { - f.DurationVar(&cfg.Expiration, "memcached.expiration", 0, "How long chunks stay in the memcache.") + f.DurationVar(&cfg.Expiration, "memcached.expiration", 0, "How long keys stay in the memcache.") + f.IntVar(&cfg.BatchSize, "memcached.batchsize", 0, "How many keys to fetch in each batch.") + f.IntVar(&cfg.Parallelism, "memcached.parallelism", 100, "Maximum active requests to memcache.") } // Memcached type caches chunks in memcached type Memcached struct 
{ cfg MemcachedConfig memcache MemcachedClient + + wg sync.WaitGroup + inputCh chan *work } -// NewMemcached makes a new Memcache +// NewMemcached makes a new Memcached @@ -46,9 +58,46 @@ func NewMemcached(cfg MemcachedConfig, client MemcachedClient) *Memcached { cfg: cfg, memcache: client, } + + if cfg.BatchSize == 0 || cfg.Parallelism == 0 { + return c + } + + c.inputCh = make(chan *work) + c.wg.Add(cfg.Parallelism) + + for i := 0; i < cfg.Parallelism; i++ { + go func() { + for input := range c.inputCh { + res := &result{ + batchID: input.batchID, + } + res.found, res.bufs, res.missed, res.err = c.fetch(input.ctx, input.keys) + input.resultCh <- res + } + + c.wg.Done() + }() + } + return c } +type work struct { + keys []string + ctx context.Context + resultCh chan<- *result + batchID int // For ordering results. +} + +type result struct { + found []string + bufs [][]byte + missed []string + err error + batchID int // For ordering results. +} + func memcacheStatusCode(err error) string { // See https://godoc.org/github.com/bradfitz/gomemcache/memcache#pkg-variables switch err { @@ -63,17 +112,41 @@ func memcacheStatusCode(err error) string { } } -// Fetch gets keys from the cache. +// Fetch gets keys from the cache. Found keys and their values are returned in the same relative order as the requested keys.
func (c *Memcached) Fetch(ctx context.Context, keys []string) (found []string, bufs [][]byte, missed []string, err error) { + err = instr.TimeRequestHistogramStatus(ctx, "Memcache.Get", memcacheRequestDuration, memcacheStatusCode, func(ctx context.Context) error { + sp := opentracing.SpanFromContext(ctx) + sp.LogFields(otlog.Int("keys requested", len(keys))) + defer func() { + sp.LogFields(otlog.Int("keys found", len(found)), otlog.Int("keys missing", len(missed))) + }() + + var err error + if c.cfg.BatchSize == 0 { + found, bufs, missed, err = c.fetch(ctx, keys) + return err + } + + found, bufs, missed, err = c.fetchKeysBatched(ctx, keys) + return err + }) + + return +} + +func (c *Memcached) fetch(ctx context.Context, keys []string) (found []string, bufs [][]byte, missed []string, err error) { var items map[string]*memcache.Item - err = instr.TimeRequestHistogramStatus(ctx, "Memcache.Get", memcacheRequestDuration, memcacheStatusCode, func(_ context.Context) error { + err = UntracedCollectedRequest(ctx, "Memcache.GetMulti", instr.NewHistogramCollector(memcacheRequestDuration), memcacheStatusCode, func(_ context.Context) error { var err error items, err = c.memcache.GetMulti(keys) return err }) + if err != nil { + missed = keys return } + for _, key := range keys { item, ok := items[key] if ok { @@ -86,6 +159,50 @@ func (c *Memcached) Fetch(ctx context.Context, keys []string) (found []string, b return } +func (c *Memcached) fetchKeysBatched(ctx context.Context, keys []string) (found []string, bufs [][]byte, missed []string, err error) { + resultsCh := make(chan *result) + batchSize := c.cfg.BatchSize + + go func() { + for i, j := 0, 0; i < len(keys); i += batchSize { + batchKeys := keys[i:util.Min(i+batchSize, len(keys))] + c.inputCh <- &work{ + keys: batchKeys, + ctx: ctx, + resultCh: resultsCh, + batchID: j, + } + j++ + } + }() + + // Read all values from this channel to avoid blocking upstream. 
+ numResults := len(keys) / batchSize + if len(keys)%batchSize != 0 { + numResults++ + } + + // We need to order found by the input keys order. + results := make([]*result, numResults) + for i := 0; i < numResults; i++ { + result := <-resultsCh + results[result.batchID] = result + } + close(resultsCh) + + for _, result := range results { + if result.err != nil { + err = result.err + } + + found = append(found, result.found...) + bufs = append(bufs, result.bufs...) + missed = append(missed, result.missed...) + } + + return +} + // Store stores the key in the cache. func (c *Memcached) Store(ctx context.Context, key string, buf []byte) error { return instr.TimeRequestHistogramStatus(ctx, "Memcache.Put", memcacheRequestDuration, memcacheStatusCode, func(_ context.Context) error { @@ -99,6 +216,12 @@ func (c *Memcached) Store(ctx context.Context, key string, buf []byte) error { } -// Stop does nothing. -func (*Memcached) Stop() error { +// Stop shuts down the background batch workers, if any, and waits for them to exit. +func (c *Memcached) Stop() error { + if c.inputCh == nil { + return nil + } + + close(c.inputCh) + c.wg.Wait() return nil } diff --git a/cache/memcached_client_test.go b/cache/memcached_client_test.go new file mode 100644 index 000000000000..028fba8ef460 --- /dev/null +++ b/cache/memcached_client_test.go @@ -0,0 +1,39 @@ +package cache_test + +import ( + "sync" + + "github.com/bradfitz/gomemcache/memcache" +) + +type mockMemcache struct { + sync.RWMutex + contents map[string][]byte +} + +func newMockMemcache() *mockMemcache { + return &mockMemcache{ + contents: map[string][]byte{}, + } +} + +func (m *mockMemcache) GetMulti(keys []string) (map[string]*memcache.Item, error) { + m.RLock() + defer m.RUnlock() + result := map[string]*memcache.Item{} + for _, k := range keys { + if c, ok := m.contents[k]; ok { + result[k] = &memcache.Item{ + Value: c, + } + } + } + return result, nil +} + +func (m *mockMemcache) Set(item *memcache.Item) error { + m.Lock() + defer m.Unlock() + m.contents[item.Key] = item.Value + return nil +} diff --git
a/cache/memcached_test.go b/cache/memcached_test.go index 028fba8ef460..b003e6a2d07c 100644 --- a/cache/memcached_test.go +++ b/cache/memcached_test.go @@ -1,39 +1,146 @@ package cache_test import ( - "sync" + "context" + "errors" + "sync/atomic" + "testing" "github.com/bradfitz/gomemcache/memcache" + "github.com/stretchr/testify/require" + "github.com/weaveworks/cortex/pkg/chunk/cache" ) -type mockMemcache struct { - sync.RWMutex - contents map[string][]byte +func TestMemcached(t *testing.T) { + t.Run("unbatched", func(t *testing.T) { + client := newMockMemcache() + memcache := cache.NewMemcached(cache.MemcachedConfig{}, client) + + testMemcache(t, memcache) + }) + + t.Run("batched", func(t *testing.T) { + client := newMockMemcache() + memcache := cache.NewMemcached(cache.MemcachedConfig{ + BatchSize: 10, + Parallelism: 5, + }, client) + + testMemcache(t, memcache) + }) } -func newMockMemcache() *mockMemcache { - return &mockMemcache{ - contents: map[string][]byte{}, +func testMemcache(t *testing.T, memcache *cache.Memcached) { + numKeys := 1000 + + ctx := context.Background() + keys := make([]string, 0, numKeys) + // Insert 1000 keys skipping all multiples of 5. 
+ for i := 0; i < numKeys; i++ { + keys = append(keys, string(i)) + if i%5 == 0 { + continue + } + + require.NoError(t, memcache.Store(ctx, string(i), []byte(string(i)))) } -} -func (m *mockMemcache) GetMulti(keys []string) (map[string]*memcache.Item, error) { - m.RLock() - defer m.RUnlock() - result := map[string]*memcache.Item{} - for _, k := range keys { - if c, ok := m.contents[k]; ok { - result[k] = &memcache.Item{ - Value: c, - } + found, bufs, missing, err := memcache.Fetch(ctx, keys) + require.NoError(t, err) + for i := 0; i < numKeys; i++ { + if i%5 == 0 { + require.Equal(t, string(i), missing[0]) + missing = missing[1:] + continue } + + require.Equal(t, string(i), found[0]) + require.Equal(t, string(i), string(bufs[0])) + found = found[1:] + bufs = bufs[1:] } - return result, nil } -func (m *mockMemcache) Set(item *memcache.Item) error { - m.Lock() - defer m.Unlock() - m.contents[item.Key] = item.Value - return nil +// mockMemcache whose calls fail 1/3rd of the time. +type mockMemcacheFailing struct { + *mockMemcache + calls uint64 +} + +func newMockMemcacheFailing() *mockMemcacheFailing { + return &mockMemcacheFailing{ + mockMemcache: newMockMemcache(), + } +} + +func (c *mockMemcacheFailing) GetMulti(keys []string) (map[string]*memcache.Item, error) { + calls := atomic.AddUint64(&c.calls, 1) + if calls%3 == 0 { + return nil, errors.New("fail") + } + + return c.mockMemcache.GetMulti(keys) +} + +func TestMemcacheFailure(t *testing.T) { + t.Run("unbatched", func(t *testing.T) { + client := newMockMemcacheFailing() + memcache := cache.NewMemcached(cache.MemcachedConfig{}, client) + + testMemcacheFailing(t, memcache) + }) + + t.Run("batched", func(t *testing.T) { + client := newMockMemcacheFailing() + memcache := cache.NewMemcached(cache.MemcachedConfig{ + BatchSize: 10, + Parallelism: 5, + }, client) + + testMemcacheFailing(t, memcache) + }) +} + +func testMemcacheFailing(t *testing.T, memcache *cache.Memcached) { + numKeys := 1000 + + ctx := 
context.Background() + keys := make([]string, 0, numKeys) + // Insert 1000 keys skipping all multiples of 5. + for i := 0; i < numKeys; i++ { + keys = append(keys, string(i)) + if i%5 == 0 { + continue + } + + require.NoError(t, memcache.Store(ctx, string(i), []byte(string(i)))) + } + + for i := 0; i < 10; i++ { + found, bufs, missing, _ := memcache.Fetch(ctx, keys) + + require.Equal(t, len(found), len(bufs)) + for i := range found { + require.Equal(t, found[i], string(bufs[i])) + } + + keysReturned := make(map[string]struct{}) + for _, key := range found { + _, ok := keysReturned[key] + require.False(t, ok, "duplicate key returned") + + keysReturned[key] = struct{}{} + } + for _, key := range missing { + _, ok := keysReturned[key] + require.False(t, ok, "duplicate key returned") + + keysReturned[key] = struct{}{} + } + + for _, key := range keys { + _, ok := keysReturned[key] + require.True(t, ok, "key missing %s", key) + } + } }