diff --git a/tests/apps/app/commands/test_new.py b/tests/apps/app/commands/test_new.py index efdacdb5f..fe2addd88 100644 --- a/tests/apps/app/commands/test_new.py +++ b/tests/apps/app/commands/test_new.py @@ -39,5 +39,7 @@ def test_new_with_clashing_name(self): exception = context.exception self.assertTrue( - exception.code.startswith("A module called sys already exists") + str(exception.code).startswith( + "A module called sys already exists" + ) ) diff --git a/tests/apps/fixtures/commands/test_dump_load.py b/tests/apps/fixtures/commands/test_dump_load.py index 407bd4504..59c4d04a2 100644 --- a/tests/apps/fixtures/commands/test_dump_load.py +++ b/tests/apps/fixtures/commands/test_dump_load.py @@ -276,5 +276,3 @@ def test_on_conflict(self): run_sync(load(path=json_file_path, on_conflict="DO NOTHING")) run_sync(load(path=json_file_path, on_conflict="DO UPDATE")) - run_sync(load(path=json_file_path, on_conflict="do nothing")) - run_sync(load(path=json_file_path, on_conflict="do update")) diff --git a/tests/apps/fixtures/commands/test_shared.py b/tests/apps/fixtures/commands/test_shared.py index b2246aa71..34e2af4ee 100644 --- a/tests/apps/fixtures/commands/test_shared.py +++ b/tests/apps/fixtures/commands/test_shared.py @@ -56,5 +56,5 @@ def test_shared(self): } model = pydantic_model(**data) - self.assertEqual(model.mega.SmallTable[0].id, 1) - self.assertEqual(model.mega.MegaTable[0].id, 1) + self.assertEqual(model.mega.SmallTable[0].id, 1) # type: ignore + self.assertEqual(model.mega.MegaTable[0].id, 1) # type: ignore diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 8d58bb8ce..a1988a029 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -1,5 +1,6 @@ import asyncio import random +import typing as t from io import StringIO from unittest import TestCase from unittest.mock import MagicMock, patch @@ -267,7 +268,7 @@ def test_add_table(self, get_app_config: MagicMock): self.assertEqual(self.table_exists("musician"), False) @engines_only("postgres", "cockroach") - def test_add_column(self): + def test_add_column(self) -> None: """ Test adding a column to a MigrationManager. """ @@ -304,21 +305,21 @@ def test_add_column(self): response = self.run_sync("SELECT * FROM manager;") self.assertEqual(response, [{"id": 1, "name": "Dave"}]) - id = 0 + row_id: t.Optional[int] = None if engine_is("cockroach"): - id = self.run_sync( + row_id = self.run_sync( "INSERT INTO manager VALUES (default, 'Dave', 'dave@me.com') RETURNING id;" # noqa: E501 - ) + )[0]["id"] response = self.run_sync("SELECT * FROM manager;") self.assertEqual( response, - [{"id": id[0]["id"], "name": "Dave", "email": "dave@me.com"}], + [{"id": row_id, "name": "Dave", "email": "dave@me.com"}], ) # Reverse asyncio.run(manager.run(backwards=True)) response = self.run_sync("SELECT * FROM manager;") - self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}]) + self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) # Preview manager.preview = True @@ -333,7 +334,7 @@ def test_add_column(self): if engine_is("postgres"): self.assertEqual(response, [{"id": 1, "name": "Dave"}]) if engine_is("cockroach"): - self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}]) + self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) @engines_only("postgres", "cockroach") def test_add_column_with_index(self): diff --git a/tests/apps/user/test_tables.py b/tests/apps/user/test_tables.py index d48e2cc75..59d274905 100644 --- a/tests/apps/user/test_tables.py +++ b/tests/apps/user/test_tables.py @@ -221,7 +221,10 @@ def test_long_password_error(self): def test_no_username_error(self): with self.assertRaises(ValueError) as manager: - BaseUser.create_user_sync(username=None, password="abc123") + BaseUser.create_user_sync( + username=None, # type: ignore + password="abc123", + ) self.assertEqual( manager.exception.__str__(), "A username must be provided." @@ -229,7 +232,10 @@ def test_no_username_error(self): def test_no_password_error(self): with self.assertRaises(ValueError) as manager: - BaseUser.create_user_sync(username="bob", password=None) + BaseUser.create_user_sync( + username="bob", + password=None, # type: ignore + ) self.assertEqual( manager.exception.__str__(), "A password must be provided." @@ -272,12 +278,14 @@ def test_hash_update(self): BaseUser.login_sync(username=username, password=password) ) - hashed_password = ( + user_data = ( BaseUser.select(BaseUser.password) .where(BaseUser.id == user.id) .first() - .run_sync()["password"] + .run_sync() ) + assert user_data is not None + hashed_password = user_data["password"] algorithm, iterations_, salt, hashed = BaseUser.split_stored_password( hashed_password diff --git a/tests/base.py b/tests/base.py index 6f9bfbb55..b05f85622 100644 --- a/tests/base.py +++ b/tests/base.py @@ -19,19 +19,19 @@ ENGINE = engine_finder() -def engine_version_lt(version: float): - return ENGINE and run_sync(ENGINE.get_version()) < version +def engine_version_lt(version: float) -> bool: + return ENGINE is not None and run_sync(ENGINE.get_version()) < version -def is_running_postgres(): +def is_running_postgres() -> bool: return type(ENGINE) is PostgresEngine -def is_running_sqlite(): +def is_running_sqlite() -> bool: return type(ENGINE) is SQLiteEngine -def is_running_cockroach(): +def is_running_cockroach() -> bool: return type(ENGINE) is CockroachEngine @@ -228,6 +228,8 @@ def get_postgres_varchar_length( ########################################################################### def create_tables(self): + assert ENGINE is not None + if ENGINE.engine_type in ("postgres", "cockroach"): self.run_sync( """ @@ -308,6 +310,8 @@ def create_tables(self): raise Exception("Unrecognised engine") def insert_row(self): + assert ENGINE is not None + if ENGINE.engine_type == "cockroach": id = self.run_sync( """ @@ -352,6 +356,8 @@ def insert_row(self): ) def insert_rows(self): + assert ENGINE is not None + if ENGINE.engine_type == "cockroach": id = self.run_sync( """ @@ -428,6 +434,8 @@ def insert_many_rows(self, row_count=10000): self.run_sync(f"INSERT INTO manager (name) VALUES {values_string};") def drop_tables(self): + assert ENGINE is not None + if ENGINE.engine_type in ("postgres", "cockroach"): self.run_sync("DROP TABLE IF EXISTS band CASCADE;") self.run_sync("DROP TABLE IF EXISTS manager CASCADE;") diff --git a/tests/columns/foreign_key/test_all_columns.py b/tests/columns/foreign_key/test_all_columns.py index e2718ce5b..0d6828ddf 100644 --- a/tests/columns/foreign_key/test_all_columns.py +++ b/tests/columns/foreign_key/test_all_columns.py @@ -24,17 +24,17 @@ def test_all_columns_deep(self): """ Make sure ``all_columns`` works when the joins are several layers deep. """ - all_columns = Concert.band_1.manager.all_columns() - self.assertEqual(all_columns, [Band.manager.id, Band.manager.name]) + all_columns = Concert.band_1._.manager.all_columns() + self.assertEqual(all_columns, [Band.manager._.id, Band.manager._.name]) # Make sure the call chains are also correct. self.assertEqual( all_columns[0]._meta.call_chain, - Concert.band_1.manager.id._meta.call_chain, + Concert.band_1._.manager._.id._meta.call_chain, ) self.assertEqual( all_columns[1]._meta.call_chain, - Concert.band_1.manager.name._meta.call_chain, + Concert.band_1._.manager._.name._meta.call_chain, ) def test_all_columns_exclude(self): diff --git a/tests/columns/foreign_key/test_all_related.py b/tests/columns/foreign_key/test_all_related.py index 737ad1924..94ebf7dc2 100644 --- a/tests/columns/foreign_key/test_all_related.py +++ b/tests/columns/foreign_key/test_all_related.py @@ -38,13 +38,13 @@ def test_all_related_deep(self): """ Make sure ``all_related`` works when the joins are several layers deep. """ - all_related = Ticket.concert.band_1.all_related() - self.assertEqual(all_related, [Ticket.concert.band_1.manager]) + all_related = Ticket.concert._.band_1.all_related() + self.assertEqual(all_related, [Ticket.concert._.band_1._.manager]) # Make sure the call chains are also correct. self.assertEqual( all_related[0]._meta.call_chain, - Ticket.concert.band_1.manager._meta.call_chain, + Ticket.concert._.band_1._.manager._meta.call_chain, ) def test_all_related_exclude(self): @@ -57,6 +57,6 @@ def test_all_related_exclude(self): ) self.assertEqual( - Ticket.concert.all_related(exclude=[Ticket.concert.venue]), + Ticket.concert.all_related(exclude=[Ticket.concert._.venue]), [Ticket.concert.band_1, Ticket.concert.band_2], ) diff --git a/tests/columns/foreign_key/test_attribute_access.py b/tests/columns/foreign_key/test_attribute_access.py index ccfc77818..597b33bd6 100644 --- a/tests/columns/foreign_key/test_attribute_access.py +++ b/tests/columns/foreign_key/test_attribute_access.py @@ -60,6 +60,6 @@ def test_recursion_time(self): Make sure that a really large call chain doesn't take too long. """ start = time.time() - Manager.manager.manager.manager.manager.manager.manager.name + Manager.manager._.manager._.manager._.manager._.manager._.manager._.name # noqa: E501 end = time.time() self.assertLess(end - start, 1.0) diff --git a/tests/columns/foreign_key/test_foreign_key_self.py b/tests/columns/foreign_key/test_foreign_key_self.py index 830c147ce..18c35e337 100644 --- a/tests/columns/foreign_key/test_foreign_key_self.py +++ b/tests/columns/foreign_key/test_foreign_key_self.py @@ -1,10 +1,11 @@ from unittest import TestCase -from piccolo.columns import ForeignKey, Varchar +from piccolo.columns import ForeignKey, Serial, Varchar from piccolo.table import Table class Manager(Table, tablename="manager"): + id: Serial name = Varchar() manager: ForeignKey["Manager"] = ForeignKey("self", null=True) diff --git a/tests/columns/foreign_key/test_schema.py b/tests/columns/foreign_key/test_schema.py index 121e32ebd..7e6b45c18 100644 --- a/tests/columns/foreign_key/test_schema.py +++ b/tests/columns/foreign_key/test_schema.py @@ -84,7 +84,7 @@ def test_with_schema(self): query = Concert.select( Concert.start_date, Concert.band.name.as_alias("band_name"), - Concert.band.manager.name.as_alias("manager_name"), + Concert.band._.manager._.name.as_alias("manager_name"), ) self.assertIn('"schema_1"."concert"', query.__str__()) self.assertIn('"schema_1"."band"', query.__str__()) diff --git a/tests/columns/m2m/base.py b/tests/columns/m2m/base.py index a3f282d23..6386ffcaf 100644 --- a/tests/columns/m2m/base.py +++ b/tests/columns/m2m/base.py @@ -1,27 +1,53 @@ import typing as t +from piccolo.columns.column_types import ( + ForeignKey, + LazyTableReference, + Serial, + Text, + Varchar, +) +from piccolo.columns.m2m import M2M from piccolo.engine.finder import engine_finder +from piccolo.schema import SchemaManager from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync from tests.base import engine_is, engines_skip engine = engine_finder() +class Band(Table): + id: Serial + name = Varchar() + genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + +class Genre(Table): + id: Serial + name = Varchar() + bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) + + +class GenreToBand(Table): + id: Serial + band = ForeignKey(Band) + genre = ForeignKey(Genre) + reason = Text(help_text="For testing additional columns on join tables.") + + class M2MBase: """ This allows us to test M2M when the tables are in different schemas (public vs non-public). """ - band: t.Type[Table] - genre: t.Type[Table] - genre_to_band: t.Type[Table] - all_tables: t.List[t.Type[Table]] + def _setUp(self, schema: t.Optional[str] = None): + self.schema = schema - def setUp(self): - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band + for table_class in (Band, Genre, GenreToBand): + table_class._meta.schema = schema + + self.all_tables = [Band, Genre, GenreToBand] create_db_tables_sync(*self.all_tables, if_not_exists=True) @@ -77,14 +103,22 @@ def setUp(self): def tearDown(self): drop_db_tables_sync(*self.all_tables) + if self.schema: + SchemaManager().drop_schema( + schema_name="schema_1", cascade=True + ).run_sync() + + def assertEqual(self, first, second, msg=None): + assert first == second + + def assertTrue(self, first, msg=None): + assert first is True + @engines_skip("cockroach") def test_select_name(self): """ 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg """ # noqa: E501 - Band = self.band - Genre = self.genre - response = Band.select( Band.name, Band.genres(Genre.name, as_list=True) ).run_sync() @@ -118,9 +152,6 @@ def test_no_related(self): """ 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg """ # noqa: E501 - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band GenreToBand.delete(force=True).run_sync() @@ -156,8 +187,6 @@ def test_select_multiple(self): """ 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg """ # noqa: E501 - Band = self.band - Genre = self.genre response = Band.select( Band.name, Band.genres(Genre.id, Genre.name) @@ -218,8 +247,6 @@ def test_select_id(self): """ 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg """ # noqa: E501 - Band = self.band - Genre = self.genre response = Band.select( Band.name, Band.genres(Genre.id, as_list=True) @@ -257,8 +284,6 @@ def test_select_all_columns(self): 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/71908 "could not decorrelate subquery" error under asyncpg """ # noqa: E501 - Band = self.band - Genre = self.genre response = Band.select( Band.name, Band.genres(Genre.all_columns(exclude=(Genre.id,))) @@ -288,11 +313,9 @@ def test_add_m2m(self): """ Make sure we can add items to the joining table. """ - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None band.add_m2m(Genre(name="Punk Rock"), m2m=Band.genres).run_sync() self.assertTrue( @@ -314,13 +337,11 @@ def test_extra_columns_str(self): Make sure the ``extra_column_values`` parameter for ``add_m2m`` works correctly when the dictionary keys are strings. """ - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band reason = "Their second album was very punk rock." band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None band.add_m2m( Genre(name="Punk Rock"), m2m=Band.genres, @@ -329,7 +350,7 @@ def test_extra_columns_str(self): }, ).run_sync() - genre_to_band = ( + Genreto_band = ( GenreToBand.objects() .get( (GenreToBand.band.name == "Pythonistas") @@ -337,21 +358,20 @@ def test_extra_columns_str(self): ) .run_sync() ) + assert Genreto_band is not None - self.assertEqual(genre_to_band.reason, reason) + self.assertEqual(Genreto_band.reason, reason) def test_extra_columns_class(self): """ Make sure the ``extra_column_values`` parameter for ``add_m2m`` works correctly when the dictionary keys are ``Column`` classes. """ - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band reason = "Their second album was very punk rock." band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None band.add_m2m( Genre(name="Punk Rock"), m2m=Band.genres, @@ -360,7 +380,7 @@ def test_extra_columns_class(self): }, ).run_sync() - genre_to_band = ( + Genreto_band = ( GenreToBand.objects() .get( (GenreToBand.band.name == "Pythonistas") @@ -368,20 +388,20 @@ def test_extra_columns_class(self): ) .run_sync() ) + assert Genreto_band is not None - self.assertEqual(genre_to_band.reason, reason) + self.assertEqual(Genreto_band.reason, reason) def test_add_m2m_existing(self): """ Make sure we can add an existing element to the joining table. """ - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None genre = Genre.objects().get(Genre.name == "Classical").run_sync() + assert genre is not None band.add_m2m(genre, m2m=Band.genres).run_sync() @@ -404,9 +424,9 @@ def test_get_m2m(self): """ Make sure we can get related items via the joining table. """ - Band = self.band band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None genres = band.get_m2m(Band.genres).run_sync() @@ -418,13 +438,12 @@ def test_remove_m2m(self): """ Make sure we can remove related items via the joining table. """ - Band = self.band - Genre = self.genre - GenreToBand = self.genre_to_band band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None genre = Genre.objects().get(Genre.name == "Rock").run_sync() + assert genre is not None band.remove_m2m(genre, m2m=Band.genres).run_sync() diff --git a/tests/columns/m2m/test_m2m.py b/tests/columns/m2m/test_m2m.py index d897eace6..85731a715 100644 --- a/tests/columns/m2m/test_m2m.py +++ b/tests/columns/m2m/test_m2m.py @@ -44,27 +44,9 @@ engine = engine_finder() -class Band(Table): - name = Varchar() - genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) - - -class Genre(Table): - name = Varchar() - bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) - - -class GenreToBand(Table): - band = ForeignKey(Band) - genre = ForeignKey(Genre) - reason = Text(help_text="For testing additional columns on join tables.") - - class TestM2M(M2MBase, TestCase): - band = Band - genre = Genre - genre_to_band = GenreToBand - all_tables = [Band, Genre, GenreToBand] + def setUp(self): + return self._setUp(schema=None) ############################################################################### @@ -161,6 +143,7 @@ def test_add_m2m(self): Make sure we can add items to the joining table. """ customer = Customer.objects().get(Customer.name == "Bob").run_sync() + assert customer is not None customer.add_m2m( Concert(name="Jazzfest"), m2m=Customer.concerts ).run_sync() @@ -192,6 +175,7 @@ def test_add_m2m_within_transaction(self): async def add_m2m_in_transaction(): async with engine.transaction(): customer = await Customer.objects().get(Customer.name == "Bob") + assert customer is not None await customer.add_m2m( Concert(name="Jazzfest"), m2m=Customer.concerts ) @@ -217,6 +201,7 @@ def test_get_m2m(self): Make sure we can get related items via the joining table. """ customer = Customer.objects().get(Customer.name == "Bob").run_sync() + assert customer is not None concerts = customer.get_m2m(Customer.concerts).run_sync() diff --git a/tests/columns/m2m/test_m2m_schema.py b/tests/columns/m2m/test_m2m_schema.py index f9b958838..01ed90681 100644 --- a/tests/columns/m2m/test_m2m_schema.py +++ b/tests/columns/m2m/test_m2m_schema.py @@ -1,35 +1,10 @@ from unittest import TestCase -from piccolo.columns.column_types import ( - ForeignKey, - LazyTableReference, - Text, - Varchar, -) -from piccolo.columns.m2m import M2M -from piccolo.schema import SchemaManager -from piccolo.table import Table from tests.base import engines_skip from .base import M2MBase -class Band(Table, schema="schema_1"): - name = Varchar() - genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) - - -class Genre(Table, schema="schema_1"): - name = Varchar() - bands = M2M(LazyTableReference("GenreToBand", module_path=__name__)) - - -class GenreToBand(Table, schema="schema_1"): - band = ForeignKey(Band) - genre = ForeignKey(Genre) - reason = Text(help_text="For testing additional columns on join tables.") - - @engines_skip("sqlite") class TestM2MWithSchema(M2MBase, TestCase): """ @@ -37,12 +12,5 @@ class TestM2MWithSchema(M2MBase, TestCase): works. """ - band = Band - genre = Genre - genre_to_band = GenreToBand - all_tables = [Band, Genre, GenreToBand] - - def tearDown(self): - SchemaManager().drop_schema( - schema_name="schema_1", cascade=True - ).run_sync() + def setUp(self): + return self._setUp(schema="schema_1") diff --git a/tests/columns/test_array.py b/tests/columns/test_array.py index 8db64c17d..bc3c863db 100644 --- a/tests/columns/test_array.py +++ b/tests/columns/test_array.py @@ -42,6 +42,7 @@ def test_storage(self): MyTable(value=[1, 2, 3]).save().run_sync() row = MyTable.objects().first().run_sync() + assert row is not None self.assertEqual(row.value, [1, 2, 3]) @engines_only("postgres") diff --git a/tests/columns/test_base.py b/tests/columns/test_base.py index db111f6f1..2ef4c2a8c 100644 --- a/tests/columns/test_base.py +++ b/tests/columns/test_base.py @@ -121,7 +121,7 @@ def test_non_column(self): Make sure non-column values don't match. """ for value in (1, "abc", None): - self.assertFalse(Manager.name._equals(value)) + self.assertFalse(Manager.name._equals(value)) # type: ignore def test_equals(self): """ diff --git a/tests/columns/test_boolean.py b/tests/columns/test_boolean.py index eea3df8d0..4c67ef6db 100644 --- a/tests/columns/test_boolean.py +++ b/tests/columns/test_boolean.py @@ -1,3 +1,4 @@ +import typing as t from unittest import TestCase from piccolo.columns.column_types import Boolean @@ -15,23 +16,30 @@ def setUp(self): def tearDown(self): MyTable.alter().drop_table().run_sync() - def test_return_type(self): + def test_return_type(self) -> None: for value in (True, False, None, ...): - kwargs = {} if value is ... else {"boolean": value} + kwargs: t.Dict[str, t.Any] = ( + {} if value is ... else {"boolean": value} + ) expected = MyTable.boolean.default if value is ... else value row = MyTable(**kwargs) row.save().run_sync() self.assertEqual(row.boolean, expected) - self.assertEqual( + row_from_db = ( MyTable.select(MyTable.boolean) .where( MyTable._meta.primary_key == getattr(row, MyTable._meta.primary_key._meta.name) ) .first() - .run_sync()["boolean"], + .run_sync() + ) + assert row_from_db is not None + + self.assertEqual( + row_from_db["boolean"], expected, ) diff --git a/tests/columns/test_bytea.py b/tests/columns/test_bytea.py index 666a5a8d6..6976a8840 100644 --- a/tests/columns/test_bytea.py +++ b/tests/columns/test_bytea.py @@ -59,4 +59,4 @@ def test_json_default(self): def test_invalid_default(self): with self.assertRaises(ValueError): for value in ("a", 1, ("x", "y", "z")): - Bytea(default=value) + Bytea(default=value) # type: ignore diff --git a/tests/columns/test_choices.py b/tests/columns/test_choices.py index ccea6b97f..d127a87d0 100644 --- a/tests/columns/test_choices.py +++ b/tests/columns/test_choices.py @@ -34,6 +34,7 @@ def test_default(self): """ Shirt().save().run_sync() shirt = Shirt.objects().first().run_sync() + assert shirt is not None self.assertEqual(shirt.size, "l") def test_update(self): diff --git a/tests/columns/test_date.py b/tests/columns/test_date.py index 3169d8bf6..5c1016211 100644 --- a/tests/columns/test_date.py +++ b/tests/columns/test_date.py @@ -27,6 +27,7 @@ def test_timestamp(self): row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) @@ -43,4 +44,5 @@ def test_timestamp(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) diff --git a/tests/columns/test_db_column_name.py b/tests/columns/test_db_column_name.py index 24d9bd9eb..33beffed9 100644 --- a/tests/columns/test_db_column_name.py +++ b/tests/columns/test_db_column_name.py @@ -1,9 +1,10 @@ -from piccolo.columns.column_types import Integer, Varchar +from piccolo.columns.column_types import Integer, Serial, Varchar from piccolo.table import Table from tests.base import DBTestCase, engine_is, engines_only, engines_skip class Band(Table): + id: Serial name = Varchar(db_column_name="regrettable_column_name") popularity = Integer() @@ -48,6 +49,7 @@ def test_save(self): band.save().run_sync() band_from_db = Band.objects().first().run_sync() + assert band_from_db is not None self.assertEqual(band_from_db.name, "Pythonistas") def test_create(self): @@ -62,6 +64,7 @@ def test_create(self): self.assertEqual(band.name, "Pythonistas") band_from_db = Band.objects().first().run_sync() + assert band_from_db is not None self.assertEqual(band_from_db.name, "Pythonistas") def test_select(self): diff --git a/tests/columns/test_defaults.py b/tests/columns/test_defaults.py index db95b165c..77df731bf 100644 --- a/tests/columns/test_defaults.py +++ b/tests/columns/test_defaults.py @@ -35,32 +35,32 @@ def test_int(self): _type(default=0) _type(default=None, null=True) with self.assertRaises(ValueError): - _type(default="hello world") + _type(default="hello world") # type: ignore def test_text(self): for _type in (Text, Varchar): _type(default="") _type(default=None, null=True) with self.assertRaises(ValueError): - _type(default=123) + _type(default=123) # type: ignore def test_real(self): Real(default=0.0) Real(default=None, null=True) with self.assertRaises(ValueError): - Real(default="hello world") + Real(default="hello world") # type: ignore def test_double_precision(self): DoublePrecision(default=0.0) DoublePrecision(default=None, null=True) with self.assertRaises(ValueError): - DoublePrecision(default="hello world") + DoublePrecision(default="hello world") # type: ignore def test_numeric(self): Numeric(default=decimal.Decimal(1.0)) Numeric(default=None, null=True) with self.assertRaises(ValueError): - Numeric(default="hello world") + Numeric(default="hello world") # type: ignore def test_uuid(self): UUID(default=None, null=True) @@ -74,21 +74,21 @@ def test_time(self): Time(default=TimeNow()) Time(default=datetime.datetime.now().time()) with self.assertRaises(ValueError): - Time(default="hello world") + Time(default="hello world") # type: ignore def test_date(self): Date(default=None, null=True) Date(default=DateNow()) Date(default=datetime.datetime.now().date()) with self.assertRaises(ValueError): - Date(default="hello world") + Date(default="hello world") # type: ignore def test_timestamp(self): Timestamp(default=None, null=True) Timestamp(default=TimestampNow()) Timestamp(default=datetime.datetime.now()) with self.assertRaises(ValueError): - Timestamp(default="hello world") + Timestamp(default="hello world") # type: ignore def test_foreignkey(self): class MyTable(Table): diff --git a/tests/columns/test_double_precision.py b/tests/columns/test_double_precision.py index 20d63331b..99e411f11 100644 --- a/tests/columns/test_double_precision.py +++ b/tests/columns/test_double_precision.py @@ -20,5 +20,6 @@ def test_creation(self): row.save().run_sync() _row = MyTable.objects().first().run_sync() + assert _row is not None self.assertEqual(type(_row.column_a), float) self.assertAlmostEqual(_row.column_a, 1.23) diff --git a/tests/columns/test_interval.py b/tests/columns/test_interval.py index 11f30c670..d9e7c6f10 100644 --- a/tests/columns/test_interval.py +++ b/tests/columns/test_interval.py @@ -48,6 +48,7 @@ def test_interval(self): .first() .run_sync() ) + assert result is not None self.assertEqual(result.interval, interval) def test_interval_where_clause(self): @@ -102,4 +103,5 @@ def test_interval(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None self.assertEqual(result.interval.days, 1) diff --git a/tests/columns/test_json.py b/tests/columns/test_json.py index 932c244d3..69808163c 100644 --- a/tests/columns/test_json.py +++ b/tests/columns/test_json.py @@ -34,11 +34,11 @@ def test_json_string(self): row = MyTable(json='{"a": 1}') row.save().run_sync() + row_from_db = MyTable.select(MyTable.json).first().run_sync() + assert row_from_db is not None + self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row_from_db["json"].replace(" ", ""), '{"a":1}', ) @@ -49,11 +49,11 @@ def test_json_object(self): row = MyTable(json={"a": 1}) row.save().run_sync() + row_from_db = MyTable.select(MyTable.json).first().run_sync() + assert row_from_db is not None + self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row_from_db["json"].replace(" ", ""), '{"a":1}', ) @@ -78,7 +78,7 @@ def test_json_default(self): def test_invalid_default(self): with self.assertRaises(ValueError): for value in ("a", 1, ("x", "y", "z")): - JSON(default=value) + JSON(default=value) # type: ignore class TestJSONInsert(TestCase): @@ -89,11 +89,10 @@ def tearDown(self): MyTable.alter().drop_table().run_sync() def check_response(self): + row = MyTable.select(MyTable.json).first().run_sync() + assert row is not None self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row["json"].replace(" ", ""), '{"message":"original"}', ) @@ -125,11 +124,10 @@ def add_row(self): row.save().run_sync() def check_response(self): + row = MyTable.select(MyTable.json).first().run_sync() + assert row is not None self.assertEqual( - MyTable.select(MyTable.json) - .first() - .run_sync()["json"] - .replace(" ", ""), + row["json"].replace(" ", ""), '{"message":"updated"}', ) diff --git a/tests/columns/test_jsonb.py b/tests/columns/test_jsonb.py index 4a2ed1395..7c2be3a5a 100644 --- a/tests/columns/test_jsonb.py +++ b/tests/columns/test_jsonb.py @@ -159,6 +159,7 @@ def test_arrow(self): .first() .run_sync() ) + assert row is not None self.assertEqual(row["facilities"], "true") row = ( @@ -169,6 +170,7 @@ def test_arrow(self): .first() .run_sync() ) + assert row is not None self.assertEqual(row["facilities"], True) def test_arrow_as_alias(self): @@ -188,6 +190,7 @@ def test_arrow_as_alias(self): .first() .run_sync() ) + assert row is not None self.assertEqual(row["mixing_desk"], "true") def test_arrow_where(self): diff --git a/tests/columns/test_numeric.py b/tests/columns/test_numeric.py index 129db5946..872d739a9 100644 --- a/tests/columns/test_numeric.py +++ b/tests/columns/test_numeric.py @@ -22,6 +22,7 @@ def test_creation(self): row.save().run_sync() _row = MyTable.objects().first().run_sync() + assert _row is not None self.assertEqual(type(_row.column_a), Decimal) self.assertEqual(type(_row.column_b), Decimal) diff --git a/tests/columns/test_primary_key.py b/tests/columns/test_primary_key.py index 98d1f5d4c..1850944cc 100644 --- a/tests/columns/test_primary_key.py +++ b/tests/columns/test_primary_key.py @@ -133,6 +133,7 @@ def test_primary_key_queries(self): ) manager_dict = Manager.select().first().run_sync() + assert manager_dict is not None self.assertEqual( [i for i in manager_dict.keys()], @@ -151,6 +152,7 @@ def test_primary_key_queries(self): band.save().run_sync() band_dict = Band.select().first().run_sync() + assert band_dict is not None self.assertEqual( [i for i in band_dict.keys()], ["pk", "name", "manager"] @@ -163,6 +165,7 @@ def test_primary_key_queries(self): # type (i.e. `uuid.UUID`). manager = Manager.objects().first().run_sync() + assert manager is not None band_2 = Band(manager=manager.pk, name="Pythonistas 2") band_2.save().run_sync() @@ -178,9 +181,10 @@ def test_primary_key_queries(self): ####################################################################### # Make sure `get_related` works - self.assertEqual( - band_2.get_related(Band.manager).run_sync().pk, manager.pk - ) + manager_from_db = band_2.get_related(Band.manager).run_sync() + assert manager_from_db is not None + + self.assertEqual(manager_from_db.pk, manager.pk) ####################################################################### # Make sure `remove` works diff --git a/tests/columns/test_real.py b/tests/columns/test_real.py index 30dc4338d..3257111de 100644 --- a/tests/columns/test_real.py +++ b/tests/columns/test_real.py @@ -20,5 +20,6 @@ def test_creation(self): row.save().run_sync() _row = MyTable.objects().first().run_sync() + assert _row is not None self.assertEqual(type(_row.column_a), float) self.assertAlmostEqual(_row.column_a, 1.23) diff --git a/tests/columns/test_time.py b/tests/columns/test_time.py index b0be9768e..a6d931448 100644 --- a/tests/columns/test_time.py +++ b/tests/columns/test_time.py @@ -30,6 +30,7 @@ def test_timestamp(self): row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) @@ -49,6 +50,7 @@ def test_timestamp(self): _datetime = partial(datetime.datetime, year=2020, month=1, day=1) result = MyTableDefault.objects().first().run_sync() + assert result is not None self.assertLess( _datetime( hour=result.created_on.hour, diff --git a/tests/columns/test_timestamp.py b/tests/columns/test_timestamp.py index 2c79728e9..ad1fa01f0 100644 --- a/tests/columns/test_timestamp.py +++ b/tests/columns/test_timestamp.py @@ -35,6 +35,7 @@ def test_timestamp(self): row.save().run_sync() result = MyTable.objects().first().run_sync() + assert result is not None self.assertEqual(result.created_on, created_on) def test_timezone_aware(self): @@ -61,6 +62,7 @@ def test_timestamp(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None self.assertLess( result.created_on - created_on, datetime.timedelta(seconds=1) ) diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 09755e340..8e239900b 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -71,6 +71,7 @@ def test_timestamptz_timezone_aware(self): .first() .run_sync() ) + assert result is not None self.assertEqual(result.created_on, created_on) # The database converts it to UTC @@ -93,6 +94,7 @@ def test_timestamptz_default(self): row.save().run_sync() result = MyTableDefault.objects().first().run_sync() + assert result is not None delta = result.created_on - created_on self.assertLess(delta, datetime.timedelta(seconds=1)) self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) diff --git a/tests/conftest.py b/tests/conftest.py index dd71cba21..457c4dae7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ async def drop_tables(): "mega_table", "small_table", ] - assert ENGINE + assert ENGINE is not None if ENGINE.engine_type == "sqlite": # SQLite doesn't allow us to drop more than one table at a time. diff --git a/tests/engine/test_extra_nodes.py b/tests/engine/test_extra_nodes.py index 2e078c9ee..0d59868c2 100644 --- a/tests/engine/test_extra_nodes.py +++ b/tests/engine/test_extra_nodes.py @@ -1,3 +1,4 @@ +import typing as t from unittest import TestCase from unittest.mock import MagicMock @@ -16,6 +17,9 @@ def test_extra_nodes(self): """ # Get the test database credentials: test_engine = engine_finder() + assert test_engine is not None + + test_engine = t.cast(PostgresEngine, test_engine) EXTRA_NODE = MagicMock(spec=PostgresEngine(config=test_engine.config)) EXTRA_NODE.run_querystring = AsyncMock(return_value=[]) diff --git a/tests/query/mixins/test_columns_delegate.py b/tests/query/mixins/test_columns_delegate.py index dd9110431..e16a13dda 100644 --- a/tests/query/mixins/test_columns_delegate.py +++ b/tests/query/mixins/test_columns_delegate.py @@ -37,26 +37,26 @@ def test_as_of(self): self.insert_rows() time.sleep(1) # Ensure time travel queries have some history to use! - result = ( + query = ( Band.select() .where(Band.name == "Pythonistas") .as_of("-500ms") .limit(1) ) - self.assertTrue("AS OF SYSTEM TIME '-500ms'" in str(result)) - result = result.run_sync() + self.assertTrue("AS OF SYSTEM TIME '-500ms'" in str(query)) + result = query.run_sync() self.assertTrue(result[0]["name"] == "Pythonistas") - result = Band.select().as_of() - self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(result)) - result = result.run_sync() + query = Band.select().as_of() + self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(query)) + result = query.run_sync() self.assertTrue(result[0]["name"] == "Pythonistas") # Alternative syntax. - result = Band.objects().get(Band.name == "Pythonistas").as_of("-1s") - self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(result)) - result = result.run_sync() + query = Band.objects().get(Band.name == "Pythonistas").as_of("-1s") + self.assertTrue("AS OF SYSTEM TIME '-1s'" in str(query)) + result = query.run_sync() - self.assertTrue(result.name == "Pythonistas") + self.assertTrue(result.name == "Pythonistas") # type: ignore diff --git a/tests/query/test_slots.py b/tests/query/test_slots.py index 971019910..ad6322502 100644 --- a/tests/query/test_slots.py +++ b/tests/query/test_slots.py @@ -41,4 +41,4 @@ def test_attributes(self): AttributeError, msg=f"{class_name} didn't raised an error" ): print(f"Setting {class_name} attribute") - query_class(table=Manager).abc = 123 + query_class(table=Manager).abc = 123 # type: ignore diff --git a/tests/table/instance/test_get_related_readable.py b/tests/table/instance/test_get_related_readable.py index 2088c867f..982c4a5bc 100644 --- a/tests/table/instance/test_get_related_readable.py +++ b/tests/table/instance/test_get_related_readable.py @@ -34,7 +34,7 @@ def get_readable(cls): columns=[ cls.name, cls.thing_two.name, - cls.thing_two.thing_one.name, + cls.thing_two._.thing_one._.name, ], ) diff --git a/tests/table/instance/test_to_dict.py b/tests/table/instance/test_to_dict.py index b40bad790..b5d75f52b 100644 --- a/tests/table/instance/test_to_dict.py +++ b/tests/table/instance/test_to_dict.py @@ -10,6 +10,7 @@ def test_to_dict(self): self.insert_row() instance = Manager.objects().first().run_sync() + assert instance is not None dictionary = instance.to_dict() if engine_is("cockroach"): self.assertDictEqual( @@ -26,6 +27,7 @@ def test_nested(self): self.insert_row() instance = Band.objects(Band.manager).first().run_sync() + assert instance is not None dictionary = instance.to_dict() if engine_is("cockroach"): self.assertDictEqual( @@ -58,6 +60,7 @@ def test_filter_rows(self): self.insert_row() instance = Manager.objects().first().run_sync() + assert instance is not None dictionary = instance.to_dict(Manager.name) self.assertDictEqual(dictionary, {"name": "Guido"}) @@ -69,6 +72,7 @@ def test_nested_filter(self): self.insert_row() instance = Band.objects(Band.manager).first().run_sync() + assert instance is not None dictionary = instance.to_dict(Band.name, Band.manager.id) if engine_is("cockroach"): self.assertDictEqual( @@ -94,6 +98,7 @@ def test_aliases(self): self.insert_row() instance = Manager.objects().first().run_sync() + assert instance is not None dictionary = instance.to_dict( Manager.id, Manager.name.as_alias("title") ) diff --git a/tests/table/test_alter.py b/tests/table/test_alter.py index e8df1b004..836a552a8 100644 --- a/tests/table/test_alter.py +++ b/tests/table/test_alter.py @@ -90,7 +90,7 @@ class TestDropColumn(DBTestCase): SQLite has very limited support for ALTER statements. """ - def _test_drop(self, column: str): + def _test_drop(self, column: t.Union[str, Column]): self.insert_row() Band.alter().drop_column(column).run_sync() @@ -229,10 +229,9 @@ def test_integer_to_bigint(self): "BIGINT", ) - popularity = ( - Band.select(Band.popularity).first().run_sync()["popularity"] - ) - self.assertEqual(popularity, 1000) + row = Band.select(Band.popularity).first().run_sync() + assert row is not None + self.assertEqual(row["popularity"], 1000) def test_integer_to_varchar(self): """ @@ -252,10 +251,9 @@ def test_integer_to_varchar(self): "CHARACTER VARYING", ) - popularity = ( - Band.select(Band.popularity).first().run_sync()["popularity"] - ) - self.assertEqual(popularity, "1000") + row = Band.select(Band.popularity).first().run_sync() + assert row is not None + self.assertEqual(row["popularity"], "1000") def test_using_expression(self): """ @@ -271,8 +269,9 @@ def test_using_expression(self): ) alter_query.run_sync() - popularity = Band.select(Band.name).first().run_sync()["name"] - self.assertEqual(popularity, 1) + row = Band.select(Band.name).first().run_sync() + assert row is not None + self.assertEqual(row["name"], 1) @engines_only("postgres", "cockroach") @@ -321,12 +320,12 @@ def test_set_default(self): ).run_sync() manager = Manager.objects().first().run_sync() + assert manager is not None self.assertEqual(manager.name, "Pending") @engines_only("postgres", "cockroach") class TestSetSchema(TestCase): - schema_manager = SchemaManager() schema_name = "schema_1" diff --git a/tests/table/test_batch.py b/tests/table/test_batch.py index 8fca2944c..2de9e0c7d 100644 --- a/tests/table/test_batch.py +++ b/tests/table/test_batch.py @@ -104,6 +104,7 @@ def test_batch_extra_node(self): # Get the test database credentials: test_engine = engine_finder() + assert isinstance(test_engine, PostgresEngine) EXTRA_NODE = AsyncMock(spec=PostgresEngine(config=test_engine.config)) diff --git a/tests/table/test_inheritance.py b/tests/table/test_inheritance.py index 8030bb4b7..a7ab2c90e 100644 --- a/tests/table/test_inheritance.py +++ b/tests/table/test_inheritance.py @@ -61,6 +61,7 @@ def test_inheritance(self): ).save().run_sync() response = Manager.select().first().run_sync() + assert response is not None self.assertEqual(response["started_on"], started_on) self.assertEqual(response["name"], name) self.assertEqual(response["favourite"], favourite) @@ -98,6 +99,7 @@ def test_inheritance(self): _Table(name=name, started_on=started_on).save().run_sync() response = _Table.select().first().run_sync() + assert response is not None self.assertEqual(response["started_on"], started_on) self.assertEqual(response["name"], name) diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 1c5fab732..b2c58e378 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -3,7 +3,7 @@ import pytest -from piccolo.columns import Integer, Varchar +from piccolo.columns import Integer, Serial, Varchar from piccolo.query.methods.insert import OnConflictAction from piccolo.table import Table from piccolo.utils.lazy_loader import LazyLoader @@ -98,6 +98,7 @@ def test_insert_returning_alias(self): ) class TestOnConflict(TestCase): class Band(Table): + id: Serial name = Varchar(unique=True) popularity = Integer() diff --git a/tests/table/test_join.py b/tests/table/test_join.py index b5ebc867d..1d3aba2f3 100644 --- a/tests/table/test_join.py +++ b/tests/table/test_join.py @@ -93,7 +93,7 @@ def test_join(self): ) # Now make sure that even deeper joins work: - select_query = Concert.select(Concert.band_1.manager.name) + select_query = Concert.select(Concert.band_1._.manager._.name) response = select_query.run_sync() self.assertEqual(response, [{"band_1.manager.name": "Guido"}]) @@ -126,10 +126,11 @@ def test_select_all_columns(self): explicitly specifying them. """ result = ( - Band.select(Band.name, Band.manager.all_columns()) + Band.select(Band.name, *Band.manager.all_columns()) .first() .run_sync() ) + assert result is not None if engine_is("cockroach"): self.assertDictEqual( @@ -156,13 +157,14 @@ def test_select_all_columns_deep(self): """ result = ( Concert.select( - Concert.venue.all_columns(), - Concert.band_1.manager.all_columns(), - Concert.band_2.manager.all_columns(), + *Concert.venue.all_columns(), + *Concert.band_1._.manager.all_columns(), + *Concert.band_2._.manager.all_columns(), ) .first() .run_sync() ) + assert result is not None if engine_is("cockroach"): self.assertDictEqual( @@ -203,7 +205,8 @@ def test_proxy_columns(self): # We call it multiple times to make sure it doesn't change with time. for _ in range(2): self.assertEqual( - len(Concert.band_1.manager._foreign_key_meta.proxy_columns), 2 + len(Concert.band_1._.manager._foreign_key_meta.proxy_columns), + 2, ) self.assertEqual( len(Concert.band_1._foreign_key_meta.proxy_columns), 4 @@ -216,12 +219,13 @@ def test_select_all_columns_root(self): """ result = ( Band.select( - Band.all_columns(), - Band.manager.all_columns(), + *Band.all_columns(), + *Band.manager.all_columns(), ) .first() .run_sync() ) + assert result is not None if engine_is("cockroach"): self.assertDictEqual( @@ -254,11 +258,12 @@ def test_select_all_columns_root_nested(self): with using it for referenced tables. """ result = ( - Band.select(Band.all_columns(), Band.manager.all_columns()) + Band.select(*Band.all_columns(), *Band.manager.all_columns()) .output(nested=True) .first() .run_sync() ) + assert result is not None if engine_is("cockroach"): self.assertDictEqual( @@ -290,23 +295,25 @@ def test_select_all_columns_exclude(self): """ result = ( Band.select( - Band.all_columns(exclude=[Band.id]), - Band.manager.all_columns(exclude=[Band.manager.id]), + *Band.all_columns(exclude=[Band.id]), + *Band.manager.all_columns(exclude=[Band.manager.id]), ) .output(nested=True) .first() .run_sync() ) + assert result is not None result_str_args = ( Band.select( - Band.all_columns(exclude=["id"]), - Band.manager.all_columns(exclude=["id"]), + *Band.all_columns(exclude=["id"]), + *Band.manager.all_columns(exclude=["id"]), ) .output(nested=True) .first() .run_sync() ) + assert result_str_args is not None for data in (result, result_str_args): self.assertDictEqual( @@ -325,6 +332,7 @@ def test_objects_nested(self): Make sure the prefetch argument works correctly for objects. """ band = Band.objects(Band.manager).first().run_sync() + assert band is not None self.assertIsInstance(band.manager, Manager) def test_objects__all_related__root(self): @@ -333,6 +341,7 @@ def test_objects__all_related__root(self): root table of the query. """ concert = Concert.objects(Concert.all_related()).first().run_sync() + assert concert is not None self.assertIsInstance(concert.band_1, Band) self.assertIsInstance(concert.band_2, Band) self.assertIsInstance(concert.venue, Venue) @@ -344,15 +353,16 @@ def test_objects_nested_deep(self): ticket = ( Ticket.objects( Ticket.concert, - Ticket.concert.band_1, - Ticket.concert.band_2, - Ticket.concert.venue, - Ticket.concert.band_1.manager, - Ticket.concert.band_2.manager, + Ticket.concert._.band_1, + Ticket.concert._.band_2, + Ticket.concert._.venue, + Ticket.concert._.band_1._.manager, + Ticket.concert._.band_2._.manager, ) .first() .run_sync() ) + assert ticket is not None self.assertIsInstance(ticket.concert, Concert) self.assertIsInstance(ticket.concert.band_1, Band) @@ -370,12 +380,13 @@ def test_objects__all_related__deep(self): Ticket.objects( Ticket.all_related(), Ticket.concert.all_related(), - Ticket.concert.band_1.all_related(), - Ticket.concert.band_2.all_related(), + Ticket.concert._.band_1.all_related(), + Ticket.concert._.band_2.all_related(), ) .first() .run_sync() ) + assert ticket is not None self.assertIsInstance(ticket.concert, Concert) self.assertIsInstance(ticket.concert.band_1, Band) @@ -393,12 +404,13 @@ def test_objects_prefetch_clause(self): .prefetch( Ticket.all_related(), Ticket.concert.all_related(), - Ticket.concert.band_1.all_related(), - Ticket.concert.band_2.all_related(), + Ticket.concert._.band_1.all_related(), + Ticket.concert._.band_2.all_related(), ) .first() .run_sync() ) + assert ticket is not None self.assertIsInstance(ticket.concert, Concert) self.assertIsInstance(ticket.concert.band_1, Band) @@ -415,11 +427,12 @@ def test_objects_prefetch_intermediate(self): ticket = ( Ticket.objects() .prefetch( - Ticket.concert.band_1.manager, + Ticket.concert._.band_1._.manager, ) .first() .run_sync() ) + assert ticket is not None self.assertIsInstance(ticket.price, decimal.Decimal) self.assertIsInstance(ticket.concert, Concert) @@ -444,12 +457,13 @@ def test_objects_prefetch_multiple_intermediate(self): ticket = ( Ticket.objects() .prefetch( - Ticket.concert.band_1.manager, - Ticket.concert.band_2.manager, + Ticket.concert._.band_1._.manager, + Ticket.concert._.band_2._.manager, ) .first() .run_sync() ) + assert ticket is not None self.assertIsInstance(ticket.price, decimal.Decimal) self.assertIsInstance(ticket.concert, Concert) diff --git a/tests/table/test_join_on.py b/tests/table/test_join_on.py index 7983b218a..5be16158c 100644 --- a/tests/table/test_join_on.py +++ b/tests/table/test_join_on.py @@ -1,26 +1,28 @@ from unittest import TestCase -from piccolo.columns import Varchar +from piccolo.columns import Serial, Varchar from piccolo.table import Table class Manager(Table): + id: Serial name = Varchar(unique=True) email = Varchar(unique=True) class Band(Table): + id: Serial name = Varchar(unique=True) manager_name = Varchar() class Concert(Table): + id: Serial title = Varchar() band_name = Varchar() class TestJoinOn(TestCase): - tables = [Manager, Band, Concert] def setUp(self): diff --git a/tests/table/test_objects.py b/tests/table/test_objects.py index 853de6a13..e2db53ba9 100644 --- a/tests/table/test_objects.py +++ b/tests/table/test_objects.py @@ -66,6 +66,7 @@ def test_get(self): self.insert_row() band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None self.assertEqual(band.name, "Pythonistas") @@ -79,7 +80,8 @@ def test_get_prefetch(self): .prefetch(Band.manager) .run_sync() ) - self.assertIsInstance(band.manager, Manager) + assert band is not None + self.assertIsInstance(band.manager, Manager) # type: ignore # Just passing it straight into objects band = ( @@ -87,6 +89,7 @@ def test_get_prefetch(self): .get((Band.name == "Pythonistas")) .run_sync() ) + assert band is not None self.assertIsInstance(band.manager, Manager) @@ -97,12 +100,14 @@ def test_simple_where_clause(self): """ # When the row doesn't exist in the db: Band.objects().get_or_create( - Band.name == "Pink Floyd", defaults={"popularity": 100} + Band.name == "Pink Floyd", + defaults={"popularity": 100}, # type: ignore ).run_sync() instance = ( Band.objects().where(Band.name == "Pink Floyd").first().run_sync() ) + assert instance is not None self.assertIsInstance(instance, Band) self.assertEqual(instance.name, "Pink Floyd") @@ -116,6 +121,7 @@ def test_simple_where_clause(self): instance = ( Band.objects().where(Band.name == "Pink Floyd").first().run_sync() ) + assert instance is not None self.assertIsInstance(instance, Band) self.assertEqual(instance.name, "Pink Floyd") @@ -219,8 +225,8 @@ def test_prefetch_existing_object(self): .prefetch(Band.manager) .run_sync() ) - self.assertIsInstance(band.manager, Manager) - self.assertEqual(band.manager.name, "Guido") + self.assertIsInstance(band.manager, Manager) # type: ignore + self.assertEqual(band.manager.name, "Guido") # type: ignore # Just passing it straight into objects band = ( @@ -248,8 +254,8 @@ def test_prefetch_new_object(self): .prefetch(Band.manager) .run_sync() ) - self.assertIsInstance(band.manager, Manager) - self.assertEqual(band.name, "New Band") + self.assertIsInstance(band.manager, Manager) # type: ignore + self.assertEqual(band.name, "New Band") # type: ignore # Just passing it straight into objects band = ( diff --git a/tests/table/test_output.py b/tests/table/test_output.py index 97256dfcb..8298396fd 100644 --- a/tests/table/test_output.py +++ b/tests/table/test_output.py @@ -101,6 +101,8 @@ def test_output_nested_with_first(self): .output(nested=True) .run_sync() ) + assert response is not None self.assertDictEqual( - response, {"name": "Pythonistas", "manager": {"name": "Guido"}} + response, # type: ignore + {"name": "Pythonistas", "manager": {"name": "Guido"}}, ) diff --git a/tests/table/test_repr.py b/tests/table/test_repr.py index 59ec57ae0..37a98d37b 100644 --- a/tests/table/test_repr.py +++ b/tests/table/test_repr.py @@ -11,4 +11,5 @@ def test_repr_postgres(self): self.insert_row() manager = Manager.objects().first().run_sync() + assert manager is not None self.assertEqual(manager.__repr__(), f"") diff --git a/tests/table/test_select.py b/tests/table/test_select.py index 892972aec..a2bb86981 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -278,7 +278,7 @@ def test_where_bool(self): ``where(Band.has_drummer is None)``, which evaluates to a boolean. """ with self.assertRaises(ValueError): - Band.select().where(False) + Band.select().where(False) # type: ignore def test_where_is_not_null(self): self.insert_rows() @@ -680,6 +680,7 @@ def test_avg(self): self.insert_rows() response = Band.select(Avg(Band.popularity)).first().run_sync() + assert response is not None self.assertEqual(float(response["avg"]), 1003.3333333333334) @@ -691,6 +692,7 @@ def test_avg_alias(self): .first() .run_sync() ) + assert response is not None self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) @@ -702,6 +704,7 @@ def test_avg_as_alias_method(self): .first() .run_sync() ) + assert response is not None self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) @@ -714,6 +717,7 @@ def test_avg_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["avg"], 1500) @@ -730,6 +734,7 @@ def test_avg_alias_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_avg"], 1500) @@ -746,6 +751,7 @@ def test_avg_as_alias_method_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_avg"], 1500) @@ -753,6 +759,7 @@ def test_max(self): self.insert_rows() response = Band.select(Max(Band.popularity)).first().run_sync() + assert response is not None self.assertEqual(response["max"], 2000) @@ -764,6 +771,7 @@ def test_max_alias(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_max"], 2000) @@ -775,6 +783,7 @@ def test_max_as_alias_method(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_max"], 2000) @@ -782,6 +791,7 @@ def test_min(self): self.insert_rows() response = Band.select(Min(Band.popularity)).first().run_sync() + assert response is not None self.assertEqual(response["min"], 10) @@ -793,6 +803,7 @@ def test_min_alias(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_min"], 10) @@ -804,6 +815,7 @@ def test_min_as_alias_method(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_min"], 10) @@ -811,6 +823,7 @@ def test_sum(self): self.insert_rows() response = Band.select(Sum(Band.popularity)).first().run_sync() + assert response is not None self.assertEqual(response["sum"], 3010) @@ -822,6 +835,7 @@ def test_sum_alias(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_sum"], 3010) @@ -833,6 +847,7 @@ def test_sum_as_alias_method(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_sum"], 3010) @@ -845,6 +860,7 @@ def test_sum_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["sum"], 3000) @@ -861,6 +877,7 @@ def test_sum_alias_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_sum"], 3000) @@ -877,6 +894,7 @@ def test_sum_as_alias_method_with_where_clause(self): .first() .run_sync() ) + assert response is not None self.assertEqual(response["popularity_sum"], 3000) @@ -888,6 +906,7 @@ def test_chain_different_functions(self): .first() .run_sync() ) + assert response is not None self.assertEqual(float(response["avg"]), 1003.3333333333334) self.assertEqual(response["sum"], 3010) @@ -903,6 +922,7 @@ def test_chain_different_functions_alias(self): .first() .run_sync() ) + assert response is not None self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) self.assertEqual(response["popularity_sum"], 3010) @@ -929,7 +949,8 @@ def test_columns(self): .first() .run_sync() ) - self.assertEqual(response, {"name": "Pythonistas"}) + assert response is not None + self.assertDictEqual(response, {"name": "Pythonistas"}) # Multiple calls to 'columns' should be additive. response = ( @@ -940,6 +961,7 @@ def test_columns(self): .first() .run_sync() ) + assert response is not None if engine_is("cockroach"): self.assertEqual( @@ -953,7 +975,9 @@ def test_call_chain(self): Make sure the call chain lengths are the correct size. """ self.assertEqual(len(Concert.band_1.name._meta.call_chain), 1) - self.assertEqual(len(Concert.band_1.manager.name._meta.call_chain), 2) + self.assertEqual( + len(Concert.band_1._.manager._.name._meta.call_chain), 2 + ) def test_as_alias(self): """ @@ -1028,6 +1052,7 @@ def test_secret(self): user.save().run_sync() user_dict = BaseUser.select(exclude_secrets=True).first().run_sync() + assert user_dict is not None self.assertNotIn("password", user_dict.keys()) @@ -1047,6 +1072,7 @@ def test_secret_parameter(self): venue.save().run_sync() venue_dict = Venue.select(exclude_secrets=True).first().run_sync() + assert venue_dict is not None if engine_is("cockroach"): self.assertTrue( venue_dict, {"id": venue_dict["id"], "name": "The Garage"} @@ -1379,7 +1405,7 @@ def test_distinct_on_error(self): raise a ValueError. """ with self.assertRaises(ValueError) as manager: - Album.select().distinct(on=Album.band) + Album.select().distinct(on=Album.band) # type: ignore self.assertEqual( manager.exception.__str__(), diff --git a/tests/table/test_update.py b/tests/table/test_update.py index 366b54fce..554774f60 100644 --- a/tests/table/test_update.py +++ b/tests/table/test_update.py @@ -553,7 +553,9 @@ def test_edge_cases(self): with self.assertRaises(ValueError): # An error should be raised because we can't save at this level # of resolution - 1 millisecond is the minimum. - MyTable.timestamp + datetime.timedelta(microseconds=1) + MyTable.timestamp + datetime.timedelta( # type: ignore + microseconds=1 + ) ############################################################################### @@ -604,12 +606,18 @@ def test_update(self): # Insert a row for us to update AutoUpdateTable.insert(AutoUpdateTable(name="test")).run_sync() - self.assertDictEqual( + data = ( AutoUpdateTable.select( AutoUpdateTable.name, AutoUpdateTable.modified_on ) .first() - .run_sync(), + .run_sync() + ) + + assert data is not None + + self.assertDictEqual( + data, {"name": "test", "modified_on": None}, ) @@ -626,6 +634,7 @@ def test_update(self): .first() .run_sync() ) + assert updated_row is not None self.assertIsInstance(updated_row["modified_on"], datetime.datetime) self.assertEqual(updated_row["name"], "test 2") diff --git a/tests/test_schema.py b/tests/test_schema.py index 29a4652db..d8ec3d481 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -21,10 +21,11 @@ def test_list_tables(self): """ Make sure we can list all the tables in a schema. """ + schema_name = Band._meta.schema + + assert schema_name is not None table_list = ( - SchemaManager() - .list_tables(schema_name=Band._meta.schema) - .run_sync() + SchemaManager().list_tables(schema_name=schema_name).run_sync() ) self.assertListEqual(table_list, [Band._meta.tablename]) @@ -49,7 +50,6 @@ def test_create_and_drop(self): @engines_skip("sqlite") class TestMoveTable(TestCase): - new_schema = "schema_2" def setUp(self): @@ -89,7 +89,6 @@ def test_move_table(self): @engines_skip("sqlite") class TestRenameSchema(TestCase): - manager = SchemaManager() schema_name = "test_schema" new_schema_name = "test_schema_2" diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index f56fcf956..93a079e37 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -93,6 +93,7 @@ def test_choices(self): queried_shirt = ( Shirt.objects().where(Shirt.id == shirt.id).first().run_sync() ) + assert queried_shirt is not None self.assertIn( queried_shirt.size, @@ -157,6 +158,7 @@ def test_valid_column(self): .first() .run_sync() ) + assert queried_manager is not None self.assertEqual(queried_manager.name, "Guido") @@ -169,6 +171,7 @@ def test_valid_column_string(self): .first() .run_sync() ) + assert queried_manager is not None self.assertEqual(queried_manager.name, "Guido")