From a071b1768e43b905f369242d4f7de18c02ff3712 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Thu, 3 Sep 2020 12:58:11 +0100 Subject: [PATCH 1/4] feat: CRUD REST API for saved queries --- superset/app.py | 2 + superset/queries/savedqueries/__init__.py | 16 + superset/queries/savedqueries/api.py | 103 ++++++ tests/queries/saved_queries/__init__.py | 16 + tests/queries/saved_queries/api_tests.py | 378 ++++++++++++++++++++++ 5 files changed, 515 insertions(+) create mode 100644 superset/queries/savedqueries/__init__.py create mode 100644 superset/queries/savedqueries/api.py create mode 100644 tests/queries/saved_queries/__init__.py create mode 100644 tests/queries/saved_queries/api_tests.py diff --git a/superset/app.py b/superset/app.py index 4878a83d6afe2..5280922a1d1b0 100644 --- a/superset/app.py +++ b/superset/app.py @@ -143,6 +143,7 @@ def init_views(self) -> None: from superset.databases.api import DatabaseRestApi from superset.datasets.api import DatasetRestApi from superset.queries.api import QueryRestApi + from superset.queries.savedqueries.api import SavedQueryRestApi from superset.views.access_requests import AccessRequestsModelView from superset.views.alerts import ( AlertLogModelView, @@ -198,6 +199,7 @@ def init_views(self) -> None: appbuilder.add_api(DatabaseRestApi) appbuilder.add_api(DatasetRestApi) appbuilder.add_api(QueryRestApi) + appbuilder.add_api(SavedQueryRestApi) # # Setup regular views # diff --git a/superset/queries/savedqueries/__init__.py b/superset/queries/savedqueries/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/queries/savedqueries/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/queries/savedqueries/api.py b/superset/queries/savedqueries/api.py new file mode 100644 index 0000000000000..79be4399a475f --- /dev/null +++ b/superset/queries/savedqueries/api.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.constants import RouteMethod +from superset.databases.filters import DatabaseFilter +from superset.models.sql_lab import SavedQuery +from superset.views.base_api import BaseSupersetModelRestApi + +logger = logging.getLogger(__name__) + + +class SavedQueryRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(SavedQuery) + + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.RELATED, + RouteMethod.DISTINCT, + } + class_permission_name = "SavedQueryView" + resource_name = "saved_query" + allow_browser_login = True + show_columns = [ + "id", + "schema", + "label", + "description", + "sql", + "user.first_name", + "user.last_name", + "user.id", + "database.database_name", + "database.id", + ] + list_columns = [ + "user_id", + "db_id", + "schema", + "label", + "description", + "sql", + "user.first_name", + "user.last_name", + "user.id", + "database.database_name", + "database.id", + ] + add_columns = [ + "schema", + "label", + "description", + "sql", + "user_id", + "db_id", + ] + edit_columns = add_columns + order_columns = [ + "schema", + "label", + "description", + "sql", + "user.first_name", + "database.database_name", + ] + + openapi_spec_tag = "Queries" + openapi_spec_methods = { + "get": {"get": {"description": "Get a saved query",}}, + "get_list": { + "get": { + "description": "Get a list of saved queries, use Rison or JSON " + "query parameters for filtering, sorting," + " pagination and for selecting specific" + " columns and metadata.", + } + }, + "post": {"post": {"description": "Create saved query",}}, + "put": {"put": {"description": "Update saved query",}}, + "delete": {"delete": {"description": "Delete saved query",}}, + } + + related_field_filters = { + "database": "database_name", + } + filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]} + allowed_rel_fields = {"database"} + allowed_distinct_fields = {"schema"} diff --git a/tests/queries/saved_queries/__init__.py b/tests/queries/saved_queries/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/queries/saved_queries/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py new file mode 100644 index 0000000000000..43231f7c4d53d --- /dev/null +++ b/tests/queries/saved_queries/api_tests.py @@ -0,0 +1,378 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +"""Unit tests for Superset""" +import json +from typing import Optional + +import prison +from sqlalchemy.sql import func, asc + +import tests.test_app +from superset import db, security_manager +from superset.models.core import Database +from superset.models.sql_lab import SavedQuery +from superset.utils.core import get_example_database + +from tests.base_tests import SupersetTestCase + + +class TestSavedQueryApi(SupersetTestCase): + def insert_saved_query( + self, + label: str, + sql: str, + db_id: Optional[int] = None, + user_id: Optional[int] = None, + schema: Optional[str] = "", + ) -> SavedQuery: + database = None + user = None + if db_id: + database = db.session.query(Database).get(db_id) + if user_id: + user = db.session.query(security_manager.user_model).get(user_id) + query = SavedQuery( + database=database, user=user, sql=sql, label=label, schema=schema + ) + db.session.add(query) + db.session.commit() + return query + + def insert_default_saved_query( + self, label: str = "saved1", schema: str = "schema1", + ) -> SavedQuery: + admin = self.get_user("admin") + example_db = get_example_database() + return self.insert_saved_query( + label, + "SELECT col1, col2 from table1", + db_id=example_db.id, + user_id=admin.id, + schema=schema, + ) + + def test_get_list_saved_query(self): + """ + Saved Query API: Test get list saved query + """ + query = self.insert_default_saved_query() + queries = db.session.query(SavedQuery).all() + + self.login(username="admin") + uri = f"api/v1/saved_query/" + rv = self.get_assert_metric(uri, "get_list") + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], len(queries)) + expected_columns = [ + "user_id", + "db_id", + "schema", + "label", + "description", + "sql", + "user", + "database", + ] + for expected_column in expected_columns: + self.assertIn(expected_column, data["result"][0]) + # rollback changes + db.session.delete(query) + db.session.commit() + + def test_get_list_sort_saved_query(self): + """ + Saved Query API: Test get list and sort saved query + """ + num_saved_queries = 5 + saved_queries = [] + for cx in range(num_saved_queries): + saved_queries.append( + self.insert_default_saved_query( + label=f"label{cx}", schema=f"schema{cx}" + ) + ) + all_queries = ( + db.session.query(SavedQuery).order_by(asc(SavedQuery.schema)).all() + ) + + self.login(username="admin") + query_string = {"order_column": "schema", "order_direction": "asc"} + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], len(all_queries)) + for i, query in enumerate(all_queries): + self.assertEqual(query.schema, data["result"][i]["schema"]) + + query_string = { + "order_column": "database.database_name", + "order_direction": "asc", + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + self.assertEqual(rv.status_code, 200) + + query_string = {"order_column": "user.first_name", "order_direction": "asc"} + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + self.assertEqual(rv.status_code, 200) + + # rollback changes + for saved_query in saved_queries: + db.session.delete(saved_query) + db.session.commit() + + def test_get_list_filter_saved_query(self): + """ + Saved Query API: Test get list and filter saved query + """ + num_saved_queries = 5 + saved_queries = [] + for cx in range(num_saved_queries): + saved_queries.append( + self.insert_default_saved_query( + label=f"label{cx}", schema=f"schema{cx}" + ) + ) + all_queries = ( + db.session.query(SavedQuery).filter(SavedQuery.label.ilike("%2%")).all() + ) + + self.login(username="admin") + query_string = { + "filters": [{"col": "label", "opr": "ct", "value": "2"}], + } + uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" + rv = self.get_assert_metric(uri, "get_list") + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], len(all_queries)) + # rollback changes + for saved_query in saved_queries: + db.session.delete(saved_query) + db.session.commit() + + def test_info_saved_query(self): + """ + SavedQuery API: Test info + """ + self.login(username="admin") + uri = f"api/v1/saved_query/_info" + rv = self.get_assert_metric(uri, "info") + self.assertEqual(rv.status_code, 200) + + def test_related_saved_query(self): + """ + SavedQuery API: Test related databases + """ + self.login(username="admin") + uri = f"api/v1/saved_query/related/database" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + expected_result = {"count": 1, "result": [{"text": "examples", "value": 1}]} + self.assertEqual(data, expected_result) + + def test_related_saved_query_not_found(self): + """ + SavedQuery API: Test related user not found + """ + self.login(username="admin") + uri = f"api/v1/saved_query/related/user" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_distinct_saved_query(self): + """ + SavedQuery API: Test distinct schemas + """ + query1 = self.insert_default_saved_query(schema="schema1") + query2 = self.insert_default_saved_query(schema="schema2") + + self.login(username="admin") + uri = f"api/v1/saved_query/distinct/schema" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + expected_response = { + "count": 2, + "result": [ + {"text": "schema1", "value": "schema1"}, + {"text": "schema2", "value": "schema2"}, + ], + } + self.assertEqual(data, expected_response) + # Rollback changes + db.session.delete(query1) + db.session.delete(query2) + db.session.commit() + + def test_get_saved_query_not_allowed(self): + """ + SavedQuery API: Test related user not allowed + """ + self.login(username="admin") + uri = f"api/v1/saved_query/wrong" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 405) + + def test_get_saved_query(self): + """ + Saved Query API: Test get saved query + """ + query = self.insert_default_saved_query() + self.login(username="admin") + uri = f"api/v1/saved_query/{query.id}" + rv = self.get_assert_metric(uri, "get") + self.assertEqual(rv.status_code, 200) + + expected_result = { + "id": query.id, + "database": {"id": query.database.id, "database_name": "examples"}, + "description": None, + "user": {"first_name": "admin", "id": query.user_id, "last_name": "user"}, + "sql": "SELECT col1, col2 from table1", + "schema": "schema1", + "label": "saved1", + } + data = json.loads(rv.data.decode("utf-8")) + for key, value in data["result"].items(): + self.assertEqual(value, expected_result[key]) + # rollback changes + db.session.delete(query) + db.session.commit() + + def test_get_saved_query_not_found(self): + """ + Saved Query API: Test get saved query not found + """ + query = self.insert_default_saved_query() + max_id = db.session.query(func.max(SavedQuery.id)).scalar() + self.login(username="admin") + uri = f"api/v1/saved_query/{max_id + 1}" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_create_saved_query(self): + """ + Saved Query API: Test create + """ + admin = self.get_user("admin") + example_db = get_example_database() + + post_data = { + "schema": "schema1", + "label": "label1", + "description": "some description", + "sql": "SELECT col1, col2 from table1", + "user_id": admin.id, + "db_id": example_db.id, + } + + self.login(username="admin") + uri = f"api/v1/saved_query/" + rv = self.client.post(uri, json=post_data) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + saved_query_id = data.get("id") + model = db.session.query(SavedQuery).get(saved_query_id) + for key in post_data: + self.assertEqual(getattr(model, key), data["result"][key]) + + # Rollback changes + db.session.delete(model) + db.session.commit() + + def test_update_saved_query(self): + """ + Saved Query API: Test update + """ + saved_query = self.insert_default_saved_query() + + put_data = { + "schema": "schema_changed", + "label": "label_changed", + } + + self.login(username="admin") + uri = f"api/v1/saved_query/{saved_query.id}" + rv = self.client.put(uri, json=put_data) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model = db.session.query(SavedQuery).get(saved_query.id) + self.assertEqual(model.label, "label_changed") + self.assertEqual(model.schema, "schema_changed") + # Rollback changes + db.session.delete(saved_query) + db.session.commit() + + def test_update_saved_query_not_found(self): + """ + Saved Query API: Test update not found + """ + saved_query = self.insert_default_saved_query() + + max_id = db.session.query(func.max(SavedQuery.id)).scalar() + self.login(username="admin") + + put_data = { + "schema": "schema_changed", + "label": "label_changed", + } + + uri = f"api/v1/saved_query/{max_id + 1}" + rv = self.client.put(uri, json=put_data) + self.assertEqual(rv.status_code, 404) + + # Rollback changes + db.session.delete(saved_query) + db.session.commit() + + def test_delete_saved_query(self): + """ + Saved Query API: Test delete + """ + saved_query = self.insert_default_saved_query() + + self.login(username="admin") + uri = f"api/v1/saved_query/{saved_query.id}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 200) + + model = db.session.query(SavedQuery).get(saved_query.id) + self.assertIsNone(model) + + def test_delete_saved_query_not_found(self): + """ + Saved Query API: Test delete not found + """ + saved_query = self.insert_default_saved_query() + + max_id = db.session.query(func.max(SavedQuery.id)).scalar() + self.login(username="admin") + uri = f"api/v1/saved_query/{max_id + 1}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 404) + + # Rollback changes + db.session.delete(saved_query) + db.session.commit() From 0402e781778d41a90de175a2988570f51f973cb7 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Thu, 3 Sep 2020 13:19:37 +0100 Subject: [PATCH 2/4] debug test --- tests/queries/saved_queries/api_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index 43231f7c4d53d..ff977325ebc51 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -187,6 +187,7 @@ def test_related_saved_query(self): rv = self.client.get(uri) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) + raise Exception(data) expected_result = {"count": 1, "result": [{"text": "examples", "value": 1}]} self.assertEqual(data, expected_result) From 58cc7214e5f3653508182e7de6ce1d9a133f91a4 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Thu, 3 Sep 2020 13:44:59 +0100 Subject: [PATCH 3/4] fix test --- superset/queries/savedqueries/api.py | 6 +++--- tests/queries/saved_queries/api_tests.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/superset/queries/savedqueries/api.py b/superset/queries/savedqueries/api.py index 79be4399a475f..0b62d37410c87 100644 --- a/superset/queries/savedqueries/api.py +++ b/superset/queries/savedqueries/api.py @@ -90,9 +90,9 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): " columns and metadata.", } }, - "post": {"post": {"description": "Create saved query",}}, - "put": {"put": {"description": "Update saved query",}}, - "delete": {"delete": {"description": "Delete saved query",}}, + "post": {"post": {"description": "Create a saved query"}}, + "put": {"put": {"description": "Update a saved query"}}, + "delete": {"delete": {"description": "Delete saved query"}}, } related_field_filters = { diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index ff977325ebc51..2019a2eeb16a8 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -183,12 +183,18 @@ def test_related_saved_query(self): SavedQuery API: Test related databases """ self.login(username="admin") + databases = db.session.query(Database).all() + expected_result = { + "count": len(databases), + "result": [ + {"text": str(database), "value": database.id} for database in databases + ], + } + uri = f"api/v1/saved_query/related/database" rv = self.client.get(uri) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) - raise Exception(data) - expected_result = {"count": 1, "result": [{"text": "examples", "value": 1}]} self.assertEqual(data, expected_result) def test_related_saved_query_not_found(self): From 047de3388671cca9291d21b68f9d16f74f6ed610 Mon Sep 17 00:00:00 2001 From: Daniel Gaspar Date: Thu, 3 Sep 2020 15:58:58 +0100 Subject: [PATCH 4/4] use pytest fixtures --- tests/queries/saved_queries/api_tests.py | 164 ++++++++++------------- 1 file changed, 69 insertions(+), 95 deletions(-) diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index 2019a2eeb16a8..73d373089a42c 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -19,6 +19,7 @@ import json from typing import Optional +import pytest import prison from sqlalchemy.sql import func, asc @@ -66,19 +67,37 @@ def insert_default_saved_query( schema=schema, ) + @pytest.fixture() + def create_saved_queries(self): + with self.create_app().app_context(): + num_saved_queries = 5 + saved_queries = [] + for cx in range(num_saved_queries): + saved_queries.append( + self.insert_default_saved_query( + label=f"label{cx}", schema=f"schema{cx}" + ) + ) + yield saved_queries + + # rollback changes + for saved_query in saved_queries: + db.session.delete(saved_query) + db.session.commit() + + @pytest.mark.usefixtures("create_saved_queries") def test_get_list_saved_query(self): """ Saved Query API: Test get list saved query """ - query = self.insert_default_saved_query() queries = db.session.query(SavedQuery).all() self.login(username="admin") uri = f"api/v1/saved_query/" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], len(queries)) + assert data["count"] == len(queries) expected_columns = [ "user_id", "db_id", @@ -90,36 +109,25 @@ def test_get_list_saved_query(self): "database", ] for expected_column in expected_columns: - self.assertIn(expected_column, data["result"][0]) - # rollback changes - db.session.delete(query) - db.session.commit() + assert expected_column in data["result"][0] + @pytest.mark.usefixtures("create_saved_queries") def test_get_list_sort_saved_query(self): """ Saved Query API: Test get list and sort saved query """ - num_saved_queries = 5 - saved_queries = [] - for cx in range(num_saved_queries): - saved_queries.append( - self.insert_default_saved_query( - label=f"label{cx}", schema=f"schema{cx}" - ) - ) all_queries = ( db.session.query(SavedQuery).order_by(asc(SavedQuery.schema)).all() ) - self.login(username="admin") query_string = {"order_column": "schema", "order_direction": "asc"} uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], len(all_queries)) + assert data["count"] == len(all_queries) for i, query in enumerate(all_queries): - self.assertEqual(query.schema, data["result"][i]["schema"]) + assert query.schema == data["result"][i]["schema"] query_string = { "order_column": "database.database_name", @@ -127,47 +135,30 @@ def test_get_list_sort_saved_query(self): } uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 query_string = {"order_column": "user.first_name", "order_direction": "asc"} uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) - - # rollback changes - for saved_query in saved_queries: - db.session.delete(saved_query) - db.session.commit() + assert rv.status_code == 200 + @pytest.mark.usefixtures("create_saved_queries") def test_get_list_filter_saved_query(self): """ Saved Query API: Test get list and filter saved query """ - num_saved_queries = 5 - saved_queries = [] - for cx in range(num_saved_queries): - saved_queries.append( - self.insert_default_saved_query( - label=f"label{cx}", schema=f"schema{cx}" - ) - ) all_queries = ( db.session.query(SavedQuery).filter(SavedQuery.label.ilike("%2%")).all() ) - self.login(username="admin") query_string = { "filters": [{"col": "label", "opr": "ct", "value": "2"}], } uri = f"api/v1/saved_query/?q={prison.dumps(query_string)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], len(all_queries)) - # rollback changes - for saved_query in saved_queries: - db.session.delete(saved_query) - db.session.commit() + assert data["count"] == len(all_queries) def test_info_saved_query(self): """ @@ -176,7 +167,7 @@ def test_info_saved_query(self): self.login(username="admin") uri = f"api/v1/saved_query/_info" rv = self.get_assert_metric(uri, "info") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 def test_related_saved_query(self): """ @@ -193,9 +184,9 @@ def test_related_saved_query(self): uri = f"api/v1/saved_query/related/database" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data, expected_result) + assert data == expected_result def test_related_saved_query_not_found(self): """ @@ -204,32 +195,23 @@ def test_related_saved_query_not_found(self): self.login(username="admin") uri = f"api/v1/saved_query/related/user" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 + @pytest.mark.usefixtures("create_saved_queries") def test_distinct_saved_query(self): """ SavedQuery API: Test distinct schemas """ - query1 = self.insert_default_saved_query(schema="schema1") - query2 = self.insert_default_saved_query(schema="schema2") - self.login(username="admin") uri = f"api/v1/saved_query/distinct/schema" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) expected_response = { - "count": 2, - "result": [ - {"text": "schema1", "value": "schema1"}, - {"text": "schema2", "value": "schema2"}, - ], + "count": 5, + "result": [{"text": f"schema{i}", "value": f"schema{i}"} for i in range(5)], } - self.assertEqual(data, expected_response) - # Rollback changes - db.session.delete(query1) - db.session.delete(query2) - db.session.commit() + assert data == expected_response def test_get_saved_query_not_allowed(self): """ @@ -238,17 +220,20 @@ def test_get_saved_query_not_allowed(self): self.login(username="admin") uri = f"api/v1/saved_query/wrong" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 405) + assert rv.status_code == 405 + @pytest.mark.usefixtures("create_saved_queries") def test_get_saved_query(self): """ Saved Query API: Test get saved query """ - query = self.insert_default_saved_query() + query = ( + db.session.query(SavedQuery).filter(SavedQuery.label == "label1").all()[0] + ) self.login(username="admin") uri = f"api/v1/saved_query/{query.id}" rv = self.get_assert_metric(uri, "get") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 expected_result = { "id": query.id, @@ -257,14 +242,11 @@ def test_get_saved_query(self): "user": {"first_name": "admin", "id": query.user_id, "last_name": "user"}, "sql": "SELECT col1, col2 from table1", "schema": "schema1", - "label": "saved1", + "label": "label1", } data = json.loads(rv.data.decode("utf-8")) for key, value in data["result"].items(): - self.assertEqual(value, expected_result[key]) - # rollback changes - db.session.delete(query) - db.session.commit() + assert value == expected_result[key] def test_get_saved_query_not_found(self): """ @@ -275,7 +257,7 @@ def test_get_saved_query_not_found(self): self.login(username="admin") uri = f"api/v1/saved_query/{max_id + 1}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_create_saved_query(self): """ @@ -297,22 +279,25 @@ def test_create_saved_query(self): uri = f"api/v1/saved_query/" rv = self.client.post(uri, json=post_data) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 saved_query_id = data.get("id") model = db.session.query(SavedQuery).get(saved_query_id) for key in post_data: - self.assertEqual(getattr(model, key), data["result"][key]) + assert getattr(model, key) == data["result"][key] # Rollback changes db.session.delete(model) db.session.commit() + @pytest.mark.usefixtures("create_saved_queries") def test_update_saved_query(self): """ Saved Query API: Test update """ - saved_query = self.insert_default_saved_query() + saved_query = ( + db.session.query(SavedQuery).filter(SavedQuery.label == "label1").all()[0] + ) put_data = { "schema": "schema_changed", @@ -322,22 +307,17 @@ def test_update_saved_query(self): self.login(username="admin") uri = f"api/v1/saved_query/{saved_query.id}" rv = self.client.put(uri, json=put_data) - data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(SavedQuery).get(saved_query.id) - self.assertEqual(model.label, "label_changed") - self.assertEqual(model.schema, "schema_changed") - # Rollback changes - db.session.delete(saved_query) - db.session.commit() + assert model.label == "label_changed" + assert model.schema == "schema_changed" + @pytest.mark.usefixtures("create_saved_queries") def test_update_saved_query_not_found(self): """ Saved Query API: Test update not found """ - saved_query = self.insert_default_saved_query() - max_id = db.session.query(func.max(SavedQuery.id)).scalar() self.login(username="admin") @@ -348,38 +328,32 @@ def test_update_saved_query_not_found(self): uri = f"api/v1/saved_query/{max_id + 1}" rv = self.client.put(uri, json=put_data) - self.assertEqual(rv.status_code, 404) - - # Rollback changes - db.session.delete(saved_query) - db.session.commit() + assert rv.status_code == 404 + @pytest.mark.usefixtures("create_saved_queries") def test_delete_saved_query(self): """ Saved Query API: Test delete """ - saved_query = self.insert_default_saved_query() + saved_query = ( + db.session.query(SavedQuery).filter(SavedQuery.label == "label1").all()[0] + ) self.login(username="admin") uri = f"api/v1/saved_query/{saved_query.id}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(SavedQuery).get(saved_query.id) - self.assertIsNone(model) + assert model is None + @pytest.mark.usefixtures("create_saved_queries") def test_delete_saved_query_not_found(self): """ Saved Query API: Test delete not found """ - saved_query = self.insert_default_saved_query() - max_id = db.session.query(func.max(SavedQuery.id)).scalar() self.login(username="admin") uri = f"api/v1/saved_query/{max_id + 1}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 404) - - # Rollback changes - db.session.delete(saved_query) - db.session.commit() + assert rv.status_code == 404