diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py
index 46390777acac5..ac57f13f94b3a 100644
--- a/tests/db_engine_specs/hive_tests.py
+++ b/tests/db_engine_specs/hive_tests.py
@@ -19,9 +19,10 @@
 from unittest import mock
 
 import pytest
-
+import pandas as pd
+from sqlalchemy.sql import select
 from tests.test_app import app
-from superset.db_engine_specs.hive import HiveEngineSpec
+from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
 from superset.exceptions import SupersetException
 from superset.sql_parse import Table, ParsedQuery
 
@@ -168,6 +169,102 @@ def test_create_table_from_csv_append() -> None:
     )
 
 
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True},
+)
+@mock.patch("superset.db_engine_specs.hive.g", spec={})
+@mock.patch("tableschema.Table")
+def test_create_table_from_csv_if_exists_fail(mock_table, mock_g):
+    mock_table.infer.return_value = {}
+    mock_g.user = True
+    mock_database = mock.MagicMock()
+    mock_database.get_df.return_value.empty = False
+    with pytest.raises(SupersetException, match="Table already exists"):
+        HiveEngineSpec.create_table_from_csv(
+            "foo.csv", Table("foobar"), mock_database, {}, {"if_exists": "fail"}
+        )
+
+
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True},
+)
+@mock.patch("superset.db_engine_specs.hive.g", spec={})
+@mock.patch("tableschema.Table")
+def test_create_table_from_csv_if_exists_fail_with_schema(mock_table, mock_g):
+    mock_table.infer.return_value = {}
+    mock_g.user = True
+    mock_database = mock.MagicMock()
+    mock_database.get_df.return_value.empty = False
+    with pytest.raises(SupersetException, match="Table already exists"):
+        HiveEngineSpec.create_table_from_csv(
+            "foo.csv",
+            Table(table="foobar", schema="schema"),
+            mock_database,
+            {},
+            {"if_exists": "fail"},
+        )
+
+
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True},
+)
+@mock.patch("superset.db_engine_specs.hive.g", spec={})
+@mock.patch("tableschema.Table")
+@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
+def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table, mock_g):
+    mock_upload_to_s3.return_value = "mock-location"
+    mock_table.infer.return_value = {}
+    mock_g.user = True
+    mock_database = mock.MagicMock()
+    mock_database.get_df.return_value.empty = False
+    mock_execute = mock.MagicMock(return_value=True)
+    mock_database.get_sqla_engine.return_value.execute = mock_execute
+    table_name = "foobar"
+
+    HiveEngineSpec.create_table_from_csv(
+        "foo.csv",
+        Table(table=table_name),
+        mock_database,
+        {"sep": "mock", "header": 1, "na_values": "mock"},
+        {"if_exists": "replace"},
+    )
+
+    mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}")
+
+
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True},
+)
+@mock.patch("superset.db_engine_specs.hive.g", spec={})
+@mock.patch("tableschema.Table")
+@mock.patch("superset.db_engine_specs.hive.upload_to_s3")
+def test_create_table_from_csv_if_exists_replace_with_schema(
+    mock_upload_to_s3, mock_table, mock_g
+):
+    mock_upload_to_s3.return_value = "mock-location"
+    mock_table.infer.return_value = {}
+    mock_g.user = True
+    mock_database = mock.MagicMock()
+    mock_database.get_df.return_value.empty = False
+    mock_execute = mock.MagicMock(return_value=True)
+    mock_database.get_sqla_engine.return_value.execute = mock_execute
+    table_name = "foobar"
+    schema = "schema"
+    HiveEngineSpec.create_table_from_csv(
+        "foo.csv",
+        Table(table=table_name, schema=schema),
+        mock_database,
+        {"sep": "mock", "header": 1, "na_values": "mock"},
+        {"if_exists": "replace"},
+    )
+
+    mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}")
+
+
 def test_get_create_table_stmt() -> None:
     table = Table("employee")
     schema_def = """eid int, name String, salary String, destination String"""
@@ -247,3 +344,107 @@ def is_readonly(sql: str) -> bool:
     assert is_readonly("EXPLAIN SELECT 1")
     assert is_readonly("SELECT 1")
     assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
+
+
+def test_upload_to_s3_no_bucket_path():
+    with pytest.raises(
+        Exception,
+        match="No upload bucket specified. You can specify one in the config file.",
+    ):
+        upload_to_s3("filename", "prefix", Table("table"))
+
+
+@mock.patch("boto3.client")
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
+)
+def test_upload_to_s3_client_error(client):
+    from botocore.exceptions import ClientError
+
+    client.return_value.upload_file.side_effect = ClientError(
+        {"Error": {}}, "operation_name"
+    )
+
+    with pytest.raises(ClientError):
+        upload_to_s3("filename", "prefix", Table("table"))
+
+
+@mock.patch("boto3.client")
+@mock.patch(
+    "superset.db_engine_specs.hive.config",
+    {**app.config, "CSV_TO_HIVE_UPLOAD_S3_BUCKET": "bucket"},
+)
+def test_upload_to_s3_success(client):
+    client.return_value.upload_file.return_value = True
+
+    location = upload_to_s3("filename", "prefix", Table("table"))
+    assert f"s3a://bucket/prefix/table" == location
+
+
+def test_fetch_data_query_error():
+    from TCLIService import ttypes
+
+    err_msg = "error message"
+    cursor = mock.Mock()
+    cursor.poll.return_value.operationState = ttypes.TOperationState.ERROR_STATE
+    cursor.poll.return_value.errorMessage = err_msg
+    with pytest.raises(Exception, match=f"('Query error', '{err_msg}')"):
+        HiveEngineSpec.fetch_data(cursor)
+
+
+@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.fetch_data")
+def test_fetch_data_programming_error(fetch_data_mock):
+    from pyhive.exc import ProgrammingError
+
+    fetch_data_mock.side_effect = ProgrammingError
+    cursor = mock.Mock()
+    assert HiveEngineSpec.fetch_data(cursor) == []
+
+
+@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.fetch_data")
+def test_fetch_data_success(fetch_data_mock):
+    return_value = ["a", "b"]
+    fetch_data_mock.return_value = return_value
+    cursor = mock.Mock()
+    assert HiveEngineSpec.fetch_data(cursor) == return_value
+
+
+@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df")
+def test_where_latest_partition(mock_method):
+    mock_method.return_value = ("01-01-19", 1)
+    db = mock.Mock()
+    db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
+    db.get_extra = mock.Mock(return_value={})
+    db.get_df = mock.Mock()
+    columns = [{"name": "ds"}, {"name": "hour"}]
+    with app.app_context():
+        result = HiveEngineSpec.where_latest_partition(
+            "test_table", "test_schema", db, select(), columns
+        )
+    query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
+    assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
+
+
+@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
+def test_where_latest_partition_super_method_exception(mock_method):
+    mock_method.side_effect = Exception()
+    db = mock.Mock()
+    columns = [{"name": "ds"}, {"name": "hour"}]
+    with app.app_context():
+        result = HiveEngineSpec.where_latest_partition(
+            "test_table", "test_schema", db, select(), columns
+        )
+    assert result is None
+    mock_method.assert_called()
+
+
+@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
+def test_where_latest_partition_no_columns_no_values(mock_method):
+    mock_method.return_value = ("01-01-19", None)
+    db = mock.Mock()
+    with app.app_context():
+        result = HiveEngineSpec.where_latest_partition(
+            "test_table", "test_schema", db, select()
+        )
+    assert result is None