Skip to content

Commit

Permalink
fix: memoize primitives (#19930)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored May 2, 2022
1 parent 7b3d0f0 commit 1ebdaac
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 49 deletions.
26 changes: 16 additions & 10 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,18 +864,24 @@ def get_all_datasource_names(
all_datasources: List[utils.DatasourceName] = []
for schema in schemas:
if datasource_type == "table":
all_datasources += database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
all_datasources.extend(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
)
elif datasource_type == "view":
all_datasources += database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
all_datasources.extend(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
)
else:
raise Exception(f"Unsupported datasource_type: {datasource_type}")
Expand Down
30 changes: 18 additions & 12 deletions superset/db_engine_specs/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,25 @@ def get_all_datasource_names(
)
schema = schemas[0]
if datasource_type == "table":
return database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
return [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
]
if datasource_type == "view":
return database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
return [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema,
force=True,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
]
raise Exception(f"Unsupported datasource_type: {datasource_type}")

@classmethod
Expand Down
28 changes: 18 additions & 10 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,11 +522,16 @@ def get_all_table_names_in_database( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.get_all_datasource_names(self, "table")
return [
(datasource_name.table, datasource_name.schema)
for datasource_name in self.db_engine_spec.get_all_datasource_names(
self, "table"
)
]

@cache_util.memoized_func(
key="db:{self.id}:schema:None:view_list",
Expand All @@ -537,11 +542,16 @@ def get_all_view_names_in_database( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[bool] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
return []
return self.db_engine_spec.get_all_datasource_names(self, "view")
return [
(datasource_name.table, datasource_name.schema)
for datasource_name in self.db_engine_spec.get_all_datasource_names(
self, "view"
)
]

@cache_util.memoized_func(
key="db:{self.id}:schema:{schema}:table_list",
Expand All @@ -553,7 +563,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
Expand All @@ -569,9 +579,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
tables = self.db_engine_spec.get_table_names(
database=self, inspector=self.inspector, schema=schema
)
return [
utils.DatasourceName(table=table, schema=schema) for table in tables
]
return [(table, schema) for table in tables]
except Exception as ex: # pylint: disable=broad-except
logger.warning(ex)
return []
Expand All @@ -586,7 +594,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[utils.DatasourceName]:
) -> List[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
Expand All @@ -602,7 +610,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
views = self.db_engine_spec.get_view_names(
database=self, inspector=self.inspector, schema=schema
)
return [utils.DatasourceName(table=view, schema=schema) for view in views]
return [(view, schema) for view in views]
except Exception as ex: # pylint: disable=broad-except
logger.warning(ex)
return []
Expand Down
13 changes: 12 additions & 1 deletion superset/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,18 @@ def memoized_func(
key: Optional[str] = None,
cache: Cache = cache_manager.cache,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
"""
Decorator with configurable key and cache backend.
@memoized_func(key="{a}+{b}", cache=cache_manager.data_cache)
def sum(a: int, b: int) -> int:
return a + b
In the example above the result for `1+2` will be stored under the key of name "1+2",
in the `cache_manager.data_cache` cache.
Note: this decorator should be used only with functions that return primitives,
otherwise the deserialization might not work correctly.
enable_cache is treated as True by default,
except enable_cache = False is passed to the decorated function.
Expand Down
34 changes: 20 additions & 14 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,31 +1115,37 @@ def tables( # pylint: disable=too-many-locals,no-self-use,too-many-arguments
substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True)

if schema_parsed:
tables = (
database.get_all_table_names_in_schema(
tables = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
or []
)
views = (
database.get_all_view_names_in_schema(
] or []
views = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
or []
)
] or []
else:
tables = database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
views = database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
tables = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
]
views = [
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_database(
cache=True, force=False, cache_timeout=24 * 60 * 60
)
]
tables = security_manager.get_datasources_accessible_by_user(
database, tables, schema_parsed
)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/db_engine_specs/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_get_all_datasource_names_table(app_context: AppContext) -> None:

database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
table_names = ["table1", "table2"]
table_names = [("table1", "schema1"), ("table2", "schema1")]
get_tables = mock.MagicMock(return_value=table_names)
database.get_all_table_names_in_schema = get_tables
result = SqliteEngineSpec.get_all_datasource_names(database, "table")
Expand All @@ -65,7 +65,7 @@ def test_get_all_datasource_names_view(app_context: AppContext) -> None:

database = mock.MagicMock()
database.get_all_schema_names.return_value = ["schema1"]
views_names = ["view1", "view2"]
views_names = [("view1", "schema1"), ("view2", "schema1")]
get_views = mock.MagicMock(return_value=views_names)
database.get_all_view_names_in_schema = get_views
result = SqliteEngineSpec.get_all_datasource_names(database, "view")
Expand Down

0 comments on commit 1ebdaac

Please sign in to comment.