diff --git a/gino/loader.py b/gino/loader.py index 05e8dbd7..944ec90f 100644 --- a/gino/loader.py +++ b/gino/loader.py @@ -56,7 +56,7 @@ def __init__(self, model, *column_names, **extras): self.model = model self._distinct = None if column_names: - self.columns = [getattr(model, name) for name in column_names] + self.columns = self._column_loader(model, column_names) else: self.columns = model self.extras = dict((key, self.get(value)) @@ -123,11 +123,28 @@ def get_from(self): def load(self, *column_names, **extras): if column_names: - self.columns = [getattr(self.model, name) for name in column_names] + self.columns = self._column_loader(self.model, column_names) + self.extras.update((key, self.get(value)) for key, value in extras.items()) return self + @classmethod + def _column_loader(cls, model, column_names): + def column_formatter(column_name): + if isinstance(column_name, str): + return getattr(model, column_name) + elif isinstance(column_name, Column): + if column_name not in model: + raise AttributeError('Column {} does not belong ' + 'to this model'.format(column_name)) + return column_name + else: + raise TypeError('Unknown column name {} type {}'. + format(column_name, type(column_name))) + + return [column_formatter(column_name) for column_name in column_names] + def on(self, on_clause): self.on_clause = on_clause return self diff --git a/tests/test_loader.py b/tests/test_loader.py index 0a5104d8..2bdf9409 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -43,10 +43,17 @@ async def test_scalar(user): async def test_model_load(user): - u = await User.query.gino.load(User.load('nickname')).first() + u = await User.query.gino.load(User.load('nickname', User.team_id)).first() assert isinstance(u, User) assert u.id is None assert u.nickname == user.nickname + assert u.team_id == user.team.id + + with pytest.raises(TypeError): + await User.query.gino.load(User.load(123)).first() + + with pytest.raises(AttributeError): + await User.query.gino.load(User.load(Team.id)).first() async def test_216_model_load_passive_partial(user):