Skip to content

Commit

Permalink
topdown: Make http.send() caching use full request
Browse files Browse the repository at this point in the history
Previously it just used a string key with the method and url. This
causes problems if headers or body contents affect the response, and
users were seeing cached responses returned for different requests.

The new version takes into account the entire parameters object that
is provided to the builtin function.

Fixes: #1980
Signed-off-by: Patrick East <east.patrick@gmail.com>
  • Loading branch information
patrick-east committed Feb 19, 2020
1 parent d27d687 commit c6a587f
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 27 deletions.
76 changes: 49 additions & 27 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,30 @@ var allowedKeys = ast.NewSet()

var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))

type httpSendKey string

// httpSendBuiltinCacheKey is the key in the builtin context cache that
// points to the http.send() specific cache resides at.
const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY"

func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

req, err := validateHTTPRequestOperand(args[0], 1)
if err != nil {
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
}

resp, err := executeHTTPRequest(bctx, req)
if err != nil {
return handleHTTPSendErr(bctx, err)
// check if cache already has a response for this query
resp := checkHTTPSendCache(bctx, req)
if resp == nil {
var err error
resp, err = executeHTTPRequest(bctx, req)
if err != nil {
return handleHTTPSendErr(bctx, err)
}

// add result to cache
insertIntoHTTPSendCache(bctx, req, resp)
}

return iter(ast.NewTerm(resp))
Expand Down Expand Up @@ -295,12 +309,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
body = bytes.NewBufferString("")
}

// check if cache already has a response for this query
cachedResponse := checkCache(method, url, bctx)
if cachedResponse != nil {
return cachedResponse, nil
}

// create the http request, use the builtin context's context to ensure
// the request is cancelled if evaluation is cancelled.
req, err := http.NewRequest(method, url, body)
Expand All @@ -309,8 +317,7 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
}
req = req.WithContext(bctx.Context)

// Add custom headers passed from CLI

// Add custom headers
if len(customHeaders) != 0 {
if ok, err := addHeaders(req, customHeaders); !ok {
return nil, err
Expand Down Expand Up @@ -358,33 +365,48 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
return nil, err
}

// add result to cache
key := getCtxKey(method, url)
bctx.Cache.Put(key, resultObj)

return resultObj, nil
}

func isContentTypeJSON(header http.Header) bool {
return strings.Contains(header.Get("Content-Type"), "application/json")
}

// getCtxKey returns the cache key.
// Key format: <METHOD>_<url>
func getCtxKey(method string, url string) string {
keyTerms := []string{strings.ToUpper(method), url}
return strings.Join(keyTerms, "_")
// In the BuiltinContext cache we only store a single entry that points to
// our ValueMap which is the "real" http.send() cache.
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
if !ok {
// Initialize if it isn't there
cache := ast.NewValueMap()
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
return cache
}

cache, ok := raw.(*ast.ValueMap)
if !ok {
return nil
}
return cache
}

// checkCache checks for the given key's value in the cache
func checkCache(method string, url string, bctx BuiltinContext) ast.Value {
key := getCtxKey(method, url)
// checkHTTPSendCache checks for the given key's value in the cache
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
return nil
}

return requestCache.Get(key)
}

val, ok := bctx.Cache.Get(key)
if ok {
return val.(ast.Value)
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
// Should never happen.. if it does just skip caching the value
return
}
return nil
requestCache.Put(key, value)
}

func createAllowedKeys() {
Expand Down
124 changes: 124 additions & 0 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,130 @@ func TestHTTPRedirectEnable(t *testing.T) {
runTopDownTestCase(t, data, "http.send", rule, resultObj.String())
}

func TestHTTPSendCaching(t *testing.T) {
// test server
nextResponse := "{}"
var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)
w.WriteHeader(http.StatusOK)
w.Write([]byte(nextResponse))
}))
defer ts.Close()

// expected result

var body []interface{}
bodyMap := map[string]string{"id": "1", "firstname": "John"}
body = append(body, bodyMap)

// run the test
tests := []struct {
note string
ruleTemplate string
body string
response string
expectedReqCount int
}{
{
note: "http.send GET single",
ruleTemplate: `p = x { http.send({"method": "get", "url": "%URL%", "force_json_decode": true}, r); x = r.body }`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET cache hit",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r3 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r1 == r2
r2 == r3
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET cache miss different method",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different url",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true})
r1_2 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different decode opt",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different headers",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 3,
},
{
note: "http.send POST cache miss different body",
ruleTemplate: `p = x {
r1 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"})
r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"})
r1_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) # cached
r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
}

data := loadSmallTestData()

for _, tc := range tests {
nextResponse = tc.response
requests = nil
runTopDownTestCase(t, data, tc.note, []string{strings.ReplaceAll(tc.ruleTemplate, "%URL%", ts.URL)}, tc.response)

// Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial
// eval first), so expect 2x the total request count the test case specified.
actualCount := len(requests) / 2
if actualCount != tc.expectedReqCount {
t.Fatalf("Expected to only get %d requests, got %d", tc.expectedReqCount, actualCount)
}
}
}

func getTestServer() (baseURL string, teardownFn func()) {
mux := http.NewServeMux()
ts := httptest.NewServer(mux)
Expand Down

0 comments on commit c6a587f

Please sign in to comment.