diff --git a/plugins/core/optout.py b/plugins/core/optout.py index abf08d87..6400d7d6 100644 --- a/plugins/core/optout.py +++ b/plugins/core/optout.py @@ -5,6 +5,7 @@ from collections import defaultdict from functools import total_ordering from threading import RLock +from typing import List, MutableMapping, Optional from irclib.util.compare import match_mask from sqlalchemy import ( @@ -33,7 +34,7 @@ PrimaryKeyConstraint("network", "chan", "hook"), ) -optout_cache = DefaultKeyFoldDict(list) +optout_cache: MutableMapping[str, List["OptOut"]] = DefaultKeyFoldDict(list) cache_lock = RLock() @@ -47,16 +48,18 @@ def __init__(self, channel, hook_pattern, allow): def __lt__(self, other): if isinstance(other, OptOut): - diff = len(self.channel) - len(other.channel) - if diff: - return diff < 0 - - return len(self.hook) < len(other.hook) + return (self.channel.rstrip("*"), self.hook.rstrip("*")) < ( + other.channel.rstrip("*"), + other.hook.rstrip("*"), + ) return NotImplemented - def __str__(self): - return f"{self.channel} {self.hook} {self.allow}" + def __eq__(self, other): + if isinstance(other, OptOut): + return self.channel == other.channel and self.hook == other.hook + + return NotImplemented def __repr__(self): return "{}({}, {}, {})".format( @@ -82,12 +85,12 @@ async def check_channel_permissions(event, chan, *perms): return allowed -def get_conn_optouts(conn_name): +def get_conn_optouts(conn_name) -> List[OptOut]: with cache_lock: return optout_cache[conn_name.casefold()] -def get_channel_optouts(conn_name, chan=None): +def get_channel_optouts(conn_name, chan=None) -> List[OptOut]: with cache_lock: return [ opt @@ -96,6 +99,14 @@ def get_channel_optouts(conn_name, chan=None): ] +def get_first_matching_optout(conn_name, chan, hook_name) -> Optional[OptOut]: + for optout in get_conn_optouts(conn_name): + if optout.match(chan, hook_name): + return optout + + return None + + def format_optout_list(opts): headers = ("Channel Pattern", "Hook Pattern", "Allowed") table = [ @@ -186,19 +197,12 @@ def optout_sieve(bot, event, _hook): return event hook_name = _hook.plugin.title + "." + _hook.function_name - with cache_lock: - optouts = get_conn_optouts(event.conn.name) - for _optout in optouts: - if _optout.match(event.chan, hook_name): - if not _optout.allow: - if _hook.type == "command": - event.notice( - "Sorry, that command is disabled in this channel." - ) - - return None - - break + _optout = get_first_matching_optout(event.conn.name, event.chan, hook_name) + if _optout and not _optout.allow: + if _hook.type == "command": + event.notice("Sorry, that command is disabled in this channel.") + + return None return event diff --git a/pytest.ini b/pytest.ini index 21d07025..c88bfb7f 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ filterwarnings = error ignore:pkg_resources is deprecated as an API:DeprecationWarning ignore:datetime.*:DeprecationWarning:sqlalchemy.* +asyncio_mode = auto diff --git a/tests/core_tests/test_plugin_manager.py b/tests/core_tests/test_plugin_manager.py index 7f48de01..dc520699 100644 --- a/tests/core_tests/test_plugin_manager.py +++ b/tests/core_tests/test_plugin_manager.py @@ -14,6 +14,7 @@ from cloudbot.util import database from tests.util.mock_module import MockModule + @pytest.fixture() def mock_bot(mock_bot_factory, event_loop, tmp_path): tmp_base = tmp_path / "tmp" diff --git a/tests/plugin_tests/test_chan_log.py b/tests/plugin_tests/test_chan_log.py index afa352f9..8ce21dab 100644 --- a/tests/plugin_tests/test_chan_log.py +++ b/tests/plugin_tests/test_chan_log.py @@ -6,7 +6,7 @@ def test_format_exception_chain(): def _get_data(exc): yield repr(exc) - if hasattr(exc, 'add_note'): + if hasattr(exc, "add_note"): yield f" add_note = {exc.add_note!r}" yield f" args = {exc.args!r}" diff --git a/tests/plugin_tests/test_optout.py b/tests/plugin_tests/test_optout.py index 50065076..051a2b09 100644 --- a/tests/plugin_tests/test_optout.py +++ b/tests/plugin_tests/test_optout.py @@ -1,6 +1,9 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from plugins.core import optout +from tests.util.mock_db import MockDB def test_conn_case(): @@ -57,5 +60,452 @@ def test_match(): assert res is None -def test_get_channel_optouts(): - pass +def test_get_first_matching_optout(mock_db): + optout.optout_table.create(mock_db.engine) + mock_db.load_data( + optout.optout_table, + [ + { + "network": "my_net", + "chan": "#*", + "hook": "my.*", + "allow": False, + }, + { + "network": "my_net", + "chan": "#my*", + "hook": "*.*", + "allow": True, + }, + ], + ) + + optout.load_cache(mock_db.session()) + + assert ( + optout.get_first_matching_optout("my_net", "#my_chan", "my.hook").allow + is True + ) + + +def test_optout_match(): + assert optout.OptOut( + channel="#foo*", hook_pattern="foo.*", allow=True + ).match("#foobar", "foo.bar") + + +def test_optout_compare(): + with pytest.raises(TypeError): + assert ( + optout.OptOut(channel="#foo*", hook_pattern="foo.*", allow=True) > 5 # type: ignore[operator] + ) + + +def test_optout_eq_other(): + assert optout.OptOut(channel="#foo*", hook_pattern="foo.*", allow=True) != 5 + + +def test_optout_equals(): + args = {"channel": "#foo", "hook_pattern": "foo", "allow": True} + assert optout.OptOut(**args) == optout.OptOut(**args) + + +def test_optout_sort(): + optouts = [ + optout.OptOut(channel="#aaa", hook_pattern="test", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="test*", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="tes*", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="te*", allow=True), + optout.OptOut(channel="#aa", hook_pattern="test", allow=True), + ] + + assert sorted(optouts) == [ + optout.OptOut(channel="#aa", hook_pattern="test", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="te*", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="tes*", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="test", allow=True), + optout.OptOut(channel="#aaa", hook_pattern="test*", allow=True), + ] + + +def test_exact_override(): + net = "my_net" + channel = "#foobar" + hook = "my.hook" + optouts = { + net: [ + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar", + hook_pattern="my.*", + allow=True, + ), + optout.OptOut( + channel="#fooba*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar*", + hook_pattern="my.*", + allow=False, + ), + ] + } + + for opts in optouts.values(): + opts.sort(reverse=True) + + with patch.dict(optout.optout_cache, clear=True, values=optouts): + assert ( + optout.get_first_matching_optout(net, channel, hook).allow is True + ) + + +async def test_check_channel_permissions(): + event = MagicMock() + event.chan = "#foo" + event.check_permissions = AsyncMock(return_value=True) + res = await optout.check_channel_permissions( + event, "#bar", "botcontrol", "staff" + ) + assert res + + +def test_get_global_optouts(): + optouts = { + "net2": [ + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=True, + ), + ], + "net": [ + optout.OptOut( + "#baz", + hook_pattern="my.hook", + allow=True, + ), + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar", + hook_pattern="my.*", + allow=True, + ), + optout.OptOut( + channel="#fooba*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar*", + hook_pattern="my.*", + allow=False, + ), + ], + } + + for opts in optouts.values(): + opts.sort(reverse=True) + + with patch.dict(optout.optout_cache, clear=True, values=optouts): + assert optout.get_channel_optouts("net", None) == [ + optout.OptOut( + channel="#foobar", + hook_pattern="my.*", + allow=True, + ), + optout.OptOut( + channel="#foobar*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#fooba*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + "#baz", + hook_pattern="my.hook", + allow=True, + ), + ] + + +def test_match_optout(): + optouts = { + "net2": [ + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=True, + ), + ], + "net": [ + optout.OptOut( + channel="#other", + hook_pattern="my.hook", + allow=True, + ), + optout.OptOut( + "#baz", + hook_pattern="my.hook", + allow=True, + ), + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar", + hook_pattern="my.*", + allow=True, + ), + optout.OptOut( + channel="#fooba*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar*", + hook_pattern="my.*", + allow=False, + ), + ], + } + + for opts in optouts.values(): + opts.sort(reverse=True) + + with patch.dict(optout.optout_cache, clear=True, values=optouts): + assert ( + optout.get_first_matching_optout( + "my_other_net", "#foobar", "my.hook" + ) + is None + ) + assert ( + optout.get_first_matching_optout("net", "#chan", "my.hook") is None + ) + assert ( + optout.get_first_matching_optout( + "net", "#foobar", "some.other_hook" + ) + is None + ) + assert ( + optout.get_first_matching_optout("net", "#foobar", "my.hook").allow + is True + ) + + +def test_format(): + optouts = [ + optout.OptOut( + channel="#other", + hook_pattern="my.hook", + allow=True, + ), + optout.OptOut( + "#baz", + hook_pattern="my.hook", + allow=True, + ), + optout.OptOut( + channel="#foo*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar", + hook_pattern="my.*", + allow=True, + ), + optout.OptOut( + channel="#fooba*", + hook_pattern="my.*", + allow=False, + ), + optout.OptOut( + channel="#foobar*", + hook_pattern="my.*", + allow=False, + ), + ] + + assert ( + optout.format_optout_list(optouts) + == """\ +| Channel Pattern | Hook Pattern | Allowed | +| --------------- | ------------ | ------- | +| #other | my.hook | true | +| #baz | my.hook | true | +| #foo* | my.* | false | +| #foobar | my.* | true | +| #fooba* | my.* | false | +| #foobar* | my.* | false |""" + ) + + +class TestSetOptOut: + def test_add(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + optout.set_optout(session, "net", "#chan", "my.hook", True) + + assert mock_db.get_data(optout.optout_table) == [ + ("net", "#chan", "my.hook", True) + ] + + def test_update(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + mock_db.load_data( + optout.optout_table, + [ + { + "network": "net", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + } + ], + ) + + assert mock_db.get_data(optout.optout_table) == [ + ("net", "#chan", "my.hook", False) + ] + + optout.set_optout(session, "net", "#chan", "my.hook", True) + + assert mock_db.get_data(optout.optout_table) == [ + ("net", "#chan", "my.hook", True) + ] + + +class TestDelOptOut: + def test_del_no_match(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + assert ( + optout.del_optout(session, "net", "#chan", "my.hook") is False + ) + + def test_del(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + mock_db.load_data( + optout.optout_table, + [ + { + "network": "net", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + } + ], + ) + + assert mock_db.get_data(optout.optout_table) == [ + ("net", "#chan", "my.hook", False) + ] + + assert optout.del_optout(session, "net", "#chan", "my.hook") is True + + assert mock_db.get_data(optout.optout_table) == [] + + +class TestClearOptOut: + def test_clear_chan(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + mock_db.load_data( + optout.optout_table, + [ + { + "network": "net", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "othernet", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "net", + "chan": "#otherchan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "net", + "chan": "#chan", + "hook": "other.hook", + "allow": False, + }, + ], + ) + + assert len(mock_db.get_data(optout.optout_table)) == 4 + + assert optout.clear_optout(session, "net", "#chan") == 2 + + assert len(mock_db.get_data(optout.optout_table)) == 2 + + def test_clear_conn(self, mock_db: MockDB): + with mock_db.session() as session: + optout.optout_table.create(mock_db.engine) + mock_db.load_data( + optout.optout_table, + [ + { + "network": "net", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "othernet", + "chan": "#chan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "net", + "chan": "#otherchan", + "hook": "my.hook", + "allow": False, + }, + { + "network": "net", + "chan": "#chan", + "hook": "other.hook", + "allow": False, + }, + ], + ) + + assert len(mock_db.get_data(optout.optout_table)) == 4 + + assert optout.clear_optout(session, "net") == 3 + + assert len(mock_db.get_data(optout.optout_table)) == 1 diff --git a/tests/util/mock_db.py b/tests/util/mock_db.py index 26e50e23..b15abeda 100644 --- a/tests/util/mock_db.py +++ b/tests/util/mock_db.py @@ -1,4 +1,6 @@ -from sqlalchemy import create_engine +from typing import Any, Dict, List + +from sqlalchemy import Table, create_engine from sqlalchemy.orm import scoped_session, sessionmaker from cloudbot.util.database import Session @@ -15,7 +17,11 @@ def __init__(self, path="sqlite:///:memory:", force_session=False): def get_data(self, table): return self.session().execute(table.select()).fetchall() - def add_row(self, *args, **data): - table = args[0] + def add_row(self, table: Table, /, **data: Any) -> None: self.session().execute(table.insert().values(data)) self.session().commit() + + def load_data(self, table: Table, data: List[Dict[str, Any]]): + with self.session() as session, session.begin(): + for item in data: + session.execute(table.insert().values(item))