Skip to content

Commit

Permalink
feat: Merge pull request #67 from Era-Dorta/base64-attachments
Browse files Browse the repository at this point in the history
Enable getting attachments when receiving a message.
  • Loading branch information
Era-Dorta authored Dec 20, 2024
2 parents cdde4e3 + ebfabb3 commit f8dc961
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 30 deletions.
27 changes: 27 additions & 0 deletions signalbot/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64

import aiohttp
import websockets
from typing import Any
Expand Down Expand Up @@ -139,6 +141,27 @@ async def get_groups(self):
):
raise GroupsError

async def get_attachment(self, attachment_id: str) -> str:
uri = f"{self._attachment_rest_uri()}/{attachment_id}"
try:
async with aiohttp.ClientSession() as session:
resp = await session.get(uri)
resp.raise_for_status()
content = await resp.content.read()
except (
aiohttp.ClientError,
aiohttp.http_exceptions.HttpProcessingError,
):
raise GetAttachmentError

base64_bytes = base64.b64encode(content)
base64_string = str(base64_bytes, encoding="utf-8")

return base64_string

def _attachment_rest_uri(self):
return f"http://{self.signal_service}/v1/attachments"

def _receive_ws_uri(self):
return f"ws://{self.signal_service}/v1/receive/{self.phone_number}"

Expand Down Expand Up @@ -181,3 +204,7 @@ class ReactionError(Exception):

class GroupsError(Exception):
pass


class GetAttachmentError(Exception):
pass
2 changes: 1 addition & 1 deletion signalbot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def _produce(self, name: int) -> None:
logging.info(f"[Raw Message] {raw_message}")

try:
message = Message.parse(raw_message)
message = await Message.parse(self._signal, raw_message)
except UnknownMessageFormatError:
continue

Expand Down
23 changes: 19 additions & 4 deletions signalbot/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Optional


from signalbot.api import SignalAPI


class MessageType(Enum):
SYNC_MESSAGE = 1
DATA_MESSAGE = 2
Expand Down Expand Up @@ -61,7 +64,7 @@ def is_group(self) -> bool:
return bool(self.group)

@classmethod
def parse(cls, raw_message: str):
async def parse(cls, signal: SignalAPI, raw_message: str):
try:
raw_message = json.loads(raw_message)
except Exception:
Expand Down Expand Up @@ -90,6 +93,7 @@ def parse(cls, raw_message: str):
mentions = cls._parse_mentions(
raw_message["envelope"]["syncMessage"]["sentMessage"]
)
base64_attachments = None

# Option 2: dataMessage
elif "dataMessage" in raw_message["envelope"]:
Expand All @@ -98,13 +102,13 @@ def parse(cls, raw_message: str):
group = cls._parse_group_information(raw_message["envelope"]["dataMessage"])
reaction = cls._parse_reaction(raw_message["envelope"]["dataMessage"])
mentions = cls._parse_mentions(raw_message["envelope"]["dataMessage"])
base64_attachments = await cls._parse_attachments(
signal, raw_message["envelope"]["dataMessage"]
)

else:
raise UnknownMessageFormatError

# TODO: base64_attachments
base64_attachments = []

return cls(
source,
source_number,
Expand All @@ -119,6 +123,17 @@ def parse(cls, raw_message: str):
raw_message,
)

@classmethod
async def _parse_attachments(cls, signal: SignalAPI, data_message: dict) -> str:

if "attachments" not in data_message:
return []

return [
await signal.get_attachment(attachment["id"])
for attachment in data_message["attachments"]
]

@classmethod
def _parse_sync_message(cls, sync_message: dict) -> str:
try:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def test_send_uri(self):
actual_uri = self.signal_api._send_rest_uri()
self.assertEqual(actual_uri, expected_uri)

def test_attachment_rest_uri(self):
expected_uri = f"http://{self.signal_service}/v1/attachments"
actual_uri = self.signal_api._attachment_rest_uri()
self.assertEqual(actual_uri, expected_uri)


if __name__ == "__main__":
unittest.main()
88 changes: 63 additions & 25 deletions tests/test_message.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,105 @@
import base64
import unittest
from unittest.mock import AsyncMock, patch, Mock
import aiohttp
from signalbot import Message, MessageType
from signalbot.api import SignalAPI
from signalbot.utils import ChatTestCase, SendMessagesMock, ReceiveMessagesMock


