diff --git a/sopel/db.py b/sopel/db.py index 4a56b8df99..654a66f6f7 100644 --- a/sopel/db.py +++ b/sopel/db.py @@ -295,7 +295,7 @@ def delete_nick_value(self, nick, key): finally: session.close() - def get_nick_value(self, nick, key): + def get_nick_value(self, nick, key, default=None): """Retrieves the value for a given key associated with a nick.""" nick = Identifier(nick) session = self.ssession() @@ -307,6 +307,8 @@ def get_nick_value(self, nick, key): .one_or_none() if result is not None: result = result.value + elif default is not None: + result = default return _deserialize(result) except SQLAlchemyError: session.rollback() @@ -434,7 +436,7 @@ def delete_channel_value(self, channel, key): finally: session.close() - def get_channel_value(self, channel, key): + def get_channel_value(self, channel, key, default=None): """Retrieves the value for a given key associated with a channel.""" channel = Identifier(channel).lower() session = self.ssession() @@ -445,6 +447,8 @@ def get_channel_value(self, channel, key): .one_or_none() if result is not None: result = result.value + elif default is not None: + result = default return _deserialize(result) except SQLAlchemyError: session.rollback() @@ -498,7 +502,7 @@ def delete_plugin_value(self, plugin, key): finally: session.close() - def get_plugin_value(self, plugin, key): + def get_plugin_value(self, plugin, key, default=None): """Retrieves the value for a given key associated with a plugin.""" plugin = plugin.lower() session = self.ssession() @@ -509,6 +513,8 @@ def get_plugin_value(self, plugin, key): .one_or_none() if result is not None: result = result.value + elif default is not None: + result = default return _deserialize(result) except SQLAlchemyError: session.rollback() @@ -518,13 +524,13 @@ def get_plugin_value(self, plugin, key): # NICK AND CHANNEL FUNCTIONS - def get_nick_or_channel_value(self, name, key): + def get_nick_or_channel_value(self, name, key, default=None): """Gets the value `key` associated to the nick or channel `name`.""" name = Identifier(name) if name.is_nick(): - return self.get_nick_value(name, key) + return self.get_nick_value(name, key, default) else: - return self.get_channel_value(name, key) + return self.get_channel_value(name, key, default) def get_preferred_value(self, names, key): """Gets the value for the first name which has it set. diff --git a/sopel/modules/adminchannel.py b/sopel/modules/adminchannel.py index d0586013b7..ff0b9d6077 100644 --- a/sopel/modules/adminchannel.py +++ b/sopel/modules/adminchannel.py @@ -347,6 +347,4 @@ def set_mask(bot, trigger): @commands('showmask') def show_mask(bot, trigger): """Show the topic mask for the current channel.""" - mask = bot.db.get_channel_value(trigger.sender, 'topic_mask') - mask = mask or default_mask(trigger) - bot.say(mask) + bot.say(bot.db.get_channel_value(trigger.sender, 'topic_mask', default_mask(trigger))) diff --git a/test/test_db.py b/test/test_db.py index 3fbb3af979..78bc8bfd1c 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -151,6 +151,11 @@ def test_get_nick_value(db): assert found_value == value +def test_get_nick_value_default(db): + assert db.get_nick_value("TestUser", "DoesntExist") is None + assert db.get_nick_value("TestUser", "DoesntExist", "MyDefault") == "MyDefault" + + def test_delete_nick_value(db): nick = 'Embolalia' db.set_nick_value(nick, 'wasd', 'uldr') @@ -255,6 +260,11 @@ def test_get_channel_value(db): assert result == 'zxcv' +def test_get_channel_value_default(db): + assert db.get_channel_value("TestChan", "DoesntExist") is None + assert db.get_channel_value("TestChan", "DoesntExist", "MyDefault") == "MyDefault" + + def test_get_nick_or_channel_value(db): db.set_nick_value('asdf', 'qwer', 'poiu') db.set_channel_value('#asdf', 'qwer', '/.,m') @@ -262,6 +272,11 @@ def test_get_nick_or_channel_value(db): assert db.get_nick_or_channel_value('#asdf', 'qwer') == '/.,m' +def test_get_nick_or_channel_value_default(db): + assert db.get_nick_or_channel_value("Test", "DoesntExist") is None + assert db.get_nick_or_channel_value("Test", "DoesntExist", "MyDefault") == "MyDefault" + + def test_get_preferred_value(db): db.set_nick_value('asdf', 'qwer', 'poiu') db.set_channel_value('#asdf', 'qwer', '/.,m') @@ -288,6 +303,11 @@ def test_get_plugin_value(db): assert result == 'zxcv' +def test_get_plugin_value_default(db): + assert db.get_plugin_value("TestPlugin", "DoesntExist") is None + assert db.get_plugin_value("TestPlugin", "DoesntExist", "MyDefault") == "MyDefault" + + def test_delete_plugin_value(db): db.set_plugin_value('plugin', 'wasd', 'uldr') assert db.get_plugin_value('plugin', 'wasd') == 'uldr'