diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 20cc8ef..f656a8a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -113,6 +113,8 @@ jobs: #env: #DOCKER_BUILDKIT: 1 - name: "Run Complement-Crypto unit tests" + env: + COMPLEMENT_BASE_IMAGE: homeserver run: | export LIBRARY_PATH="$(pwd)/rust-sdk/target/debug" export LD_LIBRARY_PATH="$(pwd)/rust-sdk/target/debug" diff --git a/.gitignore b/.gitignore index 60df069..728c285 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +logs js-sdk/node_modules js-sdk/dist /internal/api/js/chrome/dist @@ -11,4 +12,4 @@ __pycache__ /tests/logs/ /tests/chromedp/ /tests/rust_storage/ -_temp_rust_sdk \ No newline at end of file +_temp_rust_sdk diff --git a/internal/deploy/callback_addon_test.go b/internal/deploy/callback_addon_test.go new file mode 100644 index 0000000..2e49b0d --- /dev/null +++ b/internal/deploy/callback_addon_test.go @@ -0,0 +1,284 @@ +package deploy + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/matrix-org/complement" + "github.com/matrix-org/complement/ct" + "github.com/matrix-org/complement/helpers" + "github.com/matrix-org/complement/must" +) + +func TestMain(m *testing.M) { + complement.TestMain(m, "deploy") +} + +// Test the functionality of the mitmproxy addon 'callback'. +func TestCallbackAddon(t *testing.T) { + workingDir, err := os.Getwd() + must.NotError(t, "failed to get working dir", err) + mitmProxyAddonsDir := filepath.Join(workingDir, "../../tests/mitmproxy_addons") + deployment := RunNewDeployment(t, mitmProxyAddonsDir, "") + defer deployment.Teardown() + client := deployment.Register(t, "hs1", helpers.RegistrationOpts{ + LocalpartSuffix: "callback", + }) + other := deployment.Register(t, "hs1", helpers.RegistrationOpts{ + LocalpartSuffix: "callback", + }) + + testCases := []struct { + name string + filter string + inner func(t *testing.T, checker *checker) + }{ + { + name: "works", + filter: "", + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + Method: "GET", + PathContains: "_matrix/client/v3/capabilities", + AccessToken: client.AccessToken, + ResponseCode: 200, + }) + client.GetCapabilities(t) + checker.wait() + }, + }, + { + name: "can be filtered by path", + filter: "~u capabilities", + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + Method: "GET", + PathContains: "_matrix/client/v3/capabilities", + AccessToken: client.AccessToken, + ResponseCode: 200, + }) + client.GetCapabilities(t) + checker.wait() + checker.expectNoCallbacks(true) + client.GetGlobalAccountData(t, "this_does_a_get") + checker.expectNoCallbacks(false) + }, + }, + { + name: "can be filtered by method", + filter: "~m GET", + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + Method: "GET", + PathContains: "_matrix/client/v3/capabilities", + AccessToken: client.AccessToken, + ResponseCode: 200, + }) + client.GetCapabilities(t) + checker.wait() + checker.expectNoCallbacks(true) + client.MustSetGlobalAccountData(t, "this_does_a_put", map[string]any{}) + checker.expectNoCallbacks(false) + }, + }, + { + name: "can be filtered by access token", + filter: "~hq " + client.AccessToken, + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + Method: "GET", + PathContains: "_matrix/client/v3/capabilities", + AccessToken: client.AccessToken, + ResponseCode: 200, + }) + client.GetCapabilities(t) + checker.wait() + checker.expectNoCallbacks(true) + other.GetCapabilities(t) + checker.expectNoCallbacks(false) + }, + }, + { + name: "can be filtered by combinations of method path and access token", + filter: "~m GET ~u capabilities ~hq " + client.AccessToken, + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + Method: "GET", + PathContains: "_matrix/client/v3/capabilities", + AccessToken: client.AccessToken, + ResponseCode: 200, + }) + client.GetCapabilities(t) + checker.wait() + checker.expectNoCallbacks(true) + other.GetCapabilities(t) + checker.expectNoCallbacks(false) + }, + }, + { + // ensure that if we tarpit a request it doesn't tarpit unrelated requests + name: "processes callbacks concurrently", + filter: "~hq " + client.AccessToken, + inner: func(t *testing.T, checker *checker) { + // signal when to make the unrelated request + signalSendUnrelatedRequest := make(chan bool) + signalTestFinished := make(chan bool) + checker.expect(&callbackRequest{ + OnCallback: func(cd CallbackData) { + if strings.Contains(cd.URL, "capabilities") { + close(signalSendUnrelatedRequest) // send the signal to make the unrelated request + time.Sleep(time.Second) // tarpit this request + close(signalTestFinished) // test is done, cleanup + } + }, + }) + beforeSendingRequests := time.Now() + // send the tarpit request without waiting + go func() { + client.GetCapabilities(t) + }() + select { + case <-signalSendUnrelatedRequest: + // send the unrelated request + t.Logf("received signal @ %v", time.Since(beforeSendingRequests)) + client.GetGlobalAccountData(t, "this_does_a_get") + t.Logf("received unrelated response @ %v", time.Since(beforeSendingRequests)) + case <-time.After(time.Second): + ct.Errorf(t, "did not receive signal to send unrelated request") + return + } + since := time.Since(beforeSendingRequests) + if since > time.Second { + ct.Errorf(t, "unrelated request was tarpitted, took %v", since) + return + } + + // now wait for the tarpit + select { + case <-signalTestFinished: + case <-time.After(2 * time.Second): + ct.Errorf(t, "timed out waiting for tarpit response") + } + }, + }, + + // TODO: migrate functionality from status_code addon + // TODO: can modify response codes + // TODO: can modify response bodies + // TODO: can block requests + // TODO: can block responses + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := &checker{ + t: t, + ch: make(chan callbackRequest, 3), + mu: &sync.Mutex{}, + } + callbackURL, close := NewCallbackServer( + t, deployment.GetConfig().HostnameRunningComplement, + func(cd CallbackData) { + checker.onCallback(cd) + }, + ) + defer close() + mitmClient := deployment.MITM() + mitmOpts := map[string]any{ + "callback": map[string]any{ + "callback_url": callbackURL, + }, + } + if tc.filter != "" { + cb := mitmOpts["callback"].(map[string]any) + cb["filter"] = tc.filter + mitmOpts["callback"] = cb + } + lockID := mitmClient.lockOptions(t, mitmOpts) + tc.inner(t, checker) + mitmClient.unlockOptions(t, lockID) + }) + } +} + +type callbackRequest struct { + Method string + PathContains string + AccessToken string + ResponseCode int + OnCallback func(cd CallbackData) +} + +type checker struct { + t *testing.T + ch chan callbackRequest + mu *sync.Mutex + want *callbackRequest + noCallbacks bool +} + +func (c *checker) onCallback(cd CallbackData) { + c.mu.Lock() + if c.noCallbacks { + ct.Errorf(c.t, "wanted no callbacks but got %+v", cd) + } + if c.want == nil { + c.mu.Unlock() + return + } + if c.want.AccessToken != "" { + must.Equal(c.t, cd.AccessToken, c.want.AccessToken, "access token mismatch") + } + if c.want.Method != "" { + must.Equal(c.t, cd.Method, c.want.Method, "HTTP method mismatch") + } + if c.want.PathContains != "" { + must.Equal(c.t, strings.Contains(cd.URL, c.want.PathContains), true, + fmt.Sprintf("path mismatch, got %v want partial %v", cd.URL, c.want.PathContains), + ) + } + if c.want.ResponseCode != 0 { + must.Equal(c.t, cd.ResponseCode, c.want.ResponseCode, "response code mismatch") + } + + customCallback := c.want.OnCallback + // unlock early so we don't block other requests, as custom callbacks are generally + // used for testing tarpitting. + c.mu.Unlock() + if customCallback != nil { + customCallback(cd) + } + // signal that we processed the callback + c.ch <- *c.want +} + +func (c *checker) expect(want *callbackRequest) { + c.mu.Lock() + defer c.mu.Unlock() + c.want = want +} + +func (c *checker) expectNoCallbacks(noCallbacks bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.noCallbacks = noCallbacks +} + +func (c *checker) wait() { + select { + case got := <-c.ch: + if !reflect.DeepEqual(got, *c.want) { + ct.Fatalf(c.t, "checker: got success from a different request: did you forget to wait?"+ + " Received %+v but expected +%v", got, c.want) + } + return + case <-time.After(time.Second): + ct.Fatalf(c.t, "timed out waiting for %+v", c.want) + } +} diff --git a/internal/deploy/deploy.go b/internal/deploy/deploy.go index 8d3de43..f184b9d 100644 --- a/internal/deploy/deploy.go +++ b/internal/deploy/deploy.go @@ -361,6 +361,7 @@ func externalURL(t *testing.T, c testcontainers.Container, exposedPort string) s } func writeContainerLogs(readCloser io.ReadCloser, filename string) error { + os.Mkdir("./logs", os.ModePerm) // ignore error, we don't care if it already exists w, err := os.Create("./logs/" + filename) if err != nil { return fmt.Errorf("os.Create: %s", err) diff --git a/tests/mitmproxy_addons/__init__.py b/tests/mitmproxy_addons/__init__.py index d294d91..91c4367 100644 --- a/tests/mitmproxy_addons/__init__.py +++ b/tests/mitmproxy_addons/__init__.py @@ -1,4 +1,14 @@ from mitmproxy.addons import asgiapp +import subprocess +import sys + +# some addons need non-std packages. +# Rather than try to bundle in `pip install` commands in the CMD section of the Dockerfile, +# just install them when the addon loads. +def install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + +install("aiohttp") from callback import Callback from status_code import StatusCode diff --git a/tests/mitmproxy_addons/callback.py b/tests/mitmproxy_addons/callback.py index 03e1516..2551e05 100644 --- a/tests/mitmproxy_addons/callback.py +++ b/tests/mitmproxy_addons/callback.py @@ -1,4 +1,6 @@ from typing import Optional +import asyncio +import aiohttp import json from mitmproxy import ctx, flowfilter @@ -6,6 +8,7 @@ from controller import MITM_DOMAIN_NAME from urllib.request import urlopen, Request from urllib.error import HTTPError, URLError +from datetime import datetime # Callback will intercept a response and send a POST request to the provided callback_url, with # the following JSON object. Supports filters: https://docs.mitmproxy.org/stable/concepts-filters/ @@ -56,7 +59,7 @@ def configure(self, updates): else: self.filter = self.matchall - def response(self, flow): + async def response(self, flow): # always ignore the controller if flow.request.pretty_host == MITM_DOMAIN_NAME: return @@ -71,26 +74,20 @@ def response(self, flow): res_body = flow.response.json() except: res_body = None - data = json.dumps({ - "method": flow.request.method, - "access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "), - "url": flow.request.url, - "response_code": flow.response.status_code, - "request_body": req_body, - "response_body": res_body, - }) - request = Request( - self.config["callback_url"], - headers={"Content-Type": "application/json"}, - data=data.encode("utf-8"), - ) + print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting callback for {flow.request.url}') try: - with urlopen(request, timeout=10) as response: - print(f"callback returned HTTP {response.status}") - return response.read(), response - except HTTPError as error: - print(f"ERR: callback returned {error.status} {error.reason}") - except URLError as error: - print(f"ERR: callback returned {error.reason}") - except TimeoutError: - print(f"ERR: callback request timed out") + # use asyncio so we don't block other unrelated requests from being processed + async with aiohttp.request( + method="POST",url=self.config["callback_url"], timeout=aiohttp.ClientTimeout(total=10), + headers={"Content-Type": "application/json"}, + json={ + "method": flow.request.method, + "access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "), + "url": flow.request.url, + "response_code": flow.response.status_code, + "request_body": req_body, + "response_body": res_body, + }) as response: + print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}') + except Exception as error: + print(f"ERR: callback returned {error}")