From 12f8b3aa513f4c099a84e2bc1ed3d8d40153cd9d Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Thu, 10 Sep 2020 13:52:10 +0100 Subject: [PATCH] fix: improve type annotations on SQLAlchemy (#1458) * fix: improve type annotations * lint --- flask_appbuilder/models/base.py | 4 +-- flask_appbuilder/models/sqla/interface.py | 40 +++++++++++++---------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/flask_appbuilder/models/base.py b/flask_appbuilder/models/base.py index 8971da66a8..9fbed309e8 100644 --- a/flask_appbuilder/models/base.py +++ b/flask_appbuilder/models/base.py @@ -1,7 +1,7 @@ import datetime from functools import reduce import logging -from typing import Type +from typing import Any, Type from flask_babel import lazy_gettext @@ -44,7 +44,7 @@ class BaseInterface(object): """ Tuple with message and text with severity type ex: ("Added Row", "info") """ message = () - def __init__(self, obj): + def __init__(self, obj: Type[Any]): self.obj = obj def _get_attr(self, col_name): diff --git a/flask_appbuilder/models/sqla/interface.py b/flask_appbuilder/models/sqla/interface.py index 028c0d9f65..310c1ab3cb 100644 --- a/flask_appbuilder/models/sqla/interface.py +++ b/flask_appbuilder/models/sqla/interface.py @@ -55,7 +55,7 @@ class SQLAInterface(BaseInterface): filter_converter_class = filters.SQLAFilterConverter - def __init__(self, obj: Model, session: Optional[SessionBase] = None) -> None: + def __init__(self, obj: Type[Model], session: Optional[SessionBase] = None) -> None: _include_filters(self) self.list_columns = dict() self.list_properties = dict() @@ -115,7 +115,11 @@ def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuer return query def apply_engine_specific_hack( - self, query: BaseQuery, page, page_size, order_column + self, + query: BaseQuery, + page: Optional[int], + page_size: Optional[int], + order_column: Optional[str], ) -> BaseQuery: # MSSQL exception page/limit must have an order by if ( @@ -579,7 +583,7 @@ def get_max_length(self, col_name: str) -> int: ------------------------------- """ - def add(self, item, raise_exception=False): + def add(self, item: Model, raise_exception: bool = False) -> bool: try: self.session.add(item) self.session.commit() @@ -603,7 +607,7 @@ def add(self, item, raise_exception=False): raise e return False - def edit(self, item, raise_exception=False): + def edit(self, item: Model, raise_exception: bool = False) -> bool: try: self.session.merge(item) self.session.commit() @@ -627,7 +631,7 @@ def edit(self, item, raise_exception=False): raise e return False - def delete(self, item, raise_exception=False): + def delete(self, item: Model, raise_exception: bool = False) -> bool: try: self._delete_files(item) self.session.delete(item) @@ -652,7 +656,7 @@ def delete(self, item, raise_exception=False): raise e return False - def delete_all(self, items): + def delete_all(self, items: List[Model]) -> bool: try: for item in items: self._delete_files(item) @@ -680,7 +684,7 @@ def delete_all(self, items): ----------------------- """ - def _add_files(self, this_request, item): + def _add_files(self, this_request, item: Model): fm = FileManager() im = ImageManager() for file_col in this_request.files: @@ -690,7 +694,7 @@ def _add_files(self, this_request, item): if self.is_image(file_col): im.save_file(this_request.files[file_col], getattr(item, file_col)) - def _delete_files(self, item): + def _delete_files(self, item: Model): for file_col in self.get_file_column_list(): if self.is_file(file_col): if getattr(item, file_col): @@ -708,7 +712,7 @@ def _delete_files(self, item): ------------------------------ """ - def get_col_default(self, col_name): + def get_col_default(self, col_name: str) -> Any: default = getattr(self.list_columns[col_name], "default", None) if default is not None: value = getattr(default, "arg", None) @@ -720,10 +724,12 @@ def get_col_default(self, col_name): return None return value - def get_related_model(self, col_name: str) -> Model: + def get_related_model(self, col_name: str) -> Type[Model]: return self.list_properties[col_name].mapper.class_ - def get_related_model_and_join(self, col_name: str) -> List[Tuple[Model, object]]: + def get_related_model_and_join( + self, col_name: str + ) -> List[Tuple[Type[Model], object]]: relation = self.list_properties[col_name] if relation.direction.name == "MANYTOMANY": return [ @@ -744,14 +750,14 @@ def get_related_obj(self, col_name: str, value: Any) -> Optional[Type[Model]]: def get_related_fks(self, related_views) -> List[str]: return [view.datamodel.get_related_fk(self.obj) for view in related_views] - def get_related_fk(self, model) -> Optional[str]: + def get_related_fk(self, model: Type[Model]) -> Optional[str]: for col_name in self.list_properties.keys(): if self.is_relation(col_name): if model == self.get_related_model(col_name): return col_name return None - def get_info(self, col_name): + def get_info(self, col_name: str): if col_name in self.list_properties: return self.list_properties[col_name].info return {} @@ -815,14 +821,14 @@ def get_order_columns_list(self, list_columns: List[str] = None) -> List[str]: ret_lst.append(col_name) return ret_lst - def get_file_column_list(self): + def get_file_column_list(self) -> List[str]: return [ i.name for i in self.obj.__mapper__.columns if isinstance(i.type, FileColumn) ] - def get_image_column_list(self): + def get_image_column_list(self) -> List[str]: return [ i.name for i in self.obj.__mapper__.columns @@ -880,7 +886,7 @@ def get_pk_name(self) -> Optional[Union[List[str], str]]: """ return self._get_pk_name(self.obj) - def get_pk(self, model: Optional[Model] = None): + def get_pk(self, model: Optional[Type[Model]] = None): """ Get the model primary key SQLAlchemy column. Will not support composite keys @@ -891,7 +897,7 @@ def get_pk(self, model: Optional[Model] = None): return getattr(model_, pk_name) return None - def _get_pk_name(self, model: Model) -> Optional[Union[List[str], str]]: + def _get_pk_name(self, model: Type[Model]) -> Optional[Union[List[str], str]]: pk = [pk.name for pk in model.__mapper__.primary_key] if pk: return pk if self.is_pk_composite() else pk[0]