Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topdown: Make http.send() caching use full request #2067

Merged
merged 1 commit into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 49 additions & 27 deletions topdown/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,30 @@ var allowedKeys = ast.NewSet()

var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))

type httpSendKey string

// httpSendBuiltinCacheKey is the key in the builtin context cache that
// points to the http.send() specific cache resides at.
const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY"

func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

req, err := validateHTTPRequestOperand(args[0], 1)
if err != nil {
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
}

resp, err := executeHTTPRequest(bctx, req)
if err != nil {
return handleHTTPSendErr(bctx, err)
// check if cache already has a response for this query
resp := checkHTTPSendCache(bctx, req)
if resp == nil {
var err error
resp, err = executeHTTPRequest(bctx, req)
if err != nil {
return handleHTTPSendErr(bctx, err)
}

// add result to cache
insertIntoHTTPSendCache(bctx, req, resp)
}

return iter(ast.NewTerm(resp))
Expand Down Expand Up @@ -295,12 +309,6 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
body = bytes.NewBufferString("")
}

// check if cache already has a response for this query
cachedResponse := checkCache(method, url, bctx)
if cachedResponse != nil {
return cachedResponse, nil
}

// create the http request, use the builtin context's context to ensure
// the request is cancelled if evaluation is cancelled.
req, err := http.NewRequest(method, url, body)
Expand All @@ -309,8 +317,7 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
}
req = req.WithContext(bctx.Context)

// Add custom headers passed from CLI

// Add custom headers
if len(customHeaders) != 0 {
if ok, err := addHeaders(req, customHeaders); !ok {
return nil, err
Expand Down Expand Up @@ -358,33 +365,48 @@ func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error)
return nil, err
}

// add result to cache
key := getCtxKey(method, url)
bctx.Cache.Put(key, resultObj)

return resultObj, nil
}

func isContentTypeJSON(header http.Header) bool {
return strings.Contains(header.Get("Content-Type"), "application/json")
}

// getCtxKey returns the cache key.
// Key format: <METHOD>_<url>
func getCtxKey(method string, url string) string {
keyTerms := []string{strings.ToUpper(method), url}
return strings.Join(keyTerms, "_")
// In the BuiltinContext cache we only store a single entry that points to
// our ValueMap which is the "real" http.send() cache.
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
if !ok {
// Initialize if it isn't there
cache := ast.NewValueMap()
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
return cache
}

cache, ok := raw.(*ast.ValueMap)
if !ok {
return nil
}
return cache
}

// checkCache checks for the given key's value in the cache
func checkCache(method string, url string, bctx BuiltinContext) ast.Value {
key := getCtxKey(method, url)
// checkHTTPSendCache checks for the given key's value in the cache
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
return nil
}

return requestCache.Get(key)
}

val, ok := bctx.Cache.Get(key)
if ok {
return val.(ast.Value)
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
// Should never happen.. if it does just skip caching the value
return
}
return nil
requestCache.Put(key, value)
}

func createAllowedKeys() {
Expand Down
124 changes: 124 additions & 0 deletions topdown/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,130 @@ func TestHTTPRedirectEnable(t *testing.T) {
runTopDownTestCase(t, data, "http.send", rule, resultObj.String())
}

func TestHTTPSendCaching(t *testing.T) {
// test server
nextResponse := "{}"
var requests []*http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)
w.WriteHeader(http.StatusOK)
w.Write([]byte(nextResponse))
}))
defer ts.Close()

// expected result

var body []interface{}
bodyMap := map[string]string{"id": "1", "firstname": "John"}
body = append(body, bodyMap)

// run the test
tests := []struct {
patrick-east marked this conversation as resolved.
Show resolved Hide resolved
note string
ruleTemplate string
body string
response string
expectedReqCount int
}{
{
note: "http.send GET single",
ruleTemplate: `p = x { http.send({"method": "get", "url": "%URL%", "force_json_decode": true}, r); x = r.body }`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET cache hit",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r3 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r1 == r2
r2 == r3
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 1,
},
{
note: "http.send GET cache miss different method",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different url",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true})
r1_2 = http.send({"method": "get", "url": "%URL%/foo", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%/bar", "force_json_decode": true}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different decode opt",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": false}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
{
note: "http.send GET cache miss different headers",
ruleTemplate: `p = x {
r1 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}})
r2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}})
r1_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h1": "v1", "h2": "v2"}}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}}) # cached
r2_2 = http.send({"method": "get", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 3,
},
{
note: "http.send POST cache miss different body",
ruleTemplate: `p = x {
r1 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"})
r2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"})
r1_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v2"}, "body": "{\"foo\": 42}"}) # cached
r2_2 = http.send({"method": "post", "url": "%URL%", "force_json_decode": true, "headers": {"h2": "v3"}, "body": "{\"foo\": 23}"}) # cached
x = r1.body
}`,
response: `{"x": 1}`,
expectedReqCount: 2,
},
}

data := loadSmallTestData()

for _, tc := range tests {
nextResponse = tc.response
requests = nil
runTopDownTestCase(t, data, tc.note, []string{strings.ReplaceAll(tc.ruleTemplate, "%URL%", ts.URL)}, tc.response)

// Note: The runTopDownTestCase ends up evaluating twice (once with and once without partial
// eval first), so expect 2x the total request count the test case specified.
actualCount := len(requests) / 2
if actualCount != tc.expectedReqCount {
t.Fatalf("Expected to only get %d requests, got %d", tc.expectedReqCount, actualCount)
}
}
}

func getTestServer() (baseURL string, teardownFn func()) {
mux := http.NewServeMux()
ts := httptest.NewServer(mux)
Expand Down