Skip to content

Commit

Permalink
Get request callback interception working with the callback server
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed Jul 5, 2024
1 parent 26941ae commit 005bb91
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
56 changes: 37 additions & 19 deletions internal/deploy/callback_addon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ func TestCallbackAddon(t *testing.T) {
})

testCases := []struct {
name string
filter string
inner func(t *testing.T, checker *checker)
name string
filter string
needsRequestCallback bool
inner func(t *testing.T, checker *checker)
}{
{
name: "works",
Expand Down Expand Up @@ -234,16 +235,15 @@ func TestCallbackAddon(t *testing.T) {
},
},
{
name: "can block requests and modify response codes and bodies",
filter: "~m PUT",
name: "can block requests and modify response codes and bodies",
filter: "~m PUT",
needsRequestCallback: true,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
OnRequestCallback: func(cd CallbackData) *CallbackResponse {
return &CallbackResponse{
RespondStatusCode: 200,
RespondBody: json.RawMessage(`{
"yep": "ok",
}`),
RespondBody: json.RawMessage(`{"yep": "ok"}`),
}
},
})
Expand Down Expand Up @@ -273,22 +273,30 @@ func TestCallbackAddon(t *testing.T) {
t, deployment.GetConfig().HostnameRunningComplement,
)
callbackURL := cbServer.SetOnResponseCallback(t, func(cd CallbackData) *CallbackResponse {
return checker.onCallback(cd)
return checker.onResponseCallback(cd)
})
var reqCallbackURL string
if tc.needsRequestCallback {
reqCallbackURL = cbServer.SetOnRequestCallback(t, func(cd CallbackData) *CallbackResponse {
return checker.onRequestCallback(cd)
})
}
must.NotError(t, "failed to create callback server", err)
defer cbServer.Close()
mitmClient := deployment.MITM()
mitmOpts := map[string]any{
"callback": map[string]any{
"callback_response_url": callbackURL,
},
callbackOpts := map[string]any{
"callback_response_url": callbackURL,
}
if tc.filter != "" {
cb := mitmOpts["callback"].(map[string]any)
cb["filter"] = tc.filter
mitmOpts["callback"] = cb
callbackOpts["filter"] = tc.filter
}
if reqCallbackURL != "" {
callbackOpts["callback_request_url"] = reqCallbackURL
}
lockID := mitmClient.lockOptions(t, mitmOpts)

mitmClient := deployment.MITM()
lockID := mitmClient.lockOptions(t, map[string]any{
"callback": callbackOpts,
})
tc.inner(t, checker)
mitmClient.unlockOptions(t, lockID)
})
Expand All @@ -312,7 +320,7 @@ type checker struct {
noCallbacks bool
}

func (c *checker) onCallback(cd CallbackData) *CallbackResponse {
func (c *checker) onResponseCallback(cd CallbackData) *CallbackResponse {
c.mu.Lock()
if c.noCallbacks {
ct.Errorf(c.t, "wanted no callbacks but got %+v", cd)
Expand Down Expand Up @@ -349,6 +357,16 @@ func (c *checker) onCallback(cd CallbackData) *CallbackResponse {
return callbackResponse
}

func (c *checker) onRequestCallback(cd CallbackData) *CallbackResponse {
c.mu.Lock()
cb := c.want.OnRequestCallback
c.mu.Unlock()
if cb != nil {
return cb(cd)
}
return nil
}

func (c *checker) expect(want *callbackRequest) {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
11 changes: 9 additions & 2 deletions tests/mitmproxy_addons/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,18 @@ async def send_callback(self, flow, url: str, body: dict):
headers={"Content-Type": "application/json"},
json=body) as response:
print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}')
if response.content_type != 'application/json':
err_response_body = await response.text()
print(f'ERR: callback server returned non-json: {err_response_body}')
raise Exception("callback server content-type: " + response.content_type)
test_response_body = await response.json()
# if the response includes some keys then we are modifying the response on a per-key basis.
if len(test_response_body) > 0:
respond_status_code = test_response_body.get("respond_status_code", flow.response.status_code)
respond_body = test_response_body.get("respond_body", body["response_body"])
# use what fields were provided preferentially.
# For requests: both fields must be provided so the default case won't execute.
# For responses: fields are optional but the default case is always specified.
respond_status_code = test_response_body.get("respond_status_code", body.get("response_code"))
respond_body = test_response_body.get("respond_body", body.get("response_body"))
flow.response = Response.make(
respond_status_code, json.dumps(respond_body),
headers={
Expand Down

0 comments on commit 005bb91

Please sign in to comment.