Skip to content

Commit

Permalink
feat: implements external data response cache
Browse files Browse the repository at this point in the history
Signed-off-by: Nilekh Chaudhari <1626598+nilekhc@users.noreply.github.com>
  • Loading branch information
nilekhc committed Jul 6, 2023
1 parent 6ccacf8 commit 06a54a8
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 4 deletions.
8 changes: 8 additions & 0 deletions constraint/pkg/client/drivers/rego/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ func AddExternalDataProviderCache(providerCache *externaldata.ProviderCache) Arg
}
}

func AddExternalDataProviderResponseCache(providerResponseCache *externaldata.ProviderResponseCache) Arg {
return func(d *Driver) error {
d.providerResponseCache = providerResponseCache

return nil
}
}

func DisableBuiltins(builtins ...string) Arg {
return func(d *Driver) error {
if d.compilers.capabilities == nil {
Expand Down
80 changes: 76 additions & 4 deletions constraint/pkg/client/drivers/rego/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@ package rego

import (
"net/http"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/externaldata"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
)

const (
providerResponseAPIVersion = "externaldata.gatekeeper.sh/v1beta1"
providerResponseKind = "ProviderResponse"
)

func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *ast.Term) (*ast.Term, error) {
return func(bctx rego.BuiltinContext, regorequest *ast.Term) (*ast.Term, error) {
var regoReq externaldata.RegoRequest
Expand All @@ -25,12 +31,78 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
return externaldata.HandleError(http.StatusBadRequest, err)
}

externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, regoReq.Keys, clientCert)
if err != nil {
return externaldata.HandleError(statusCode, err)
// check provider response cache
var providerRequestKeys []string
var providerResponseStatusCode int
var prepareResponse externaldata.Response

prepareResponse.Idempotent = true
for _, k := range regoReq.Keys {
cachedResponse, err := d.providerResponseCache.Get(
externaldata.CacheKey{
ProviderName: regoReq.ProviderName,
Key: k,
},
)
if err != nil || time.Since(time.Unix(cachedResponse.Received, 0)) > d.providerResponseCache.TTL {
// key is not found or cache entry is stale, add key to the provider request keys
providerRequestKeys = append(providerRequestKeys, k)
} else {
prepareResponse.Items = append(
prepareResponse.Items, externaldata.Item{
Key: k,
Value: cachedResponse.Value,
Error: cachedResponse.Error,
},
)

// we are taking conservative approach here, if any of the cached response is not idempotent
// we will mark the whole response as not idempotent
if !cachedResponse.Idempotent {
prepareResponse.Idempotent = false
}
}
}

if len(providerRequestKeys) > 0 {
externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, clientCert)
if err != nil {
return externaldata.HandleError(statusCode, err)
}

for _, item := range externaldataResponse.Response.Items {
d.providerResponseCache.Upsert(
externaldata.CacheKey{
ProviderName: regoReq.ProviderName,
Key: item.Key,
},
externaldata.CacheValue{
Received: time.Now().Unix(),
Value: item.Value,
Error: item.Error,
Idempotent: externaldataResponse.Response.Idempotent,
},
)
}

// we are taking conservative approach here, if any of the response is not idempotent
// we will mark the whole response as not idempotent
if !externaldataResponse.Response.Idempotent {
prepareResponse.Idempotent = false
}

prepareResponse.Items = append(prepareResponse.Items, externaldataResponse.Response.Items...)
prepareResponse.SystemError = externaldataResponse.Response.SystemError
providerResponseStatusCode = statusCode
}

providerResponse := &externaldata.ProviderResponse{
APIVersion: providerResponseAPIVersion,
Kind: providerResponseKind,
Response: prepareResponse,
}

regoResponse := externaldata.NewRegoResponse(statusCode, externaldataResponse)
regoResponse := externaldata.NewRegoResponse(providerResponseStatusCode, providerResponse)
return externaldata.PrepareRegoResponse(regoResponse)
}
}
3 changes: 3 additions & 0 deletions constraint/pkg/client/drivers/rego/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ type Driver struct {
// providerCache allows Rego to read from external_data in Rego queries.
providerCache *externaldata.ProviderCache

// providerResponseCache allows to cache responses from external_data providers.
providerResponseCache *externaldata.ProviderResponseCache

// sendRequestToProvider allows Rego to send requests to the provider specified in external_data.
sendRequestToProvider externaldata.SendRequestToProvider

Expand Down
2 changes: 2 additions & 0 deletions constraint/pkg/client/drivers/rego/driver_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"sort"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -743,6 +744,7 @@ func TestDriver_ExternalData(t *testing.T) {

d, err := New(
AddExternalDataProviderCache(externaldata.NewCache()),
AddExternalDataProviderResponseCache(externaldata.NewProviderResponseCache(context.Background(), 1*time.Minute)),
EnableExternalDataClientAuth(),
AddExternalDataClientCertWatcher(clientCertWatcher),
)
Expand Down
70 changes: 70 additions & 0 deletions constraint/pkg/externaldata/cache.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,91 @@
package externaldata

import (
"context"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"net/url"
"sync"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
"k8s.io/apimachinery/pkg/util/wait"
)

type ProviderCache struct {
cache map[string]unversioned.Provider
mux sync.RWMutex
}

type ProviderResponseCache struct {
cache sync.Map
TTL time.Duration
}

type CacheKey struct {
ProviderName string
Key string
}

type CacheValue struct {
Received int64
Value interface{}
Error string
Idempotent bool
}

func NewProviderResponseCache(ctx context.Context, ttl time.Duration) *ProviderResponseCache {
providerResponseCache := &ProviderResponseCache{
cache: sync.Map{},
TTL: ttl,
}

go wait.UntilWithContext(ctx, func(ctx context.Context) {
providerResponseCache.invalidateProviderResponseCache(providerResponseCache.TTL)
}, ttl)

return providerResponseCache
}

func (c *ProviderResponseCache) Get(key CacheKey) (*CacheValue, error) {
if v, ok := c.cache.Load(key); ok {
value, ok := v.(*CacheValue)
if !ok {
return nil, fmt.Errorf("value is not of type CacheValue")
}
return value, nil
}
return nil, fmt.Errorf("key '%s:%s' is not found in provider response cache", key.ProviderName, key.Key)
}

func (c *ProviderResponseCache) Upsert(key CacheKey, value CacheValue) {
c.cache.Store(key, &value)
}

func (c *ProviderResponseCache) Remove(key CacheKey) {
c.cache.Delete(key)
}

func (c *ProviderResponseCache) invalidateProviderResponseCache(ttl time.Duration) {
c.cache.Range(func(k, v interface{}) bool {
value, ok := v.(*CacheValue)
if !ok {
return false
}

if time.Since(time.Unix(value.Received, 0)) > ttl {
key, ok := k.(CacheKey)
if !ok {
return false
}
c.Remove(key)
}
return true
})
}

func NewCache() *ProviderCache {
return &ProviderCache{
cache: make(map[string]unversioned.Provider),
Expand Down
97 changes: 97 additions & 0 deletions constraint/pkg/externaldata/cache_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package externaldata

import (
"context"
"fmt"
"testing"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -169,3 +172,97 @@ func TestRemove(t *testing.T) {
})
}
}

func TestProviderResponseCache(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

tests := []struct {
name string
key CacheKey
value CacheValue
expected *CacheValue
expectedErr error
}{
{
name: "Upsert and Get",
key: CacheKey{ProviderName: "test", Key: "key1"},
value: CacheValue{Received: time.Now().Unix(), Value: "value1"},
expected: &CacheValue{Received: time.Now().Unix(), Value: "value1"},
expectedErr: nil,
},
{
name: "Remove",
key: CacheKey{ProviderName: "test", Key: "key1"},
value: CacheValue{Received: time.Now().Unix(), Value: "value1"},
expected: nil,
expectedErr: fmt.Errorf("key 'test:key1' is not found in provider response cache"),
},
{
name: "Invalidation",
key: CacheKey{ProviderName: "test", Key: "key2"},
value: CacheValue{Value: "value2"},
expected: nil,
expectedErr: fmt.Errorf("key 'test:key2' is not found in provider response cache"),
},
{
name: "Error",
key: CacheKey{ProviderName: "test", Key: "key3"},
value: CacheValue{},
expected: nil,
expectedErr: fmt.Errorf("key 'test:key3' is not found in provider response cache"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
switch tt.name {
case "Upsert and Get":
cache := NewProviderResponseCache(ctx, 1*time.Minute)
cache.Upsert(tt.key, tt.value)

cachedValue, err := cache.Get(tt.key)
if err != tt.expectedErr {
t.Errorf("Expected error to be %v, but got %v", tt.expectedErr, err)
}
if cachedValue != nil && cachedValue.Value != tt.expected.Value {
t.Errorf("Expected cached value to be %v, but got %v", tt.expected.Value, cachedValue.Value)
}
case "Remove":
cache := NewProviderResponseCache(ctx, 1*time.Minute)
cache.Remove(tt.key)

_, err := cache.Get(tt.key)
if err == nil {
t.Errorf("Expected error, but got nil")
}
if err.Error() != tt.expectedErr.Error() {
t.Errorf("Expected error message to be '%s', but got '%s'", tt.expectedErr.Error(), err.Error())
}
case "Invalidation":
cache := NewProviderResponseCache(ctx, 5*time.Second)
tt.value.Received = time.Now().Add(-10 * time.Second).Unix()
cache.Upsert(tt.key, tt.value)

time.Sleep(5 * time.Second)

_, err := cache.Get(tt.key)
if err == nil {
t.Errorf("Expected error, but got nil")
}
if err.Error() != tt.expectedErr.Error() {
t.Errorf("Expected error message to be '%s', but got '%s'", tt.expectedErr.Error(), err.Error())
}
case "Error":
cache := NewProviderResponseCache(ctx, 1*time.Minute)
_, err := cache.Get(tt.key)
if err == nil {
t.Errorf("Expected error, but got nil")
}
if err.Error() != tt.expectedErr.Error() {
t.Errorf("Expected error message to be '%s', but got '%s'", tt.expectedErr.Error(), err.Error())
}
}
})
}
}

0 comments on commit 06a54a8

Please sign in to comment.