From 86bc48de7d75dc36142f9b59e8464efb996db20d Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 20 Apr 2023 11:48:29 +0300 Subject: [PATCH] Type-check entire code base with mypy Lots of `# type: ignore` in the internals, but the types should be decent enough for consumption. --- .circleci/config.yml | 2 +- psqlextra/backend/base.py | 20 +++++++- psqlextra/backend/base_impl.py | 16 ++++-- psqlextra/backend/introspection.py | 19 +++++-- .../migrations/patched_autodetector.py | 25 ++++++---- psqlextra/backend/migrations/state/model.py | 24 +++++---- .../backend/migrations/state/partitioning.py | 4 +- psqlextra/backend/migrations/state/view.py | 8 +-- psqlextra/backend/operations.py | 2 +- psqlextra/backend/schema.py | 45 ++++++++++------- psqlextra/compiler.py | 25 +++++----- psqlextra/error.py | 21 +++++--- psqlextra/expressions.py | 2 +- .../management/commands/pgmakemigrations.py | 4 +- psqlextra/management/commands/pgpartition.py | 2 +- psqlextra/manager/manager.py | 2 +- psqlextra/models/base.py | 5 +- psqlextra/models/partitioned.py | 6 ++- psqlextra/models/view.py | 20 +++++--- psqlextra/partitioning/config.py | 4 +- psqlextra/partitioning/manager.py | 6 ++- psqlextra/partitioning/partition.py | 6 +-- psqlextra/partitioning/plan.py | 12 +++-- psqlextra/partitioning/range_partition.py | 6 +-- psqlextra/partitioning/shorthands.py | 4 +- psqlextra/query.py | 49 ++++++++++++++----- psqlextra/schema.py | 17 +++++-- psqlextra/settings.py | 16 +++--- psqlextra/sql.py | 22 ++++++--- psqlextra/type_assertions.py | 2 +- psqlextra/util.py | 7 ++- pyproject.toml | 15 ++++++ setup.py | 23 +++++++++ 33 files changed, 315 insertions(+), 126 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bb545bad..f5ee6a31 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -119,7 +119,7 @@ jobs: steps: - checkout - install-dependencies: - extra: analysis + extra: analysis, test - run: name: Verify command: python setup.py verify diff --git a/psqlextra/backend/base.py b/psqlextra/backend/base.py index 40086da8..5c788a05 100644 --- a/psqlextra/backend/base.py +++ b/psqlextra/backend/base.py @@ -1,5 +1,7 @@ import logging +from typing import TYPE_CHECKING + from django.conf import settings from django.db import ProgrammingError @@ -8,17 +10,31 @@ from .operations import PostgresOperations from .schema import PostgresSchemaEditor +from django.db.backends.postgresql.base import ( # isort:skip + DatabaseWrapper as PostgresDatabaseWrapper, +) + + logger = logging.getLogger(__name__) -class DatabaseWrapper(base_impl.backend()): +if TYPE_CHECKING: + + class Wrapper(PostgresDatabaseWrapper): + pass + +else: + Wrapper = base_impl.backend() + + +class DatabaseWrapper(Wrapper): """Wraps the standard PostgreSQL database back-end. Overrides the schema editor with our custom schema editor and makes sure the `hstore` extension is enabled. """ - SchemaEditorClass = PostgresSchemaEditor + SchemaEditorClass = PostgresSchemaEditor # type: ignore[assignment] introspection_class = PostgresIntrospection ops_class = PostgresOperations diff --git a/psqlextra/backend/base_impl.py b/psqlextra/backend/base_impl.py index 4e2af04c..88bf9278 100644 --- a/psqlextra/backend/base_impl.py +++ b/psqlextra/backend/base_impl.py @@ -3,6 +3,14 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS, connections +from django.db.backends.postgresql.base import DatabaseWrapper +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) +from django.db.backends.postgresql.operations import DatabaseOperations +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.backends.postgresql.base import ( # isort:skip DatabaseWrapper as Psycopg2DatabaseWrapper, @@ -68,13 +76,13 @@ def base_backend_instance(): return base_instance -def backend(): +def backend() -> DatabaseWrapper: """Gets the base class for the database back-end.""" return base_backend_instance().__class__ -def schema_editor(): +def schema_editor() -> DatabaseSchemaEditor: """Gets the base class for the schema editor. We have to use the configured base back-end's schema editor for @@ -84,7 +92,7 @@ def schema_editor(): return base_backend_instance().SchemaEditorClass -def introspection(): +def introspection() -> DatabaseIntrospection: """Gets the base class for the introspection class. We have to use the configured base back-end's introspection class @@ -94,7 +102,7 @@ def introspection(): return base_backend_instance().introspection.__class__ -def operations(): +def operations() -> DatabaseOperations: """Gets the base class for the operations class. We have to use the configured base back-end's operations class for diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 0f7daf1a..bd775779 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,5 +1,9 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) from psqlextra.types import PostgresPartitioningMethod @@ -45,7 +49,16 @@ def partition_by_name( ) -class PostgresIntrospection(base_impl.introspection()): +if TYPE_CHECKING: + + class Introspection(DatabaseIntrospection): + pass + +else: + Introspection = base_impl.introspection() + + +class PostgresIntrospection(Introspection): """Adds introspection features specific to PostgreSQL.""" # TODO: This class is a mess, both here and in the @@ -66,7 +79,7 @@ class PostgresIntrospection(base_impl.introspection()): def get_partitioned_tables( self, cursor - ) -> PostgresIntrospectedPartitonedTable: + ) -> List[PostgresIntrospectedPartitonedTable]: """Gets a list of partitioned tables.""" cursor.execute( diff --git a/psqlextra/backend/migrations/patched_autodetector.py b/psqlextra/backend/migrations/patched_autodetector.py index cd647fb0..e5ba8938 100644 --- a/psqlextra/backend/migrations/patched_autodetector.py +++ b/psqlextra/backend/migrations/patched_autodetector.py @@ -12,7 +12,7 @@ RenameField, ) from django.db.migrations.autodetector import MigrationAutodetector -from django.db.migrations.operations.base import Operation +from django.db.migrations.operations.fields import FieldOperation from psqlextra.models import ( PostgresMaterializedViewModel, @@ -83,7 +83,7 @@ def rename_field(self, operation: RenameField): return self._transform_view_field_operations(operation) - def _transform_view_field_operations(self, operation: Operation): + def _transform_view_field_operations(self, operation: FieldOperation): """Transforms operations on fields on a (materialized) view into state only operations. @@ -199,9 +199,15 @@ def add_create_partitioned_model(self, operation: CreateModel): ) ) + partitioned_kwargs = { + **kwargs, + "partitioning_options": partitioning_options, + } + self.add( operations.PostgresCreatePartitionedModel( - *args, **kwargs, partitioning_options=partitioning_options + *args, + **partitioned_kwargs, ) ) @@ -231,11 +237,9 @@ def add_create_view_model(self, operation: CreateModel): _, args, kwargs = operation.deconstruct() - self.add( - operations.PostgresCreateViewModel( - *args, **kwargs, view_options=view_options - ) - ) + view_kwargs = {**kwargs, "view_options": view_options} + + self.add(operations.PostgresCreateViewModel(*args, **view_kwargs)) def add_delete_view_model(self, operation: DeleteModel): """Adds a :see:PostgresDeleteViewModel operation to the list of @@ -261,9 +265,12 @@ def add_create_materialized_view_model(self, operation: CreateModel): _, args, kwargs = operation.deconstruct() + view_kwargs = {**kwargs, "view_options": view_options} + self.add( operations.PostgresCreateMaterializedViewModel( - *args, **kwargs, view_options=view_options + *args, + **view_kwargs, ) ) diff --git a/psqlextra/backend/migrations/state/model.py b/psqlextra/backend/migrations/state/model.py index 465b6152..797147f4 100644 --- a/psqlextra/backend/migrations/state/model.py +++ b/psqlextra/backend/migrations/state/model.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Type +from typing import Tuple, Type, cast from django.db.migrations.state import ModelState from django.db.models import Model @@ -17,8 +17,8 @@ class PostgresModelState(ModelState): """ @classmethod - def from_model( - cls, model: PostgresModel, *args, **kwargs + def from_model( # type: ignore[override] + cls, model: Type[PostgresModel], *args, **kwargs ) -> "PostgresModelState": """Creates a new :see:PostgresModelState object from the specified model. @@ -29,28 +29,32 @@ def from_model( We also need to patch up the base class for the model. """ - model_state = super().from_model(model, *args, **kwargs) - model_state = cls._pre_new(model, model_state) + model_state = super().from_model( + cast(Type[Model], model), *args, **kwargs + ) + model_state = cls._pre_new( + model, cast("PostgresModelState", model_state) + ) # django does not add abstract bases as a base in migrations # because it assumes the base does not add anything important # in a migration.. but it does, so we replace the Model # base with the actual base - bases = tuple() + bases: Tuple[Type[Model], ...] = tuple() for base in model_state.bases: if issubclass(base, Model): bases += (cls._get_base_model_class(),) else: bases += (base,) - model_state.bases = bases + model_state.bases = cast(Tuple[Type[Model]], bases) return model_state def clone(self) -> "PostgresModelState": """Gets an exact copy of this :see:PostgresModelState.""" model_state = super().clone() - return self._pre_clone(model_state) + return self._pre_clone(cast(PostgresModelState, model_state)) def render(self, apps): """Renders this state into an actual model.""" @@ -95,7 +99,9 @@ def render(self, apps): @classmethod def _pre_new( - cls, model: PostgresModel, model_state: "PostgresModelState" + cls, + model: Type[PostgresModel], + model_state: "PostgresModelState", ) -> "PostgresModelState": """Called when a new model state is created from the specified model.""" diff --git a/psqlextra/backend/migrations/state/partitioning.py b/psqlextra/backend/migrations/state/partitioning.py index aef7a5e3..e8b9a5eb 100644 --- a/psqlextra/backend/migrations/state/partitioning.py +++ b/psqlextra/backend/migrations/state/partitioning.py @@ -94,7 +94,7 @@ def delete_partition(self, name: str): del self.partitions[name] @classmethod - def _pre_new( + def _pre_new( # type: ignore[override] cls, model: PostgresPartitionedModel, model_state: "PostgresPartitionedModelState", @@ -108,7 +108,7 @@ def _pre_new( ) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresPartitionedModelState" ) -> "PostgresPartitionedModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/migrations/state/view.py b/psqlextra/backend/migrations/state/view.py index d59b3120..0f5b52eb 100644 --- a/psqlextra/backend/migrations/state/view.py +++ b/psqlextra/backend/migrations/state/view.py @@ -22,8 +22,10 @@ def __init__(self, *args, view_options={}, **kwargs): self.view_options = dict(view_options) @classmethod - def _pre_new( - cls, model: PostgresViewModel, model_state: "PostgresViewModelState" + def _pre_new( # type: ignore[override] + cls, + model: Type[PostgresViewModel], + model_state: "PostgresViewModelState", ) -> "PostgresViewModelState": """Called when a new model state is created from the specified model.""" @@ -31,7 +33,7 @@ def _pre_new( model_state.view_options = dict(model._view_meta.original_attrs) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresViewModelState" ) -> "PostgresViewModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py index 52793fac..3bcf1897 100644 --- a/psqlextra/backend/operations.py +++ b/psqlextra/backend/operations.py @@ -9,7 +9,7 @@ from . import base_impl -class PostgresOperations(base_impl.operations()): +class PostgresOperations(base_impl.operations()): # type: ignore[misc] """Simple operations specific to PostgreSQL.""" compiler_module = "psqlextra.compiler" diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 85978f05..28e9211a 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type +from typing import TYPE_CHECKING, Any, List, Optional, Type, cast from unittest import mock import django @@ -10,6 +10,9 @@ ) from django.db import transaction from django.db.backends.ddl_references import Statement +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.models import Field, Model from psqlextra.settings import ( @@ -26,7 +29,13 @@ HStoreUniqueSchemaEditorSideEffect, ) -SchemaEditor = base_impl.schema_editor() +if TYPE_CHECKING: + + class SchemaEditor(DatabaseSchemaEditor): + pass + +else: + SchemaEditor = base_impl.schema_editor() class PostgresSchemaEditor(SchemaEditor): @@ -72,9 +81,9 @@ class PostgresSchemaEditor(SchemaEditor): sql_delete_partition = "DROP TABLE %s" sql_table_comment = "COMMENT ON TABLE %s IS %s" - side_effects = [ - HStoreUniqueSchemaEditorSideEffect(), - HStoreRequiredSchemaEditorSideEffect(), + side_effects: List[DatabaseSchemaEditor] = [ + cast(DatabaseSchemaEditor, HStoreUniqueSchemaEditorSideEffect()), + cast(DatabaseSchemaEditor, HStoreRequiredSchemaEditorSideEffect()), ] def __init__(self, connection, collect_sql=False, atomic=True): @@ -231,7 +240,7 @@ def clone_model_constraints_and_indexes_to_schema( [schema_name], using=self.connection.alias ): for constraint in model._meta.constraints: - self.add_constraint(model, constraint) + self.add_constraint(model, constraint) # type: ignore[attr-defined] for index in model._meta.indexes: self.add_index(model, index) @@ -246,14 +255,14 @@ def clone_model_constraints_and_indexes_to_schema( model, tuple(), model._meta.index_together ) - for field in model._meta.local_concrete_fields: + for field in model._meta.local_concrete_fields: # type: ignore[attr-defined] # Django creates primary keys later added to the model with # a custom name. We want the name as it was created originally. if field.primary_key: with postgres_reset_local_search_path( using=self.connection.alias ): - [primary_key_name] = self._constraint_names( + [primary_key_name] = self._constraint_names( # type: ignore[attr-defined] model, primary_key=True ) @@ -278,7 +287,7 @@ def clone_model_constraints_and_indexes_to_schema( with postgres_reset_local_search_path( using=self.connection.alias ): - [fk_name] = self._constraint_names( + [fk_name] = self._constraint_names( # type: ignore[attr-defined] model, [field.column], foreign_key=True ) @@ -304,7 +313,7 @@ def clone_model_constraints_and_indexes_to_schema( with postgres_reset_local_search_path( using=self.connection.alias ): - [field_check_name] = self._constraint_names( + [field_check_name] = self._constraint_names( # type: ignore[attr-defined] model, [field.column], check=True, @@ -315,7 +324,7 @@ def clone_model_constraints_and_indexes_to_schema( ) self.execute( - self._create_check_sql( + self._create_check_sql( # type: ignore[attr-defined] model, field_check_name, field_check ) ) @@ -361,7 +370,7 @@ def clone_model_foreign_keys_to_schema( resides. """ - constraint_names = self._constraint_names(model, foreign_key=True) + constraint_names = self._constraint_names(model, foreign_key=True) # type: ignore[attr-defined] with postgres_prepend_local_search_path( [schema_name], using=self.connection.alias @@ -569,7 +578,7 @@ def replace_materialized_view_model(self, model: Type[Model]) -> None: if not constraint_options["definition"]: raise SuspiciousOperation( "Table %s has a constraint '%s' that no definition could be generated for", - (model._meta.db_tabel, constraint_name), + (model._meta.db_table, constraint_name), ) self.execute(constraint_options["definition"]) @@ -597,7 +606,7 @@ def create_partitioned_model(self, model: Type[Model]) -> None: # create a composite key that includes the partitioning key sql = sql.replace(" PRIMARY KEY", "") - if model._meta.pk.name not in meta.key: + if model._meta.pk and model._meta.pk.name not in meta.key: sql = sql[:-1] + ", PRIMARY KEY (%s, %s))" % ( self.quote_name(model._meta.pk.name), partitioning_key_sql, @@ -927,7 +936,9 @@ def vacuum_model( """ columns = [ - field.column for field in fields if field.concrete and field.column + field.column + for field in fields + if getattr(field, "concrete", False) and field.column ] self.vacuum_table(model._meta.db_table, columns, **kwargs) @@ -1080,8 +1091,8 @@ def _clone_model_field(self, field: Field, **overrides) -> Field: cloned_field.model = field.model cloned_field.set_attributes_from_name(field.name) - if cloned_field.remote_field: + if cloned_field.remote_field and field.remote_field: cloned_field.remote_field.model = field.remote_field.model - cloned_field.set_attributes_from_rel() + cloned_field.set_attributes_from_rel() # type: ignore[attr-defined] return cloned_field diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index be96e50d..12fff3fa 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -71,25 +71,25 @@ def append_caller_to_sql(sql): return sql -class SQLCompiler(django_compiler.SQLCompiler): +class SQLCompiler(django_compiler.SQLCompiler): # type: ignore [attr-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler): +class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler): # type: ignore [name-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler): +class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler): # type: ignore [name-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler): +class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler): # type: ignore [name-defined] """Compiler for SQL UPDATE statements that allows us to use expressions inside HStore values. @@ -146,7 +146,7 @@ def _does_dict_contain_expression(data: dict) -> bool: return False -class SQLInsertCompiler(django_compiler.SQLInsertCompiler): +class SQLInsertCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" def as_sql(self, *args, **kwargs): @@ -159,7 +159,7 @@ def as_sql(self, *args, **kwargs): return queries -class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): +class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" def __init__(self, *args, **kwargs): @@ -237,15 +237,15 @@ def _rewrite_insert_on_conflict( update_columns = ", ".join( [ "{0} = EXCLUDED.{0}".format(self.qn(field.column)) - for field in self.query.update_fields + for field in self.query.update_fields # type: ignore[attr-defined] ] ) # build the conflict target, the columns to watch # for conflicts conflict_target = self._build_conflict_target() - index_predicate = self.query.index_predicate - update_condition = self.query.conflict_update_condition + index_predicate = self.query.index_predicate # type: ignore[attr-defined] + update_condition = self.query.conflict_update_condition # type: ignore[attr-defined] rewritten_sql = f"{sql} ON CONFLICT {conflict_target}" @@ -355,12 +355,15 @@ def _get_model_field(self, name: str): field_name = self._normalize_field_name(name) + if not self.query.model: + return None + # 'pk' has special meaning and always refers to the primary # key of a model, we have to respect this de-facto standard behaviour if field_name == "pk" and self.query.model._meta.pk: return self.query.model._meta.pk - for field in self.query.model._meta.local_concrete_fields: + for field in self.query.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.name == field_name or field.column == field_name: return field @@ -402,7 +405,7 @@ def _format_field_value(self, field_name) -> str: if isinstance(field, RelatedField) and isinstance(value, Model): value = value.pk - return django_compiler.SQLInsertCompiler.prepare_value( + return django_compiler.SQLInsertCompiler.prepare_value( # type: ignore[attr-defined] self, field, # Note: this deliberately doesn't use `pre_save_val` as we don't diff --git a/psqlextra/error.py b/psqlextra/error.py index 66438a5b..b3a5cf83 100644 --- a/psqlextra/error.py +++ b/psqlextra/error.py @@ -1,21 +1,30 @@ -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Type, Union from django import db +if TYPE_CHECKING: + from psycopg2 import Error as _Psycopg2Error + + Psycopg2Error: Optional[Type[_Psycopg2Error]] + + from psycopg import Error as _Psycopg3Error + + Psycopg3Error: Optional[Type[_Psycopg3Error]] + try: - from psycopg2 import Error as Psycopg2Error + from psycopg2 import Error as Psycopg2Error # type: ignore[no-redef] except ImportError: - Psycopg2Error = None + Psycopg2Error = None # type: ignore[misc] try: - from psycopg import Error as Psycopg3Error + from psycopg import Error as Psycopg3Error # type: ignore[no-redef] except ImportError: - Psycopg3Error = None + Psycopg3Error = None # type: ignore[misc] def extract_postgres_error( error: db.Error, -) -> Optional[Union["Psycopg2Error", "Psycopg3Error"]]: +) -> Optional[Union["_Psycopg2Error", "_Psycopg3Error"]]: """Extracts the underlying :see:psycopg2.Error from the specified Django database error. diff --git a/psqlextra/expressions.py b/psqlextra/expressions.py index 75351e68..d9c6bb54 100644 --- a/psqlextra/expressions.py +++ b/psqlextra/expressions.py @@ -140,7 +140,7 @@ def __init__(self, name: str, key: str): def resolve_expression(self, *args, **kwargs): """Resolves the expression into a :see:HStoreColumn expression.""" - original_expression: expressions.Col = super().resolve_expression( + original_expression: expressions.Col = super().resolve_expression( # type: ignore[annotation-unchecked] *args, **kwargs ) expression = HStoreColumn( diff --git a/psqlextra/management/commands/pgmakemigrations.py b/psqlextra/management/commands/pgmakemigrations.py index cdb7131b..7b678855 100644 --- a/psqlextra/management/commands/pgmakemigrations.py +++ b/psqlextra/management/commands/pgmakemigrations.py @@ -1,4 +1,6 @@ -from django.core.management.commands import makemigrations +from django.core.management.commands import ( # type: ignore[attr-defined] + makemigrations, +) from psqlextra.backend.migrations import postgres_patched_migrations diff --git a/psqlextra/management/commands/pgpartition.py b/psqlextra/management/commands/pgpartition.py index 592b57d7..8a6fa636 100644 --- a/psqlextra/management/commands/pgpartition.py +++ b/psqlextra/management/commands/pgpartition.py @@ -57,7 +57,7 @@ def add_arguments(self, parser): default=False, ) - def handle( + def handle( # type: ignore[override] self, dry: bool, yes: bool, diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index 4b96e34f..0931b38a 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -8,7 +8,7 @@ from psqlextra.query import PostgresQuerySet -class PostgresManager(Manager.from_queryset(PostgresQuerySet)): +class PostgresManager(Manager.from_queryset(PostgresQuerySet)): # type: ignore[misc] """Adds support for PostgreSQL specifics.""" use_in_migrations = True diff --git a/psqlextra/models/base.py b/psqlextra/models/base.py index 21caad36..d240237a 100644 --- a/psqlextra/models/base.py +++ b/psqlextra/models/base.py @@ -1,4 +1,7 @@ +from typing import Any + from django.db import models +from django.db.models import Manager from psqlextra.manager import PostgresManager @@ -10,4 +13,4 @@ class Meta: abstract = True base_manager_name = "objects" - objects = PostgresManager() + objects: "Manager[Any]" = PostgresManager() diff --git a/psqlextra/models/partitioned.py b/psqlextra/models/partitioned.py index c03f3e93..f0115367 100644 --- a/psqlextra/models/partitioned.py +++ b/psqlextra/models/partitioned.py @@ -1,3 +1,5 @@ +from typing import Iterable + from django.db.models.base import ModelBase from psqlextra.types import PostgresPartitioningMethod @@ -15,7 +17,7 @@ class PostgresPartitionedModelMeta(ModelBase): """ default_method = PostgresPartitioningMethod.RANGE - default_key = [] + default_key: Iterable[str] = [] def __new__(cls, name, bases, attrs, **kwargs): new_class = super().__new__(cls, name, bases, attrs, **kwargs) @@ -38,6 +40,8 @@ class PostgresPartitionedModel( """Base class for taking advantage of PostgreSQL's 11.x native support for table partitioning.""" + _partitioning_meta: PostgresPartitionedModelOptions + class Meta: abstract = True base_manager_name = "objects" diff --git a/psqlextra/models/view.py b/psqlextra/models/view.py index a9497057..b19f88c8 100644 --- a/psqlextra/models/view.py +++ b/psqlextra/models/view.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast from django.core.exceptions import ImproperlyConfigured from django.db import connections @@ -12,6 +12,9 @@ from .base import PostgresModel from .options import PostgresViewOptions +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + ViewQueryValue = Union[QuerySet, SQLWithParams, SQL] ViewQuery = Optional[Union[ViewQueryValue, Callable[[], ViewQueryValue]]] @@ -77,23 +80,26 @@ def _view_query_as_sql_with_params( " to be a valid `django.db.models.query.QuerySet`" " SQL string, or tuple of SQL string and params." ) - % (model.__name__) + % (model.__class__.__name__) ) # querysets can easily be converted into sql, params if is_query_set(view_query): - return view_query.query.sql_with_params() + return cast("QuerySet[Any]", view_query).query.sql_with_params() # query was already specified in the target format if is_sql_with_params(view_query): - return view_query + return cast(SQLWithParams, view_query) - return view_query, tuple() + view_query_sql = cast(str, view_query) + return view_query_sql, tuple() class PostgresViewModel(PostgresModel, metaclass=PostgresViewModelMeta): """Base class for creating a model that is a view.""" + _view_meta: PostgresViewOptions + class Meta: abstract = True base_manager_name = "objects" @@ -127,4 +133,6 @@ def refresh( conn_name = using or "default" with connections[conn_name].schema_editor() as schema_editor: - schema_editor.refresh_materialized_view_model(cls, concurrently) + cast( + "PostgresSchemaEditor", schema_editor + ).refresh_materialized_view_model(cls, concurrently) diff --git a/psqlextra/partitioning/config.py b/psqlextra/partitioning/config.py index df21c057..976bf1ae 100644 --- a/psqlextra/partitioning/config.py +++ b/psqlextra/partitioning/config.py @@ -1,3 +1,5 @@ +from typing import Type + from psqlextra.models import PostgresPartitionedModel from .strategy import PostgresPartitioningStrategy @@ -9,7 +11,7 @@ class PostgresPartitioningConfig: def __init__( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], strategy: PostgresPartitioningStrategy, ) -> None: self.model = model diff --git a/psqlextra/partitioning/manager.py b/psqlextra/partitioning/manager.py index 4dcbb599..074cc1c6 100644 --- a/psqlextra/partitioning/manager.py +++ b/psqlextra/partitioning/manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type from django.db import connections @@ -111,7 +111,9 @@ def _plan_for_config( return model_plan @staticmethod - def _get_partitioned_table(connection, model: PostgresPartitionedModel): + def _get_partitioned_table( + connection, model: Type[PostgresPartitionedModel] + ): with connection.cursor() as cursor: table = connection.introspection.get_partitioned_table( cursor, model._meta.db_table diff --git a/psqlextra/partitioning/partition.py b/psqlextra/partitioning/partition.py index ca64bbdc..4c13fda0 100644 --- a/psqlextra/partitioning/partition.py +++ b/psqlextra/partitioning/partition.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -15,7 +15,7 @@ def name(self) -> str: @abstractmethod def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -24,7 +24,7 @@ def create( @abstractmethod def delete( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: """Deletes this partition from the database.""" diff --git a/psqlextra/partitioning/plan.py b/psqlextra/partitioning/plan.py index 31746360..3fcac44d 100644 --- a/psqlextra/partitioning/plan.py +++ b/psqlextra/partitioning/plan.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, cast from django.db import connections, transaction @@ -7,6 +7,9 @@ from .constants import AUTO_PARTITIONED_COMMENT from .partition import PostgresPartition +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + @dataclass class PostgresModelPartitioningPlan: @@ -38,12 +41,15 @@ def apply(self, using: Optional[str]) -> None: for partition in self.creations: partition.create( self.config.model, - schema_editor, + cast("PostgresSchemaEditor", schema_editor), comment=AUTO_PARTITIONED_COMMENT, ) for partition in self.deletions: - partition.delete(self.config.model, schema_editor) + partition.delete( + self.config.model, + cast("PostgresSchemaEditor", schema_editor), + ) def print(self) -> None: """Prints this model plan to the terminal in a readable format.""" diff --git a/psqlextra/partitioning/range_partition.py b/psqlextra/partitioning/range_partition.py index b49fe784..a2f3e82f 100644 --- a/psqlextra/partitioning/range_partition.py +++ b/psqlextra/partitioning/range_partition.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -23,7 +23,7 @@ def deconstruct(self) -> dict: def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -37,7 +37,7 @@ def create( def delete( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: schema_editor.delete_partition(model, self.name()) diff --git a/psqlextra/partitioning/shorthands.py b/psqlextra/partitioning/shorthands.py index dab65e4f..30175273 100644 --- a/psqlextra/partitioning/shorthands.py +++ b/psqlextra/partitioning/shorthands.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type from dateutil.relativedelta import relativedelta @@ -10,7 +10,7 @@ def partition_by_current_time( - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], count: int, years: Optional[int] = None, months: Optional[int] = None, diff --git a/psqlextra/query.py b/psqlextra/query.py index 2f117e3d..5c5e6f47 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -1,10 +1,20 @@ from collections import OrderedDict from itertools import chain -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) from django.core.exceptions import SuspiciousOperation from django.db import connections, models, router -from django.db.models import Expression, Q +from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED from .sql import PostgresInsertQuery, PostgresQuery @@ -13,7 +23,17 @@ ConflictTarget = List[Union[str, Tuple[str]]] -class PostgresQuerySet(models.QuerySet): +TModel = TypeVar("TModel", bound=models.Model, covariant=True) + +if TYPE_CHECKING: + from typing_extensions import Self + + QuerySetBase = QuerySet[TModel] +else: + QuerySetBase = QuerySet + + +class PostgresQuerySet(QuerySetBase, Generic[TModel]): """Adds support for PostgreSQL specifics.""" def __init__(self, model=None, query=None, using=None, hints=None): @@ -28,7 +48,7 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_update_condition = None self.index_predicate = None - def annotate(self, **annotations): + def annotate(self, **annotations) -> "Self": # type: ignore[valid-type, override] """Custom version of the standard annotate function that allows using field names as annotated fields. @@ -112,7 +132,7 @@ def on_conflict( def bulk_insert( self, - rows: List[dict], + rows: Iterable[dict], return_model: bool = False, using: Optional[str] = None, ): @@ -202,7 +222,10 @@ def insert(self, using: Optional[str] = None, **fields): compiler = self._build_insert_compiler([fields], using=using) rows = compiler.execute_sql(return_id=True) - _, pk_db_column = self.model._meta.pk.get_attname_column() + if not self.model or not self.model.pk: + return None + + _, pk_db_column = self.model._meta.pk.get_attname_column() # type: ignore[union-attr] if not rows or len(rows) == 0: return None @@ -245,7 +268,7 @@ def insert_and_get(self, using: Optional[str] = None, **fields): # preserve the fact that the attribute name # might be different than the database column name model_columns = {} - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] model_columns[field.column] = field.attname # strip out any columns/fields returned by the db that @@ -298,7 +321,9 @@ def upsert( index_predicate=index_predicate, update_condition=update_condition, ) - return self.insert(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert(**kwargs) def upsert_and_get( self, @@ -340,7 +365,9 @@ def upsert_and_get( index_predicate=index_predicate, update_condition=update_condition, ) - return self.insert_and_get(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert_and_get(**kwargs) def bulk_upsert( self, @@ -403,7 +430,7 @@ def _create_model_instance( if apply_converters: connection = connections[using] - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.attname not in converted_field_values: continue @@ -447,7 +474,7 @@ def _build_insert_compiler( # ask the db router which connection to use using = ( - using or self._db or router.db_for_write(self.model, **self._hints) + using or self._db or router.db_for_write(self.model, **self._hints) # type: ignore[attr-defined] ) # create model objects, we also have to detect cases diff --git a/psqlextra/schema.py b/psqlextra/schema.py index 4ee81cd8..9edb83bd 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -1,11 +1,16 @@ import os from contextlib import contextmanager +from typing import TYPE_CHECKING, Generator, cast from django.core.exceptions import SuspiciousOperation, ValidationError from django.db import DEFAULT_DB_ALIAS, connections, transaction from django.utils import timezone +if TYPE_CHECKING: + from psqlextra.backend.introspection import PostgresIntrospection + from psqlextra.backend.schema import PostgresSchemaEditor + class PostgresSchema: """Represents a Postgres schema. @@ -47,7 +52,7 @@ def create( ) with connections[using].schema_editor() as schema_editor: - schema_editor.create_schema(name) + cast("PostgresSchemaEditor", schema_editor).create_schema(name) return cls(name) @@ -133,7 +138,9 @@ def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool: connection = connections[using] with connection.cursor() as cursor: - return name in connection.introspection.get_schema_list(cursor) + return name in cast( + "PostgresIntrospection", connection.introspection + ).get_schema_list(cursor) def delete( self, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS @@ -157,7 +164,9 @@ def delete( ) with connections[using].schema_editor() as schema_editor: - schema_editor.delete_schema(self.name, cascade=cascade) + cast("PostgresSchemaEditor", schema_editor).delete_schema( + self.name, cascade=cascade + ) @classmethod def _create_generated_name(cls, prefix: str, suffix: str) -> str: @@ -183,7 +192,7 @@ def postgres_temporary_schema( cascade: bool = False, delete_on_throw: bool = False, using: str = DEFAULT_DB_ALIAS, -) -> PostgresSchema: +) -> Generator[PostgresSchema, None, None]: """Creates a temporary schema that only lives in the context of this context manager. diff --git a/psqlextra/settings.py b/psqlextra/settings.py index 6dd32f37..6f75c779 100644 --- a/psqlextra/settings.py +++ b/psqlextra/settings.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Dict, List, Optional, Union +from typing import Generator, List, Optional, Union from django.core.exceptions import SuspiciousOperation from django.db import DEFAULT_DB_ALIAS, connections @@ -9,8 +9,8 @@ def postgres_set_local( *, using: str = DEFAULT_DB_ALIAS, - **options: Dict[str, Optional[Union[str, int, float, List[str]]]], -) -> None: + **options: Optional[Union[str, int, float, List[str]]], +) -> Generator[None, None, None]: """Sets the specified PostgreSQL options using SET LOCAL so that they apply to the current transacton only. @@ -29,7 +29,7 @@ def postgres_set_local( ) sql = [] - params = [] + params: List[Union[str, int, float, List[str]]] = [] for name, value in options.items(): if value is None: sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") @@ -78,7 +78,7 @@ def postgres_set_local( @contextmanager def postgres_set_local_search_path( search_path: List[str], *, using: str = DEFAULT_DB_ALIAS -) -> None: +) -> Generator[None, None, None]: """Sets the search path to the specified schemas.""" with postgres_set_local(search_path=search_path, using=using): @@ -88,7 +88,7 @@ def postgres_set_local_search_path( @contextmanager def postgres_prepend_local_search_path( search_path: List[str], *, using: str = DEFAULT_DB_ALIAS -) -> None: +) -> Generator[None, None, None]: """Prepends the current local search path with the specified schemas.""" connection = connections[using] @@ -111,7 +111,9 @@ def postgres_prepend_local_search_path( @contextmanager -def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None: +def postgres_reset_local_search_path( + *, using: str = DEFAULT_DB_ALIAS +) -> Generator[None, None, None]: """Resets the local search path to the default.""" with postgres_set_local(search_path=None, using=using): diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 25c8314e..2a5b418e 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,11 +1,11 @@ from collections import OrderedDict -from typing import List, Optional, Tuple +from typing import Optional, Tuple import django from django.core.exceptions import SuspiciousOperation from django.db import connections, models -from django.db.models import sql +from django.db.models import Expression, sql from django.db.models.constants import LOOKUP_SEP from .compiler import PostgresInsertOnConflictCompiler @@ -16,6 +16,8 @@ class PostgresQuery(sql.Query): + select: Tuple[Expression, ...] + def chain(self, klass=None): """Chains this query to another. @@ -68,7 +70,7 @@ def rename_annotations(self, annotations) -> None: self.annotations.clear() self.annotations.update(new_annotations) - def add_fields(self, field_names: List[str], *args, **kwargs) -> None: + def add_fields(self, field_names, *args, **kwargs) -> None: """Adds the given (model) fields to the select set. The field names are added in the order specified. This overrides @@ -100,10 +102,11 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: if len(parts) > 1: column_name, hstore_key = parts[:2] is_hstore, field = self._is_hstore_field(column_name) - if is_hstore: + if self.model and is_hstore: select.append( HStoreColumn( - self.model._meta.db_table or self.model.name, + self.model._meta.db_table + or self.model.__class__.__name__, field, hstore_key, ) @@ -115,7 +118,7 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: super().add_fields(field_names_without_hstore, *args, **kwargs) if len(select) > 0: - self.set_select(self.select + tuple(select)) + self.set_select(list(self.select + tuple(select))) def _is_hstore_field( self, field_name: str @@ -127,8 +130,11 @@ def _is_hstore_field( instance. """ + if not self.model: + return (False, None) + field_instance = None - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.name == field_name or field.column == field_name: field_instance = field break @@ -151,7 +157,7 @@ def __init__(self, *args, **kwargs): self.update_fields = [] - def values(self, objs: List, insert_fields: List, update_fields: List = []): + def values(self, objs, insert_fields, update_fields=[]): """Sets the values to be used in this query. Insert fields are fields that are definitely diff --git a/psqlextra/type_assertions.py b/psqlextra/type_assertions.py index 0a7e8608..e18d13be 100644 --- a/psqlextra/type_assertions.py +++ b/psqlextra/type_assertions.py @@ -7,7 +7,7 @@ def is_query_set(value: Any) -> bool: """Gets whether the specified value is a :see:QuerySet.""" - return isinstance(value, QuerySet) + return isinstance(value, QuerySet) # type: ignore[misc] def is_sql(value: Any) -> bool: diff --git a/psqlextra/util.py b/psqlextra/util.py index edc4e955..d0bca000 100644 --- a/psqlextra/util.py +++ b/psqlextra/util.py @@ -1,10 +1,15 @@ from contextlib import contextmanager +from typing import Generator, Type + +from django.db import models from .manager import PostgresManager @contextmanager -def postgres_manager(model): +def postgres_manager( + model: Type[models.Model], +) -> Generator[PostgresManager, None, None]: """Allows you to use the :see:PostgresManager with the specified model instance on the fly. diff --git a/pyproject.toml b/pyproject.toml index 126ae9a3..fb35b3b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,18 @@ exclude = ''' )/ ) ''' + +[tool.mypy] +python_version = "3.8" +plugins = ["mypy_django_plugin.main"] +mypy_path = ["stubs", "."] +exclude = "(env|build|dist|migrations)" + +[[tool.mypy.overrides]] +module = [ + "psycopg.*" +] +ignore_missing_imports = true + +[tool.django-stubs] +django_settings_module = "settings" diff --git a/setup.py b/setup.py index 281be89d..311acf11 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,15 @@ def run(self): "autopep8==1.6.0", "isort==5.10.0", "docformatter==1.4", + "mypy==1.2.0; python_version > '3.6'", + "mypy==0.971; python_version <= '3.6'", + "django-stubs==1.16.0; python_version > '3.6'", + "django-stubs==1.9.0; python_version <= '3.6'", + "typing-extensions==4.5.0; python_version > '3.6'", + "typing-extensions==4.1.0; python_version <= '3.6'", + "types-dj-database-url==1.3.0.0", + "types-psycopg2==2.9.21.9", + "types-python-dateutil==2.8.19.12", ], "publish": [ "build==0.7.0", @@ -124,6 +133,18 @@ def run(self): ["autopep8", "-i", "-r", "setup.py", "psqlextra", "tests"], ], ), + "lint_types": create_command( + "Type-checks the code", + [ + [ + "mypy", + "--package", + "psqlextra", + "--pretty", + "--show-error-codes", + ], + ], + ), "format": create_command( "Formats the code", [["black", "setup.py", "psqlextra", "tests"]] ), @@ -162,6 +183,7 @@ def run(self): ["python", "setup.py", "sort_imports"], ["python", "setup.py", "lint_fix"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "verify": create_command( @@ -171,6 +193,7 @@ def run(self): ["python", "setup.py", "format_docstrings_verify"], ["python", "setup.py", "sort_imports_verify"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "test": create_command(