Skip to content

Commit

Permalink
db: allow requesting a default if key not found
Browse files Browse the repository at this point in the history
  • Loading branch information
half-duplex committed Aug 3, 2019
1 parent 5b21488 commit f018746
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
14 changes: 9 additions & 5 deletions sopel/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def delete_nick_value(self, nick, key):
finally:
session.close()

def get_nick_value(self, nick, key):
def get_nick_value(self, nick, key, default=None):
"""Retrieves the value for a given key associated with a nick."""
nick = Identifier(nick)
session = self.ssession()
Expand All @@ -288,6 +288,8 @@ def get_nick_value(self, nick, key):
.one_or_none()
if result is not None:
result = result.value
elif default is not None:
result = default
return _deserialize(result)
except SQLAlchemyError:
session.rollback()
Expand Down Expand Up @@ -415,7 +417,7 @@ def delete_channel_value(self, channel, key):
finally:
session.close()

def get_channel_value(self, channel, key):
def get_channel_value(self, channel, key, default=None):
"""Retrieves the value for a given key associated with a channel."""
channel = Identifier(channel).lower()
session = self.ssession()
Expand All @@ -426,6 +428,8 @@ def get_channel_value(self, channel, key):
.one_or_none()
if result is not None:
result = result.value
elif default is not None:
result = default
return _deserialize(result)
except SQLAlchemyError:
session.rollback()
Expand All @@ -435,13 +439,13 @@ def get_channel_value(self, channel, key):

# NICK AND CHANNEL FUNCTIONS

def get_nick_or_channel_value(self, name, key):
def get_nick_or_channel_value(self, name, key, default=None):
"""Gets the value `key` associated to the nick or channel `name`."""
name = Identifier(name)
if name.is_nick():
return self.get_nick_value(name, key)
return self.get_nick_value(name, key, default)
else:
return self.get_channel_value(name, key)
return self.get_channel_value(name, key, default)

def get_preferred_value(self, names, key):
"""Gets the value for the first name which has it set.
Expand Down
15 changes: 15 additions & 0 deletions test/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def test_get_nick_value(db):
assert found_value == value


def test_get_nick_value_default(db):
assert db.get_nick_value("TestUser", "DoesntExist") == None
assert db.get_nick_value("TestUser", "DoesntExist", "MyDefault") == "MyDefault"


def test_delete_nick_value(db):
nick = 'Embolalia'
db.set_nick_value(nick, 'wasd', 'uldr')
Expand Down Expand Up @@ -255,13 +260,23 @@ def test_get_channel_value(db):
assert result == 'zxcv'


def test_get_channel_value_default(db):
assert db.get_channel_value("TestChan", "DoesntExist") == None
assert db.get_channel_value("TestChan", "DoesntExist", "MyDefault") == "MyDefault"


def test_get_nick_or_channel_value(db):
db.set_nick_value('asdf', 'qwer', 'poiu')
db.set_channel_value('#asdf', 'qwer', '/.,m')
assert db.get_nick_or_channel_value('asdf', 'qwer') == 'poiu'
assert db.get_nick_or_channel_value('#asdf', 'qwer') == '/.,m'


def test_get_nick_or_channel_value_default(db):
assert db.get_nick_or_channel_value("Test", "DoesntExist") == None
assert db.get_nick_or_channel_value("Test", "DoesntExist", "MyDefault") == "MyDefault"


def test_get_preferred_value(db):
db.set_nick_value('asdf', 'qwer', 'poiu')
db.set_channel_value('#asdf', 'qwer', '/.,m')
Expand Down

0 comments on commit f018746

Please sign in to comment.