Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add coverage for db access in optout.py #677

Merged
merged 1 commit into from
Mar 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions plugins/core/optout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict
from functools import total_ordering
from threading import RLock
from typing import List, MutableMapping, Optional

from irclib.util.compare import match_mask
from sqlalchemy import (
Expand Down Expand Up @@ -33,7 +34,7 @@
PrimaryKeyConstraint("network", "chan", "hook"),
)

optout_cache = DefaultKeyFoldDict(list)
optout_cache: MutableMapping[str, List["OptOut"]] = DefaultKeyFoldDict(list)

cache_lock = RLock()

Expand All @@ -47,16 +48,18 @@ def __init__(self, channel, hook_pattern, allow):

def __lt__(self, other):
if isinstance(other, OptOut):
diff = len(self.channel) - len(other.channel)
if diff:
return diff < 0

return len(self.hook) < len(other.hook)
return (self.channel.rstrip("*"), self.hook.rstrip("*")) < (
other.channel.rstrip("*"),
other.hook.rstrip("*"),
)

return NotImplemented

def __str__(self):
return f"{self.channel} {self.hook} {self.allow}"
def __eq__(self, other):
if isinstance(other, OptOut):
return self.channel == other.channel and self.hook == other.hook

return NotImplemented

def __repr__(self):
return "{}({}, {}, {})".format(
Expand All @@ -82,12 +85,12 @@ async def check_channel_permissions(event, chan, *perms):
return allowed


def get_conn_optouts(conn_name):
def get_conn_optouts(conn_name) -> List[OptOut]:
with cache_lock:
return optout_cache[conn_name.casefold()]


def get_channel_optouts(conn_name, chan=None):
def get_channel_optouts(conn_name, chan=None) -> List[OptOut]:
with cache_lock:
return [
opt
Expand All @@ -96,6 +99,14 @@ def get_channel_optouts(conn_name, chan=None):
]


def get_first_matching_optout(conn_name, chan, hook_name) -> Optional[OptOut]:
for optout in get_conn_optouts(conn_name):
if optout.match(chan, hook_name):
return optout

return None


def format_optout_list(opts):
headers = ("Channel Pattern", "Hook Pattern", "Allowed")
table = [
Expand Down Expand Up @@ -186,19 +197,12 @@ def optout_sieve(bot, event, _hook):
return event

hook_name = _hook.plugin.title + "." + _hook.function_name
with cache_lock:
optouts = get_conn_optouts(event.conn.name)
for _optout in optouts:
if _optout.match(event.chan, hook_name):
if not _optout.allow:
if _hook.type == "command":
event.notice(
"Sorry, that command is disabled in this channel."
)

return None

break
_optout = get_first_matching_optout(event.conn.name, event.chan, hook_name)
if _optout and not _optout.allow:
if _hook.type == "command":
event.notice("Sorry, that command is disabled in this channel.")

return None

return event

Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ filterwarnings =
error
ignore:pkg_resources is deprecated as an API:DeprecationWarning
ignore:datetime.*:DeprecationWarning:sqlalchemy.*
asyncio_mode = auto
1 change: 1 addition & 0 deletions tests/core_tests/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cloudbot.util import database
from tests.util.mock_module import MockModule


@pytest.fixture()
def mock_bot(mock_bot_factory, event_loop, tmp_path):
tmp_base = tmp_path / "tmp"
Expand Down
2 changes: 1 addition & 1 deletion tests/plugin_tests/test_chan_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def test_format_exception_chain():
def _get_data(exc):
yield repr(exc)
if hasattr(exc, 'add_note'):
if hasattr(exc, "add_note"):
yield f" add_note = {exc.add_note!r}"

yield f" args = {exc.args!r}"
Expand Down
Loading