From 70bb9a631ced1db2602df0c061dc656a944bee20 Mon Sep 17 00:00:00 2001 From: nex Date: Thu, 9 Nov 2023 11:35:09 +0000 Subject: [PATCH] Allow importing encryption keys on start --- src/niobot/client.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/niobot/client.py b/src/niobot/client.py index a2ee636..5f1a980 100644 --- a/src/niobot/client.py +++ b/src/niobot/client.py @@ -1,9 +1,12 @@ import asyncio +import getpass import importlib import inspect import logging import os +import pathlib import re +import sys import time import typing import warnings @@ -44,6 +47,7 @@ class NioBot(nio.AsyncClient): :param owner_id: The user ID of the bot owner. If set, only this user can run owner-only commands, etc. :param max_message_cache: The maximum number of messages to cache. Defaults to 1000. :param ignore_self: Whether to ignore messages sent by the bot itself. Defaults to False. Useful for self-bots. + :param import_keys: A key export file and password tuple. These keys will be imported at startup. """ def __init__( @@ -66,6 +70,7 @@ def __init__( automatic_markdown_renderer: bool = True, max_message_cache: int = 1000, ignore_self: bool = True, + import_keys: typing.Tuple[os.PathLike, typing.Optional[str]] = None ): if user_id == owner_id and ignore_self is True: warnings.warn( @@ -154,6 +159,20 @@ def __init__( self.log.info("Auto-joining rooms enabled.") self.add_event_callback(self._auto_join_room_backlog_callback, nio.InviteMemberEvent) # type: ignore + if import_keys: + keys_path, keys_password = import_keys + if not keys_password: + if sys.stdin.isatty(): + keys_password = getpass.getpass(f"Password for key import ({keys_path}): ") + else: + raise ValueError( + "No password was provided for automatic key import and cannot interactively get password." + ) + + self.__key_import = pathlib.Path(keys_path), keys_password + else: + self.__key_import = None + async def sync(self, *args, **kwargs) -> U[nio.SyncResponse, nio.SyncError]: sync = await super().sync(*args, **kwargs) if isinstance(sync, nio.SyncResponse): @@ -248,7 +267,7 @@ async def update_read_receipts(self, room: U[str, nio.MatrixRoom], event: nio.Ev msg = result.message if isinstance(result, nio.ErrorResponse) else "?" self.log.warning("Failed to update read receipts for %s: %s", room, msg) else: - self.log.debug("Updated read receipts for %s to %s.", room, event) + self.log.debug("Updated read receipts for %s to %s.", room, event_id) async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessageText) -> None: """Processes a message and runs the command it is trying to invoke if any.""" @@ -894,6 +913,10 @@ async def start( """Starts the bot, running the sync loop.""" self.loop = asyncio.get_event_loop() self.dispatch("event_loop_ready") + if self.__key_import: + self.log.info("Starting automatic key import") + await self.import_keys(*map(str, self.__key_import)) + if password or sso_token: if password: self.log.critical(