diff --git a/topdown/http.go b/topdown/http.go index 84ef7ac5c2..f4bd64f145 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -50,6 +50,12 @@ var allowedKeys = ast.NewSet() var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) +type httpSendKey string + +// httpSendBuiltinCacheKey is the key in the builtin context cache that +// points to the http.send() specific cache resides at. +const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY" + func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { req, err := validateHTTPRequestOperand(args[0], 1) @@ -57,9 +63,17 @@ func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) } - resp, err := executeHTTPRequest(bctx, req) - if err != nil { - return handleHTTPSendErr(bctx, err) + // check if cache already has a response for this query + resp := checkHTTPSendCache(bctx, req) + if resp == nil { + var err error + resp, err = executeHTTPRequest(bctx, req) + if err != nil { + return handleHTTPSendErr(bctx, err) + } + + // add result to cache + insertIntoHTTPSendCache(bctx, req, resp) } return iter(ast.NewTerm(resp)) @@ -295,12 +309,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) body = bytes.NewBufferString("") } - // check if cache already has a response for this query - cachedResponse := checkCache(method, url, bctx) - if cachedResponse != nil { - return cachedResponse, nil - } - // create the http request, use the builtin context's context to ensure // the request is cancelled if evaluation is cancelled. req, err := http.NewRequest(method, url, body) @@ -309,8 +317,7 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) } req = req.WithContext(bctx.Context) - // Add custom headers passed from CLI - + // Add custom headers if len(customHeaders) != 0 { if ok, err := addHeaders(req, customHeaders); !ok { return nil, err @@ -358,10 +365,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) return nil, err } - // add result to cache - key := getCtxKey(method, url) - bctx.Cache.Put(key, resultObj) - return resultObj, nil } @@ -369,22 +372,41 @@ func isContentTypeJSON(header http.Header) bool { return strings.Contains(header.Get("Content-Type"), "application/json") } -// getCtxKey returns the cache key. -// Key format: _ -func getCtxKey(method string, url string) string { - keyTerms := []string{strings.ToUpper(method), url} - return strings.Join(keyTerms, "_") +// In the BuiltinContext cache we only store a single entry that points to +// our ValueMap which is the "real" http.send() cache. +func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap { + raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey) + if !ok { + // Initialize if it isn't there + cache := ast.NewValueMap() + bctx.Cache.Put(httpSendBuiltinCacheKey, cache) + return cache + } + + cache, ok := raw.(*ast.ValueMap) + if !ok { + return nil + } + return cache } -// checkCache checks for the given key's value in the cache -func checkCache(method string, url string, bctx BuiltinContext) ast.Value { - key := getCtxKey(method, url) +// checkHTTPSendCache checks for the given key's value in the cache +func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + return nil + } + + return requestCache.Get(key) +} - val, ok := bctx.Cache.Get(key) - if ok { - return val.(ast.Value) +func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + // Should never happen.. if it does just skip caching the value + return } - return nil + requestCache.Put(key, value) } func createAllowedKeys() { diff --git a/topdown/http_test.go b/topdown/http_test.go index 64d6cc0942..c793a8de79 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -633,6 +633,130 @@ func TestHTTPRedirectEnable(t *testing.T) { runTopDownTestCase(t, data, "http.send", rule, resultObj.String()) } +func TestHTTPSendCaching(t *testing.T) { + // test server + nextResponse := "{}" + var requests []*http.Request + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) + w.WriteHeader(http.StatusOK) + w.Write([]byte(nextResponse)) + })) + defer ts.Close() + + // expected result + + var body []interface{} + bodyMap := map[string]string{"id": "1", "firstname": "John"} + body = append(body, bodyMap) + + // run the test + tests := []struct { + note string + ruleTemplate string + body string + response string + expectedReqCount int + }{ + { + note: "http.send GET single", + ruleTemplate: `p = x { http.send({"method": "get", "url": "%URL%", "force_json_decode": true}, r); x = r.body }`, + response: `{"x": 1}`, + expectedReqCount: 1, + }, + { + note: "http.send GET cache hit", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r3 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r1 == r2 + r2 == r3 + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 1, + }, + { + note: "http.send GET cache miss different method", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different url", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) + r1_2 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different decode opt", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different headers", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 3, + }, + { + note: "http.send POST cache miss different body", + ruleTemplate: `p = x { + r1 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) + r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) + r1_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) # cached + r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + } + + data := loadSmallTestData() + + for _, tc := range tests { + nextResponse = tc.response + requests = nil + runTopDownTestCase(t, data, tc.note, []string{strings.ReplaceAll(tc.ruleTemplate, "%URL%", ts.URL)}, tc.response) + + // Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial + // eval first), so expect 2x the total request count the test case specified. + actualCount := len(requests) / 2 + if actualCount != tc.expectedReqCount { + t.Fatalf("Expected to only get %d requests, got %d", tc.expectedReqCount, actualCount) + } + } +} + func getTestServer() (baseURL string, teardownFn func()) { mux := http.NewServeMux() ts := httptest.NewServer(mux)