Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mitmproxy addon: refactor the callback addon #100

Merged
merged 4 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
logs
js-sdk/node_modules
js-sdk/dist
/internal/api/js/chrome/dist
Expand All @@ -11,4 +12,4 @@ __pycache__
/tests/logs/
/tests/chromedp/
/tests/rust_storage/
_temp_rust_sdk
_temp_rust_sdk
284 changes: 284 additions & 0 deletions internal/deploy/callback_addon_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
package deploy

import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"time"

"github.com/matrix-org/complement"
"github.com/matrix-org/complement/ct"
"github.com/matrix-org/complement/helpers"
"github.com/matrix-org/complement/must"
)

func TestMain(m *testing.M) {
complement.TestMain(m, "deploy")
}

// Test the functionality of the mitmproxy addon 'callback'.
func TestCallbackAddon(t *testing.T) {
workingDir, err := os.Getwd()
must.NotError(t, "failed to get working dir", err)
mitmProxyAddonsDir := filepath.Join(workingDir, "../../tests/mitmproxy_addons")
deployment := RunNewDeployment(t, mitmProxyAddonsDir, "")
defer deployment.Teardown()
client := deployment.Register(t, "hs1", helpers.RegistrationOpts{
LocalpartSuffix: "callback",
})
other := deployment.Register(t, "hs1", helpers.RegistrationOpts{
LocalpartSuffix: "callback",
})

testCases := []struct {
name string
filter string
inner func(t *testing.T, checker *checker)
}{
{
name: "works",
filter: "",
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
Method: "GET",
PathContains: "_matrix/client/v3/capabilities",
AccessToken: client.AccessToken,
ResponseCode: 200,
})
client.GetCapabilities(t)
checker.wait()
},
},
{
name: "can be filtered by path",
filter: "~u capabilities",
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
Method: "GET",
PathContains: "_matrix/client/v3/capabilities",
AccessToken: client.AccessToken,
ResponseCode: 200,
})
client.GetCapabilities(t)
checker.wait()
checker.expectNoCallbacks(true)
client.GetGlobalAccountData(t, "this_does_a_get")
checker.expectNoCallbacks(false)
},
},
{
name: "can be filtered by method",
filter: "~m GET",
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
Method: "GET",
PathContains: "_matrix/client/v3/capabilities",
AccessToken: client.AccessToken,
ResponseCode: 200,
})
client.GetCapabilities(t)
checker.wait()
checker.expectNoCallbacks(true)
client.MustSetGlobalAccountData(t, "this_does_a_put", map[string]any{})
checker.expectNoCallbacks(false)
},
},
{
name: "can be filtered by access token",
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
Method: "GET",
PathContains: "_matrix/client/v3/capabilities",
AccessToken: client.AccessToken,
ResponseCode: 200,
})
client.GetCapabilities(t)
checker.wait()
checker.expectNoCallbacks(true)
other.GetCapabilities(t)
checker.expectNoCallbacks(false)
},
},
{
name: "can be filtered by combinations of method path and access token",
filter: "~m GET ~u capabilities ~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
checker.expect(&callbackRequest{
Method: "GET",
PathContains: "_matrix/client/v3/capabilities",
AccessToken: client.AccessToken,
ResponseCode: 200,
})
client.GetCapabilities(t)
checker.wait()
checker.expectNoCallbacks(true)
other.GetCapabilities(t)
checker.expectNoCallbacks(false)
},
},
{
// ensure that if we tarpit a request it doesn't tarpit unrelated requests
name: "processes callbacks concurrently",
filter: "~hq " + client.AccessToken,
inner: func(t *testing.T, checker *checker) {
// signal when to make the unrelated request
signalSendUnrelatedRequest := make(chan bool)
signalTestFinished := make(chan bool)
checker.expect(&callbackRequest{
OnCallback: func(cd CallbackData) {
if strings.Contains(cd.URL, "capabilities") {
close(signalSendUnrelatedRequest) // send the signal to make the unrelated request
time.Sleep(time.Second) // tarpit this request
close(signalTestFinished) // test is done, cleanup
}
},
})
beforeSendingRequests := time.Now()
// send the tarpit request without waiting
go func() {
client.GetCapabilities(t)
}()
select {
case <-signalSendUnrelatedRequest:
// send the unrelated request
t.Logf("received signal @ %v", time.Since(beforeSendingRequests))
client.GetGlobalAccountData(t, "this_does_a_get")
t.Logf("received unrelated response @ %v", time.Since(beforeSendingRequests))
case <-time.After(time.Second):
ct.Errorf(t, "did not receive signal to send unrelated request")
return
}
since := time.Since(beforeSendingRequests)
if since > time.Second {
ct.Errorf(t, "unrelated request was tarpitted, took %v", since)
return
}

// now wait for the tarpit
select {
case <-signalTestFinished:
case <-time.After(2 * time.Second):
ct.Errorf(t, "timed out waiting for tarpit response")
}
},
},

// TODO: migrate functionality from status_code addon
// TODO: can modify response codes
// TODO: can modify response bodies
// TODO: can block requests
// TODO: can block responses
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checker := &checker{
t: t,
ch: make(chan callbackRequest, 3),
mu: &sync.Mutex{},
}
callbackURL, close := NewCallbackServer(
t, deployment.GetConfig().HostnameRunningComplement,
func(cd CallbackData) {
checker.onCallback(cd)
},
)
defer close()
mitmClient := deployment.MITM()
mitmOpts := map[string]any{
"callback": map[string]any{
"callback_url": callbackURL,
},
}
if tc.filter != "" {
cb := mitmOpts["callback"].(map[string]any)
cb["filter"] = tc.filter
mitmOpts["callback"] = cb
}
lockID := mitmClient.lockOptions(t, mitmOpts)
tc.inner(t, checker)
mitmClient.unlockOptions(t, lockID)
})
}
}

