From 8ff8d07128325dc2894b9be47d8fdce548bef3fb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Sep 2023 23:18:38 +0300 Subject: [PATCH] Add appservice user/device masquerading support to base HTTPAPI --- mautrix/api.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mautrix/api.py b/mautrix/api.py index f6c9f475..d034149b 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import AsyncGenerator, ClassVar, Literal, Mapping, Union +from typing import ClassVar, Literal, Mapping from enum import Enum from json.decoder import JSONDecodeError from urllib.parse import quote as urllib_quote, urljoin as urllib_join @@ -28,7 +28,7 @@ if __optional_imports__: # Safe to import, but it's not actually needed, so don't force-import the whole types module. - from mautrix.types import JSON + from mautrix.types import JSON, DeviceID, UserID API_CALLS = Counter( name="bridge_matrix_api_calls", @@ -193,6 +193,13 @@ class HTTPAPI: default_retry_count: int """The default retry count to use if a custom value is not passed to :meth:`request`""" + as_user_id: UserID | None + """An optional user ID to set as the user_id query parameter for appservice requests.""" + as_device_id: DeviceID | None + """ + An optional device ID to set as the user_id query parameter for appservice requests (MSC3202). + """ + def __init__( self, base_url: URL | str, @@ -203,6 +210,8 @@ def __init__( txn_id: int = 0, log: TraceLogger | None = None, loop: asyncio.AbstractEventLoop | None = None, + as_user_id: UserID | None = None, + as_device_id: UserID | None = None, ) -> None: """ Args: @@ -212,6 +221,10 @@ def __init__( txn_id: The outgoing transaction ID to start with. log: The :class:`logging.Logger` instance to log requests with. default_retry_count: Default number of retries to do when encountering network errors. + as_user_id: An optional user ID to set as the user_id query parameter for + appservice requests. + as_device_id: An optional device ID to set as the user_id query parameter for + appservice requests (MSC3202). """ self.base_url = URL(base_url) self.token = token @@ -219,6 +232,8 @@ def __init__( self.session = client_session or ClientSession( loop=loop, headers={"User-Agent": self.default_ua} ) + self.as_user_id = as_user_id + self.as_device_id = as_device_id if txn_id is not None: self.txn_id = txn_id if default_retry_count is not None: @@ -360,6 +375,11 @@ async def request( query_params = query_params or {} if isinstance(query_params, dict): query_params = {k: v for k, v in query_params.items() if v is not None} + if self.as_user_id: + query_params["user_id"] = self.as_user_id + if self.as_device_id: + query_params["org.matrix.msc3202.device_id"] = self.as_device_id + query_params["device_id"] = self.as_device_id if method != Method.GET: content = content or {}