Skip to content

Commit

Permalink
fix: improve type annotations on SQLAlchemy (#1458)
Browse files Browse the repository at this point in the history
* fix: improve type annotations

* lint
  • Loading branch information
dpgaspar authored Sep 10, 2020
1 parent 9f1f64d commit 12f8b3a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
4 changes: 2 additions & 2 deletions flask_appbuilder/models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from functools import reduce
import logging
from typing import Type
from typing import Any, Type

from flask_babel import lazy_gettext

Expand Down Expand Up @@ -44,7 +44,7 @@ class BaseInterface(object):
""" Tuple with message and text with severity type ex: ("Added Row", "info") """
message = ()

def __init__(self, obj):
def __init__(self, obj: Type[Any]):
self.obj = obj

def _get_attr(self, col_name):
Expand Down
40 changes: 23 additions & 17 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class SQLAInterface(BaseInterface):

filter_converter_class = filters.SQLAFilterConverter

def __init__(self, obj: Model, session: Optional[SessionBase] = None) -> None:
def __init__(self, obj: Type[Model], session: Optional[SessionBase] = None) -> None:
_include_filters(self)
self.list_columns = dict()
self.list_properties = dict()
Expand Down Expand Up @@ -115,7 +115,11 @@ def _query_join_relation(self, query: BaseQuery, root_relation: str) -> BaseQuer
return query

def apply_engine_specific_hack(
self, query: BaseQuery, page, page_size, order_column
self,
query: BaseQuery,
page: Optional[int],
page_size: Optional[int],
order_column: Optional[str],
) -> BaseQuery:
# MSSQL exception page/limit must have an order by
if (
Expand Down Expand Up @@ -579,7 +583,7 @@ def get_max_length(self, col_name: str) -> int:
-------------------------------
"""

def add(self, item, raise_exception=False):
def add(self, item: Model, raise_exception: bool = False) -> bool:
try:
self.session.add(item)
self.session.commit()
Expand All @@ -603,7 +607,7 @@ def add(self, item, raise_exception=False):
raise e
return False

def edit(self, item, raise_exception=False):
def edit(self, item: Model, raise_exception: bool = False) -> bool:
try:
self.session.merge(item)
self.session.commit()
Expand All @@ -627,7 +631,7 @@ def edit(self, item, raise_exception=False):
raise e
return False

def delete(self, item, raise_exception=False):
def delete(self, item: Model, raise_exception: bool = False) -> bool:
try:
self._delete_files(item)
self.session.delete(item)
Expand All @@ -652,7 +656,7 @@ def delete(self, item, raise_exception=False):
raise e
return False

def delete_all(self, items):
def delete_all(self, items: List[Model]) -> bool:
try:
for item in items:
self._delete_files(item)
Expand Down Expand Up @@ -680,7 +684,7 @@ def delete_all(self, items):
-----------------------
"""

def _add_files(self, this_request, item):
def _add_files(self, this_request, item: Model):
fm = FileManager()
im = ImageManager()
for file_col in this_request.files:
Expand All @@ -690,7 +694,7 @@ def _add_files(self, this_request, item):
if self.is_image(file_col):
im.save_file(this_request.files[file_col], getattr(item, file_col))

def _delete_files(self, item):
def _delete_files(self, item: Model):
for file_col in self.get_file_column_list():
if self.is_file(file_col):
if getattr(item, file_col):
Expand All @@ -708,7 +712,7 @@ def _delete_files(self, item):
------------------------------
"""

def get_col_default(self, col_name):
def get_col_default(self, col_name: str) -> Any:
default = getattr(self.list_columns[col_name], "default", None)
if default is not None:
value = getattr(default, "arg", None)
Expand All @@ -720,10 +724,12 @@ def get_col_default(self, col_name):
return None
return value

def get_related_model(self, col_name: str) -> Model:
def get_related_model(self, col_name: str) -> Type[Model]:
return self.list_properties[col_name].mapper.class_

def get_related_model_and_join(self, col_name: str) -> List[Tuple[Model, object]]:
def get_related_model_and_join(
self, col_name: str
) -> List[Tuple[Type[Model], object]]:
relation = self.list_properties[col_name]
if relation.direction.name == "MANYTOMANY":
return [
Expand All @@ -744,14 +750,14 @@ def get_related_obj(self, col_name: str, value: Any) -> Optional[Type[Model]]:
def get_related_fks(self, related_views) -> List[str]:
return [view.datamodel.get_related_fk(self.obj) for view in related_views]

def get_related_fk(self, model) -> Optional[str]:
def get_related_fk(self, model: Type[Model]) -> Optional[str]:
for col_name in self.list_properties.keys():
if self.is_relation(col_name):
if model == self.get_related_model(col_name):
return col_name
return None

def get_info(self, col_name):
def get_info(self, col_name: str):
if col_name in self.list_properties:
return self.list_properties[col_name].info
return {}
Expand Down Expand Up @@ -815,14 +821,14 @@ def get_order_columns_list(self, list_columns: List[str] = None) -> List[str]:
ret_lst.append(col_name)
return ret_lst

def get_file_column_list(self):
def get_file_column_list(self) -> List[str]:
return [
i.name
for i in self.obj.__mapper__.columns
if isinstance(i.type, FileColumn)
]

def get_image_column_list(self):
def get_image_column_list(self) -> List[str]:
return [
i.name
for i in self.obj.__mapper__.columns
Expand Down Expand Up @@ -880,7 +886,7 @@ def get_pk_name(self) -> Optional[Union[List[str], str]]:
"""
return self._get_pk_name(self.obj)

def get_pk(self, model: Optional[Model] = None):
def get_pk(self, model: Optional[Type[Model]] = None):
"""
Get the model primary key SQLAlchemy column.
Will not support composite keys
Expand All @@ -891,7 +897,7 @@ def get_pk(self, model: Optional[Model] = None):
return getattr(model_, pk_name)
return None

def _get_pk_name(self, model: Model) -> Optional[Union[List[str], str]]:
def _get_pk_name(self, model: Type[Model]) -> Optional[Union[List[str], str]]:
pk = [pk.name for pk in model.__mapper__.primary_key]
if pk:
return pk if self.is_pk_composite() else pk[0]
Expand Down

0 comments on commit 12f8b3a

Please sign in to comment.