Skip to content

Commit

Permalink
Add appservice user/device masquerading support to base HTTPAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Sep 6, 2023
1 parent 92e6091 commit 8ff8d07
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions mautrix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import AsyncGenerator, ClassVar, Literal, Mapping, Union
from typing import ClassVar, Literal, Mapping
from enum import Enum
from json.decoder import JSONDecodeError
from urllib.parse import quote as urllib_quote, urljoin as urllib_join
Expand All @@ -28,7 +28,7 @@

if __optional_imports__:
# Safe to import, but it's not actually needed, so don't force-import the whole types module.
from mautrix.types import JSON
from mautrix.types import JSON, DeviceID, UserID

API_CALLS = Counter(
name="bridge_matrix_api_calls",
Expand Down Expand Up @@ -193,6 +193,13 @@ class HTTPAPI:
default_retry_count: int
"""The default retry count to use if a custom value is not passed to :meth:`request`"""

as_user_id: UserID | None
"""An optional user ID to set as the user_id query parameter for appservice requests."""
as_device_id: DeviceID | None
"""
An optional device ID to set as the user_id query parameter for appservice requests (MSC3202).
"""

def __init__(
self,
base_url: URL | str,
Expand All @@ -203,6 +210,8 @@ def __init__(
txn_id: int = 0,
log: TraceLogger | None = None,
loop: asyncio.AbstractEventLoop | None = None,
as_user_id: UserID | None = None,
as_device_id: UserID | None = None,
) -> None:
"""
Args:
Expand All @@ -212,13 +221,19 @@ def __init__(
txn_id: The outgoing transaction ID to start with.
log: The :class:`logging.Logger` instance to log requests with.
default_retry_count: Default number of retries to do when encountering network errors.
as_user_id: An optional user ID to set as the user_id query parameter for
appservice requests.
as_device_id: An optional device ID to set as the user_id query parameter for
appservice requests (MSC3202).
"""
self.base_url = URL(base_url)
self.token = token
self.log = log or logging.getLogger("mau.http")
self.session = client_session or ClientSession(
loop=loop, headers={"User-Agent": self.default_ua}
)
self.as_user_id = as_user_id
self.as_device_id = as_device_id
if txn_id is not None:
self.txn_id = txn_id
if default_retry_count is not None:
Expand Down Expand Up @@ -360,6 +375,11 @@ async def request(
query_params = query_params or {}
if isinstance(query_params, dict):
query_params = {k: v for k, v in query_params.items() if v is not None}
if self.as_user_id:
query_params["user_id"] = self.as_user_id
if self.as_device_id:
query_params["org.matrix.msc3202.device_id"] = self.as_device_id
query_params["device_id"] = self.as_device_id

if method != Method.GET:
content = content or {}
Expand Down

0 comments on commit 8ff8d07

Please sign in to comment.