Skip to content

Commit

Permalink
ModelLoader Columns objects support in column loader
Browse files Browse the repository at this point in the history
  • Loading branch information
jekel committed Sep 3, 2018
1 parent 365fd60 commit cb0e174
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
17 changes: 15 additions & 2 deletions gino/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras):
self.model = model
self._distinct = None
if column_names:
self.columns = [getattr(model, name) for name in column_names]
self.columns = self._column_loader(model, column_names)
else:
self.columns = model
self.extras = dict((key, self.get(value))
Expand Down Expand Up @@ -123,11 +123,24 @@ def get_from(self):

def load(self, *column_names, **extras):
if column_names:
self.columns = [getattr(self.model, name) for name in column_names]
self.columns = self._column_loader(self.model, column_names)

self.extras.update((key, self.get(value))
for key, value in extras.items())
return self

@classmethod
def _column_loader(cls, model, column_names):
def column_formatter(column_name):
if isinstance(column_name, str):
return getattr(model, column_name)
elif isinstance(column_name, Column):
return column_name
else:
raise TypeError('Unknown column name {} type {}'.format(column_name, type(column_name)))

return [column_formatter(column_name) for column_name in column_names]

def on(self, on_clause):
self.on_clause = on_clause
return self
Expand Down
24 changes: 24 additions & 0 deletions tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ async def test_model_load(user):
assert u.id is None
assert u.nickname == user.nickname

with pytest.raises(TypeError):
await User.query.gino.load(User.load(123)).first()


async def test_216_model_load_passive_partial(user):
u = await db.select([User.nickname]).gino.model(User).first()
Expand Down Expand Up @@ -144,6 +147,27 @@ def loader(row, context):
assert u.team.parent.name == user.team.parent.name


async def test_adjanency_list_on_nested_load(user):
subquery = db.select(User).alias()
base_query = subquery.outerjoin(Team).select()

query = base_query.execution_options(loader=(User.load('id')))
u = await query.gino.first()
# Because here arrives team_id, not user_id, and replaces it
assert u.id is None

query = base_query.execution_options(loader=(User.load(
*(map(subquery.corresponding_column, User)), team=Team)))
u = await query.gino.first()
assert u.id == user.id
assert u.realname == user.realname
assert u.nickname == user.nickname

assert isinstance(u.team, Team)
assert u.team.id == user.team.id
assert u.team.name == user.team.name


async def test_adjacency_list_query_builder(user):
group = Team.alias()
u = await User.load(team=Team.load(parent=group.on(
Expand Down

0 comments on commit cb0e174

Please sign in to comment.