Skip to content

Commit

Permalink
Merge pull request #1673 from half-duplex/db-default
Browse files Browse the repository at this point in the history
db: allow requesting a default if key not found
  • Loading branch information
dgw authored Nov 22, 2019
2 parents 8333360 + 7ad2709 commit d00e08a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
18 changes: 12 additions & 6 deletions sopel/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def delete_nick_value(self, nick, key):
finally:
session.close()

def get_nick_value(self, nick, key):
def get_nick_value(self, nick, key, default=None):
"""Retrieves the value for a given key associated with a nick."""
nick = Identifier(nick)
session = self.ssession()
Expand All @@ -307,6 +307,8 @@ def get_nick_value(self, nick, key):
.one_or_none()
if result is not None:
result = result.value
elif default is not None:
result = default
return _deserialize(result)
except SQLAlchemyError:
session.rollback()
Expand Down Expand Up @@ -434,7 +436,7 @@ def delete_channel_value(self, channel, key):
finally:
session.close()

def get_channel_value(self, channel, key):
def get_channel_value(self, channel, key, default=None):
"""Retrieves the value for a given key associated with a channel."""
channel = Identifier(channel).lower()
session = self.ssession()
Expand All @@ -445,6 +447,8 @@ def get_channel_value(self, channel, key):
.one_or_none()
if result is not None:
result = result.value
elif default is not None:
result = default
return _deserialize(result)
except SQLAlchemyError:
session.rollback()
Expand Down Expand Up @@ -498,7 +502,7 @@ def delete_plugin_value(self, plugin, key):
finally:
session.close()

def get_plugin_value(self, plugin, key):
def get_plugin_value(self, plugin, key, default=None):
"""Retrieves the value for a given key associated with a plugin."""
plugin = plugin.lower()
session = self.ssession()
Expand All @@ -509,6 +513,8 @@ def get_plugin_value(self, plugin, key):
.one_or_none()
if result is not None:
result = result.value
elif default is not None:
result = default
return _deserialize(result)
except SQLAlchemyError:
session.rollback()
Expand All @@ -518,13 +524,13 @@ def get_plugin_value(self, plugin, key):

# NICK AND CHANNEL FUNCTIONS

def get_nick_or_channel_value(self, name, key):
def get_nick_or_channel_value(self, name, key, default=None):
"""Gets the value `key` associated to the nick or channel `name`."""
name = Identifier(name)
if name.is_nick():
return self.get_nick_value(name, key)
return self.get_nick_value(name, key, default)
else:
return self.get_channel_value(name, key)
return self.get_channel_value(name, key, default)

def get_preferred_value(self, names, key):
"""Gets the value for the first name which has it set.
Expand Down
4 changes: 1 addition & 3 deletions sopel/modules/adminchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,4 @@ def set_mask(bot, trigger):
@commands('showmask')
def show_mask(bot, trigger):
"""Show the topic mask for the current channel."""
mask = bot.db.get_channel_value(trigger.sender, 'topic_mask')
mask = mask or default_mask(trigger)
bot.say(mask)
bot.say(bot.db.get_channel_value(trigger.sender, 'topic_mask', default_mask(trigger)))
20 changes: 20 additions & 0 deletions test/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def test_get_nick_value(db):
assert found_value == value


def test_get_nick_value_default(db):
assert db.get_nick_value("TestUser", "DoesntExist") is None
assert db.get_nick_value("TestUser", "DoesntExist", "MyDefault") == "MyDefault"


def test_delete_nick_value(db):
nick = 'Embolalia'
db.set_nick_value(nick, 'wasd', 'uldr')
Expand Down Expand Up @@ -255,13 +260,23 @@ def test_get_channel_value(db):
assert result == 'zxcv'


def test_get_channel_value_default(db):
assert db.get_channel_value("TestChan", "DoesntExist") is None
assert db.get_channel_value("TestChan", "DoesntExist", "MyDefault") == "MyDefault"


def test_get_nick_or_channel_value(db):
db.set_nick_value('asdf', 'qwer', 'poiu')
db.set_channel_value('#asdf', 'qwer', '/.,m')
assert db.get_nick_or_channel_value('asdf', 'qwer') == 'poiu'
assert db.get_nick_or_channel_value('#asdf', 'qwer') == '/.,m'


def test_get_nick_or_channel_value_default(db):
assert db.get_nick_or_channel_value("Test", "DoesntExist") is None
assert db.get_nick_or_channel_value("Test", "DoesntExist", "MyDefault") == "MyDefault"


def test_get_preferred_value(db):
db.set_nick_value('asdf', 'qwer', 'poiu')
db.set_channel_value('#asdf', 'qwer', '/.,m')
Expand All @@ -288,6 +303,11 @@ def test_get_plugin_value(db):
assert result == 'zxcv'


def test_get_plugin_value_default(db):
assert db.get_plugin_value("TestPlugin", "DoesntExist") is None
assert db.get_plugin_value("TestPlugin", "DoesntExist", "MyDefault") == "MyDefault"


def test_delete_plugin_value(db):
db.set_plugin_value('plugin', 'wasd', 'uldr')
assert db.get_plugin_value('plugin', 'wasd') == 'uldr'
Expand Down

0 comments on commit d00e08a

Please sign in to comment.