From 1dc9c74a43ea96905a9d374a973c4255b3a44888 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Wed, 11 May 2022 16:34:25 +0000 Subject: [PATCH 1/4] restart --- superset/dao/datasource/dao.py | 151 ++++++++++++++++++++++++ tests/unit_tests/dao/datasource_test.py | 116 ++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 superset/dao/datasource/dao.py create mode 100644 tests/unit_tests/dao/datasource_test.py diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py new file mode 100644 index 0000000000000..1302629b39ddb --- /dev/null +++ b/superset/dao/datasource/dao.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Optional, Set, Union + +from flask_babel import _ +from sqlalchemy import or_ +from sqlalchemy.orm import Session, subqueryload +from sqlalchemy.orm.exc import NoResultFound + +from superset.connectors.sqla.models import SqlaTable, Table +from superset.dao.base import BaseDAO +from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.datasets.models import Dataset +from superset.models.core import Database +from superset.models.sql_lab import Query, SavedQuery + +Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] + + +class DatasourceDAO(BaseDAO): + + sources = { + # using table -> SqlaTable for backward compatibility at the moment + "table": SqlaTable, + "query": Query, + "saved_query": SavedQuery, + "sl_dataset": Dataset, + "sl_table": Table, + } + + @classmethod + def get_datasource( + cls, datasource_type: str, datasource_id: int, session: Session + ) -> Datasource: + if datasource_type not in cls.sources: + raise DatasetNotFoundError() + + datasource = ( + session.query(cls.sources[datasource_type]) + .filter_by(id=datasource_id) + .one_or_none() + ) + + if not datasource: + raise DatasetNotFoundError() + + return datasource + + def get_all_datasources(self, session: Session) -> List[Datasource]: + datasources: List["Datasource"] = [] + for source_class in DatasourceDAO.sources.values(): + qry = session.query(source_class) + qry = source_class.default_query(qry) + datasources.extend(qry.all()) + return datasources + + def get_datasource_by_id(self, session: Session, datasource_id: int) -> Datasource: + """ + Find a datasource instance based on the unique id. + :param session: Session to use + :param datasource_id: unique id of datasource + :return: Datasource corresponding to the id + :raises NoResultFound: if no datasource is found corresponding to the id + """ + for datasource_class in DatasourceDAO.sources.values(): + try: + return ( + session.query(datasource_class) + .filter(datasource_class.id == datasource_id) + .one() + ) + except NoResultFound: + # proceed to next datasource type + pass + raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id)) + + def get_datasource_by_name( # pylint: disable=too-many-arguments + self, + session: Session, + datasource_type: str, + datasource_name: str, + schema: str, + database_name: str, + ) -> Optional[Datasource]: + datasource_class = DatasourceDAO.sources[datasource_type] + return datasource_class.get_datasource_by_name( + session, datasource_name, schema, database_name + ) + + def query_datasources_by_permissions( # pylint: disable=invalid-name + self, + session: Session, + database: Database, + permissions: Set[str], + schema_perms: Set[str], + ) -> List[Datasource]: + # TODO(bogdan): add unit test + datasource_class = DatasourceDAO.sources[database.type] + return ( + session.query(datasource_class) + .filter_by(database_id=database.id) + .filter( + or_( + datasource_class.perm.in_(permissions), + datasource_class.schema_perm.in_(schema_perms), + ) + ) + .all() + ) + + def get_eager_datasource( + self, session: Session, datasource_type: str, datasource_id: int + ) -> Datasource: + """Returns datasource with columns and metrics.""" + datasource_class = DatasourceDAO.sources[datasource_type] + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics), + ) + .filter_by(id=datasource_id) + .one() + ) + + def query_datasources_by_name( + self, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, + ) -> List[Datasource]: + datasource_class = DatasourceDAO.sources[database.type] + return datasource_class.query_datasources_by_name( + session, database, datasource_name, schema=schema + ) diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py new file mode 100644 index 0000000000000..91a91a4ecc70d --- /dev/null +++ b/tests/unit_tests/dao/datasource_test.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from sqlalchemy.orm.session import Session + + +def create_test_data(session: Session) -> None: + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.models.core import Database + from superset.models.sql_lab import Query, SavedQuery + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=columns, + metrics=[], + database=db, + ) + + query_obj = Query( + client_id="foo", + database=db, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=100, + error_message="none", + results_key="abc", + ) + + saved_query = SavedQuery(database=db, sql="select * from foo") + + session.add(saved_query) + session.add(query_obj) + session.add(db) + session.add(sqla_table) + session.flush() + + +def test_get_datasource_sqlatable(app_context: None, session: Session) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.dao.datasource.dao import DatasourceDAO + + create_test_data(session) + + result = DatasourceDAO.get_datasource( + datasource_type="table", datasource_id=1, session=session + ) + + assert 1 == result.id + assert "my_sqla_table" == result.table_name + assert isinstance(result, SqlaTable) + + +def test_get_datasource_query(app_context: None, session: Session) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.models.sql_lab import Query + + create_test_data(session) + + result = DatasourceDAO.get_datasource( + datasource_type="query", datasource_id=1, session=session + ) + + assert result.id == 1 + assert isinstance(result, Query) + + +def test_get_datasource_saved_query(app_context: None, session: Session) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.models.sql_lab import SavedQuery + + create_test_data(session) + + result = DatasourceDAO.get_datasource( + datasource_type="saved_query", datasource_id=1, session=session + ) + + assert result.id == 1 + assert isinstance(result, SavedQuery) + + +def test_get_datasource_sl_table(app_context: None, session: Session) -> None: + pass + + +def test_get_datasource_sl_dataset(app_context: None, session: Session) -> None: + pass From bcfc6831f16bbf5649ea4a1f7741a098e42c9f46 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Wed, 11 May 2022 16:48:39 +0000 Subject: [PATCH 2/4] update with enums --- superset/dao/datasource/dao.py | 127 ++++++++++++------------ tests/unit_tests/dao/datasource_test.py | 75 +++++++++++++- 2 files changed, 131 insertions(+), 71 deletions(-) diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py index 1302629b39ddb..d58717ad5f60c 100644 --- a/superset/dao/datasource/dao.py +++ b/superset/dao/datasource/dao.py @@ -15,37 +15,38 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union from flask_babel import _ from sqlalchemy import or_ from sqlalchemy.orm import Session, subqueryload from sqlalchemy.orm.exc import NoResultFound -from superset.connectors.sqla.models import SqlaTable, Table +from superset.connectors.sqla.models import SqlaTable from superset.dao.base import BaseDAO from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.models import Dataset from superset.models.core import Database from superset.models.sql_lab import Query, SavedQuery +from superset.tables.models import Table +from superset.utils.core import DatasourceType -Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] +Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery, Any] class DatasourceDAO(BaseDAO): - sources = { - # using table -> SqlaTable for backward compatibility at the moment - "table": SqlaTable, - "query": Query, - "saved_query": SavedQuery, - "sl_dataset": Dataset, - "sl_table": Table, + sources: Dict[DatasourceType, Datasource] = { + DatasourceType.SQLATABLE: SqlaTable, + DatasourceType.QUERY: Query, + DatasourceType.SAVEDQUERY: SavedQuery, + DatasourceType.DATASET: Dataset, + DatasourceType.TABLE: Table, } @classmethod def get_datasource( - cls, datasource_type: str, datasource_id: int, session: Session + cls, datasource_type: DatasourceType, datasource_id: int, session: Session ) -> Datasource: if datasource_type not in cls.sources: raise DatasetNotFoundError() @@ -61,91 +62,85 @@ def get_datasource( return datasource - def get_all_datasources(self, session: Session) -> List[Datasource]: - datasources: List["Datasource"] = [] + @classmethod + def get_all_datasources(cls, session: Session) -> List[Datasource]: + datasources: List[Datasource] = [] for source_class in DatasourceDAO.sources.values(): qry = session.query(source_class) - qry = source_class.default_query(qry) + if isinstance(source_class, SqlaTable): + qry = source_class.default_query(qry) datasources.extend(qry.all()) return datasources - def get_datasource_by_id(self, session: Session, datasource_id: int) -> Datasource: - """ - Find a datasource instance based on the unique id. - :param session: Session to use - :param datasource_id: unique id of datasource - :return: Datasource corresponding to the id - :raises NoResultFound: if no datasource is found corresponding to the id - """ - for datasource_class in DatasourceDAO.sources.values(): - try: - return ( - session.query(datasource_class) - .filter(datasource_class.id == datasource_id) - .one() - ) - except NoResultFound: - # proceed to next datasource type - pass - raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id)) - + @classmethod def get_datasource_by_name( # pylint: disable=too-many-arguments - self, + cls, session: Session, - datasource_type: str, + datasource_type: DatasourceType, datasource_name: str, schema: str, database_name: str, ) -> Optional[Datasource]: datasource_class = DatasourceDAO.sources[datasource_type] - return datasource_class.get_datasource_by_name( - session, datasource_name, schema, database_name - ) + if isinstance(datasource_class, SqlaTable): + return datasource_class.get_datasource_by_name( + session, datasource_name, schema, database_name + ) + return None + @classmethod def query_datasources_by_permissions( # pylint: disable=invalid-name - self, + cls, session: Session, database: Database, permissions: Set[str], schema_perms: Set[str], ) -> List[Datasource]: # TODO(bogdan): add unit test - datasource_class = DatasourceDAO.sources[database.type] - return ( - session.query(datasource_class) - .filter_by(database_id=database.id) - .filter( - or_( - datasource_class.perm.in_(permissions), - datasource_class.schema_perm.in_(schema_perms), + datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] + if isinstance(datasource_class, SqlaTable): + return ( + session.query(datasource_class) + .filter_by(database_id=database.id) + .filter( + or_( + datasource_class.perm.in_(permissions), + datasource_class.schema_perm.in_(schema_perms), + ) ) + .all() ) - .all() - ) + return [] + @classmethod def get_eager_datasource( - self, session: Session, datasource_type: str, datasource_id: int - ) -> Datasource: + cls, session: Session, datasource_type: str, datasource_id: int + ) -> Optional[Datasource]: """Returns datasource with columns and metrics.""" - datasource_class = DatasourceDAO.sources[datasource_type] - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics), + datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]] + if isinstance(datasource_class, SqlaTable): + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics), + ) + .filter_by(id=datasource_id) + .one() ) - .filter_by(id=datasource_id) - .one() - ) + return None + @classmethod def query_datasources_by_name( - self, + cls, session: Session, database: Database, datasource_name: str, schema: Optional[str] = None, ) -> List[Datasource]: - datasource_class = DatasourceDAO.sources[database.type] - return datasource_class.query_datasources_by_name( - session, database, datasource_name, schema=schema - ) + datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] + if isinstance(datasource_class, SqlaTable): + return datasource_class.query_datasources_by_name( + session, database, datasource_name, schema=schema + ) + return [] diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py index 91a91a4ecc70d..e69cee8699efd 100644 --- a/tests/unit_tests/dao/datasource_test.py +++ b/tests/unit_tests/dao/datasource_test.py @@ -18,12 +18,16 @@ import pytest from sqlalchemy.orm.session import Session +from superset.utils.core import DatasourceType + def create_test_data(session: Session) -> None: from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset from superset.models.core import Database from superset.models.sql_lab import Query, SavedQuery + from superset.tables.models import Table engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member @@ -58,6 +62,32 @@ def create_test_data(session: Session) -> None: saved_query = SavedQuery(database=db, sql="select * from foo") + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=db, + columns=[], + ) + + dataset = Dataset( + database=table.database, + name="positions", + expression=""" +SELECT array_agg(array[longitude,latitude]) AS position +FROM my_catalog.my_schema.my_table +""", + tables=[table], + columns=[ + Column( + name="position", + expression="array_agg(array[longitude,latitude])", + ), + ], + ) + + session.add(dataset) + session.add(table) session.add(saved_query) session.add(query_obj) session.add(db) @@ -72,7 +102,7 @@ def test_get_datasource_sqlatable(app_context: None, session: Session) -> None: create_test_data(session) result = DatasourceDAO.get_datasource( - datasource_type="table", datasource_id=1, session=session + datasource_type=DatasourceType.SQLATABLE, datasource_id=1, session=session ) assert 1 == result.id @@ -87,7 +117,7 @@ def test_get_datasource_query(app_context: None, session: Session) -> None: create_test_data(session) result = DatasourceDAO.get_datasource( - datasource_type="query", datasource_id=1, session=session + datasource_type=DatasourceType.QUERY, datasource_id=1, session=session ) assert result.id == 1 @@ -101,7 +131,7 @@ def test_get_datasource_saved_query(app_context: None, session: Session) -> None create_test_data(session) result = DatasourceDAO.get_datasource( - datasource_type="saved_query", datasource_id=1, session=session + datasource_type=DatasourceType.SAVEDQUERY, datasource_id=1, session=session ) assert result.id == 1 @@ -109,8 +139,43 @@ def test_get_datasource_saved_query(app_context: None, session: Session) -> None def test_get_datasource_sl_table(app_context: None, session: Session) -> None: - pass + from superset.dao.datasource.dao import DatasourceDAO + from superset.tables.models import Table + + create_test_data(session) + + # todo(hugh): This will break once we remove the dual write + # update the datsource_id=1 and this will pass again + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.TABLE, datasource_id=2, session=session + ) + + assert result.id == 2 + assert isinstance(result, Table) def test_get_datasource_sl_dataset(app_context: None, session: Session) -> None: - pass + from superset.dao.datasource.dao import DatasourceDAO + from superset.datasets.models import Dataset + + create_test_data(session) + + # todo(hugh): This will break once we remove the dual write + # update the datsource_id=1 and this will pass again + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.DATASET, datasource_id=2, session=session + ) + + assert result.id == 2 + assert isinstance(result, Dataset) + + +def test_get_all_datasources(app_context: None, session: Session) -> None: + from superset.dao.datasource.dao import DatasourceDAO + + create_test_data(session) + + # todo(hugh): This will break once we remove the dual write + # update the assert len(result) == 5 and this will pass again + result = DatasourceDAO.get_all_datasources(session=session) + assert len(result) == 7 From 9983ce97bf5f4a6516ac7fd19878b2ae4849ce32 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Thu, 12 May 2022 17:31:39 +0000 Subject: [PATCH 3/4] address concerns --- superset/dao/datasource/dao.py | 78 ++++++++++++------------- superset/dao/exceptions.py | 12 ++++ tests/unit_tests/dao/datasource_test.py | 60 ++++++++++--------- 3 files changed, 83 insertions(+), 67 deletions(-) diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py index d58717ad5f60c..51b9bf0935770 100644 --- a/superset/dao/datasource/dao.py +++ b/superset/dao/datasource/dao.py @@ -24,6 +24,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.dao.base import BaseDAO +from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.models import Dataset from superset.models.core import Database @@ -46,10 +47,10 @@ class DatasourceDAO(BaseDAO): @classmethod def get_datasource( - cls, datasource_type: DatasourceType, datasource_id: int, session: Session + cls, session: Session, datasource_type: DatasourceType, datasource_id: int ) -> Datasource: if datasource_type not in cls.sources: - raise DatasetNotFoundError() + raise DatasourceTypeNotSupportedError() datasource = ( session.query(cls.sources[datasource_type]) @@ -58,19 +59,16 @@ def get_datasource( ) if not datasource: - raise DatasetNotFoundError() + raise DatasourceNotFound() return datasource @classmethod - def get_all_datasources(cls, session: Session) -> List[Datasource]: - datasources: List[Datasource] = [] - for source_class in DatasourceDAO.sources.values(): - qry = session.query(source_class) - if isinstance(source_class, SqlaTable): - qry = source_class.default_query(qry) - datasources.extend(qry.all()) - return datasources + def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: + source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE] + qry = session.query(source_class) + qry = source_class.default_query(qry) + return qry.all() @classmethod def get_datasource_by_name( # pylint: disable=too-many-arguments @@ -78,8 +76,8 @@ def get_datasource_by_name( # pylint: disable=too-many-arguments session: Session, datasource_type: DatasourceType, datasource_name: str, - schema: str, database_name: str, + schema: str, ) -> Optional[Datasource]: datasource_class = DatasourceDAO.sources[datasource_type] if isinstance(datasource_class, SqlaTable): @@ -96,21 +94,22 @@ def query_datasources_by_permissions( # pylint: disable=invalid-name permissions: Set[str], schema_perms: Set[str], ) -> List[Datasource]: - # TODO(bogdan): add unit test + # TODO(hughhhh): add unit test datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] - if isinstance(datasource_class, SqlaTable): - return ( - session.query(datasource_class) - .filter_by(database_id=database.id) - .filter( - or_( - datasource_class.perm.in_(permissions), - datasource_class.schema_perm.in_(schema_perms), - ) + if not isinstance(datasource_class, SqlaTable): + return [] + + return ( + session.query(datasource_class) + .filter_by(database_id=database.id) + .filter( + or_( + datasource_class.perm.in_(permissions), + datasource_class.schema_perm.in_(schema_perms), ) - .all() ) - return [] + .all() + ) @classmethod def get_eager_datasource( @@ -118,17 +117,17 @@ def get_eager_datasource( ) -> Optional[Datasource]: """Returns datasource with columns and metrics.""" datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]] - if isinstance(datasource_class, SqlaTable): - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics), - ) - .filter_by(id=datasource_id) - .one() + if not isinstance(datasource_class, SqlaTable): + return None + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics), ) - return None + .filter_by(id=datasource_id) + .one() + ) @classmethod def query_datasources_by_name( @@ -139,8 +138,9 @@ def query_datasources_by_name( schema: Optional[str] = None, ) -> List[Datasource]: datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] - if isinstance(datasource_class, SqlaTable): - return datasource_class.query_datasources_by_name( - session, database, datasource_name, schema=schema - ) - return [] + if not isinstance(datasource_class, SqlaTable): + return [] + + return datasource_class.query_datasources_by_name( + session, database, datasource_name, schema=schema + ) diff --git a/superset/dao/exceptions.py b/superset/dao/exceptions.py index 822b23982e5d8..9b5624bd5d31d 100644 --- a/superset/dao/exceptions.py +++ b/superset/dao/exceptions.py @@ -53,3 +53,15 @@ class DAOConfigError(DAOException): """ message = "DAO is not configured correctly missing model definition" + + +class DatasourceTypeNotSupportedError(DAOException): + """ + DAO datasource query source type is not supported + """ + + message = "DAO datasource query source type is not supported" + + +class DatasourceNotFound(DAOException): + message = "Datasource does not exist" diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py index e69cee8699efd..dd0db265e7a02 100644 --- a/tests/unit_tests/dao/datasource_test.py +++ b/tests/unit_tests/dao/datasource_test.py @@ -15,13 +15,16 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterator + import pytest from sqlalchemy.orm.session import Session from superset.utils.core import DatasourceType -def create_test_data(session: Session) -> None: +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: from superset.columns.models import Column from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.datasets.models import Dataset @@ -93,16 +96,19 @@ def create_test_data(session: Session) -> None: session.add(db) session.add(sqla_table) session.flush() + yield session -def test_get_datasource_sqlatable(app_context: None, session: Session) -> None: +def test_get_datasource_sqlatable( + app_context: None, session_with_data: Session +) -> None: from superset.connectors.sqla.models import SqlaTable from superset.dao.datasource.dao import DatasourceDAO - create_test_data(session) - result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.SQLATABLE, datasource_id=1, session=session + datasource_type=DatasourceType.SQLATABLE, + datasource_id=1, + session=session_with_data, ) assert 1 == result.id @@ -110,72 +116,70 @@ def test_get_datasource_sqlatable(app_context: None, session: Session) -> None: assert isinstance(result, SqlaTable) -def test_get_datasource_query(app_context: None, session: Session) -> None: +def test_get_datasource_query(app_context: None, session_with_data: Session) -> None: from superset.dao.datasource.dao import DatasourceDAO from superset.models.sql_lab import Query - create_test_data(session) - result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.QUERY, datasource_id=1, session=session + datasource_type=DatasourceType.QUERY, datasource_id=1, session=session_with_data ) assert result.id == 1 assert isinstance(result, Query) -def test_get_datasource_saved_query(app_context: None, session: Session) -> None: +def test_get_datasource_saved_query( + app_context: None, session_with_data: Session +) -> None: from superset.dao.datasource.dao import DatasourceDAO from superset.models.sql_lab import SavedQuery - create_test_data(session) - result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.SAVEDQUERY, datasource_id=1, session=session + datasource_type=DatasourceType.SAVEDQUERY, + datasource_id=1, + session=session_with_data, ) assert result.id == 1 assert isinstance(result, SavedQuery) -def test_get_datasource_sl_table(app_context: None, session: Session) -> None: +def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None: from superset.dao.datasource.dao import DatasourceDAO from superset.tables.models import Table - create_test_data(session) - # todo(hugh): This will break once we remove the dual write # update the datsource_id=1 and this will pass again result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.TABLE, datasource_id=2, session=session + datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data ) assert result.id == 2 assert isinstance(result, Table) -def test_get_datasource_sl_dataset(app_context: None, session: Session) -> None: +def test_get_datasource_sl_dataset( + app_context: None, session_with_data: Session +) -> None: from superset.dao.datasource.dao import DatasourceDAO from superset.datasets.models import Dataset - create_test_data(session) - # todo(hugh): This will break once we remove the dual write # update the datsource_id=1 and this will pass again result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.DATASET, datasource_id=2, session=session + datasource_type=DatasourceType.DATASET, + datasource_id=2, + session=session_with_data, ) assert result.id == 2 assert isinstance(result, Dataset) -def test_get_all_datasources(app_context: None, session: Session) -> None: +def test_get_all_sqlatables_datasources( + app_context: None, session_with_data: Session +) -> None: from superset.dao.datasource.dao import DatasourceDAO - create_test_data(session) - - # todo(hugh): This will break once we remove the dual write - # update the assert len(result) == 5 and this will pass again - result = DatasourceDAO.get_all_datasources(session=session) - assert len(result) == 7 + result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data) + assert len(result) == 1 From afa8760e69600343a7f3dffc13a44a970b180d43 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Thu, 12 May 2022 18:36:42 +0000 Subject: [PATCH 4/4] remove any --- superset/dao/datasource/dao.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py index 51b9bf0935770..8b4845db3c51b 100644 --- a/superset/dao/datasource/dao.py +++ b/superset/dao/datasource/dao.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Set, Union +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Type, Union from flask_babel import _ from sqlalchemy import or_ @@ -32,12 +33,12 @@ from superset.tables.models import Table from superset.utils.core import DatasourceType -Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery, Any] +Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] class DatasourceDAO(BaseDAO): - sources: Dict[DatasourceType, Datasource] = { + sources: Dict[DatasourceType, Type[Datasource]] = { DatasourceType.SQLATABLE: SqlaTable, DatasourceType.QUERY: Query, DatasourceType.SAVEDQUERY: SavedQuery,