diff --git a/ast/term.go b/ast/term.go index 370095895e..3ace3c1864 100644 --- a/ast/term.go +++ b/ast/term.go @@ -364,6 +364,22 @@ func (term *Term) Equal(other *Term) bool { if term == other { return true } + + // TODO(tsandall): This early-exit avoids allocations for types that have + // Equal() functions that just use == underneath. We should revisit the + // other types and implement Equal() functions that do not require + // allocations. + switch v := term.Value.(type) { + case Null: + return v.Equal(other.Value) + case Boolean: + return v.Equal(other.Value) + case String: + return v.Equal(other.Value) + case Var: + return v.Equal(other.Value) + } + return term.Value.Compare(other.Value) == 0 } diff --git a/topdown/cache.go b/topdown/cache.go index f69c8d9ac7..33a8b16f26 100644 --- a/topdown/cache.go +++ b/topdown/cache.go @@ -6,6 +6,7 @@ package topdown import ( "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/util" ) type virtualCache struct { @@ -14,7 +15,7 @@ type virtualCache struct { type virtualCacheElem struct { value *ast.Term - children map[ast.Value]*virtualCacheElem + children *util.HashMap } func newVirtualCache() *virtualCache { @@ -34,12 +35,11 @@ func (c *virtualCache) Pop() { func (c *virtualCache) Get(ref ast.Ref) *ast.Term { node := c.stack[len(c.stack)-1] for i := 0; i < len(ref); i++ { - key := ref[i].Value - next := node.children[key] - if next == nil { + x, ok := node.children.Get(ref[i]) + if !ok { return nil } - node = next + node = x.(*virtualCacheElem) } return node.value } @@ -47,21 +47,28 @@ func (c *virtualCache) Get(ref ast.Ref) *ast.Term { func (c *virtualCache) Put(ref ast.Ref, value *ast.Term) { node := c.stack[len(c.stack)-1] for i := 0; i < len(ref); i++ { - key := ref[i].Value - next := node.children[key] - if next == nil { - next = newVirtualCacheElem() - node.children[key] = next + x, ok := node.children.Get(ref[i]) + if ok { + node = x.(*virtualCacheElem) + } else { + next := newVirtualCacheElem() + node.children.Put(ref[i], next) + node = next } - node = next } node.value = value } func newVirtualCacheElem() *virtualCacheElem { - return &virtualCacheElem{ - children: map[ast.Value]*virtualCacheElem{}, - } + return &virtualCacheElem{children: newVirtualCacheHashMap()} +} + +func newVirtualCacheHashMap() *util.HashMap { + return util.NewHashMap(func(a, b util.T) bool { + return a.(*ast.Term).Equal(b.(*ast.Term)) + }, func(x util.T) int { + return x.(*ast.Term).Hash() + }) } // baseCache implements a trie structure to cache base documents read out of diff --git a/topdown/cache_bench_test.go b/topdown/cache_bench_test.go new file mode 100644 index 0000000000..ea29298f37 --- /dev/null +++ b/topdown/cache_bench_test.go @@ -0,0 +1,48 @@ +// Copyright 2019 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "fmt" + "testing" + + "github.com/open-policy-agent/opa/ast" +) + +func BenchmarkVirtualCache(b *testing.B) { + + n := 10 + max := n * n * n + + keys := make([]ast.Ref, 0, max) + values := make([]*ast.Term, 0, max) + + for i := 0; i < n; i++ { + k1 := ast.StringTerm(fmt.Sprintf("aaaa%v", i)) + for j := 0; j < n; j++ { + k2 := ast.StringTerm(fmt.Sprintf("bbbb%v", j)) + for k := 0; k < n; k++ { + k3 := ast.StringTerm(fmt.Sprintf("cccc%v", k)) + key := ast.Ref{ast.DefaultRootDocument, k1, k2, k3} + value := ast.ArrayTerm(k1, k2, k3) + keys = append(keys, key) + values = append(values, value) + } + } + } + + cache := newVirtualCache() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + idx := i % max + cache.Put(keys[idx], values[idx]) + result := cache.Get(keys[idx]) + if !result.Equal(values[idx]) { + b.Fatal("expected equal") + } + } + +} diff --git a/topdown/cache_test.go b/topdown/cache_test.go index ab71204e1b..7e06f2c687 100644 --- a/topdown/cache_test.go +++ b/topdown/cache_test.go @@ -10,6 +10,16 @@ import ( "github.com/open-policy-agent/opa/ast" ) +func TestVirtualCacheCompositeKey(t *testing.T) { + cache := newVirtualCache() + ref := ast.MustParseRef("data.x.y[[1]].z") + cache.Put(ref, ast.BooleanTerm(true)) + result := cache.Get(ref) + if !result.Equal(ast.BooleanTerm(true)) { + t.Fatalf("Expected true but got %v", result) + } +} + func TestVirtualCacheInvalidate(t *testing.T) { cache := newVirtualCache() cache.Push()