From cb0e174ddfbc7017d96967b22cae31aa6ff95557 Mon Sep 17 00:00:00 2001 From: Jekel Date: Fri, 31 Aug 2018 20:21:44 +0300 Subject: [PATCH] ModelLoader Columns objects support in column loader --- gino/loader.py | 17 +++++++++++++++-- tests/test_loader.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/gino/loader.py b/gino/loader.py index 05e8dbd7..4edc6964 100644 --- a/gino/loader.py +++ b/gino/loader.py @@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras): self.model = model self._distinct = None if column_names: - self.columns = [getattr(model, name) for name in column_names] + self.columns = self._column_loader(model, column_names) else: self.columns = model self.extras = dict((key, self.get(value)) @@ -123,11 +123,24 @@ def get_from(self): def load(self, *column_names, **extras): if column_names: - self.columns = [getattr(self.model, name) for name in column_names] + self.columns = self._column_loader(self.model, column_names) + self.extras.update((key, self.get(value)) for key, value in extras.items()) return self + @classmethod + def _column_loader(cls, model, column_names): + def column_formatter(column_name): + if isinstance(column_name, str): + return getattr(model, column_name) + elif isinstance(column_name, Column): + return column_name + else: + raise TypeError('Unknown column name {} type {}'.format(column_name, type(column_name))) + + return [column_formatter(column_name) for column_name in column_names] + def on(self, on_clause): self.on_clause = on_clause return self diff --git a/tests/test_loader.py b/tests/test_loader.py index 0a5104d8..1950116c 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -48,6 +48,9 @@ async def test_model_load(user): assert u.id is None assert u.nickname == user.nickname + with pytest.raises(TypeError): + await User.query.gino.load(User.load(123)).first() + async def test_216_model_load_passive_partial(user): u = await db.select([User.nickname]).gino.model(User).first() @@ -144,6 +147,27 @@ def loader(row, context): assert u.team.parent.name == user.team.parent.name +async def test_adjanency_list_on_nested_load(user): + subquery = db.select(User).alias() + base_query = subquery.outerjoin(Team).select() + + query = base_query.execution_options(loader=(User.load('id'))) + u = await query.gino.first() + # Because here arrives team_id, not user_id, and replaces it + assert u.id is None + + query = base_query.execution_options(loader=(User.load( + *(map(subquery.corresponding_column, User)), team=Team))) + u = await query.gino.first() + assert u.id == user.id + assert u.realname == user.realname + assert u.nickname == user.nickname + + assert isinstance(u.team, Team) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + + async def test_adjacency_list_query_builder(user): group = Team.alias() u = await User.load(team=Team.load(parent=group.on(