class TestMessage(unittest.TestCase):
class TestMessage(unittest.IsolatedAsyncioTestCase):
raw_sync_message = '{"envelope":{"source":"+490123456789","sourceNumber":"+490123456789","sourceUuid":"<uuid>","sourceName":"<name>","sourceDevice":1,"timestamp":1632576001632,"syncMessage":{"sentMessage":{"timestamp":1632576001632,"message":"Uhrzeit","expiresInSeconds":0,"viewOnce":false,"mentions":[],"attachments":[],"contacts":[],"groupInfo":{"groupId":"<groupid>","type":"DELIVER"},"destination":null,"destinationNumber":null,"destinationUuid":null}}}}' # noqa
raw_data_message = '{"envelope":{"source":"+490123456789","sourceNumber":"+490123456789","sourceUuid":"<uuid>","sourceName":"<name>","sourceDevice":1,"timestamp":1632576001632,"dataMessage":{"timestamp":1632576001632,"message":"Uhrzeit","expiresInSeconds":0,"viewOnce":false,"mentions":[],"attachments":[],"contacts":[],"groupInfo":{"groupId":"<groupid>","type":"DELIVER"}}}}' # noqa
raw_reaction_message = '{"envelope":{"source":"<source>","sourceNumber":"<source>","sourceUuid":"<uuid>","sourceName":"<name>","sourceDevice":1,"timestamp":1632576001632,"syncMessage":{"sentMessage":{"timestamp":1632576001632,"message":null,"expiresInSeconds":0,"viewOnce":false,"reaction":{"emoji":"👍","targetAuthor":"<target>","targetAuthorNumber":"<target>","targetAuthorUuid":"<uuid>","targetSentTimestamp":1632576001632,"isRemove":false},"mentions":[],"attachments":[],"contacts":[],"groupInfo":{"groupId":"<groupid>","type":"DELIVER"},"destination":null,"destinationNumber":null,"destinationUuid":null}}}}' # noqa
raw_user_chat_message = '{"envelope":{"source":"+490123456789","sourceNumber":"+490123456789","sourceUuid":"<uuid>","sourceName":"<name>","sourceDevice":1,"timestamp":1632576001632,"dataMessage":{"timestamp":1632576001632,"message":"Uhrzeit","expiresInSeconds":0,"viewOnce":false}},"account":"+49987654321","subscription":0}' # noqa
raw_attachment_message = '{"envelope":{"source":"+490123456789","sourceNumber":"+490123456789","sourceUuid":"<uuid>","sourceName":"<name>","sourceDevice":1,"timestamp":1632576001632,"dataMessage":{"timestamp":1632576001632,"message":"Uhrzeit","expiresInSeconds":0,"viewOnce":false, "attachments": [{"contentType": "image/png", "filename": "image.png", "id": "4296180834490578536","size": 12005}]}},"account":"+49987654321","subscription":0}' # noqa

expected_source = "+490123456789"
expected_timestamp = 1632576001632
expected_text = "Uhrzeit"
expected_group = "<groupid>"

signal_service = "127.0.0.1:8080"
phone_number = "+49123456789"

group_id = "group_id1"
group_secret = "group.group_secret1"
groups = {group_id: group_secret}

def setUp(self):
self.signal_api = SignalAPI(
TestMessage.signal_service, TestMessage.phone_number
)

# Own Message
def test_parse_source_own_message(self):
message = Message.parse(TestMessage.raw_sync_message)
async def test_parse_source_own_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_sync_message)
self.assertEqual(message.timestamp, TestMessage.expected_timestamp)

def test_parse_timestamp_own_message(self):
message = Message.parse(TestMessage.raw_sync_message)
async def test_parse_timestamp_own_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_sync_message)
self.assertEqual(message.source, TestMessage.expected_source)

def test_parse_type_own_message(self):
message = Message.parse(TestMessage.raw_sync_message)
async def test_parse_type_own_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_sync_message)
self.assertEqual(message.type, MessageType.SYNC_MESSAGE)

def test_parse_text_own_message(self):
message = Message.parse(TestMessage.raw_sync_message)
async def test_parse_text_own_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_sync_message)
self.assertEqual(message.text, TestMessage.expected_text)

def test_parse_group_own_message(self):
message = Message.parse(TestMessage.raw_sync_message)
async def test_parse_group_own_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_sync_message)
self.assertEqual(message.group, TestMessage.expected_group)

# Foreign Messages
def test_parse_source_foreign_message(self):
message = Message.parse(TestMessage.raw_data_message)
async def test_parse_source_foreign_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_data_message)
self.assertEqual(message.timestamp, TestMessage.expected_timestamp)

def test_parse_timestamp_foreign_message(self):
message = Message.parse(TestMessage.raw_data_message)
async def test_parse_timestamp_foreign_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_data_message)
self.assertEqual(message.source, TestMessage.expected_source)

def test_parse_type_foreign_message(self):
message = Message.parse(TestMessage.raw_data_message)
async def test_parse_type_foreign_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_data_message)
self.assertEqual(message.type, MessageType.DATA_MESSAGE)

def test_parse_text_foreign_message(self):
message = Message.parse(TestMessage.raw_data_message)
async def test_parse_text_foreign_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_data_message)
self.assertEqual(message.text, TestMessage.expected_text)

def test_parse_group_foreign_message(self):
message = Message.parse(TestMessage.raw_data_message)
async def test_parse_group_foreign_message(self):
message = await Message.parse(self.signal_api, TestMessage.raw_data_message)
self.assertEqual(message.group, TestMessage.expected_group)

def test_read_reaction(self):
message = Message.parse(TestMessage.raw_reaction_message)
async def test_read_reaction(self):
message = await Message.parse(self.signal_api, TestMessage.raw_reaction_message)
self.assertEqual(message.reaction, "👍")

@patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
async def test_attachments(self, mock_get):
attachment_bytes_str = b"test"

mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.raise_for_status = Mock()
mock_response.content.read = AsyncMock(return_value=attachment_bytes_str)

mock_get.return_value = mock_response

expected_base64_bytes = base64.b64encode(attachment_bytes_str)
expected_base64_str = str(expected_base64_bytes, encoding="utf-8")

message = await Message.parse(
self.signal_api, TestMessage.raw_attachment_message
)
self.assertEqual(message.base64_attachments, [expected_base64_str])

# User Chats
def test_parse_user_chat_message(self):
message = Message.parse(TestMessage.raw_user_chat_message)
async def test_parse_user_chat_message(self):
message = await Message.parse(
self.signal_api, TestMessage.raw_user_chat_message
)
self.assertEqual(message.source, TestMessage.expected_source)
self.assertEqual(message.text, TestMessage.expected_text)
self.assertEqual(message.timestamp, TestMessage.expected_timestamp)
Expand Down

0 comments on commit f8dc961

Please sign in to comment.