type callbackRequest struct {
Method string
PathContains string
AccessToken string
ResponseCode int
OnCallback func(cd CallbackData)
}

type checker struct {
t *testing.T
ch chan callbackRequest
mu *sync.Mutex
want *callbackRequest
noCallbacks bool
}

func (c *checker) onCallback(cd CallbackData) {
c.mu.Lock()
if c.noCallbacks {
ct.Errorf(c.t, "wanted no callbacks but got %+v", cd)
}
if c.want == nil {
c.mu.Unlock()
return
}
if c.want.AccessToken != "" {
must.Equal(c.t, cd.AccessToken, c.want.AccessToken, "access token mismatch")
}
if c.want.Method != "" {
must.Equal(c.t, cd.Method, c.want.Method, "HTTP method mismatch")
}
if c.want.PathContains != "" {
must.Equal(c.t, strings.Contains(cd.URL, c.want.PathContains), true,
fmt.Sprintf("path mismatch, got %v want partial %v", cd.URL, c.want.PathContains),
)
}
if c.want.ResponseCode != 0 {
must.Equal(c.t, cd.ResponseCode, c.want.ResponseCode, "response code mismatch")
}

customCallback := c.want.OnCallback
// unlock early so we don't block other requests, as custom callbacks are generally
// used for testing tarpitting.
c.mu.Unlock()
if customCallback != nil {
customCallback(cd)
}
// signal that we processed the callback
c.ch <- *c.want
}

func (c *checker) expect(want *callbackRequest) {
c.mu.Lock()
defer c.mu.Unlock()
c.want = want
}

func (c *checker) expectNoCallbacks(noCallbacks bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.noCallbacks = noCallbacks
}

func (c *checker) wait() {
select {
case got := <-c.ch:
if !reflect.DeepEqual(got, *c.want) {
ct.Fatalf(c.t, "checker: got success from a different request: did you forget to wait?"+
" Received %+v but expected +%v", got, c.want)
}
return
case <-time.After(time.Second):
ct.Fatalf(c.t, "timed out waiting for %+v", c.want)
}
}
1 change: 1 addition & 0 deletions internal/deploy/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ func externalURL(t *testing.T, c testcontainers.Container, exposedPort string) s
}

func writeContainerLogs(readCloser io.ReadCloser, filename string) error {
os.Mkdir("./logs", os.ModePerm) // ignore error, we don't care if it already exists
w, err := os.Create("./logs/" + filename)
if err != nil {
return fmt.Errorf("os.Create: %s", err)
Expand Down
10 changes: 10 additions & 0 deletions tests/mitmproxy_addons/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from mitmproxy.addons import asgiapp
import subprocess
import sys

# some addons need non-std packages.
# Rather than try to bundle in `pip install` commands in the CMD section of the Dockerfile,
# just install them when the addon loads.
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("aiohttp")

from callback import Callback
from status_code import StatusCode
Expand Down
44 changes: 21 additions & 23 deletions tests/mitmproxy_addons/callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Optional
import asyncio
import aiohttp
import json

from mitmproxy import ctx, flowfilter
from mitmproxy.http import Response
from controller import MITM_DOMAIN_NAME
from urllib.request import urlopen, Request
from urllib.error import HTTPError, URLError
from datetime import datetime

# Callback will intercept a response and send a POST request to the provided callback_url, with
# the following JSON object. Supports filters: https://docs.mitmproxy.org/stable/concepts-filters/
Expand Down Expand Up @@ -56,7 +59,7 @@ def configure(self, updates):
else:
self.filter = self.matchall

def response(self, flow):
async def response(self, flow):
# always ignore the controller
if flow.request.pretty_host == MITM_DOMAIN_NAME:
return
Expand All @@ -71,26 +74,21 @@ def response(self, flow):
res_body = flow.response.json()
except:
res_body = None
data = json.dumps({
"method": flow.request.method,
"access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "),
"url": flow.request.url,
"response_code": flow.response.status_code,
"request_body": req_body,
"response_body": res_body,
})
request = Request(
self.config["callback_url"],
headers={"Content-Type": "application/json"},
data=data.encode("utf-8"),
)
print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting callback for {flow.request.url}')
try:
with urlopen(request, timeout=10) as response:
print(f"callback returned HTTP {response.status}")
return response.read(), response
except HTTPError as error:
print(f"ERR: callback returned {error.status} {error.reason}")
except URLError as error:
print(f"ERR: callback returned {error.reason}")
except TimeoutError:
print(f"ERR: callback request timed out")
# use asyncio so we don't block other unrelated requests from being processed
async with aiohttp.request(
method="POST",url=self.config["callback_url"], timeout=aiohttp.ClientTimeout(total=10),
headers={"Content-Type": "application/json"},
json={
"method": flow.request.method,
"access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "),
"url": flow.request.url,
"response_code": flow.response.status_code,
"request_body": req_body,
"response_body": res_body,
}) as response:
print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}')
return
kegsay marked this conversation as resolved.
Show resolved Hide resolved
except Exception as error:
print(f"ERR: callback returned {error}")