Skip to content

Commit

Permalink
Allow importing encryption keys on start
Browse files Browse the repository at this point in the history
  • Loading branch information
nexy7574 committed Nov 9, 2023
1 parent e043bdf commit 70bb9a6
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/niobot/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import getpass
import importlib
import inspect
import logging
import os
import pathlib
import re
import sys
import time
import typing
import warnings
Expand Down Expand Up @@ -44,6 +47,7 @@ class NioBot(nio.AsyncClient):
:param owner_id: The user ID of the bot owner. If set, only this user can run owner-only commands, etc.
:param max_message_cache: The maximum number of messages to cache. Defaults to 1000.
:param ignore_self: Whether to ignore messages sent by the bot itself. Defaults to False. Useful for self-bots.
:param import_keys: A key export file and password tuple. These keys will be imported at startup.
"""

def __init__(
Expand All @@ -66,6 +70,7 @@ def __init__(
automatic_markdown_renderer: bool = True,
max_message_cache: int = 1000,
ignore_self: bool = True,
import_keys: typing.Tuple[os.PathLike, typing.Optional[str]] = None
):
if user_id == owner_id and ignore_self is True:
warnings.warn(
Expand Down Expand Up @@ -154,6 +159,20 @@ def __init__(
self.log.info("Auto-joining rooms enabled.")
self.add_event_callback(self._auto_join_room_backlog_callback, nio.InviteMemberEvent) # type: ignore

if import_keys:
keys_path, keys_password = import_keys
if not keys_password:
if sys.stdin.isatty():
keys_password = getpass.getpass(f"Password for key import ({keys_path}): ")
else:
raise ValueError(
"No password was provided for automatic key import and cannot interactively get password."
)

self.__key_import = pathlib.Path(keys_path), keys_password
else:
self.__key_import = None

async def sync(self, *args, **kwargs) -> U[nio.SyncResponse, nio.SyncError]:
sync = await super().sync(*args, **kwargs)
if isinstance(sync, nio.SyncResponse):
Expand Down Expand Up @@ -248,7 +267,7 @@ async def update_read_receipts(self, room: U[str, nio.MatrixRoom], event: nio.Ev
msg = result.message if isinstance(result, nio.ErrorResponse) else "?"
self.log.warning("Failed to update read receipts for %s: %s", room, msg)
else:
self.log.debug("Updated read receipts for %s to %s.", room, event)
self.log.debug("Updated read receipts for %s to %s.", room, event_id)

async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessageText) -> None:
"""Processes a message and runs the command it is trying to invoke if any."""
Expand Down Expand Up @@ -894,6 +913,10 @@ async def start(
"""Starts the bot, running the sync loop."""
self.loop = asyncio.get_event_loop()
self.dispatch("event_loop_ready")
if self.__key_import:
self.log.info("Starting automatic key import")
await self.import_keys(*map(str, self.__key_import))

if password or sso_token:
if password:
self.log.critical(
Expand Down

0 comments on commit 70bb9a6

Please sign in to comment.