From c6a587f166c29a215b84cb0f6b681c9998b0e2f2 Mon Sep 17 00:00:00 2001 From: Patrick East Date: Wed, 5 Feb 2020 16:49:43 -0800 Subject: [PATCH] topdown: Make http.send() caching use full request Previously it just used a string key with the method and url. This causes problems if headers or body contents affect the response, and users were seeing cached responses returned for different requests. The new version takes into account the entire parameters object that is provided to the builtin function. Fixes: #1980 Signed-off-by: Patrick East --- topdown/http.go | 76 ++++++++++++++++---------- topdown/http_test.go | 124 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 27 deletions(-) diff --git a/topdown/http.go b/topdown/http.go index 84ef7ac5c2..f4bd64f145 100644 --- a/topdown/http.go +++ b/topdown/http.go @@ -50,6 +50,12 @@ var allowedKeys = ast.NewSet() var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url")) +type httpSendKey string + +// httpSendBuiltinCacheKey is the key in the builtin context cache that +// points to the http.send() specific cache resides at. +const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY" + func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { req, err := validateHTTPRequestOperand(args[0], 1) @@ -57,9 +63,17 @@ func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err) } - resp, err := executeHTTPRequest(bctx, req) - if err != nil { - return handleHTTPSendErr(bctx, err) + // check if cache already has a response for this query + resp := checkHTTPSendCache(bctx, req) + if resp == nil { + var err error + resp, err = executeHTTPRequest(bctx, req) + if err != nil { + return handleHTTPSendErr(bctx, err) + } + + // add result to cache + insertIntoHTTPSendCache(bctx, req, resp) } return iter(ast.NewTerm(resp)) @@ -295,12 +309,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) body = bytes.NewBufferString("") } - // check if cache already has a response for this query - cachedResponse := checkCache(method, url, bctx) - if cachedResponse != nil { - return cachedResponse, nil - } - // create the http request, use the builtin context's context to ensure // the request is cancelled if evaluation is cancelled. req, err := http.NewRequest(method, url, body) @@ -309,8 +317,7 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) } req = req.WithContext(bctx.Context) - // Add custom headers passed from CLI - + // Add custom headers if len(customHeaders) != 0 { if ok, err := addHeaders(req, customHeaders); !ok { return nil, err @@ -358,10 +365,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) return nil, err } - // add result to cache - key := getCtxKey(method, url) - bctx.Cache.Put(key, resultObj) - return resultObj, nil } @@ -369,22 +372,41 @@ func isContentTypeJSON(header http.Header) bool { return strings.Contains(header.Get("Content-Type"), "application/json") } -// getCtxKey returns the cache key. -// Key format: _ -func getCtxKey(method string, url string) string { - keyTerms := []string{strings.ToUpper(method), url} - return strings.Join(keyTerms, "_") +// In the BuiltinContext cache we only store a single entry that points to +// our ValueMap which is the "real" http.send() cache. +func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap { + raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey) + if !ok { + // Initialize if it isn't there + cache := ast.NewValueMap() + bctx.Cache.Put(httpSendBuiltinCacheKey, cache) + return cache + } + + cache, ok := raw.(*ast.ValueMap) + if !ok { + return nil + } + return cache } -// checkCache checks for the given key's value in the cache -func checkCache(method string, url string, bctx BuiltinContext) ast.Value { - key := getCtxKey(method, url) +// checkHTTPSendCache checks for the given key's value in the cache +func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + return nil + } + + return requestCache.Get(key) +} - val, ok := bctx.Cache.Get(key) - if ok { - return val.(ast.Value) +func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) { + requestCache := getHTTPSendCache(bctx) + if requestCache == nil { + // Should never happen.. if it does just skip caching the value + return } - return nil + requestCache.Put(key, value) } func createAllowedKeys() { diff --git a/topdown/http_test.go b/topdown/http_test.go index 64d6cc0942..c793a8de79 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -633,6 +633,130 @@ func TestHTTPRedirectEnable(t *testing.T) { runTopDownTestCase(t, data, "http.send", rule, resultObj.String()) } +func TestHTTPSendCaching(t *testing.T) { + // test server + nextResponse := "{}" + var requests []*http.Request + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) + w.WriteHeader(http.StatusOK) + w.Write([]byte(nextResponse)) + })) + defer ts.Close() + + // expected result + + var body []interface{} + bodyMap := map[string]string{"id": "1", "firstname": "John"} + body = append(body, bodyMap) + + // run the test + tests := []struct { + note string + ruleTemplate string + body string + response string + expectedReqCount int + }{ + { + note: "http.send GET single", + ruleTemplate: `p = x { http.send({"method": "get", "url": "%URL%", "force_json_decode": true}, r); x = r.body }`, + response: `{"x": 1}`, + expectedReqCount: 1, + }, + { + note: "http.send GET cache hit", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r3 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r1 == r2 + r2 == r3 + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 1, + }, + { + note: "http.send GET cache miss different method", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different url", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) + r1_2 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different decode opt", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + { + note: "http.send GET cache miss different headers", + ruleTemplate: `p = x { + r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) + r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) + r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) # cached + r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 3, + }, + { + note: "http.send POST cache miss different body", + ruleTemplate: `p = x { + r1 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) + r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) + r1_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) # cached + r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) # cached + x = r1.body + }`, + response: `{"x": 1}`, + expectedReqCount: 2, + }, + } + + data := loadSmallTestData() + + for _, tc := range tests { + nextResponse = tc.response + requests = nil + runTopDownTestCase(t, data, tc.note, []string{strings.ReplaceAll(tc.ruleTemplate, "%URL%", ts.URL)}, tc.response) + + // Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial + // eval first), so expect 2x the total request count the test case specified. + actualCount := len(requests) / 2 + if actualCount != tc.expectedReqCount { + t.Fatalf("Expected to only get %d requests, got %d", tc.expectedReqCount, actualCount) + } + } +} + func getTestServer() (baseURL string, teardownFn func()) { mux := http.NewServeMux() ts := httptest.NewServer(mux)