From b1cda7c28bd311eb1bece066fd8863e5f227fd71 Mon Sep 17 00:00:00 2001
From: John Bodley
Date: Wed, 24 May 2023 01:01:46 -0700
Subject: [PATCH] feature(pre-commit): Keeping up with the Joneses

---
 .pre-commit-config.yaml | 22 +- RELEASING/changelog.py | 29 +- RELEASING/generate_email.py | 8 +- docker/pythonpath_dev/superset_config.py | 7 +- scripts/benchmark_migration.py | 24 +- scripts/cancel_github_workflows.py | 23 +- scripts/permissions_cleanup.py | 6 +- setup.py | 6 +- .../plugins/internet_address.py | 4 +- .../plugins/internet_port.py | 6 +- superset/advanced_data_type/types.py | 10 +- superset/annotation_layers/annotations/api.py | 4 +- .../annotations/commands/bulk_delete.py | 6 +- .../annotations/commands/create.py | 6 +- .../annotations/commands/update.py | 6 +- superset/annotation_layers/annotations/dao.py | 4 +- .../annotation_layers/commands/bulk_delete.py | 6 +- superset/annotation_layers/commands/create.py | 6 +- superset/annotation_layers/commands/update.py | 6 +- superset/annotation_layers/dao.py | 6 +- superset/charts/commands/bulk_delete.py | 6 +- superset/charts/commands/create.py | 6 +- superset/charts/commands/export.py | 4 +- .../charts/commands/importers/dispatcher.py | 4 +- .../charts/commands/importers/v1/__init__.py | 15 +- .../charts/commands/importers/v1/utils.py | 4 +- superset/charts/commands/update.py | 10 +- superset/charts/dao.py | 6 +- superset/charts/data/api.py | 18 +- .../data/commands/create_async_job_command.py | 4 +- .../charts/data/commands/get_data_command.py | 4 +- .../charts/data/query_context_cache_loader.py | 4 +- superset/charts/post_processing.py | 26 +- superset/charts/schemas.py | 6 +- superset/cli/importexport.py | 6 +- superset/cli/main.py | 6 +- superset/cli/native_filters.py | 9 +- superset/cli/thumbnails.py | 4 +- superset/commands/base.py | 6 +- superset/commands/exceptions.py | 12 +- superset/commands/export/assets.py | 4 +- superset/commands/export/models.py | 14 +- superset/commands/importers/v1/__init__.py | 26 +- superset/commands/importers/v1/assets.py | 30 +- superset/commands/importers/v1/examples.py | 18 +- superset/commands/importers/v1/utils.py | 38 +-- superset/commands/utils.py | 10 +- superset/common/chart_data.py | 3 +- superset/common/query_actions.py | 24 +- superset/common/query_context.py | 36 +-- superset/common/query_context_factory.py | 18 +- superset/common/query_context_processor.py | 38 +-- superset/common/query_object.py | 106 +++---- superset/common/query_object_factory.py | 22 +- superset/common/tags.py | 18 +- superset/common/utils/dataframe_utils.py | 4 +- superset/common/utils/query_cache_manager.py | 52 ++-- superset/common/utils/time_range_utils.py | 12 +- superset/config.py | 137 ++++----- superset/connectors/base/models.py | 121 ++++---- superset/connectors/sqla/models.py | 169 +++++------ superset/connectors/sqla/utils.py | 30 +- superset/connectors/sqla/views.py | 2 +- .../css_templates/commands/bulk_delete.py | 6 +- superset/css_templates/dao.py | 4 +- superset/dao/base.py | 16 +- superset/dashboards/commands/bulk_delete.py | 6 +- superset/dashboards/commands/create.py | 10 +- superset/dashboards/commands/export.py | 9 +- .../commands/importers/dispatcher.py | 4 +- superset/dashboards/commands/importers/v0.py | 14 +- .../commands/importers/v1/__init__.py | 20 +- .../dashboards/commands/importers/v1/utils.py | 20 +- superset/dashboards/commands/update.py | 10 +- superset/dashboards/dao.py | 20 +- .../dashboards/filter_sets/commands/create.py | 4 +- .../dashboards/filter_sets/commands/update.py | 4
+- superset/dashboards/filter_sets/dao.py | 4 +- superset/dashboards/filter_sets/schemas.py | 15 +- superset/dashboards/filter_state/api.py | 9 +- superset/dashboards/permalink/types.py | 8 +- superset/dashboards/schemas.py | 8 +- superset/databases/api.py | 6 +- superset/databases/commands/create.py | 6 +- superset/databases/commands/export.py | 7 +- .../commands/importers/dispatcher.py | 4 +- .../commands/importers/v1/__init__.py | 8 +- .../databases/commands/importers/v1/utils.py | 4 +- superset/databases/commands/tables.py | 4 +- .../databases/commands/test_connection.py | 4 +- superset/databases/commands/update.py | 8 +- superset/databases/commands/validate.py | 4 +- superset/databases/commands/validate_sql.py | 12 +- superset/databases/dao.py | 6 +- superset/databases/filters.py | 4 +- superset/databases/schemas.py | 26 +- .../databases/ssh_tunnel/commands/create.py | 6 +- .../databases/ssh_tunnel/commands/update.py | 4 +- superset/databases/ssh_tunnel/dao.py | 4 +- superset/databases/ssh_tunnel/models.py | 4 +- superset/databases/utils.py | 12 +- superset/dataframe.py | 4 +- superset/datasets/commands/bulk_delete.py | 6 +- superset/datasets/commands/create.py | 8 +- superset/datasets/commands/duplicate.py | 6 +- superset/datasets/commands/export.py | 4 +- .../datasets/commands/importers/dispatcher.py | 4 +- superset/datasets/commands/importers/v0.py | 8 +- .../commands/importers/v1/__init__.py | 10 +- .../datasets/commands/importers/v1/utils.py | 6 +- superset/datasets/commands/update.py | 22 +- superset/datasets/dao.py | 28 +- superset/datasets/models.py | 5 +- superset/datasets/schemas.py | 8 +- superset/datasource/dao.py | 4 +- superset/db_engine_specs/__init__.py | 12 +- superset/db_engine_specs/athena.py | 7 +- superset/db_engine_specs/base.py | 256 ++++++++-------- superset/db_engine_specs/bigquery.py | 45 +-- superset/db_engine_specs/clickhouse.py | 22 +- superset/db_engine_specs/crate.py | 6 +- superset/db_engine_specs/databricks.py | 22 +- superset/db_engine_specs/dremio.py | 4 +- superset/db_engine_specs/drill.py | 10 +- superset/db_engine_specs/druid.py | 14 +- superset/db_engine_specs/duckdb.py | 13 +- superset/db_engine_specs/dynamodb.py | 4 +- superset/db_engine_specs/elasticsearch.py | 10 +- superset/db_engine_specs/exasol.py | 4 +- superset/db_engine_specs/firebird.py | 4 +- superset/db_engine_specs/firebolt.py | 4 +- superset/db_engine_specs/gsheets.py | 19 +- superset/db_engine_specs/hana.py | 4 +- superset/db_engine_specs/hive.py | 84 +++--- superset/db_engine_specs/impala.py | 6 +- superset/db_engine_specs/kusto.py | 16 +- superset/db_engine_specs/kylin.py | 4 +- superset/db_engine_specs/mssql.py | 9 +- superset/db_engine_specs/mysql.py | 15 +- superset/db_engine_specs/ocient.py | 33 +- superset/db_engine_specs/oracle.py | 6 +- superset/db_engine_specs/pinot.py | 8 +- superset/db_engine_specs/postgres.py | 31 +- superset/db_engine_specs/presto.py | 146 +++++---- superset/db_engine_specs/redshift.py | 7 +- superset/db_engine_specs/rockset.py | 4 +- superset/db_engine_specs/snowflake.py | 33 +- superset/db_engine_specs/sqlite.py | 9 +- superset/db_engine_specs/starrocks.py | 31 +- superset/db_engine_specs/trino.py | 32 +- superset/embedded/dao.py | 6 +- superset/errors.py | 6 +- superset/examples/bart_lines.py | 2 +- superset/examples/big_data.py | 3 +- superset/examples/birth_names.py | 8 +- superset/examples/countries.py | 8 +- superset/examples/helpers.py | 10 +- superset/examples/multiformat_time_series.py | 4 +- superset/examples/paris.py | 2 +- 
superset/examples/sf_population_polygons.py | 2 +- .../examples/supported_charts_dashboard.py | 3 +- superset/examples/utils.py | 8 +- superset/examples/world_bank.py | 3 +- superset/exceptions.py | 18 +- superset/explore/commands/get.py | 6 +- superset/explore/permalink/commands/create.py | 4 +- superset/explore/permalink/types.py | 6 +- superset/extensions/__init__.py | 14 +- superset/extensions/metastore_cache.py | 4 +- superset/forms.py | 12 +- superset/initialization/__init__.py | 6 +- superset/jinja_context.py | 56 ++-- superset/key_value/types.py | 10 +- superset/key_value/utils.py | 4 +- superset/legacy.py | 4 +- superset/migrations/env.py | 3 +- .../migrations/shared/migrate_viz/base.py | 8 +- .../migrations/shared/security_converge.py | 15 +- superset/migrations/shared/utils.py | 5 +- ...31_db0c65b146bd_update_slice_model_json.py | 2 +- ...a_rewriting_url_from_shortner_with_new_.py | 2 +- .../2017-10-03_14-37_4736ec66ce19_.py | 10 +- ...11-06_21e88bc06c02_annotation_migration.py | 2 +- ...8-02-13_08-07_e866bd2d4976_smaller_grid.py | 4 +- .../2018-03-20_19-47_f231d82b9b26_.py | 4 +- ...06ae5eb46_cal_heatmap_metric_to_metrics.py | 2 +- ...-06-13_14-54_bddc498dd179_adhoc_filters.py | 2 - ..._migrate_num_period_compare_and_period_.py | 18 +- ...f3fed1fe_convert_dashboard_v1_positions.py | 16 +- ..._migrate_time_range_for_default_filters.py | 4 +- ...d1d_reconvert_legacy_filters_into_adhoc.py | 2 - ...-49_b5998378c225_add_certificate_to_dbs.py | 3 +- ...5563a02_migrate_iframe_to_dash_markdown.py | 3 +- ...ix_data_access_permissions_for_virtual_.py | 2 +- ...ff221_migrate_filter_sets_to_new_format.py | 9 +- ...95_migrate_native_filters_to_new_schema.py | 11 +- ...migrate_pivot_table_v2_heatmaps_to_new_.py | 1 - ...e2e_migrate_timeseries_limit_metric_to_.py | 1 - ...-15_32646df09c64_update_time_grain_sqla.py | 3 +- ..._a9422eeaae74_new_dataset_models_take_2.py | 8 +- superset/models/annotations.py | 4 +- superset/models/core.py | 57 ++-- superset/models/dashboard.py | 48 ++- superset/models/datasource_access_request.py | 10 +- superset/models/embedded_dashboard.py | 3 +- superset/models/filter_set.py | 6 +- superset/models/helpers.py | 190 ++++++------ superset/models/slice.py | 38 +-- superset/models/sql_lab.py | 38 +-- superset/models/sql_types/presto_sql_types.py | 12 +- superset/queries/dao.py | 6 +- .../saved_queries/commands/bulk_delete.py | 6 +- .../queries/saved_queries/commands/export.py | 4 +- .../commands/importers/dispatcher.py | 4 +- .../commands/importers/v1/__init__.py | 10 +- .../commands/importers/v1/utils.py | 4 +- superset/queries/saved_queries/dao.py | 4 +- superset/queries/schemas.py | 3 +- superset/reports/commands/alert.py | 4 +- superset/reports/commands/base.py | 6 +- superset/reports/commands/bulk_delete.py | 6 +- superset/reports/commands/create.py | 10 +- superset/reports/commands/exceptions.py | 5 +- superset/reports/commands/execute.py | 10 +- superset/reports/commands/update.py | 8 +- superset/reports/dao.py | 22 +- superset/reports/filters.py | 2 +- superset/reports/logs/api.py | 4 +- superset/reports/notifications/__init__.py | 1 - superset/reports/notifications/base.py | 7 +- superset/reports/notifications/email.py | 7 +- superset/reports/notifications/slack.py | 4 +- superset/reports/schemas.py | 4 +- superset/result_set.py | 22 +- .../commands/bulk_delete.py | 5 +- .../row_level_security/commands/create.py | 4 +- .../row_level_security/commands/update.py | 4 +- superset/security/api.py | 6 +- superset/security/guest_token.py | 8 +- 
superset/security/manager.py | 61 ++-- superset/sql_lab.py | 28 +- superset/sql_parse.py | 17 +- superset/sql_validators/__init__.py | 4 +- superset/sql_validators/base.py | 6 +- superset/sql_validators/postgres.py | 6 +- superset/sql_validators/presto_db.py | 8 +- superset/sqllab/api.py | 4 +- superset/sqllab/commands/estimate.py | 8 +- superset/sqllab/commands/execute.py | 14 +- superset/sqllab/commands/export.py | 4 +- superset/sqllab/commands/results.py | 8 +- superset/sqllab/exceptions.py | 26 +- .../sqllab/execution_context_convertor.py | 4 +- superset/sqllab/query_render.py | 16 +- superset/sqllab/sql_json_executer.py | 18 +- superset/sqllab/sqllab_execution_context.py | 36 +-- superset/sqllab/utils.py | 6 +- superset/stats_logger.py | 16 +- superset/superset_typing.py | 35 +-- superset/tables/models.py | 13 +- superset/tags/commands/create.py | 3 +- superset/tags/commands/delete.py | 3 +- superset/tags/dao.py | 12 +- superset/tags/models.py | 40 ++- superset/tasks/__init__.py | 1 - superset/tasks/async_queries.py | 16 +- superset/tasks/cache.py | 18 +- superset/tasks/cron_util.py | 2 +- superset/tasks/utils.py | 12 +- superset/translations/utils.py | 8 +- superset/utils/async_query_manager.py | 14 +- superset/utils/cache.py | 22 +- superset/utils/celery.py | 2 +- superset/utils/core.py | 284 ++++++++---------- superset/utils/csv.py | 6 +- .../dashboard_filter_scopes_converter.py | 36 +-- superset/utils/database.py | 4 +- superset/utils/date_parser.py | 8 +- superset/utils/decorators.py | 9 +- superset/utils/dict_import_export.py | 6 +- superset/utils/encrypt.py | 26 +- superset/utils/feature_flag_manager.py | 5 +- superset/utils/filters.py | 4 +- superset/utils/hashing.py | 4 +- superset/utils/log.py | 67 ++--- superset/utils/machine_auth.py | 4 +- superset/utils/mock_data.py | 34 +-- superset/utils/network.py | 4 +- .../utils/pandas_postprocessing/aggregate.py | 4 +- .../utils/pandas_postprocessing/boxplot.py | 14 +- .../utils/pandas_postprocessing/compare.py | 6 +- .../pandas_postprocessing/contribution.py | 6 +- superset/utils/pandas_postprocessing/cum.py | 3 +- superset/utils/pandas_postprocessing/diff.py | 3 +- .../utils/pandas_postprocessing/flatten.py | 4 +- .../utils/pandas_postprocessing/geography.py | 4 +- superset/utils/pandas_postprocessing/pivot.py | 8 +- .../utils/pandas_postprocessing/rename.py | 4 +- .../utils/pandas_postprocessing/rolling.py | 8 +- .../utils/pandas_postprocessing/select.py | 8 +- superset/utils/pandas_postprocessing/sort.py | 6 +- superset/utils/pandas_postprocessing/utils.py | 13 +- superset/utils/retries.py | 9 +- superset/utils/screenshots.py | 44 +-- superset/utils/ssh_tunnel.py | 8 +- superset/utils/url_map_converters.py | 4 +- superset/utils/webdriver.py | 16 +- superset/views/__init__.py | 2 - superset/views/all_entities.py | 1 - superset/views/base.py | 26 +- superset/views/base_api.py | 48 +-- superset/views/base_schemas.py | 13 +- superset/views/core.py | 72 +++-- superset/views/dashboard/views.py | 10 +- superset/views/database/forms.py | 3 +- superset/views/database/mixins.py | 2 +- superset/views/database/validators.py | 4 +- superset/views/datasource/schemas.py | 4 +- superset/views/datasource/utils.py | 6 +- superset/views/log/dao.py | 6 +- superset/views/tags.py | 1 - superset/views/users/__init__.py | 1 - superset/views/utils.py | 42 +-- superset/viz.py | 202 ++++++------- tests/common/logger_utils.py | 12 +- tests/common/query_context_generator.py | 12 +- .../data_generator/base_generator.py | 5 +- 
.../birth_names/birth_names_generator.py | 7 +- .../data_loading/data_definitions/types.py | 9 +- .../data_loading/pandas/pandas_data_loader.py | 6 +- .../pandas/pands_data_loading_conf.py | 4 +- .../data_loading/pandas/table_df_convertor.py | 6 +- tests/integration_tests/access_tests.py | 2 +- .../advanced_data_type/api_tests.py | 4 +- tests/integration_tests/base_tests.py | 12 +- .../integration_tests/cachekeys/api_tests.py | 4 +- tests/integration_tests/charts/api_tests.py | 1 - .../charts/data/api_tests.py | 10 +- tests/integration_tests/conftest.py | 8 +- tests/integration_tests/core_tests.py | 30 +- tests/integration_tests/csv_upload_tests.py | 10 +- tests/integration_tests/dashboard_tests.py | 20 +- tests/integration_tests/dashboard_utils.py | 6 +- .../integration_tests/dashboards/api_tests.py | 16 +- .../integration_tests/dashboards/base_case.py | 6 +- .../dashboards/dashboard_test_utils.py | 10 +- .../dashboards/filter_sets/conftest.py | 39 +-- .../filter_sets/create_api_tests.py | 86 +++--- .../filter_sets/delete_api_tests.py | 54 ++-- .../dashboards/filter_sets/get_api_tests.py | 14 +- .../filter_sets/update_api_tests.py | 112 +++---- .../dashboards/filter_sets/utils.py | 24 +- .../dashboards/permalink/api_tests.py | 3 +- .../dashboards/security/base_case.py | 6 +- .../dashboards/superset_factory_util.py | 26 +- .../integration_tests/databases/api_tests.py | 8 +- .../databases/commands_tests.py | 2 +- .../ssh_tunnel/commands/commands_tests.py | 2 +- tests/integration_tests/datasets/api_tests.py | 20 +- .../datasets/commands_tests.py | 6 +- .../db_engine_specs/base_tests.py | 2 - .../db_engine_specs/bigquery_tests.py | 3 +- .../db_engine_specs/hive_tests.py | 4 +- .../dict_import_export_tests.py | 18 +- tests/integration_tests/email_tests.py | 1 - tests/integration_tests/event_logger_tests.py | 4 +- .../explore/permalink/api_tests.py | 9 +- .../explore/permalink/commands_tests.py | 1 - .../fixtures/birth_names_dashboard.py | 4 +- .../integration_tests/fixtures/datasource.py | 5 +- .../fixtures/energy_dashboard.py | 7 +- .../fixtures/importexport.py | 36 +-- .../fixtures/query_context.py | 6 +- .../fixtures/world_bank_dashboard.py | 10 +- .../integration_tests/import_export_tests.py | 26 +- tests/integration_tests/insert_chart_mixin.py | 4 +- .../key_value/commands/fixtures.py | 3 +- tests/integration_tests/model_tests.py | 10 +- .../integration_tests/query_context_tests.py | 4 +- .../integration_tests/reports/alert_tests.py | 8 +- .../reports/commands_tests.py | 12 +- .../reports/scheduler_tests.py | 3 +- tests/integration_tests/reports/utils.py | 14 +- .../security/migrate_roles_tests.py | 1 - .../security/row_level_security_tests.py | 18 +- tests/integration_tests/sql_lab/api_tests.py | 4 +- .../sql_lab/commands_tests.py | 2 +- tests/integration_tests/sqla_models_tests.py | 19 +- tests/integration_tests/sqllab_tests.py | 16 +- tests/integration_tests/strategy_tests.py | 2 - .../integration_tests/superset_test_config.py | 2 +- ..._test_config_sqllab_backend_persist_off.py | 2 - .../superset_test_config_thumbnails.py | 2 +- tests/integration_tests/tagging_tests.py | 1 - tests/integration_tests/tags/api_tests.py | 3 - .../integration_tests/tags/commands_tests.py | 1 - tests/integration_tests/tags/dao_tests.py | 3 - tests/integration_tests/thumbnails_tests.py | 3 +- tests/integration_tests/users/__init__.py | 1 - tests/integration_tests/utils/csv_tests.py | 6 +- .../integration_tests/utils/encrypt_tests.py | 8 +- .../integration_tests/utils/get_dashboards.py | 3 +- 
 .../utils/public_interfaces_test.py | 4 +- tests/integration_tests/utils_tests.py | 8 +- tests/integration_tests/viz_tests.py | 5 +- tests/unit_tests/charts/dao/dao_tests.py | 2 +- .../unit_tests/charts/test_post_processing.py | 1 - .../common/test_query_object_factory.py | 16 +- tests/unit_tests/config_test.py | 4 +- tests/unit_tests/conftest.py | 3 +- tests/unit_tests/dao/queries_test.py | 3 +- .../commands/importers/v1/utils_test.py | 6 +- tests/unit_tests/dashboards/dao_tests.py | 2 +- tests/unit_tests/databases/dao/dao_tests.py | 2 +- .../ssh_tunnel/commands/create_test.py | 1 - .../ssh_tunnel/commands/delete_test.py | 2 +- .../ssh_tunnel/commands/update_test.py | 2 +- .../databases/ssh_tunnel/dao_tests.py | 1 - .../commands/importers/v1/import_test.py | 6 +- tests/unit_tests/datasets/conftest.py | 8 +- tests/unit_tests/datasets/dao/dao_tests.py | 2 +- tests/unit_tests/datasource/dao_tests.py | 2 +- .../unit_tests/db_engine_specs/test_athena.py | 2 +- tests/unit_tests/db_engine_specs/test_base.py | 6 +- .../db_engine_specs/test_clickhouse.py | 6 +- .../db_engine_specs/test_elasticsearch.py | 4 +- .../unit_tests/db_engine_specs/test_mssql.py | 6 +- .../unit_tests/db_engine_specs/test_mysql.py | 8 +- .../unit_tests/db_engine_specs/test_ocient.py | 6 +- .../db_engine_specs/test_postgres.py | 6 +- .../unit_tests/db_engine_specs/test_presto.py | 6 +- .../db_engine_specs/test_starrocks.py | 10 +- .../unit_tests/db_engine_specs/test_trino.py | 22 +- tests/unit_tests/db_engine_specs/utils.py | 14 +- tests/unit_tests/extensions/ssh_test.py | 1 - tests/unit_tests/fixtures/assets_configs.py | 14 +- tests/unit_tests/fixtures/datasets.py | 6 +- tests/unit_tests/models/core_test.py | 4 +- .../unit_tests/pandas_postprocessing/utils.py | 8 +- tests/unit_tests/sql_parse_tests.py | 5 +- tests/unit_tests/tasks/test_cron_util.py | 12 +- tests/unit_tests/tasks/test_utils.py | 16 +- tests/unit_tests/thumbnails/test_digest.py | 20 +- tests/unit_tests/utils/cache_test.py | 1 - tests/unit_tests/utils/date_parser_tests.py | 6 +- tests/unit_tests/utils/test_core.py | 5 +- tests/unit_tests/utils/test_file.py | 1 - tests/unit_tests/utils/urls_tests.py | 1 -
 448 files changed, 3084 insertions(+), 3305 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3f524b3658b57..07544d66d26db 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,14 +15,28 @@
 # limitations under the License.
 #
 repos:
+  - repo: https://github.com/MarcoGorelli/auto-walrus
+    rev: v0.2.2
+    hooks:
+      - id: auto-walrus
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.4.0
+    hooks:
+      - id: pyupgrade
+        args:
+          - --py39-plus
+  - repo: https://github.com/hadialqattan/pycln
+    rev: v2.1.2
+    hooks:
+      - id: pycln
+        args:
+          - --disable-all-dunder-policy
+          - --exclude=superset/config.py
+          - --extend-exclude=tests/integration_tests/superset_test_config.*.py
   - repo: https://github.com/PyCQA/isort
     rev: 5.12.0
    hooks:
      - id: isort
-  - repo: https://github.com/MarcoGorelli/auto-walrus
-    rev: v0.2.2
-    hooks:
-      - id: auto-walrus
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.3.0
     hooks:
diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py
index 68a54e10be360..d1ba06a620b7c 100644
--- a/RELEASING/changelog.py
+++ b/RELEASING/changelog.py
@@ -17,8 +17,9 @@ import os import re import sys +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Optional, Union import click from click.core import Context @@ -67,15 +68,15 @@ class GitChangeLog: def __init__( self, version: str, - logs: List[GitLog], + logs: list[GitLog], access_token: Optional[str] = None, risk: Optional[bool] = False, ) -> None: self._version = version self._logs = logs - self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {} - self._github_login_cache: Dict[str, Optional[str]] = {} - self._github_prs: Dict[int, Any] = {} + self._pr_logs_with_details: dict[int, dict[str, Any]] = {} + self._github_login_cache: dict[str, Optional[str]] = {} + self._github_prs: dict[int, Any] = {} self._wait = 10 github_token = access_token or os.environ.get("GITHUB_TOKEN") self._github = Github(github_token) @@ -126,7 +127,7 @@ def _has_commit_migrations(self, git_sha: str) -> bool: "superset/migrations/versions/" in file.filename for file in commit.files ) - def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]: + def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]: pr_number = git_log.pr_number if pr_number: detail = self._pr_logs_with_details.get(pr_number) @@ -156,7 +157,7 @@ def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]: return detail - def _is_risk_pull_request(self, labels: List[Any]) -> bool: + def _is_risk_pull_request(self, labels: list[Any]) -> bool: for label in labels: risk_label = re.match(SUPERSET_RISKY_LABELS, label.name) if risk_label is not None: @@ -174,8 +175,8 @@ def _get_changelog_version_head(self) -> str: def _parse_change_log( self, - changelog: Dict[str, str], - pr_info: Dict[str, str], + changelog: dict[str, str], + pr_info: dict[str, str], github_login: str, ) -> None: formatted_pr = ( @@ -227,7 +228,7 @@ def __repr__(self) -> str: result += f"**{key}** {changelog[key]}\n" return result - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: for log in self._logs: yield { "pr_number": log.pr_number, @@ -250,20 +251,20 @@ class GitLogs: def __init__(self, git_ref: str) -> None: self._git_ref = git_ref - self._logs: List[GitLog] = [] + self._logs: list[GitLog] = [] @property def git_ref(self) -> str: return self._git_ref @property - def logs(self) -> List[GitLog]: + def logs(self) -> list[GitLog]: return self._logs def fetch(self) -> None: self._logs = list(map(self._parse_log, self._git_logs()))[::-1] - def diff(self, git_logs: "GitLogs") -> List[GitLog]: + def diff(self, git_logs: "GitLogs") -> list[GitLog]:
return [log for log in git_logs.logs if log not in self._logs] def __repr__(self) -> str: @@ -284,7 +285,7 @@ def _git_checkout(self, git_ref: str) -> None: print(f"Could not checkout {git_ref}") sys.exit(1) - def _git_logs(self) -> List[str]: + def _git_logs(self) -> list[str]: # let's get current git ref so we can revert it back current_git_ref = self._git_get_current_head() self._git_checkout(self._git_ref) diff --git a/RELEASING/generate_email.py b/RELEASING/generate_email.py index 92536670cda6d..29142557c00a2 100755 --- a/RELEASING/generate_email.py +++ b/RELEASING/generate_email.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, Dict, List +from typing import Any from click.core import Context @@ -34,7 +34,7 @@ PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application" -def string_comma_to_list(message: str) -> List[str]: +def string_comma_to_list(message: str) -> list[str]: if not message: return [] return [element.strip() for element in message.split(",")] @@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str: return template.render(kwargs) -class BaseParameters(object): +class BaseParameters: def __init__( self, version: str, @@ -60,7 +60,7 @@ def __init__( ) -> None: self.version = version self.version_rc = version_rc - self.template_arguments: Dict[str, Any] = {} + self.template_arguments: dict[str, Any] = {} def __repr__(self) -> str: return f"Apache Credentials: {self.version}/{self.version_rc}" diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py index 6ea9abf63c3b1..199e79f66e6d7 100644 --- a/docker/pythonpath_dev/superset_config.py +++ b/docker/pythonpath_dev/superset_config.py @@ -22,7 +22,6 @@ # import logging import os -from datetime import timedelta from typing import Optional from cachelib.file import FileSystemCache @@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str: error_msg = "The environment variable {} was missing, abort...".format( var_name ) - raise EnvironmentError(error_msg) + raise OSError(error_msg) DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT") @@ -53,7 +52,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str: DATABASE_DB = get_env_variable("DATABASE_DB") # The SQLAlchemy connection string. 
-SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % ( +SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format( DATABASE_DIALECT, DATABASE_USER, DATABASE_PASSWORD, @@ -80,7 +79,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str: DATA_CACHE_CONFIG = CACHE_CONFIG -class CeleryConfig(object): +class CeleryConfig: broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" imports = ("superset.sql_lab",) result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}" diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index 83c06456a102b..466fab6f130e6 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -23,7 +23,7 @@ from inspect import getsource from pathlib import Path from types import ModuleType -from typing import Any, Dict, List, Set, Type +from typing import Any import click from flask import current_app @@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore return module - raise Exception( - "No module spec found in location: `{path}`".format(path=str(filepath)) - ) + raise Exception(f"No module spec found in location: `{str(filepath)}`") -def extract_modified_tables(module: ModuleType) -> Set[str]: +def extract_modified_tables(module: ModuleType) -> set[str]: """ Extract the tables being modified by a migration script. @@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]: actually traversing the AST. """ - tables: Set[str] = set() + tables: set[str] = set() for function in {"upgrade", "downgrade"}: source = getsource(getattr(module, function)) tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL)) @@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]: return tables -def find_models(module: ModuleType) -> List[Type[Model]]: +def find_models(module: ModuleType) -> list[type[Model]]: """ Find all models in a migration script. 
""" - models: List[Type[Model]] = [] + models: list[type[Model]] = [] tables = extract_modified_tables(module) # add models defined explicitly in the migration script @@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]: sorter: TopologicalSorter[Any] = TopologicalSorter() for model in models: inspector = inspect(model) - dependent_tables: List[str] = [] + dependent_tables: list[str] = [] for column in inspector.columns.values(): for foreign_key in column.foreign_keys: if foreign_key.column.table.name != model.__tablename__: @@ -174,7 +172,7 @@ def main( print("\nIdentifying models used in the migration:") models = find_models(module) - model_rows: Dict[Type[Model], int] = {} + model_rows: dict[type[Model], int] = {} for model in models: rows = session.query(model).count() print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") @@ -182,7 +180,7 @@ def main( session.close() print("Benchmarking migration") - results: Dict[str, float] = {} + results: dict[str, float] = {} start = time.time() upgrade(revision=revision) duration = time.time() - start @@ -190,14 +188,14 @@ def main( print(f"Migration on current DB took: {duration:.2f} seconds") min_entities = 10 - new_models: Dict[Type[Model], List[Model]] = defaultdict(list) + new_models: dict[type[Model], list[Model]] = defaultdict(list) while min_entities <= limit: downgrade(revision=down_revision) print(f"Running with at least {min_entities} entities of each model") for model in models: missing = min_entities - model_rows[model] if missing > 0: - entities: List[Model] = [] + entities: list[Model] = [] print(f"- Adding {missing} entities to the {model.__name__} model") bar = ChargingBar("Processing", max=missing) try: diff --git a/scripts/cancel_github_workflows.py b/scripts/cancel_github_workflows.py index 4d30d34adf405..70744c295467b 100755 --- a/scripts/cancel_github_workflows.py +++ b/scripts/cancel_github_workflows.py @@ -33,13 +33,13 @@ ./cancel_github_workflows.py 1024 --include-last """ import os -from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +from collections.abc import Iterable, Iterator +from typing import Any, Literal, Optional, Union import click import requests from click.exceptions import ClickException from dateutil import parser -from typing_extensions import Literal github_token = os.environ.get("GITHUB_TOKEN") github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset") @@ -47,7 +47,7 @@ def request( method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any -) -> Dict[str, Any]: +) -> dict[str, Any]: resp = requests.request( method, f"https://api.github.com/{endpoint.lstrip('/')}", @@ -61,8 +61,8 @@ def request( def list_runs( repo: str, - params: Optional[Dict[str, str]] = None, -) -> Iterator[Dict[str, Any]]: + params: Optional[dict[str, str]] = None, +) -> Iterator[dict[str, Any]]: """List all github workflow runs. 
Returns: An iterator that will iterate through all pages of matching runs.""" @@ -77,16 +77,15 @@ def list_runs( params={**params, "per_page": 100, "page": page}, ) total_count = result["total_count"] - for item in result["workflow_runs"]: - yield item + yield from result["workflow_runs"] page += 1 -def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]: +def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]: return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel") -def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]: +def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]: return request("GET", f"/repos/{repo}/pulls/{pull_number}") @@ -96,7 +95,7 @@ def get_runs( user: Optional[str] = None, statuses: Iterable[str] = ("queued", "in_progress"), events: Iterable[str] = ("pull_request", "push"), -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Get workflow runs associated with the given branch""" return [ item @@ -108,7 +107,7 @@ def get_runs( ] -def print_commit(commit: Dict[str, Any], branch: str) -> None: +def print_commit(commit: dict[str, Any], branch: str) -> None: """Print out commit message for verification""" indented_message = " \n".join(commit["message"].split("\n")) date_str = ( @@ -155,7 +154,7 @@ def print_commit(commit: Dict[str, Any], branch: str) -> None: def cancel_github_workflows( branch_or_pull: Optional[str], repo: str, - event: List[str], + event: list[str], include_last: bool, include_running: bool, ) -> None: diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py index 5ca75e394cccf..0416f55806821 100644 --- a/scripts/permissions_cleanup.py +++ b/scripts/permissions_cleanup.py @@ -24,7 +24,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("# of permission view menus is: {}".format(len(pvms))) + print(f"# of permission view menus is: {len(pvms)}") pvms_dict = defaultdict(list) for pvm in pvms: pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm) @@ -43,7 +43,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("Stage 1: # of permission view menus is: {}".format(len(pvms))) + print(f"Stage 1: # of permission view menus is: {len(pvms)}") # 2. Clean up None permissions or view menus pvms = security_manager.get_session.query( @@ -57,7 +57,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("Stage 2: # of permission view menus is: {}".format(len(pvms))) + print(f"Stage 2: # of permission view menus is: {len(pvms)}") # 3. Delete empty permission view menus from roles roles = security_manager.get_session.query(security_manager.role_model).all() diff --git a/setup.py b/setup.py index 41f7e11e3893d..d8adea3285a20 100644 --- a/setup.py +++ b/setup.py @@ -14,21 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import io import json import os import subprocess -import sys from setuptools import find_packages, setup BASE_DIR = os.path.abspath(os.path.dirname(__file__)) PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json") -with open(PACKAGE_JSON, "r") as package_file: +with open(PACKAGE_JSON) as package_file: version_string = json.load(package_file)["version"] -with io.open("README.md", "r", encoding="utf-8") as f: +with open("README.md", encoding="utf-8") as f: long_description = f.read() diff --git a/superset/advanced_data_type/plugins/internet_address.py b/superset/advanced_data_type/plugins/internet_address.py index 08a0925846539..8ab20fe2d032d 100644 --- a/superset/advanced_data_type/plugins/internet_address.py +++ b/superset/advanced_data_type/plugins/internet_address.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import ipaddress -from typing import Any, List +from typing import Any from sqlalchemy import Column @@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse: # Make this return a single clause def cidr_translate_filter_func( - col: Column, operator: FilterOperator, values: List[Any] + col: Column, operator: FilterOperator, values: list[Any] ) -> Any: """ Convert a passed in column, FilterOperator and diff --git a/superset/advanced_data_type/plugins/internet_port.py b/superset/advanced_data_type/plugins/internet_port.py index 60a594bfd9b45..8983e41422c8a 100644 --- a/superset/advanced_data_type/plugins/internet_port.py +++ b/superset/advanced_data_type/plugins/internet_port.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import itertools -from typing import Any, Dict, List +from typing import Any from sqlalchemy import Column @@ -26,7 +26,7 @@ ) from superset.utils.core import FilterOperator, FilterStringOperators -port_conversion_dict: Dict[str, List[int]] = { +port_conversion_dict: dict[str, list[int]] = { "http": [80], "ssh": [22], "https": [443], @@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo def port_translate_filter_func( - col: Column, operator: FilterOperator, values: List[Any] + col: Column, operator: FilterOperator, values: list[Any] ) -> Any: """ Convert a passed in column, FilterOperator diff --git a/superset/advanced_data_type/types.py b/superset/advanced_data_type/types.py index 316922f3399da..e8d5de9143570 100644 --- a/superset/advanced_data_type/types.py +++ b/superset/advanced_data_type/types.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from dataclasses import dataclass -from typing import Any, Callable, List, Optional, TypedDict, Union +from typing import Any, Callable, Optional, TypedDict, Union from sqlalchemy import Column from sqlalchemy.sql.expression import BinaryExpression @@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict): """ advanced_data_type: str - values: List[ + values: list[ Union[FilterValues, None] ] # unparsed value (usually text when passed from text box) @@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False): """ error_message: Optional[str] - values: List[Any] # parsed value (can be any value) + values: list[Any] # parsed value (can be any value) display_value: str # The string representation of the parsed values - valid_filter_operators: List[FilterStringOperators] + valid_filter_operators: list[FilterStringOperators] @dataclass @@ -54,6 +54,6 @@ class AdvancedDataType: verbose_name: str description: str - valid_data_types: List[str] + valid_data_types: list[str] translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse] translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression] diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py index 0a6a2767f0efd..70e0a1ad02d3a 100644 --- a/superset/annotation_layers/annotations/api.py +++ b/superset/annotation_layers/annotations/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask import request, Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe @@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): @staticmethod def _apply_layered_relation_to_rison( # pylint: disable=invalid-name - layer_id: int, rison_parameters: Dict[str, Any] + layer_id: int, rison_parameters: dict[str, Any] ) -> None: if "filters" not in rison_parameters: rison_parameters["filters"] = [] diff --git a/superset/annotation_layers/annotations/commands/bulk_delete.py b/superset/annotation_layers/annotations/commands/bulk_delete.py index 113725050fd89..dd47047788a59 100644 --- a/superset/annotation_layers/annotations/commands/bulk_delete.py +++ b/superset/annotation_layers/annotations/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset.annotation_layers.annotations.commands.exceptions import ( AnnotationBulkDeleteFailedError, @@ -30,9 +30,9 @@ class BulkDeleteAnnotationCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Annotation]] = None + self._models: Optional[list[Annotation]] = None def run(self) -> None: self.validate() diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py index 0974624561142..986b5642917e6 100644 --- a/superset/annotation_layers/annotations/commands/create.py +++ b/superset/annotation_layers/annotations/commands/create.py @@ -16,7 +16,7 @@ # under the License. 
import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -37,7 +37,7 @@ class CreateAnnotationCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -50,7 +50,7 @@ def run(self) -> Model: return annotation def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] layer_id: Optional[int] = self._properties.get("layer") start_dttm: Optional[datetime] = self._properties.get("start_dttm") end_dttm: Optional[datetime] = self._properties.get("end_dttm") diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index b644ddc3622d9..99ab20916501b 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -39,7 +39,7 @@ class UpdateAnnotationCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Annotation] = None @@ -54,7 +54,7 @@ def run(self) -> Model: return annotation def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] layer_id: Optional[int] = self._properties.get("layer") short_descr: str = self._properties.get("short_descr", "") diff --git a/superset/annotation_layers/annotations/dao.py b/superset/annotation_layers/annotations/dao.py index 0c8a9e47c5c06..da69e576e5087 100644 --- a/superset/annotation_layers/annotations/dao.py +++ b/superset/annotation_layers/annotations/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO): model_cls = Annotation @staticmethod - def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete( diff --git a/superset/annotation_layers/commands/bulk_delete.py b/superset/annotation_layers/commands/bulk_delete.py index b9bc17e82f3b5..4910dc4275f11 100644 --- a/superset/annotation_layers/commands/bulk_delete.py +++ b/superset/annotation_layers/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import List, Optional +from typing import Optional from superset.annotation_layers.commands.exceptions import ( AnnotationLayerBulkDeleteFailedError, @@ -31,9 +31,9 @@ class BulkDeleteAnnotationLayerCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[AnnotationLayer]] = None + self._models: Optional[list[AnnotationLayer]] = None def run(self) -> None: self.validate() diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py index 97431568a9a68..86b0cb3b85893 100644 --- a/superset/annotation_layers/commands/create.py +++ b/superset/annotation_layers/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -33,7 +33,7 @@ class CreateAnnotationLayerCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -46,7 +46,7 @@ def run(self) -> Model: return annotation_layer def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] name = self._properties.get("name", "") diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index 4a9cc31be5f8d..67d869c0054d3 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -35,7 +35,7 @@ class UpdateAnnotationLayerCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[AnnotationLayer] = None @@ -50,7 +50,7 @@ def run(self) -> Model: return annotation_layer def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] name = self._properties.get("name", "") self._model = AnnotationLayerDAO.find_by_id(self._model_id) diff --git a/superset/annotation_layers/dao.py b/superset/annotation_layers/dao.py index d9db4b582d97f..67efc19f88009 100644 --- a/superset/annotation_layers/dao.py +++ b/superset/annotation_layers/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import List, Optional, Union +from typing import Optional, Union from sqlalchemy.exc import SQLAlchemyError @@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO): @staticmethod def bulk_delete( - models: Optional[List[AnnotationLayer]], commit: bool = True + models: Optional[list[AnnotationLayer]], commit: bool = True ) -> None: item_ids = [model.id for model in models] if models else [] try: @@ -46,7 +46,7 @@ def bulk_delete( raise DAODeleteFailedError() from ex @staticmethod - def has_annotations(model_id: Union[int, List[int]]) -> bool: + def has_annotations(model_id: Union[int, list[int]]) -> bool: if isinstance(model_id, list): return ( db.session.query(AnnotationLayer) diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index c252f0be4cc28..ac801b7421e7b 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_babel import lazy_gettext as _ @@ -37,9 +37,9 @@ class BulkDeleteChartCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Slice]] = None + self._models: Optional[list[Slice]] = None def run(self) -> None: self.validate() diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index 38076fb9cde8f..78706b3a665c1 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import g from flask_appbuilder.models.sqla import Model @@ -37,7 +37,7 @@ class CreateChartCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -56,7 +56,7 @@ def validate(self) -> None: datasource_type = self._properties["datasource_type"] datasource_id = self._properties["datasource_id"] dashboard_ids = self._properties.get("dashboards", []) - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate/Populate datasource try: diff --git a/superset/charts/commands/export.py b/superset/charts/commands/export.py index 9d445cb54e235..22310ade99ce4 100644 --- a/superset/charts/commands/export.py +++ b/superset/charts/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml @@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand): not_found = ChartNotFoundError @staticmethod - def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]: + def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]: file_name = get_filename(model.slice_name, model.id) file_path = f"charts/{file_name}.yaml" diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py index afeb9c2820c88..fb5007a50ca29 100644 --- a/superset/charts/commands/importers/dispatcher.py +++ b/superset/charts/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. 
import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py index ab88038aaabe5..132df21b0815b 100644 --- a/superset/charts/commands/importers/v1/__init__.py +++ b/superset/charts/commands/importers/v1/__init__.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. -import json -from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand): dao = ChartDAO model_name = "chart" prefix = "charts/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "datasets/": ImportV1DatasetSchema(), "databases/": ImportV1DatabaseSchema(), @@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover datasets associated with charts - dataset_uuids: Set[str] = set() + dataset_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("charts/"): dataset_uuids.add(config["dataset_uuid"]) # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids: database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref - datasets: Dict[str, SqlaTable] = {} + datasets: dict[str, SqlaTable] = {} for file_name, config in configs.items(): if ( file_name.startswith("datasets/") diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py index d4aeb17a1e7a8..399e6c2243fa2 100644 --- a/superset/charts/commands/importers/v1/utils.py +++ b/superset/charts/commands/importers/v1/utils.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Any, Dict +from typing import Any from flask import g from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ def import_chart( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Slice: diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index f5fc2616a5a3c..a4265d083539e 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -16,7 +16,7 @@ # under the License. 
import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import g from flask_appbuilder.models.sqla import Model @@ -42,14 +42,14 @@ logger = logging.getLogger(__name__) -def is_query_context_update(properties: Dict[str, Any]) -> bool: +def is_query_context_update(properties: dict[str, Any]) -> bool: return set(properties) == {"query_context", "query_context_generation"} and bool( properties.get("query_context_generation") ) class UpdateChartCommand(UpdateMixin, BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Slice] = None @@ -67,9 +67,9 @@ def run(self) -> Model: return chart def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] dashboard_ids = self._properties.get("dashboards") - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate if datasource_id is provided datasource_type is required datasource_id = self._properties.get("datasource_id") diff --git a/superset/charts/dao.py b/superset/charts/dao.py index 7102e6ad234bf..9c6b2c26ef55d 100644 --- a/superset/charts/dao.py +++ b/superset/charts/dao.py @@ -17,7 +17,7 @@ # pylint: disable=arguments-renamed import logging from datetime import datetime -from typing import List, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from sqlalchemy.exc import SQLAlchemyError @@ -39,7 +39,7 @@ class ChartDAO(BaseDAO): base_filter = ChartFilter @staticmethod - def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: @@ -71,7 +71,7 @@ def overwrite(slc: Slice, commit: bool = True) -> None: db.session.commit() @staticmethod - def favorited_ids(charts: List[Slice]) -> List[FavStar]: + def favorited_ids(charts: list[Slice]) -> list[FavStar]: ids = [chart.id for chart in charts] return [ star.obj_id diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 9c620dcf5ddd9..552044ebfa909 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -18,7 +18,7 @@ import json import logging -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import simplejson from flask import current_app, g, make_response, request, Response @@ -315,7 +315,7 @@ def data_from_cache(self, cache_key: str) -> Response: return self._get_data_response(command, True) def _run_async( - self, form_data: Dict[str, Any], command: ChartDataCommand + self, form_data: dict[str, Any], command: ChartDataCommand ) -> Response: """ Execute command as an async query. 
@@ -344,9 +344,9 @@ def _run_async( def _send_chart_response( self, - result: Dict[Any, Any], - form_data: Optional[Dict[str, Any]] = None, - datasource: Optional[Union[BaseDatasource, Query]] = None, + result: dict[Any, Any], + form_data: dict[str, Any] | None = None, + datasource: BaseDatasource | Query | None = None, ) -> Response: result_type = result["query_context"].result_type result_format = result["query_context"].result_format @@ -408,8 +408,8 @@ def _get_data_response( self, command: ChartDataCommand, force_cached: bool = False, - form_data: Optional[Dict[str, Any]] = None, - datasource: Optional[Union[BaseDatasource, Query]] = None, + form_data: dict[str, Any] | None = None, + datasource: BaseDatasource | Query | None = None, ) -> Response: try: result = command.run(force_cached=force_cached) @@ -421,12 +421,12 @@ def _get_data_response( return self._send_chart_response(result, form_data, datasource) # pylint: disable=invalid-name, no-self-use - def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]: + def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]: return QueryContextCacheLoader.load(cache_key) # pylint: disable=no-self-use def _create_query_context_from_form( - self, form_data: Dict[str, Any] + self, form_data: dict[str, Any] ) -> QueryContext: try: return ChartDataQueryContextSchema().load(form_data) diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py index c4e25f742baa3..fb6e3f3dbff34 100644 --- a/superset/charts/data/commands/create_async_job_command.py +++ b/superset/charts/data/commands/create_async_job_command.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Request @@ -32,7 +32,7 @@ def validate(self, request: Request) -> None: jwt_data = async_query_manager.parse_jwt_from_request(request) self._async_channel_id = jwt_data["channel"] - def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]: + def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]: job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) load_chart_data_into_cache.delay(job_metadata, form_data) return job_metadata diff --git a/superset/charts/data/commands/get_data_command.py b/superset/charts/data/commands/get_data_command.py index 819693607bfb7..a84870a1dd306 100644 --- a/superset/charts/data/commands/get_data_command.py +++ b/superset/charts/data/commands/get_data_command.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any, Dict +from typing import Any from flask_babel import lazy_gettext as _ @@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand): def __init__(self, query_context: QueryContext): self._query_context = query_context - def run(self, **kwargs: Any) -> Dict[str, Any]: + def run(self, **kwargs: Any) -> dict[str, Any]: # caching is handled in query_context.get_df_payload # (also evals `force` property) cache_query_context = kwargs.get("cache", False) diff --git a/superset/charts/data/query_context_cache_loader.py b/superset/charts/data/query_context_cache_loader.py index b5ff3bdae87e3..97fa733a3e4ad 100644 --- a/superset/charts/data/query_context_cache_loader.py +++ b/superset/charts/data/query_context_cache_loader.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from superset import cache from superset.charts.commands.exceptions import ChartDataCacheLoadError @@ -22,7 +22,7 @@ class QueryContextCacheLoader: # pylint: disable=too-few-public-methods @staticmethod - def load(cache_key: str) -> Dict[str, Any]: + def load(cache_key: str) -> dict[str, Any]: cache_value = cache.get(cache_key) if not cache_value: raise ChartDataCacheLoadError("Cached data not found") diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index 1165769fc8df4..a6b64c08d68b6 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -27,7 +27,7 @@ """ from io import StringIO -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import pandas as pd from flask_babel import gettext as __ @@ -45,14 +45,14 @@ from superset.models.sql_lab import Query -def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: +def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]: """ Sort columns when combining metrics. MultiIndex labels have the metric name as the last element in the tuple. We want to sort these according to the list of passed metrics. """ - parts: List[Any] = list(label) + parts: list[Any] = list(label) metric = parts[-1] parts[-1] = metrics.index(metric) return tuple(parts) @@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ... def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches df: pd.DataFrame, - rows: List[str], - columns: List[str], - metrics: List[str], + rows: list[str], + columns: list[str], + metrics: list[str], aggfunc: str = "Sum", transpose_pivot: bool = False, combine_metrics: bool = False, @@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str: """ List unique values in a series. 
""" - return ", ".join(set(str(v) for v in pd.Series.unique(series))) + return ", ".join({str(v) for v in pd.Series.unique(series)}) pivot_v2_aggfunc_map = { @@ -223,7 +223,7 @@ def list_unique_values(series: pd.Series) -> str: def pivot_table_v2( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[Union["BaseDatasource", "Query"]] = None, ) -> pd.DataFrame: """ @@ -249,7 +249,7 @@ def pivot_table_v2( def pivot_table( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[Union["BaseDatasource", "Query"]] = None, ) -> pd.DataFrame: """ @@ -285,7 +285,7 @@ def pivot_table( def table( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[ # pylint: disable=unused-argument Union["BaseDatasource", "Query"] ] = None, @@ -315,10 +315,10 @@ def table( def apply_post_process( - result: Dict[Any, Any], - form_data: Optional[Dict[str, Any]] = None, + result: dict[Any, Any], + form_data: Optional[dict[str, Any]] = None, datasource: Optional[Union["BaseDatasource", "Query"]] = None, -) -> Dict[Any, Any]: +) -> dict[Any, Any]: form_data = form_data or {} viz_type = form_data.get("viz_type") diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 44252ef06f9aa..373600cd08a4c 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -18,7 +18,7 @@ from __future__ import annotations import inspect -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate @@ -1383,7 +1383,7 @@ class Meta: # pylint: disable=too-few-public-methods class ChartDataQueryContextSchema(Schema): - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None datasource = fields.Nested(ChartDataDatasourceSchema) queries = fields.List(fields.Nested(ChartDataQueryObjectSchema)) custom_cache_timeout = fields.Integer( @@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema): # pylint: disable=unused-argument @post_load - def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext: + def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext: query_context = self.get_query_context_factory().create(**data) return query_context diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py index c7689569c2436..86f6fe9b67e71 100755 --- a/superset/cli/importexport.py +++ b/superset/cli/importexport.py @@ -18,7 +18,7 @@ import sys from datetime import datetime from pathlib import Path -from typing import List, Optional +from typing import Optional from zipfile import is_zipfile, ZipFile import click @@ -309,7 +309,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None: from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand path_object = Path(path) - files: List[Path] = [] + files: list[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: @@ -363,7 +363,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None: sync_metrics = "metrics" in sync_array path_object = Path(path) - files: List[Path] = [] + files: list[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: diff --git a/superset/cli/main.py b/superset/cli/main.py index 
006f8eb5c9e80..536617cadd882 100755 --- a/superset/cli/main.py +++ b/superset/cli/main.py @@ -18,7 +18,7 @@ import importlib import logging import pkgutil -from typing import Any, Dict +from typing import Any import click from colorama import Fore, Style @@ -40,7 +40,7 @@ def superset() -> None: """This is a management script for the Superset application.""" @app.shell_context_processor - def make_shell_context() -> Dict[str, Any]: + def make_shell_context() -> dict[str, Any]: return dict(app=app, db=db) @@ -79,5 +79,5 @@ def version(verbose: bool) -> None: ) print(Fore.BLUE + "-=" * 15) if verbose: - print("[DB] : " + "{}".format(db.engine)) + print("[DB] : " + f"{db.engine}") print(Style.RESET_ALL) diff --git a/superset/cli/native_filters.py b/superset/cli/native_filters.py index 63cc185e8ed7d..a25724d38d54a 100644 --- a/superset/cli/native_filters.py +++ b/superset/cli/native_filters.py @@ -17,7 +17,6 @@ import json from copy import deepcopy from textwrap import dedent -from typing import Set, Tuple import click from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup @@ -102,7 +101,7 @@ def native_filters() -> None: ) def upgrade( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Upgrade legacy filter-box charts to native dashboard filters. @@ -251,7 +250,7 @@ def upgrade( ) def downgrade( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Downgrade native dashboard filters to legacy filter-box charts (where applicable). @@ -347,7 +346,7 @@ def downgrade( ) def cleanup( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Cleanup obsolete legacy filter-box charts and interim metadata. @@ -355,7 +354,7 @@ def cleanup( Note this operation is irreversible. """ - slice_ids: Set[int] = set() + slice_ids: set[int] = set() # Cleanup the dashboard which contains legacy fields used for downgrading. for dashboard in ( diff --git a/superset/cli/thumbnails.py b/superset/cli/thumbnails.py index 276d9981c1ec2..325fab6853d60 100755 --- a/superset/cli/thumbnails.py +++ b/superset/cli/thumbnails.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Type, Union +from typing import Union import click from celery.utils.abstract import CallableTask @@ -75,7 +75,7 @@ def compute_thumbnails( def compute_generic_thumbnail( friendly_type: str, - model_cls: Union[Type[Dashboard], Type[Slice]], + model_cls: Union[type[Dashboard], type[Slice]], model_id: int, compute_func: CallableTask, ) -> None: diff --git a/superset/commands/base.py b/superset/commands/base.py index 42d5956312cd3..caca50755de9b 100644 --- a/superset/commands/base.py +++ b/superset/commands/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
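# Illustrative sketch with hypothetical names (not from the patch): besides
# Dict/List, the hunks around here also swap typing.Tuple/Set/Type for the
# builtins, typing.Iterator for collections.abc.Iterator, set(genexpr) for a set
# comprehension, and "{}".format(...) for an f-string.
from __future__ import annotations

from collections.abc import Iterator


class ExampleExporter:
    # type[Exception] replaces typing.Type[Exception]
    not_found: type[Exception] = ValueError

    def run(self, dashboard_ids: tuple[int, ...]) -> Iterator[tuple[str, str]]:
        # set comprehension and f-string, mirroring the cli/main.py and
        # post_processing.py hunks
        seen: set[str] = {f"dashboard_{i}" for i in dashboard_ids}
        for name in sorted(seen):
            yield (name, "exported")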
from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Optional from flask_appbuilder.security.sqla.models import User @@ -45,7 +45,7 @@ def validate(self) -> None: class CreateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: + def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]: """ Populate list of owners, defaulting to the current user if `owner_ids` is undefined or empty. If current user is missing in `owner_ids`, current user @@ -60,7 +60,7 @@ def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: class UpdateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: + def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]: """ Populate list of owners. If current user is missing in `owner_ids`, current user is added unless belonging to the Admin role. diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index db9d1b6c63916..4398d740c5201 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ from marshmallow import ValidationError @@ -59,7 +59,7 @@ class CommandInvalidError(CommandException): def __init__( self, message: str = "", - exceptions: Optional[List[ValidationError]] = None, + exceptions: Optional[list[ValidationError]] = None, ) -> None: self._exceptions = exceptions or [] super().__init__(message) @@ -67,14 +67,14 @@ def __init__( def append(self, exception: ValidationError) -> None: self._exceptions.append(exception) - def extend(self, exceptions: List[ValidationError]) -> None: + def extend(self, exceptions: list[ValidationError]) -> None: self._exceptions.extend(exceptions) - def get_list_classnames(self) -> List[str]: + def get_list_classnames(self) -> list[str]: return list(sorted({ex.__class__.__name__ for ex in self._exceptions})) - def normalized_messages(self) -> Dict[Any, Any]: - errors: Dict[Any, Any] = {} + def normalized_messages(self) -> dict[Any, Any]: + errors: dict[Any, Any] = {} for exception in self._exceptions: errors.update(exception.normalized_messages()) return errors diff --git a/superset/commands/export/assets.py b/superset/commands/export/assets.py index 9f088af428d04..1bd2cf6d61ffa 100644 --- a/superset/commands/export/assets.py +++ b/superset/commands/export/assets.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Iterator from datetime import datetime, timezone -from typing import Iterator, Tuple import yaml @@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand): Command that exports all databases, datasets, charts, dashboards and saved queries. 
""" - def run(self) -> Iterator[Tuple[str, str]]: + def run(self) -> Iterator[tuple[str, str]]: metadata = { "version": EXPORT_VERSION, "type": "assets", diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py index 4edafaa7464d0..3f21f29281c44 100644 --- a/superset/commands/export/models.py +++ b/superset/commands/export/models.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Iterator from datetime import datetime, timezone -from typing import Iterator, List, Tuple, Type import yaml from flask_appbuilder import Model @@ -30,21 +30,21 @@ class ExportModelsCommand(BaseCommand): - dao: Type[BaseDAO] = BaseDAO - not_found: Type[CommandException] = CommandException + dao: type[BaseDAO] = BaseDAO + not_found: type[CommandException] = CommandException - def __init__(self, model_ids: List[int], export_related: bool = True): + def __init__(self, model_ids: list[int], export_related: bool = True): self.model_ids = model_ids self.export_related = export_related # this will be set when calling validate() - self._models: List[Model] = [] + self._models: list[Model] = [] @staticmethod - def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]: + def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]: raise NotImplementedError("Subclasses MUST implement _export") - def run(self) -> Iterator[Tuple[str, str]]: + def run(self) -> Iterator[tuple[str, str]]: self.validate() metadata = { diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index a67828bdb283d..09830bf3cf727 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List, Optional, Set +from typing import Any, Optional from marshmallow import Schema, validate from marshmallow.exceptions import ValidationError @@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand): dao = BaseDAO model_name = "model" prefix = "" - schemas: Dict[str, Schema] = {} + schemas: dict[str, Schema] = {} import_error = CommandException # pylint: disable=unused-argument - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents - self.passwords: Dict[str, str] = kwargs.get("passwords") or {} - self.ssh_tunnel_passwords: Dict[str, str] = ( + self.passwords: dict[str, str] = kwargs.get("passwords") or {} + self.ssh_tunnel_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_passwords") or {} ) - self.ssh_tunnel_private_keys: Dict[str, str] = ( + self.ssh_tunnel_private_keys: dict[str, str] = ( kwargs.get("ssh_tunnel_private_keys") or {} ) - self.ssh_tunnel_priv_key_passwords: Dict[str, str] = ( + self.ssh_tunnel_priv_key_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_priv_key_passwords") or {} ) self.overwrite: bool = kwargs.get("overwrite", False) - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: raise NotImplementedError("Subclasses MUST implement _import") @classmethod - def _get_uuids(cls) -> Set[str]: + def _get_uuids(cls) -> set[str]: return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()} def run(self) -> None: @@ -84,11 +84,11 @@ def run(self) -> None: raise self.import_error() from ex def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # verify that the metadata file is present and valid try: - metadata: Optional[Dict[str, str]] = load_metadata(self.contents) + metadata: Optional[dict[str, str]] = load_metadata(self.contents) except ValidationError as exc: exceptions.append(exc) metadata = None @@ -114,7 +114,7 @@ def validate(self) -> None: ) def _prevent_overwrite_existing_model( # pylint: disable=invalid-name - self, exceptions: List[ValidationError] + self, exceptions: list[ValidationError] ) -> None: """check if the object exists and shouldn't be overwritten""" if not self.overwrite: diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index ce8b46c2a0c46..1ab2e486cf3dd 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from marshmallow import Schema from marshmallow.exceptions import ValidationError @@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand): and will overwrite everything. 
""" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand): } # pylint: disable=unused-argument - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents - self.passwords: Dict[str, str] = kwargs.get("passwords") or {} - self.ssh_tunnel_passwords: Dict[str, str] = ( + self.passwords: dict[str, str] = kwargs.get("passwords") or {} + self.ssh_tunnel_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_passwords") or {} ) - self.ssh_tunnel_private_keys: Dict[str, str] = ( + self.ssh_tunnel_private_keys: dict[str, str] = ( kwargs.get("ssh_tunnel_private_keys") or {} ) - self.ssh_tunnel_priv_key_passwords: Dict[str, str] = ( + self.ssh_tunnel_priv_key_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_priv_key_passwords") or {} ) - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import(session: Session, configs: dict[str, Any]) -> None: # import databases first - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database(session, config, overwrite=True) @@ -95,7 +95,7 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: import_saved_query(session, config, overwrite=True) # import datasets - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if file_name.startswith("datasets/"): config["database_id"] = database_ids[config["database_uuid"]] @@ -107,7 +107,7 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: } # import charts - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("charts/"): config.update(dataset_info[config["dataset_uuid"]]) @@ -121,7 +121,7 @@ def _import(session: Session, configs: Dict[str, Any]) -> None: dashboard = import_dashboard(session, config, overwrite=True) # set ref in the dashboard_slices table - dashboard_chart_ids: List[Dict[str, int]] = [] + dashboard_chart_ids: list[dict[str, int]] = [] for uuid in find_chart_uuids(config["position"]): if uuid not in chart_ids: break @@ -151,11 +151,11 @@ def run(self) -> None: raise ImportFailedError() from ex def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # verify that the metadata file is present and valid try: - metadata: Optional[Dict[str, str]] = load_metadata(self.contents) + metadata: Optional[dict[str, str]] = load_metadata(self.contents) except ValidationError as exc: exceptions.append(exc) metadata = None diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index 35efdb13934ff..4c20e93ff7434 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List, Set, Tuple +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand): dao = BaseDAO model_name = "model" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand): } import_error = CommandException - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): super().__init__(contents, *args, **kwargs) self.force_data = kwargs.get("force_data", False) @@ -81,7 +81,7 @@ def run(self) -> None: raise self.import_error() from ex @classmethod - def _get_uuids(cls) -> Set[str]: + def _get_uuids(cls) -> set[str]: # pylint: disable=protected-access return ( ImportDatabasesCommand._get_uuids() @@ -93,12 +93,12 @@ def _get_uuids(cls) -> Set[str]: @staticmethod def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches session: Session, - configs: Dict[str, Any], + configs: dict[str, Any], overwrite: bool = False, force_data: bool = False, ) -> None: # import databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database( @@ -114,7 +114,7 @@ def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-bran # database was created before its UUID was frozen, so it has a random UUID. # We need to determine its ID so we can point the dataset to it. examples_db = get_example_database() - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if file_name.startswith("datasets/"): # find the ID of the corresponding database @@ -153,7 +153,7 @@ def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-bran } # import charts - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if ( file_name.startswith("charts/") @@ -175,7 +175,7 @@ def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-bran ).fetchall() # import dashboards - dashboard_chart_ids: List[Tuple[int, int]] = [] + dashboard_chart_ids: list[tuple[int, int]] = [] for file_name, config in configs.items(): if file_name.startswith("dashboards/"): try: diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index c8fb97c53ddbf..8ca008b3e23bf 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -15,7 +15,7 @@ import logging from pathlib import Path, PurePosixPath -from typing import Any, Dict, List, Optional +from typing import Any, Optional from zipfile import ZipFile import yaml @@ -46,7 +46,7 @@ class MetadataSchema(Schema): timestamp = fields.DateTime() -def load_yaml(file_name: str, content: str) -> Dict[str, Any]: +def load_yaml(file_name: str, content: str) -> dict[str, Any]: """Try to load a YAML file""" try: return yaml.safe_load(content) @@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]: raise ValidationError({file_name: "Not a valid YAML file"}) from ex -def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: +def load_metadata(contents: dict[str, str]) -> dict[str, str]: """Apply 
validation and load a metadata file""" if METADATA_FILE_NAME not in contents: # if the contents have no METADATA_FILE_NAME this is probably @@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: def validate_metadata_type( - metadata: Optional[Dict[str, str]], + metadata: Optional[dict[str, str]], type_: str, - exceptions: List[ValidationError], + exceptions: list[ValidationError], ) -> None: """Validate that the type declared in METADATA_FILE_NAME is correct""" if metadata and "type" in metadata: @@ -96,35 +96,35 @@ def validate_metadata_type( # pylint: disable=too-many-locals,too-many-arguments def load_configs( - contents: Dict[str, str], - schemas: Dict[str, Schema], - passwords: Dict[str, str], - exceptions: List[ValidationError], - ssh_tunnel_passwords: Dict[str, str], - ssh_tunnel_private_keys: Dict[str, str], - ssh_tunnel_priv_key_passwords: Dict[str, str], -) -> Dict[str, Any]: - configs: Dict[str, Any] = {} + contents: dict[str, str], + schemas: dict[str, Schema], + passwords: dict[str, str], + exceptions: list[ValidationError], + ssh_tunnel_passwords: dict[str, str], + ssh_tunnel_private_keys: dict[str, str], + ssh_tunnel_priv_key_passwords: dict[str, str], +) -> dict[str, Any]: + configs: dict[str, Any] = {} # load existing databases so we can apply the password validation - db_passwords: Dict[str, str] = { + db_passwords: dict[str, str] = { str(uuid): password for uuid, password in db.session.query(Database.uuid, Database.password).all() } # load existing ssh_tunnels so we can apply the password validation - db_ssh_tunnel_passwords: Dict[str, str] = { + db_ssh_tunnel_passwords: dict[str, str] = { str(uuid): password for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all() } # load existing ssh_tunnels so we can apply the private_key validation - db_ssh_tunnel_private_keys: Dict[str, str] = { + db_ssh_tunnel_private_keys: dict[str, str] = { str(uuid): private_key for uuid, private_key in db.session.query( SSHTunnel.uuid, SSHTunnel.private_key ).all() } # load existing ssh_tunnels so we can apply the private_key_password validation - db_ssh_tunnel_priv_key_passws: Dict[str, str] = { + db_ssh_tunnel_priv_key_passws: dict[str, str] = { str(uuid): private_key_password for uuid, private_key_password in db.session.query( SSHTunnel.uuid, SSHTunnel.private_key_password @@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool: return True -def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]: +def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]: return { remove_root(file_name): bundle.read(file_name).decode() for file_name in bundle.namelist() diff --git a/superset/commands/utils.py b/superset/commands/utils.py index ad58bb40506f7..7bb13984f8c69 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -16,7 +16,7 @@ # under the License. 
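# Illustrative note with a hypothetical helper (not from the patch): the pipe
# syntax `X | None` evaluates lazily only when `from __future__ import annotations`
# is in effect (or on Python 3.10+), which appears to be why modules such as
# superset/commands/utils.py switch to it while others keep typing.Optional.
from __future__ import annotations


def populate_ids(ids: list[int] | None = None) -> list[int]:
    # With the future import the annotation is stored as a string, so this also
    # imports cleanly on Python 3.8/3.9.
    return list(ids or [])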
from __future__ import annotations -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask import g from flask_appbuilder.security.sqla.models import Role, User @@ -37,9 +37,9 @@ def populate_owners( - owner_ids: Optional[List[int]], + owner_ids: list[int] | None, default_to_user: bool, -) -> List[User]: +) -> list[User]: """ Helper function for commands, will fetch all users from owners id's @@ -63,13 +63,13 @@ def populate_owners( return owners -def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]: +def populate_roles(role_ids: list[int] | None = None) -> list[Role]: """ Helper function for commands, will fetch all roles from roles id's :raises RolesNotFoundValidationError: If a role in the input list is not found :param role_ids: A List of roles by id's """ - roles: List[Role] = [] + roles: list[Role] = [] if role_ids: roles = security_manager.find_roles_by_id(role_ids) if len(roles) != len(role_ids): diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py index 659a640159378..65c0c43c11ae0 100644 --- a/superset/common/chart_data.py +++ b/superset/common/chart_data.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. from enum import Enum -from typing import Set class ChartDataResultFormat(str, Enum): @@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum): XLSX = "xlsx" @classmethod - def table_like(cls) -> Set["ChartDataResultFormat"]: + def table_like(cls) -> set["ChartDataResultFormat"]: return {cls.CSV} | {cls.XLSX} diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index f6f5a5cd62cfb..22c778b77be67 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -17,7 +17,7 @@ from __future__ import annotations import copy -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import _ @@ -49,7 +49,7 @@ def _get_datasource( def _get_columns( query_context: QueryContext, query_obj: QueryObject, _: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) return { "data": [ @@ -65,7 +65,7 @@ def _get_columns( def _get_timegrains( query_context: QueryContext, query_obj: QueryObject, _: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) return { "data": [ @@ -83,7 +83,7 @@ def _get_query( query_context: QueryContext, query_obj: QueryObject, _: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result = {"language": datasource.query_language} try: @@ -96,8 +96,8 @@ def _get_query( def _get_full( query_context: QueryContext, query_obj: QueryObject, - force_cached: Optional[bool] = False, -) -> Dict[str, Any]: + force_cached: bool | None = False, +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result_type = query_obj.result_type or query_context.result_type payload = query_context.get_df_payload(query_obj, force_cached=force_cached) @@ -141,7 +141,7 @@ def _get_full( def _get_samples( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) query_obj = copy.copy(query_obj) query_obj.is_timeseries = False @@ -162,7 +162,7 @@ def _get_samples( def _get_drill_detail( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = 
False -) -> Dict[str, Any]: +) -> dict[str, Any]: # todo(yongjie): Remove this function, # when determining whether samples should be applied to the time filter. datasource = _get_datasource(query_context, query_obj) @@ -183,13 +183,13 @@ def _get_drill_detail( def _get_results( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False -) -> Dict[str, Any]: +) -> dict[str, Any]: payload = _get_full(query_context, query_obj, force_cached) return payload -_result_type_functions: Dict[ - ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]] +_result_type_functions: dict[ + ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]] ] = { ChartDataResultType.COLUMNS: _get_columns, ChartDataResultType.TIMEGRAINS: _get_timegrains, @@ -210,7 +210,7 @@ def get_query_results( query_context: QueryContext, query_obj: QueryObject, force_cached: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Return result payload for a chart data request. diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 78eb8800c4d6d..1a8d3c518b07a 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, ClassVar, TYPE_CHECKING import pandas as pd @@ -47,15 +47,15 @@ class QueryContext: enforce_numerical_metrics: ClassVar[bool] = True datasource: BaseDatasource - slice_: Optional[Slice] = None - queries: List[QueryObject] - form_data: Optional[Dict[str, Any]] + slice_: Slice | None = None + queries: list[QueryObject] + form_data: dict[str, Any] | None result_type: ChartDataResultType result_format: ChartDataResultFormat force: bool - custom_cache_timeout: Optional[int] + custom_cache_timeout: int | None - cache_values: Dict[str, Any] + cache_values: dict[str, Any] _processor: QueryContextProcessor @@ -65,14 +65,14 @@ def __init__( self, *, datasource: BaseDatasource, - queries: List[QueryObject], - slice_: Optional[Slice], - form_data: Optional[Dict[str, Any]], + queries: list[QueryObject], + slice_: Slice | None, + form_data: dict[str, Any] | None, result_type: ChartDataResultType, result_format: ChartDataResultFormat, force: bool = False, - custom_cache_timeout: Optional[int] = None, - cache_values: Dict[str, Any], + custom_cache_timeout: int | None = None, + cache_values: dict[str, Any], ) -> None: self.datasource = datasource self.slice_ = slice_ @@ -88,18 +88,18 @@ def __init__( def get_data( self, df: pd.DataFrame, - ) -> Union[str, List[Dict[str, Any]]]: + ) -> str | list[dict[str, Any]]: return self._processor.get_data(df) def get_payload( self, - cache_query_context: Optional[bool] = False, + cache_query_context: bool | None = False, force_cached: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the query results with both metadata and data""" return self._processor.get_payload(cache_query_context, force_cached) - def get_cache_timeout(self) -> Optional[int]: + def get_cache_timeout(self) -> int | None: if self.custom_cache_timeout is not None: return self.custom_cache_timeout if self.slice_ and self.slice_.cache_timeout is not None: @@ -110,14 +110,14 @@ def get_cache_timeout(self) -> Optional[int]: return self.datasource.database.cache_timeout return None - def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + def query_cache_key(self, query_obj: QueryObject, 
**kwargs: Any) -> str | None: return self._processor.query_cache_key(query_obj, **kwargs) def get_df_payload( self, query_obj: QueryObject, - force_cached: Optional[bool] = False, - ) -> Dict[str, Any]: + force_cached: bool | None = False, + ) -> dict[str, Any]: return self._processor.get_df_payload( query_obj=query_obj, force_cached=force_cached, diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 84c0415722c99..62018def8db24 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset import app, db from superset.charts.dao import ChartDAO @@ -48,12 +48,12 @@ def create( self, *, datasource: DatasourceDict, - queries: List[Dict[str, Any]], - form_data: Optional[Dict[str, Any]] = None, - result_type: Optional[ChartDataResultType] = None, - result_format: Optional[ChartDataResultFormat] = None, + queries: list[dict[str, Any]], + form_data: dict[str, Any] | None = None, + result_type: ChartDataResultType | None = None, + result_format: ChartDataResultFormat | None = None, force: bool = False, - custom_cache_timeout: Optional[int] = None, + custom_cache_timeout: int | None = None, ) -> QueryContext: datasource_model_instance = None if datasource: @@ -101,13 +101,13 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: datasource_id=int(datasource["id"]), ) - def _get_slice(self, slice_id: Any) -> Optional[Slice]: + def _get_slice(self, slice_id: Any) -> Slice | None: return ChartDAO.find_by_id(slice_id) def _process_query_object( self, datasource: BaseDatasource, - form_data: Optional[Dict[str, Any]], + form_data: dict[str, Any] | None, query_object: QueryObject, ) -> QueryObject: self._apply_granularity(query_object, form_data, datasource) @@ -117,7 +117,7 @@ def _process_query_object( def _apply_granularity( self, query_object: QueryObject, - form_data: Optional[Dict[str, Any]], + form_data: dict[str, Any] | None, datasource: BaseDatasource, ) -> None: temporal_columns = { diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 85a2b5d97ae72..ecb8db4246f6b 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -19,7 +19,7 @@ import copy import logging import re -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, ClassVar, TYPE_CHECKING import numpy as np import pandas as pd @@ -77,8 +77,8 @@ class CachedTimeOffset(TypedDict): df: pd.DataFrame - queries: List[str] - cache_keys: List[Optional[str]] + queries: list[str] + cache_keys: list[str | None] class QueryContextProcessor: @@ -102,8 +102,8 @@ def __init__(self, query_context: QueryContext): enforce_numerical_metrics: ClassVar[bool] = True def get_df_payload( - self, query_obj: QueryObject, force_cached: Optional[bool] = False - ) -> Dict[str, Any]: + self, query_obj: QueryObject, force_cached: bool | None = False + ) -> dict[str, Any]: """Handles caching around the df payload retrieval""" cache_key = self.query_cache_key(query_obj) timeout = self.get_cache_timeout() @@ -181,7 +181,7 @@ def get_df_payload( "label_map": label_map, } - def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str 
| None: """ Returns a QueryObject cache key for objects in self.queries """ @@ -248,8 +248,8 @@ def get_query_result(self, query_object: QueryObject) -> QueryResult: def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: # todo: should support "python_date_format" and "get_column" in each datasource def _get_timestamp_format( - source: BaseDatasource, column: Optional[str] - ) -> Optional[str]: + source: BaseDatasource, column: str | None + ) -> str | None: column_obj = source.get_column(column) if ( column_obj @@ -315,9 +315,9 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme query_context = self._query_context # ensure query_object is immutable query_object_clone = copy.copy(query_object) - queries: List[str] = [] - cache_keys: List[Optional[str]] = [] - rv_dfs: List[pd.DataFrame] = [df] + queries: list[str] = [] + cache_keys: list[str | None] = [] + rv_dfs: list[pd.DataFrame] = [df] time_offsets = query_object.time_offsets outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object) @@ -449,7 +449,7 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys) - def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]: + def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]: if self._query_context.result_format in ChartDataResultFormat.table_like(): include_index = not isinstance(df.index, pd.RangeIndex) columns = list(df.columns) @@ -470,9 +470,9 @@ def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]: def get_payload( self, - cache_query_context: Optional[bool] = False, + cache_query_context: bool | None = False, force_cached: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the query results with both metadata and data""" # Get all the payloads from the QueryObjects @@ -522,13 +522,13 @@ def cache_key(self, **extra: Any) -> str: return generate_cache_key(cache_dict, key_prefix) - def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]: + def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]: """ :param query_context: :param query_obj: :return: """ - annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj) + annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj) for annotation_layer in [ layer for layer in query_obj.annotation_layers @@ -541,7 +541,7 @@ def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]: return annotation_data @staticmethod - def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]: + def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]: annotation_data = {} annotation_layers = [ layer @@ -576,8 +576,8 @@ def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]: @staticmethod def get_viz_annotation_data( - annotation_layer: Dict[str, Any], force: bool - ) -> Dict[str, Any]: + annotation_layer: dict[str, Any], force: bool + ) -> dict[str, Any]: chart = ChartDAO.find_by_id(annotation_layer["value"]) if not chart: raise QueryObjectValidationError(_("The chart does not exist")) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 802a1eed5b423..dc02b774e56be 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -21,7 +21,7 @@ import logging from 
datetime import datetime from pprint import pformat -from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING +from typing import Any, NamedTuple, TYPE_CHECKING from flask import g from flask_babel import gettext as _ @@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes and druid. The query objects are constructed on the client. """ - annotation_layers: List[Dict[str, Any]] - applied_time_extras: Dict[str, str] + annotation_layers: list[dict[str, Any]] + applied_time_extras: dict[str, str] apply_fetch_values_predicate: bool - columns: List[Column] - datasource: Optional[BaseDatasource] - extras: Dict[str, Any] - filter: List[QueryObjectFilterClause] - from_dttm: Optional[datetime] - granularity: Optional[str] - inner_from_dttm: Optional[datetime] - inner_to_dttm: Optional[datetime] + columns: list[Column] + datasource: BaseDatasource | None + extras: dict[str, Any] + filter: list[QueryObjectFilterClause] + from_dttm: datetime | None + granularity: str | None + inner_from_dttm: datetime | None + inner_to_dttm: datetime | None is_rowcount: bool is_timeseries: bool - metrics: Optional[List[Metric]] + metrics: list[Metric] | None order_desc: bool - orderby: List[OrderBy] - post_processing: List[Dict[str, Any]] - result_type: Optional[ChartDataResultType] - row_limit: Optional[int] + orderby: list[OrderBy] + post_processing: list[dict[str, Any]] + result_type: ChartDataResultType | None + row_limit: int | None row_offset: int - series_columns: List[Column] + series_columns: list[Column] series_limit: int - series_limit_metric: Optional[Metric] - time_offsets: List[str] - time_shift: Optional[str] - time_range: Optional[str] - to_dttm: Optional[datetime] + series_limit_metric: Metric | None + time_offsets: list[str] + time_shift: str | None + time_range: str | None + to_dttm: datetime | None def __init__( # pylint: disable=too-many-locals self, *, - annotation_layers: Optional[List[Dict[str, Any]]] = None, - applied_time_extras: Optional[Dict[str, str]] = None, + annotation_layers: list[dict[str, Any]] | None = None, + applied_time_extras: dict[str, str] | None = None, apply_fetch_values_predicate: bool = False, - columns: Optional[List[Column]] = None, - datasource: Optional[BaseDatasource] = None, - extras: Optional[Dict[str, Any]] = None, - filters: Optional[List[QueryObjectFilterClause]] = None, - granularity: Optional[str] = None, + columns: list[Column] | None = None, + datasource: BaseDatasource | None = None, + extras: dict[str, Any] | None = None, + filters: list[QueryObjectFilterClause] | None = None, + granularity: str | None = None, is_rowcount: bool = False, - is_timeseries: Optional[bool] = None, - metrics: Optional[List[Metric]] = None, + is_timeseries: bool | None = None, + metrics: list[Metric] | None = None, order_desc: bool = True, - orderby: Optional[List[OrderBy]] = None, - post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, - row_limit: Optional[int], - row_offset: Optional[int] = None, - series_columns: Optional[List[Column]] = None, + orderby: list[OrderBy] | None = None, + post_processing: list[dict[str, Any] | None] | None = None, + row_limit: int | None, + row_offset: int | None = None, + series_columns: list[Column] | None = None, series_limit: int = 0, - series_limit_metric: Optional[Metric] = None, - time_range: Optional[str] = None, - time_shift: Optional[str] = None, + series_limit_metric: Metric | None = None, + time_range: str | None = None, + time_shift: str | None = None, **kwargs: Any, ): 
self._set_annotation_layers(annotation_layers) @@ -166,7 +166,7 @@ def __init__( # pylint: disable=too-many-locals self._move_deprecated_extra_fields(kwargs) def _set_annotation_layers( - self, annotation_layers: Optional[List[Dict[str, Any]]] + self, annotation_layers: list[dict[str, Any]] | None ) -> None: self.annotation_layers = [ layer @@ -175,14 +175,14 @@ def _set_annotation_layers( if layer["annotationType"] != "FORMULA" ] - def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None: + def _set_is_timeseries(self, is_timeseries: bool | None) -> None: # is_timeseries is True if time column is in either columns or groupby # (both are dimensions) self.is_timeseries = ( is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns ) - def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None: + def _set_metrics(self, metrics: list[Metric] | None = None) -> None: # Support metric reference/definition in the format of # 1. 'metric_name' - name of predefined metric # 2. { label: 'label_name' } - legacy format for a predefined metric @@ -195,16 +195,16 @@ def is_str_or_adhoc(metric: Metric) -> bool: ] def _set_post_processing( - self, post_processing: Optional[List[Optional[Dict[str, Any]]]] + self, post_processing: list[dict[str, Any] | None] | None ) -> None: post_processing = post_processing or [] self.post_processing = [post_proc for post_proc in post_processing if post_proc] def _init_series_columns( self, - series_columns: Optional[List[Column]], - metrics: Optional[List[Metric]], - is_timeseries: Optional[bool], + series_columns: list[Column] | None, + metrics: list[Metric] | None, + is_timeseries: bool | None, ) -> None: if series_columns: self.series_columns = series_columns @@ -213,7 +213,7 @@ def _init_series_columns( else: self.series_columns = [] - def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None: + def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None: # rename deprecated fields for field in DEPRECATED_FIELDS: if field.old_name in kwargs: @@ -233,7 +233,7 @@ def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None: ) setattr(self, field.new_name, value) - def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None: + def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None: # move deprecated extras fields to extras for field in DEPRECATED_EXTRAS_FIELDS: if field.old_name in kwargs: @@ -256,19 +256,19 @@ def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None: self.extras[field.new_name] = value @property - def metric_names(self) -> List[str]: + def metric_names(self) -> list[str]: """Return metrics names (labels), coerce adhoc metrics to strings.""" return get_metric_names(self.metrics or []) @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: """Return column names (labels). 
Gives priority to groupbys if both groupbys and metrics are non-empty, otherwise returns column labels.""" return get_column_names(self.columns) def validate( - self, raise_exceptions: Optional[bool] = True - ) -> Optional[QueryObjectValidationError]: + self, raise_exceptions: bool | None = True + ) -> QueryObjectValidationError | None: """Validate query object""" try: self._validate_there_are_no_missing_series() @@ -314,7 +314,7 @@ def _validate_there_are_no_missing_series(self) -> None: ) ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: query_object_dict = { "apply_fetch_values_predicate": self.apply_fetch_values_predicate, "columns": self.columns, diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 88cc7ca1b461b..5676dc9eda5ff 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset.common.chart_data import ChartDataResultType from superset.common.query_object import QueryObject @@ -31,13 +31,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods - _config: Dict[str, Any] + _config: dict[str, Any] _datasource_dao: DatasourceDAO _session_maker: sessionmaker def __init__( self, - app_configurations: Dict[str, Any], + app_configurations: dict[str, Any], _datasource_dao: DatasourceDAO, session_maker: sessionmaker, ): @@ -48,11 +48,11 @@ def __init__( def create( # pylint: disable=too-many-arguments self, parent_result_type: ChartDataResultType, - datasource: Optional[DatasourceDict] = None, - extras: Optional[Dict[str, Any]] = None, - row_limit: Optional[int] = None, - time_range: Optional[str] = None, - time_shift: Optional[str] = None, + datasource: DatasourceDict | None = None, + extras: dict[str, Any] | None = None, + row_limit: int | None = None, + time_range: str | None = None, + time_shift: str | None = None, **kwargs: Any, ) -> QueryObject: datasource_model_instance = None @@ -84,13 +84,13 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: def _process_extras( # pylint: disable=no-self-use self, - extras: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: + extras: dict[str, Any] | None, + ) -> dict[str, Any]: extras = extras or {} return extras def _process_row_limit( - self, row_limit: Optional[int], result_type: ChartDataResultType + self, row_limit: int | None, result_type: ChartDataResultType ) -> int: default_row_limit = ( self._config["SAMPLES_ROW_LIMIT"] diff --git a/superset/common/tags.py b/superset/common/tags.py index 706192913a1c3..6066d0eec7978 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
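# Illustrative sketch with a hypothetical dataclass (not the QueryObject changed
# above): instance-attribute declarations follow the same conversion, e.g.
# list[str] for List[str] and datetime | None for Optional[datetime].
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class ExampleQueryObject:
    columns: list[str] = field(default_factory=list)
    time_range: str | None = None
    from_dttm: datetime | None = None
    row_limit: int | None = None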
-from typing import Any, List +from typing import Any from sqlalchemy import MetaData from sqlalchemy.exc import IntegrityError @@ -25,7 +25,7 @@ def add_types_to_charts( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: slices = metadata.tables["slices"] @@ -57,7 +57,7 @@ def add_types_to_charts( def add_types_to_dashboards( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -89,7 +89,7 @@ def add_types_to_dashboards( def add_types_to_saved_queries( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -121,7 +121,7 @@ def add_types_to_saved_queries( def add_types_to_datasets( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: tables = metadata.tables["tables"] @@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None: def add_owners_to_charts( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: slices = metadata.tables["slices"] @@ -273,7 +273,7 @@ def add_owners_to_charts( def add_owners_to_dashboards( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -309,7 +309,7 @@ def add_owners_to_dashboards( def add_owners_to_saved_queries( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -345,7 +345,7 @@ def add_owners_to_saved_queries( def add_owners_to_datasets( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: tables = metadata.tables["tables"] diff --git a/superset/common/utils/dataframe_utils.py b/superset/common/utils/dataframe_utils.py index 4dd62e3b5d886..a3421f6431061 100644 --- a/superset/common/utils/dataframe_utils.py +++ b/superset/common/utils/dataframe_utils.py @@ -17,7 +17,7 @@ from __future__ import annotations import datetime -from typing import Any, List, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import numpy as np import pandas as pd @@ -29,7 +29,7 @@ def left_join_df( left_df: pd.DataFrame, right_df: pd.DataFrame, - join_keys: List[str], + join_keys: list[str], ) -> pd.DataFrame: df = left_df.set_index(join_keys).join(right_df.set_index(join_keys)) df.reset_index(inplace=True) diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 6c1b268f46534..a0fb65b20d4cb 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +from typing import Any from flask_caching import Cache from pandas import DataFrame @@ -37,7 +37,7 @@ stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) -_cache: Dict[CacheRegion, Cache] = { 
+_cache: dict[CacheRegion, Cache] = { CacheRegion.DEFAULT: cache_manager.cache, CacheRegion.DATA: cache_manager.data_cache, } @@ -53,17 +53,17 @@ def __init__( self, df: DataFrame = DataFrame(), query: str = "", - annotation_data: Optional[Dict[str, Any]] = None, - applied_template_filters: Optional[List[str]] = None, - applied_filter_columns: Optional[List[Column]] = None, - rejected_filter_columns: Optional[List[Column]] = None, - status: Optional[str] = None, - error_message: Optional[str] = None, + annotation_data: dict[str, Any] | None = None, + applied_template_filters: list[str] | None = None, + applied_filter_columns: list[Column] | None = None, + rejected_filter_columns: list[Column] | None = None, + status: str | None = None, + error_message: str | None = None, is_loaded: bool = False, - stacktrace: Optional[str] = None, - is_cached: Optional[bool] = None, - cache_dttm: Optional[str] = None, - cache_value: Optional[Dict[str, Any]] = None, + stacktrace: str | None = None, + is_cached: bool | None = None, + cache_dttm: str | None = None, + cache_value: dict[str, Any] | None = None, ) -> None: self.df = df self.query = query @@ -85,10 +85,10 @@ def set_query_result( self, key: str, query_result: QueryResult, - annotation_data: Optional[Dict[str, Any]] = None, - force_query: Optional[bool] = False, - timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + annotation_data: dict[str, Any] | None = None, + force_query: bool | None = False, + timeout: int | None = None, + datasource_uid: str | None = None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: """ @@ -136,11 +136,11 @@ def set_query_result( @classmethod def get( cls, - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, - force_query: Optional[bool] = False, - force_cached: Optional[bool] = False, - ) -> "QueryCacheManager": + force_query: bool | None = False, + force_cached: bool | None = False, + ) -> QueryCacheManager: """ Initialize QueryCacheManager by query-cache key """ @@ -190,10 +190,10 @@ def get( @staticmethod def set( - key: Optional[str], - value: Dict[str, Any], - timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + key: str | None, + value: dict[str, Any], + timeout: int | None = None, + datasource_uid: str | None = None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: """ @@ -204,7 +204,7 @@ def set( @staticmethod def delete( - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: if key: @@ -212,7 +212,7 @@ def delete( @staticmethod def has( - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, ) -> bool: return bool(_cache[region].get(key)) if key else False diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py index fa6a5244b244b..5f9139c0474c2 100644 --- a/superset/common/utils/time_range_utils.py +++ b/superset/common/utils/time_range_utils.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, cast, Dict, Optional, Tuple +from typing import Any, cast from superset import app from superset.common.query_object import QueryObject @@ -26,10 +26,10 @@ def get_since_until_from_time_range( - time_range: Optional[str] = None, - time_shift: Optional[str] = None, - extras: Optional[Dict[str, Any]] = None, -) -> Tuple[Optional[datetime], Optional[datetime]]: + time_range: str | None = None, + time_shift: str | None = None, + extras: dict[str, Any] | None = None, +) -> 
tuple[datetime | None, datetime | None]: return get_since_until( relative_start=(extras or {}).get( "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"] @@ -45,7 +45,7 @@ def get_since_until_from_time_range( # pylint: disable=invalid-name def get_since_until_from_query_object( query_object: QueryObject, -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[datetime | None, datetime | None]: """ this function will return since and until by tuple if 1) the time_range is in the query object. diff --git a/superset/config.py b/superset/config.py index 7d9359d14fb0f..434456386d932 100644 --- a/superset/config.py +++ b/superset/config.py @@ -33,20 +33,7 @@ from collections import OrderedDict from datetime import timedelta from email.mime.multipart import MIMEMultipart -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - TypedDict, - Union, -) +from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict import pkg_resources from cachelib.base import BaseCache @@ -114,17 +101,17 @@ FAVICONS = [{"href": "/static/assets/images/favicon.png"}] -def _try_json_readversion(filepath: str) -> Optional[str]: +def _try_json_readversion(filepath: str) -> str | None: try: - with open(filepath, "r") as f: + with open(filepath) as f: return json.load(f).get("version") except Exception: # pylint: disable=broad-except return None -def _try_json_readsha(filepath: str, length: int) -> Optional[str]: +def _try_json_readsha(filepath: str, length: int) -> str | None: try: - with open(filepath, "r") as f: + with open(filepath) as f: return json.load(f).get("GIT_SHA")[:length] except Exception: # pylint: disable=broad-except return None @@ -275,7 +262,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1} # Configuration for scheduling queries from SQL Lab. -SCHEDULED_QUERIES: Dict[str, Any] = {} +SCHEDULED_QUERIES: dict[str, Any] = {} # ------------------------------ # GLOBALS FOR APP Builder @@ -294,7 +281,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: LOGO_TOOLTIP = "" # Specify any text that should appear to the right of the logo -LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = "" +LOGO_RIGHT_TEXT: Callable[[], str] | str = "" # Enables SWAGGER UI for superset openapi spec # ex: http://localhost:8080/swagger/v1 @@ -347,7 +334,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # Grant public role the same set of permissions as for a selected builtin role. # This is useful if one wants to enable anonymous users to view # dashboards. Explicit grant on specific datasets is still required. 
-PUBLIC_ROLE_LIKE: Optional[str] = None +PUBLIC_ROLE_LIKE: str | None = None # --------------------------------------------------- # Babel config for translations @@ -390,8 +377,8 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: class D3Format(TypedDict, total=False): decimal: str thousands: str - grouping: List[int] - currency: List[str] + grouping: list[int] + currency: list[str] D3_FORMAT: D3Format = {} @@ -404,7 +391,7 @@ class D3Format(TypedDict, total=False): # For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here # and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True } -DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { +DEFAULT_FEATURE_FLAGS: dict[str, bool] = { # Experimental feature introducing a client (browser) cache "CLIENT_CACHE": False, # deprecated "DISABLE_DATASET_SOURCE_EDIT": False, # deprecated @@ -527,7 +514,7 @@ class D3Format(TypedDict, total=False): ) # This is merely a default. -FEATURE_FLAGS: Dict[str, bool] = {} +FEATURE_FLAGS: dict[str, bool] = {} # A function that receives a dict of all feature flags # (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS) @@ -543,7 +530,7 @@ class D3Format(TypedDict, total=False): # if hasattr(g, "user") and g.user.is_active: # feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5 # return feature_flags_dict -GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None +GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None # A function that receives a feature flag name and an optional default value. # Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the # evaluation of all feature flags when just evaluating a single one. @@ -551,7 +538,7 @@ class D3Format(TypedDict, total=False): # Note that the default `get_feature_flags` will evaluate each feature with this # callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC # and IS_FEATURE_ENABLED_FUNC in conjunction. -IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None +IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None # A function that expands/overrides the frontend `bootstrap_data.common` object. # Can be used to implement custom frontend functionality, # or dynamically change certain configs. @@ -563,7 +550,7 @@ class D3Format(TypedDict, total=False): # Takes as a parameter the common bootstrap payload before transformations. # Returns a dict containing data that should be added or overridden to the payload. 
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ - [Dict[str, Any]], Dict[str, Any] + [dict[str, Any]], dict[str, Any] ] = lambda data: {} # default: empty dict # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes @@ -580,7 +567,7 @@ class D3Format(TypedDict, total=False): # }] # This is merely a default -EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] +EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = [] # THEME_OVERRIDES is used for adding custom theme to superset # example code for "My theme" custom scheme @@ -599,7 +586,7 @@ class D3Format(TypedDict, total=False): # } # } -THEME_OVERRIDES: Dict[str, Any] = {} +THEME_OVERRIDES: dict[str, Any] = {} # EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes # EXTRA_SEQUENTIAL_COLOR_SCHEMES = [ @@ -615,7 +602,7 @@ class D3Format(TypedDict, total=False): # }] # This is merely a default -EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] +EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = [] # --------------------------------------------------- # Thumbnail config (behind feature flag) @@ -626,7 +613,7 @@ class D3Format(TypedDict, total=False): # `superset.tasks.types.ExecutorType` for a full list of executor options. # To always use a fixed user account, use the following configuration: # THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM] -THUMBNAIL_SELENIUM_USER: Optional[str] = "admin" +THUMBNAIL_SELENIUM_USER: str | None = "admin" THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM] # By default, thumbnail digests are calculated based on various parameters in the @@ -639,10 +626,10 @@ class D3Format(TypedDict, total=False): # `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in # user if the executor type is equal to `ExecutorType.CURRENT_USER`) # and return the final digest string: -THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[ +THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | ( Callable[[Dashboard, ExecutorType, str], str] -] = None -THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None +) = None +THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None THUMBNAIL_CACHE_CONFIG: CacheConfig = { "CACHE_TYPE": "NullCache", @@ -714,7 +701,7 @@ class D3Format(TypedDict, total=False): # CORS Options ENABLE_CORS = False -CORS_OPTIONS: Dict[Any, Any] = {} +CORS_OPTIONS: dict[Any, Any] = {} # Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner. # Disabling this option is not recommended for security reasons. If you wish to allow @@ -736,7 +723,7 @@ class D3Format(TypedDict, total=False): # } # } # Be careful when extending the default schema to avoid XSS attacks. -HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {} +HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {} # Chrome allows up to 6 open connections per domain at a time. When there are more # than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for @@ -768,13 +755,13 @@ class D3Format(TypedDict, total=False): # time grains in superset/db_engine_specs/base.py). # For example: to disable 1 second time grain: # TIME_GRAIN_DENYLIST = ['PT1S'] -TIME_GRAIN_DENYLIST: List[str] = [] +TIME_GRAIN_DENYLIST: list[str] = [] # Additional time grains to be supported using similar definitions as in # superset/db_engine_specs/base.py. 
# For example: To add a new 2 second time grain: # TIME_GRAIN_ADDONS = {'PT2S': '2 second'} -TIME_GRAIN_ADDONS: Dict[str, str] = {} +TIME_GRAIN_ADDONS: dict[str, str] = {} # Implementation of additional time grains per engine. # The column to be truncated is denoted `{col}` in the expression. @@ -784,7 +771,7 @@ class D3Format(TypedDict, total=False): # 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)' # } # } -TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {} +TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {} # --------------------------------------------------- # List of viz_types not allowed in your environment @@ -792,7 +779,7 @@ class D3Format(TypedDict, total=False): # VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap'] # --------------------------------------------------- -VIZ_TYPE_DENYLIST: List[str] = [] +VIZ_TYPE_DENYLIST: list[str] = [] # -------------------------------------------------- # Modules, datasources and middleware to be registered @@ -802,8 +789,8 @@ class D3Format(TypedDict, total=False): ("superset.connectors.sqla.models", ["SqlaTable"]), ] ) -ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} -ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = [] +ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {} +ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = [] # 1) https://docs.python-guide.org/writing/logging/ # 2) https://docs.python.org/2/library/logging.config.html @@ -925,9 +912,9 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # within the app # OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will # override anything set within the app -DEFAULT_HTTP_HEADERS: Dict[str, Any] = {} -OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {} -HTTP_HEADERS: Dict[str, Any] = {} +DEFAULT_HTTP_HEADERS: dict[str, Any] = {} +OVERRIDE_HTTP_HEADERS: dict[str, Any] = {} +HTTP_HEADERS: dict[str, Any] = {} # The db id here results in selecting this one as a default in SQL Lab DEFAULT_DB_ID = None @@ -974,8 +961,8 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # return out # # QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter} -QUERY_COST_FORMATTERS_BY_ENGINE: Dict[ - str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]] +QUERY_COST_FORMATTERS_BY_ENGINE: dict[ + str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]] ] = {} # Flag that controls if limit should be enforced on the CTA (create table as queries). @@ -1000,13 +987,13 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # else: # return f'tmp_{schema}' # Function accepts database object, user object, schema name and sql that will be run. -SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[ +SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | ( Callable[[Database, models.User, str, str], str] -] = None +) = None # If enabled, it can be used to store the results of long-running queries # in SQL Lab by using the "Run Async" button/feature -RESULTS_BACKEND: Optional[BaseCache] = None +RESULTS_BACKEND: BaseCache | None = None # Use PyArrow and MessagePack for async query results serialization, # rather than JSON. This feature requires additional testing from the @@ -1028,7 +1015,7 @@ class CeleryConfig: # pylint: disable=too-few-public-methods def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name database: Database, user: models.User, # pylint: disable=unused-argument - schema: Optional[str], + schema: str | None, ) -> str: # Note the final empty path enforces a trailing slash. 
return os.path.join( @@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # The namespace within hive where the tables created from # uploading CSVs will be stored. -UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None +UPLOADED_CSV_HIVE_NAMESPACE: str | None = None # Function that computes the allowed schemas for the CSV uploads. # Allowed schemas will be a union of schemas_allowed_for_file_upload # db configuration and a result of this function. # mypy doesn't catch that if case ensures list content being always str -ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = ( +ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else [] @@ -1062,7 +1049,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # It's important to make sure that the objects exposed (as well as objects attached # to those objets) are harmless. We recommend only exposing simple/pure functions that # return native types. -JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {} +JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {} # A dictionary of macro template processors (by engine) that gets merged into global # template processors. The existing template processors get updated with this @@ -1070,7 +1057,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # dictionary. The customized addons don't necessarily need to use Jinja templating # language. This allows you to define custom logic to process templates on a per-engine # basis. Example value = `{"presto": CustomPrestoTemplateProcessor}` -CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {} +CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {} # Roles that are controlled by the API / Superset and should not be changes # by humans. @@ -1125,7 +1112,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # Integrate external Blueprints to the app by passing them to your # configuration. These blueprints will get integrated in the app -BLUEPRINTS: List[Blueprint] = [] +BLUEPRINTS: list[Blueprint] = [] # Provide a callable that receives a tracking_url and returns another # URL. This is used to translate internal Hadoop job tracker URL @@ -1142,7 +1129,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # customize the polling time of each engine -DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {} +DB_POLL_INTERVAL_SECONDS: dict[str, int] = {} # Interval between consecutive polls when using Presto Engine # See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression @@ -1159,7 +1146,7 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # "another_auth_method": auth_method, # }, # } -ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {} +ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {} # The id of a template dashboard that should be copied to every new user DASHBOARD_TEMPLATE_ID = None @@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # Owners, filters for created_by, etc. 
# The users can also be excluded by overriding the get_exclude_users_from_lists method # in security manager -EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None +EXCLUDE_USERS_FROM_LISTS: list[str] | None = None # For database connections, this dictionary will remove engines from the available # list/dropdown if you do not want these dbs to show as available. # The available list is generated by driver installed, and some engines have multiple # drivers. # e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}} -DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {} +DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {} # This auth provider is used by background (offline) tasks that need to access # protected resources. Can be overridden by end users in order to support @@ -1261,7 +1248,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # ExecutorType.OWNER, # ExecutorType.SELENIUM, # ] -ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER] +ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER] # if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout # Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds()) @@ -1286,7 +1273,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument EMAIL_REPORTS_CTA = "Explore in Superset" # Slack API token for the superset reports, either string or callable -SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None +SLACK_API_TOKEN: Callable[[], str] | str | None = None SLACK_PROXY = None # The webdriver to use for generating reports. Use one of the following @@ -1310,7 +1297,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument WEBDRIVER_AUTH_FUNC = None # Any config options to be passed as-is to the webdriver -WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"} +WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"} # Additional args to be passed as arguments to the config object # Note: If using Chrome, you'll want to add the "--marionette" arg. @@ -1353,7 +1340,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # displayed prominently in the "Add Database" dialog. You should # use the "engine_name" attribute of the corresponding DB engine spec # in `superset/db_engine_specs/`. -PREFERRED_DATABASES: List[str] = [ +PREFERRED_DATABASES: list[str] = [ "PostgreSQL", "Presto", "MySQL", @@ -1386,7 +1373,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS? SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls? -SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax" +SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax" # Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection SESSION_PROTECTION = "strong" @@ -1418,7 +1405,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # Path used to store SSL certificates that are generated when using custom certs. # Defaults to temporary directory. 
# Example: SSL_CERT_PATH = "/certs" -SSL_CERT_PATH: Optional[str] = None +SSL_CERT_PATH: str | None = None # SQLA table mutator, every time we fetch the metadata for a certain table # (superset.connectors.sqla.models.SqlaTable), we call this hook @@ -1443,9 +1430,9 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False -GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[ +GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | ( Literal["None", "Lax", "Strict"] -] = None +) = None GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling" @@ -1461,7 +1448,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes # Guest token audience for the embedded superset, either string or callable -GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None +GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., @@ -1492,7 +1479,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # cache_manager.cache.delete_memoized(func) # cache_manager.cache.set(name, code, timeout=0) # -DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None +DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None # Do not show user info or profile in the menu MENU_HIDE_USER_INFO = False @@ -1502,7 +1489,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument ENABLE_BROAD_ACTIVITY_ACCESS = True # the advanced data type key should correspond to that set in the column metadata -ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = { +ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = { "internet_address": internet_address, "port": internet_port, } @@ -1514,9 +1501,9 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # "Xyz", # [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}], # ) -WELCOME_PAGE_LAST_TAB: Union[ - Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]] -] = "all" +WELCOME_PAGE_LAST_TAB: ( + Literal["examples", "all"] | tuple[str, list[dict[str, Any]]] +) = "all" # Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag. # 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 2cb0d54c51537..d43d07863902b 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -16,21 +16,12 @@ # under the License. 
from __future__ import annotations +import builtins import json +from collections.abc import Hashable from datetime import datetime from enum import Enum -from typing import ( - Any, - Dict, - Hashable, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, TYPE_CHECKING from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __ @@ -89,23 +80,23 @@ class BaseDatasource( # --------------------------------------------------------------- # class attributes to define when deriving BaseDatasource # --------------------------------------------------------------- - __tablename__: Optional[str] = None # {connector_name}_datasource - baselink: Optional[str] = None # url portion pointing to ModelView endpoint + __tablename__: str | None = None # {connector_name}_datasource + baselink: str | None = None # url portion pointing to ModelView endpoint @property - def column_class(self) -> Type["BaseColumn"]: + def column_class(self) -> type[BaseColumn]: # link to derivative of BaseColumn raise NotImplementedError() @property - def metric_class(self) -> Type["BaseMetric"]: + def metric_class(self) -> type[BaseMetric]: # link to derivative of BaseMetric raise NotImplementedError() - owner_class: Optional[User] = None + owner_class: User | None = None # Used to do code highlighting when displaying the query in the UI - query_language: Optional[str] = None + query_language: str | None = None # Only some datasources support Row Level Security is_rls_supported: bool = False @@ -131,9 +122,9 @@ def name(self) -> str: is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) - sql: Optional[str] = None - owners: List[User] - update_from_object_fields: List[str] + sql: str | None = None + owners: list[User] + update_from_object_fields: list[str] extra_import_fields = ["is_managed_externally", "external_url"] @@ -142,7 +133,7 @@ def kind(self) -> DatasourceKind: return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL @property - def owners_data(self) -> List[Dict[str, Any]]: + def owners_data(self) -> list[dict[str, Any]]: return [ { "first_name": o.first_name, @@ -167,8 +158,8 @@ def slices(self) -> RelationshipProperty: ), ) - columns: List["BaseColumn"] = [] - metrics: List["BaseMetric"] = [] + columns: list[BaseColumn] = [] + metrics: list[BaseMetric] = [] @property def type(self) -> str: @@ -180,11 +171,11 @@ def uid(self) -> str: return f"{self.id}__{self.type}" @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: return sorted([c.column_name for c in self.columns], key=lambda x: x or "") @property - def columns_types(self) -> Dict[str, str]: + def columns_types(self) -> dict[str, str]: return {c.column_name: c.type for c in self.columns} @property @@ -196,26 +187,26 @@ def datasource_name(self) -> str: raise NotImplementedError() @property - def connection(self) -> Optional[str]: + def connection(self) -> str | None: """String representing the context of the Datasource""" return None @property - def schema(self) -> Optional[str]: + def schema(self) -> str | None: """String representing the schema of the Datasource (if it applies)""" return None @property - def filterable_column_names(self) -> List[str]: + def filterable_column_names(self) -> list[str]: return sorted([c.column_name for c in self.columns if c.filterable]) @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: return [] @property def 
url(self) -> str: - return "/{}/edit/{}".format(self.baselink, self.id) + return f"/{self.baselink}/edit/{self.id}" @property def explore_url(self) -> str: @@ -224,10 +215,10 @@ def explore_url(self) -> str: return f"/explore/?datasource_type={self.type}&datasource_id={self.id}" @property - def column_formats(self) -> Dict[str, Optional[str]]: + def column_formats(self) -> dict[str, str | None]: return {m.metric_name: m.d3format for m in self.metrics if m.d3format} - def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None: + def add_missing_metrics(self, metrics: list[BaseMetric]) -> None: existing_metrics = {m.metric_name for m in self.metrics} for metric in metrics: if metric.metric_name not in existing_metrics: @@ -235,7 +226,7 @@ def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None: self.metrics.append(metric) @property - def short_data(self) -> Dict[str, Any]: + def short_data(self) -> dict[str, Any]: """Data representation of the datasource sent to the frontend""" return { "edit_url": self.url, @@ -249,11 +240,11 @@ def short_data(self) -> Dict[str, Any]: } @property - def select_star(self) -> Optional[str]: + def select_star(self) -> str | None: pass @property - def order_by_choices(self) -> List[Tuple[str, str]]: + def order_by_choices(self) -> list[tuple[str, str]]: choices = [] # self.column_names return sorted column_names for column_name in self.column_names: @@ -267,7 +258,7 @@ def order_by_choices(self) -> List[Tuple[str, str]]: return choices @property - def verbose_map(self) -> Dict[str, str]: + def verbose_map(self) -> dict[str, str]: verb_map = {"__timestamp": "Time"} verb_map.update( {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics} @@ -278,7 +269,7 @@ def verbose_map(self) -> Dict[str, str]: return verb_map @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """Data representation of the datasource sent to the frontend""" return { # simple fields @@ -313,8 +304,8 @@ def data(self) -> Dict[str, Any]: } def data_for_slices( # pylint: disable=too-many-locals - self, slices: List[Slice] - ) -> Dict[str, Any]: + self, slices: list[Slice] + ) -> dict[str, Any]: """ The representation of the datasource containing only the required data to render the provided slices. 
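The hunks above and below all apply the same mechanical rewrite: typing.List, Dict, Tuple, Set and Type become the subscriptable builtins of PEP 585, and Optional[...]/Union[...] become PEP 604 "|" unions. A minimal sketch of the pattern, assuming a hypothetical get_owners helper (not from the Superset codebase); the from __future__ import annotations line is what allows the new syntax inside annotations on interpreters older than 3.10:

from __future__ import annotations  # postpones annotation evaluation (PEP 563)

# Before: def get_owners(names: Optional[List[str]]) -> Dict[str, Optional[int]]: ...
# After, in the style applied throughout this patch:
def get_owners(names: list[str] | None) -> dict[str, int | None]:
    # Hypothetical helper; only the annotation style mirrors the patch.
    return {name: None for name in (names or [])}

Outside annotations (for example in cast(...) calls or isinstance checks) the new syntax is still evaluated at runtime, which is why builtin generics need Python 3.9+ and "|" unions need 3.10+ there.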
@@ -381,8 +372,8 @@ def data_for_slices( # pylint: disable=too-many-locals if metric["metric_name"] in metric_names ] - filtered_columns: List[Column] = [] - column_types: Set[GenericDataType] = set() + filtered_columns: list[Column] = [] + column_types: set[GenericDataType] = set() for column in data["columns"]: generic_type = column.get("type_generic") if generic_type is not None: @@ -413,18 +404,18 @@ def data_for_slices( # pylint: disable=too-many-locals @staticmethod def filter_values_handler( # pylint: disable=too-many-arguments - values: Optional[FilterValues], + values: FilterValues | None, operator: str, target_generic_type: GenericDataType, - target_native_type: Optional[str] = None, + target_native_type: str | None = None, is_list_target: bool = False, - db_engine_spec: Optional[Type[BaseEngineSpec]] = None, - db_extra: Optional[Dict[str, Any]] = None, - ) -> Optional[FilterValues]: + db_engine_spec: builtins.type[BaseEngineSpec] | None = None, + db_extra: dict[str, Any] | None = None, + ) -> FilterValues | None: if values is None: return None - def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: + def handle_single_value(value: FilterValue | None) -> FilterValue | None: if operator == utils.FilterOperator.TEMPORAL_RANGE: return value if ( @@ -464,7 +455,7 @@ def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: values = values[0] if values else None return values - def external_metadata(self) -> List[Dict[str, str]]: + def external_metadata(self) -> list[dict[str, str]]: """Returns column information from the external system""" raise NotImplementedError() @@ -483,7 +474,7 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult: """ raise NotImplementedError() - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Given a column, returns an iterable of distinct values This is used to populate the dropdown showing a list of @@ -494,7 +485,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: def default_query(qry: Query) -> Query: return qry - def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]: + def get_column(self, column_name: str | None) -> BaseColumn | None: if not column_name: return None for col in self.columns: @@ -504,11 +495,11 @@ def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]: @staticmethod def get_fk_many_from_list( - object_list: List[Any], - fkmany: List[Column], - fkmany_class: Type[Union["BaseColumn", "BaseMetric"]], + object_list: list[Any], + fkmany: list[Column], + fkmany_class: builtins.type[BaseColumn | BaseMetric], key_attr: str, - ) -> List[Column]: + ) -> list[Column]: """Update ORM one-to-many list from object list Used for syncing metrics and columns using the same code""" @@ -541,7 +532,7 @@ def get_fk_many_from_list( fkmany += new_fks return fkmany - def update_from_object(self, obj: Dict[str, Any]) -> None: + def update_from_object(self, obj: dict[str, Any]) -> None: """Update datasource from a data structure The UI's table editor crafts a complex data structure that @@ -578,7 +569,7 @@ def update_from_object(self, obj: Dict[str, Any]) -> None: def get_extra_cache_keys( # pylint: disable=no-self-use self, query_obj: QueryObjectDict # pylint: disable=unused-argument - ) -> List[Hashable]: + ) -> list[Hashable]: """If a datasource needs to provide additional keys for calculation of cache keys, those can 
be provided via this method @@ -607,14 +598,14 @@ def raise_for_access(self) -> None: @classmethod def get_datasource_by_name( cls, session: Session, datasource_name: str, schema: str, database_name: str - ) -> Optional["BaseDatasource"]: + ) -> BaseDatasource | None: raise NotImplementedError() class BaseColumn(AuditMixinNullable, ImportExportMixin): """Interface for column""" - __tablename__: Optional[str] = None # {connector_name}_column + __tablename__: str | None = None # {connector_name}_column id = Column(Integer, primary_key=True) column_name = Column(String(255), nullable=False) @@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): is_dttm = None # [optional] Set this to support import/export functionality - export_fields: List[Any] = [] + export_fields: list[Any] = [] def __repr__(self) -> str: return str(self.column_name) @@ -666,7 +657,7 @@ def is_boolean(self) -> bool: return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types)) @property - def type_generic(self) -> Optional[utils.GenericDataType]: + def type_generic(self) -> utils.GenericDataType | None: if self.is_string: return utils.GenericDataType.STRING if self.is_boolean: @@ -686,7 +677,7 @@ def python_date_format(self) -> Column: raise NotImplementedError() @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "column_name", @@ -705,7 +696,7 @@ def data(self) -> Dict[str, Any]: class BaseMetric(AuditMixinNullable, ImportExportMixin): """Interface for Metrics""" - __tablename__: Optional[str] = None # {connector_name}_metric + __tablename__: str | None = None # {connector_name}_metric id = Column(Integer, primary_key=True) metric_name = Column(String(255), nullable=False) @@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin): """ @property - def perm(self) -> Optional[str]: + def perm(self) -> str | None: raise NotImplementedError() @property @@ -738,7 +729,7 @@ def expression(self) -> Column: raise NotImplementedError() @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "metric_name", diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8833d6f6cb561..41a9c89757891 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -22,21 +22,10 @@ import logging import re from collections import defaultdict +from collections.abc import Hashable from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import ( - Any, - Callable, - cast, - Dict, - Hashable, - List, - Optional, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Callable, cast import dateutil.parser import numpy as np @@ -136,9 +125,9 @@ @dataclass class MetadataResult: - added: List[str] = field(default_factory=list) - removed: List[str] = field(default_factory=list) - modified: List[str] = field(default_factory=list) + added: list[str] = field(default_factory=list) + removed: list[str] = field(default_factory=list) + modified: list[str] = field(default_factory=list) class AnnotationDatasource(BaseDatasource): @@ -190,7 +179,7 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult: def get_query_str(self, query_obj: QueryObjectDict) -> str: raise NotImplementedError() - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: raise NotImplementedError() @@ -201,7 +190,7 @@ class 
TableColumn(Model, BaseColumn, CertificationMixin): __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table: Mapped["SqlaTable"] = relationship( + table: Mapped[SqlaTable] = relationship( "SqlaTable", back_populates="columns", ) @@ -263,15 +252,15 @@ def is_temporal(self) -> bool: return self.type_generic == GenericDataType.TEMPORAL @property - def db_engine_spec(self) -> Type[BaseEngineSpec]: + def db_engine_spec(self) -> type[BaseEngineSpec]: return self.table.db_engine_spec @property - def db_extra(self) -> Dict[str, Any]: + def db_extra(self) -> dict[str, Any]: return self.table.database.get_extra() @property - def type_generic(self) -> Optional[utils.GenericDataType]: + def type_generic(self) -> utils.GenericDataType | None: if self.is_dttm: return GenericDataType.TEMPORAL @@ -310,8 +299,8 @@ def type_generic(self) -> Optional[utils.GenericDataType]: def get_sqla_col( self, - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.column_name db_engine_spec = self.db_engine_spec @@ -332,10 +321,10 @@ def datasource(self) -> RelationshipProperty: def get_timestamp_expression( self, - time_grain: Optional[str], - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, - ) -> Union[TimestampExpression, Label]: + time_grain: str | None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, + ) -> TimestampExpression | Label: """ Return a SQLAlchemy Core element representation of self to be used in a query. @@ -365,7 +354,7 @@ def get_timestamp_expression( return self.table.make_sqla_column_compatible(time_expr, label) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "column_name", @@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): __tablename__ = "sql_metrics" __table_args__ = (UniqueConstraint("table_id", "metric_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table: Mapped["SqlaTable"] = relationship( + table: Mapped[SqlaTable] = relationship( "SqlaTable", back_populates="metrics", ) @@ -425,8 +414,8 @@ def __repr__(self) -> str: def get_sqla_col( self, - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.metric_name expression = self.expression @@ -437,7 +426,7 @@ def get_sqla_col( return self.table.make_sqla_column_compatible(sqla_col, label) @property - def perm(self) -> Optional[str]: + def perm(self) -> str | None: return ( ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format( obj=self, parent_name=self.table.full_name @@ -446,11 +435,11 @@ def perm(self) -> Optional[str]: else None ) - def get_perm(self) -> Optional[str]: + def get_perm(self) -> str | None: return self.perm @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "is_certified", "certified_by", @@ -473,11 +462,11 @@ def data(self) -> Dict[str, Any]: def _process_sql_expression( - expression: Optional[str], + expression: str | None, database_id: int, schema: str, - template_processor: Optional[BaseTemplateProcessor] = None, -) -> Optional[str]: + template_processor: BaseTemplateProcessor | None = None, +) -> 
str | None: if template_processor and expression: expression = template_processor.process_template(expression) if expression: @@ -501,12 +490,12 @@ class SqlaTable( type = "table" query_language = "sql" is_rls_supported = True - columns: Mapped[List[TableColumn]] = relationship( + columns: Mapped[list[TableColumn]] = relationship( TableColumn, back_populates="table", cascade="all, delete-orphan", ) - metrics: Mapped[List[SqlMetric]] = relationship( + metrics: Mapped[list[SqlMetric]] = relationship( SqlMetric, back_populates="table", cascade="all, delete-orphan", @@ -577,11 +566,11 @@ def __repr__(self) -> str: # pylint: disable=invalid-repr-returned return self.name @property - def db_extra(self) -> Dict[str, Any]: + def db_extra(self) -> dict[str, Any]: return self.database.get_extra() @staticmethod - def _apply_cte(sql: str, cte: Optional[str]) -> str: + def _apply_cte(sql: str, cte: str | None) -> str: """ Append a CTE before the SELECT statement if defined @@ -594,7 +583,7 @@ def _apply_cte(sql: str, cte: Optional[str]) -> str: return sql @property - def db_engine_spec(self) -> Type[BaseEngineSpec]: + def db_engine_spec(self) -> builtins.type[BaseEngineSpec]: return self.database.db_engine_spec @property @@ -637,9 +626,9 @@ def get_datasource_by_name( cls, session: Session, datasource_name: str, - schema: Optional[str], + schema: str | None, database_name: str, - ) -> Optional[SqlaTable]: + ) -> SqlaTable | None: schema = schema or None query = ( session.query(cls) @@ -660,7 +649,7 @@ def link(self) -> Markup: anchor = f'{name}' return Markup(anchor) - def get_schema_perm(self) -> Optional[str]: + def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) @@ -685,18 +674,18 @@ def full_name(self) -> str: ) @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: l = [c.column_name for c in self.columns if c.is_dttm] if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property - def num_cols(self) -> List[str]: + def num_cols(self) -> list[str]: return [c.column_name for c in self.columns if c.is_numeric] @property - def any_dttm_col(self) -> Optional[str]: + def any_dttm_col(self) -> str | None: cols = self.dttm_cols return cols[0] if cols else None @@ -713,7 +702,7 @@ def html(self) -> str: def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) - def external_metadata(self) -> List[Dict[str, str]]: + def external_metadata(self) -> list[dict[str, str]]: # todo(yongjie): create a physical table column type in a separate PR if self.sql: return get_virtual_table_metadata(dataset=self) # type: ignore @@ -724,14 +713,14 @@ def external_metadata(self) -> List[Dict[str, str]]: @property - def time_column_grains(self) -> Dict[str, Any]: + def time_column_grains(self) -> dict[str, Any]: return { "time_columns": self.dttm_cols, "time_grains": [grain.name for grain in self.database.grains()], } @property - def select_star(self) -> Optional[str]: + def select_star(self) -> str | None: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( @@ -739,20 +728,20 @@ ) @property - def health_check_message(self) -> Optional[str]: + def health_check_message(self) -> str | None: check = config["DATASET_HEALTH_CHECK"] return check(self) if check else None @property - 
def granularity_sqla(self) -> List[Tuple[Any, Any]]: + def granularity_sqla(self) -> list[tuple[Any, Any]]: return utils.choicify(self.dttm_cols) @property - def time_grain_sqla(self) -> List[Tuple[Any, Any]]: + def time_grain_sqla(self) -> list[tuple[Any, Any]]: return [(g.duration, g.name) for g in self.database.grains() or []] @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: data_ = super().data if self.type == "table": data_["granularity_sqla"] = self.granularity_sqla @@ -767,7 +756,7 @@ def data(self) -> Dict[str, Any]: return data_ @property - def extra_dict(self) -> Dict[str, Any]: + def extra_dict(self) -> dict[str, Any]: try: return json.loads(self.extra) except (TypeError, json.JSONDecodeError): @@ -775,7 +764,7 @@ def extra_dict(self) -> Dict[str, Any]: def get_fetch_values_predicate( self, - template_processor: Optional[BaseTemplateProcessor] = None, + template_processor: BaseTemplateProcessor | None = None, ) -> TextClause: fetch_values_predicate = self.fetch_values_predicate if template_processor: @@ -792,7 +781,7 @@ def get_fetch_values_predicate( ) ) from ex - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -869,8 +858,8 @@ def get_sqla_table(self) -> TableClause: return tbl def get_from_clause( - self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Tuple[Union[TableClause, Alias], Optional[str]]: + self, template_processor: BaseTemplateProcessor | None = None + ) -> tuple[TableClause | Alias, str | None]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a @@ -899,7 +888,7 @@ def get_from_clause( return from_clause, cte def get_rendered_sql( - self, template_processor: Optional[BaseTemplateProcessor] = None + self, template_processor: BaseTemplateProcessor | None = None ) -> str: """ Render sql with template engine (Jinja). @@ -928,8 +917,8 @@ def get_rendered_sql( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, TableColumn], - template_processor: Optional[BaseTemplateProcessor] = None, + columns_by_name: dict[str, TableColumn], + template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. @@ -946,7 +935,7 @@ def adhoc_metric_to_sqla( if expression_type == utils.AdhocMetricExpressionType.SIMPLE: metric_column = metric.get("column") or {} column_name = cast(str, metric_column.get("column_name")) - table_column: Optional[TableColumn] = columns_by_name.get(column_name) + table_column: TableColumn | None = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col( template_processor=template_processor @@ -971,7 +960,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals self, col: AdhocColumn, force_type_check: bool = False, - template_processor: Optional[BaseTemplateProcessor] = None, + template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc column into a sqlalchemy column. 
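The relationship hunks above (columns: Mapped[list[TableColumn]], table: Mapped[SqlaTable]) combine the builtin generics with SQLAlchemy's Mapped[] annotations; postponed annotations also let quoted forward references such as Mapped["SqlaTable"] drop their quotes. A self-contained sketch of that pattern, assuming SQLAlchemy 2.0-style declarative mappings and hypothetical ExampleTable/ExampleColumn models (the real Superset models keep their existing Column(...) definitions):

from __future__ import annotations

from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class ExampleTable(Base):
    __tablename__ = "example_tables"

    id: Mapped[int] = mapped_column(primary_key=True)
    # builtin-generic element type inside Mapped[...]; the reference to a class
    # defined further down resolves lazily because annotations are postponed
    columns: Mapped[list[ExampleColumn]] = relationship(back_populates="table")


class ExampleColumn(Base):
    __tablename__ = "example_columns"

    id: Mapped[int] = mapped_column(primary_key=True)
    column_name: Mapped[str] = mapped_column(String(255))
    table_id: Mapped[int] = mapped_column(ForeignKey("example_tables.id"))
    # unquoted reference, as in the patch's change from Mapped["SqlaTable"] to Mapped[SqlaTable]
    table: Mapped[ExampleTable] = relationship(back_populates="columns")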
@@ -1021,7 +1010,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals return self.make_sqla_column_compatible(sqla_column, label) def make_sqla_column_compatible( - self, sqla_col: ColumnElement, label: Optional[str] = None + self, sqla_col: ColumnElement, label: str | None = None ) -> ColumnElement: """Takes a sqlalchemy column object and adds label info if supported by engine. :param sqla_col: sqlalchemy column instance @@ -1038,7 +1027,7 @@ def make_sqla_column_compatible( return sqla_col def make_orderby_compatible( - self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement] ) -> None: """ If needed, make sure aliases for selected columns are not used in @@ -1069,7 +1058,7 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool: def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, - ) -> List[TextClause]: + ) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the current user. A custom username can be passed when the user is not present in the @@ -1078,8 +1067,8 @@ def get_sqla_row_level_filters( :param template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. """ - all_filters: List[TextClause] = [] - filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + all_filters: list[TextClause] = [] + filter_groups: dict[int | str, list[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): clause = self.text( @@ -1114,9 +1103,9 @@ def text(self, clause: str) -> TextClause: def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Dict[str, SqlMetric], - columns_by_name: Dict[str, TableColumn], - template_processor: Optional[BaseTemplateProcessor] = None, + metrics_by_name: dict[str, SqlMetric], + columns_by_name: dict[str, TableColumn], + template_processor: BaseTemplateProcessor | None = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) @@ -1138,8 +1127,8 @@ def _normalize_prequery_result_type( self, row: pd.Series, dimension: str, - columns_by_name: Dict[str, TableColumn], - ) -> Union[str, int, float, bool, Text]: + columns_by_name: dict[str, TableColumn], + ) -> str | int | float | bool | Text: """ Convert a prequery result type to its equivalent Python type. 
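get_sqla_row_level_filters above now types its grouping dict as dict[int | str, list[TextClause]], a union-keyed mapping built with defaultdict. A small standalone sketch of that shape; Clause and group_clauses are stand-ins for illustration, not Superset or SQLAlchemy APIs:

from __future__ import annotations

from collections import defaultdict

Clause = str  # stand-in for SQLAlchemy's TextClause, to keep the sketch self-contained


def group_clauses(clauses: list[tuple[int | str, Clause]]) -> dict[int | str, list[Clause]]:
    # Group filter clauses by a key that may be an int (a group id) or a str (a name).
    groups: dict[int | str, list[Clause]] = defaultdict(list)
    for key, clause in clauses:
        groups[key].append(clause)
    return dict(groups)


print(group_clauses([(1, "region = 'EMEA'"), (1, "role = 'sales'"), ("ungrouped", "deleted = 0")]))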
@@ -1159,7 +1148,7 @@ def _normalize_prequery_result_type( value = value.item() column_ = columns_by_name[dimension] - db_extra: Dict[str, Any] = self.database.get_extra() + db_extra: dict[str, Any] = self.database.get_extra() if column_.type and column_.is_temporal and isinstance(value, str): sql = self.db_engine_spec.convert_dttm( @@ -1174,9 +1163,9 @@ def _normalize_prequery_result_type( def _get_top_groups( self, df: pd.DataFrame, - dimensions: List[str], - groupby_exprs: Dict[str, Any], - columns_by_name: Dict[str, TableColumn], + dimensions: list[str], + groupby_exprs: dict[str, Any], + columns_by_name: dict[str, TableColumn], ) -> ColumnElement: groups = [] for _unused, row in df.iterrows(): @@ -1201,7 +1190,7 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult: errors = None error_message = None - def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: + def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: """ Some engines change the case or generate bespoke column names, either by default or due to lack of support for aliasing. This function ensures that @@ -1283,7 +1272,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: else self.columns ) - old_columns_by_name: Dict[str, TableColumn] = { + old_columns_by_name: dict[str, TableColumn] = { col.column_name: col for col in old_columns } results = MetadataResult( @@ -1341,8 +1330,8 @@ def query_datasources_by_name( session: Session, database: Database, datasource_name: str, - schema: Optional[str] = None, - ) -> List[SqlaTable]: + schema: str | None = None, + ) -> list[SqlaTable]: query = ( session.query(cls) .filter_by(database_id=database.id) @@ -1357,9 +1346,9 @@ def query_datasources_by_permissions( # pylint: disable=invalid-name cls, session: Session, database: Database, - permissions: Set[str], - schema_perms: Set[str], - ) -> List[SqlaTable]: + permissions: set[str], + schema_perms: set[str], + ) -> list[SqlaTable]: # TODO(hughhhh): add unit test return ( session.query(cls) @@ -1389,7 +1378,7 @@ def get_eager_sqlatable_datasource( ) @classmethod - def get_all_datasources(cls, session: Session) -> List[SqlaTable]: + def get_all_datasources(cls, session: Session) -> list[SqlaTable]: qry = session.query(cls) qry = cls.default_query(qry) return qry.all() @@ -1409,7 +1398,7 @@ def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: :param query_obj: query object to analyze :return: True if there are call(s) to an `ExtraCache` method, False otherwise """ - templatable_statements: List[str] = [] + templatable_statements: list[str] = [] if self.sql: templatable_statements.append(self.sql) if self.fetch_values_predicate: @@ -1428,7 +1417,7 @@ def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: return True return False - def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]: + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: """ The cache key of a SqlaTable needs to consider any keys added by the parent class and any keys added via `ExtraCache`. @@ -1489,7 +1478,7 @@ def before_update( @staticmethod def update_column( # pylint: disable=unused-argument - mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] + mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn ) -> None: """ :param mapper: Unused. 
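This file and superset/connectors/sqla/utils.py below also move Hashable, Iterable and Iterator from typing to collections.abc; since Python 3.9 the ABCs are subscriptable themselves and the typing aliases are deprecated (PEP 585). A short sketch with hypothetical helpers, where only the imports and annotation style mirror the patch:

from __future__ import annotations

from collections.abc import Hashable, Iterable, Iterator


def extra_cache_keys(values: Iterable[str]) -> list[Hashable]:
    # Hypothetical helper shaped like get_extra_cache_keys() above.
    return [value.lower() for value in values]


def dedupe(values: Iterable[str]) -> Iterator[str]:
    # Yield each value once, preserving order.
    seen: set[str] = set()
    for value in values:
        if value not in seen:
            seen.add(value)
            yield value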
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 698311dab65ef..d41c0555d382e 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -17,19 +17,9 @@ from __future__ import annotations import logging +from collections.abc import Iterable, Iterator from functools import lru_cache -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Type, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, TYPE_CHECKING, TypeVar from uuid import UUID from flask_babel import lazy_gettext as _ @@ -58,8 +48,8 @@ def get_physical_table_metadata( database: Database, table_name: str, - schema_name: Optional[str] = None, -) -> List[Dict[str, Any]]: + schema_name: str | None = None, +) -> list[dict[str, Any]]: """Use SQLAlchemy inspector to get table metadata""" db_engine_spec = database.db_engine_spec db_dialect = database.get_dialect() @@ -103,7 +93,7 @@ def get_physical_table_metadata( return cols -def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: +def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: """Use SQLparser to get virtual dataset metadata""" if not dataset.sql: raise SupersetGenericDBErrorException( @@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: def get_columns_description( database: Database, query: str, -) -> List[ResultSetColumnType]: +) -> list[ResultSetColumnType]: db_engine_spec = database.db_engine_spec try: with database.get_raw_connection() as conn: @@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str: @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: +def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]: return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote @@ -181,9 +171,9 @@ def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: def find_cached_objects_in_session( session: Session, - cls: Type[DeclarativeModel], - ids: Optional[Iterable[int]] = None, - uuids: Optional[Iterable[UUID]] = None, + cls: type[DeclarativeModel], + ids: Iterable[int] | None = None, + uuids: Iterable[UUID] | None = None, ) -> Iterator[DeclarativeModel]: """Find known ORM instances in cached SQLA session states. diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 0989a545fdcb1..9116b9636e220 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -447,7 +447,7 @@ def edit(self, pk: str) -> FlaskResponse: resp = super().edit(pk) if isinstance(resp, str): return resp - return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk)) + return redirect(f"/explore/?datasource_type=table&datasource_id={pk}") @expose("/list/") @has_access diff --git a/superset/css_templates/commands/bulk_delete.py b/superset/css_templates/commands/bulk_delete.py index 93564208c4f15..57612d90485d7 100644 --- a/superset/css_templates/commands/bulk_delete.py +++ b/superset/css_templates/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import List, Optional +from typing import Optional from superset.commands.base import BaseCommand from superset.css_templates.commands.exceptions import ( @@ -30,9 +30,9 @@ class BulkDeleteCssTemplateCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[CssTemplate]] = None + self._models: Optional[list[CssTemplate]] = None def run(self) -> None: self.validate() diff --git a/superset/css_templates/dao.py b/superset/css_templates/dao.py index 1862fb7aafd97..bc1a796269384 100644 --- a/superset/css_templates/dao.py +++ b/superset/css_templates/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO): model_cls = CssTemplate @staticmethod - def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete( diff --git a/superset/dao/base.py b/superset/dao/base.py index d3675a0e17c93..539dbab2d5141 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=isinstance-second-argument-not-valid-type -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -37,7 +37,7 @@ class BaseDAO: Base DAO, implement base CRUD sqlalchemy operations """ - model_cls: Optional[Type[Model]] = None + model_cls: Optional[type[Model]] = None """ Child classes need to state the Model class so they don't need to implement basic create, update and delete methods @@ -75,10 +75,10 @@ def find_by_id( @classmethod def find_by_ids( cls, - model_ids: Union[List[str], List[int]], + model_ids: Union[list[str], list[int]], session: Session = None, skip_base_filter: bool = False, - ) -> List[Model]: + ) -> list[Model]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -95,7 +95,7 @@ def find_by_ids( return query.all() @classmethod - def find_all(cls) -> List[Model]: + def find_all(cls) -> list[Model]: """ Get all that fit the `base_filter` """ @@ -121,7 +121,7 @@ def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]: return query.filter_by(**filter_by).one_or_none() @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: """ Generic for creating models :raises: DAOCreateFailedError @@ -163,7 +163,7 @@ def save(cls, instance_model: Model, commit: bool = True) -> Model: @classmethod def update( - cls, model: Model, properties: Dict[str, Any], commit: bool = True + cls, model: Model, properties: dict[str, Any], commit: bool = True ) -> Model: """ Generic update a model @@ -196,7 +196,7 @@ def delete(cls, model: Model, commit: bool = True) -> Model: return model @classmethod - def bulk_delete(cls, models: List[Model], commit: bool = True) -> None: + def bulk_delete(cls, models: list[Model], commit: bool = True) -> None: try: for model 
in models: cls.delete(model, False) diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py index 13541cd946ba0..385f1fbc6d285 100644 --- a/superset/dashboards/commands/bulk_delete.py +++ b/superset/dashboards/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_babel import lazy_gettext as _ @@ -37,9 +37,9 @@ class BulkDeleteDashboardCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Dashboard]] = None + self._models: Optional[list[Dashboard]] = None def run(self) -> None: self.validate() diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 0ad8ddee7c4d7..58acc379baf5b 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -34,7 +34,7 @@ class CreateDashboardCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -48,9 +48,9 @@ def run(self) -> Model: return dashboard def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") - role_ids: Optional[List[int]] = self._properties.get("roles") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") + role_ids: Optional[list[int]] = self._properties.get("roles") slug: str = self._properties.get("slug", "") # Validate slug uniqueness diff --git a/superset/dashboards/commands/export.py b/superset/dashboards/commands/export.py index 886b84ffa6db0..2e70e29bb0caf 100644 --- a/superset/dashboards/commands/export.py +++ b/superset/dashboards/commands/export.py @@ -20,7 +20,8 @@ import logging import random import string -from typing import Any, Dict, Iterator, Optional, Set, Tuple +from typing import Any, Optional +from collections.abc import Iterator import yaml @@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str: ) -def get_default_position(title: str) -> Dict[str, Any]: +def get_default_position(title: str) -> dict[str, Any]: return { "DASHBOARD_VERSION_KEY": "v2", "ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"}, @@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]: } -def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]: +def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]: chart_hashes = [f"CHART-{suffix()}" for _ in charts] # if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid @@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand): @staticmethod def _export( model: Dashboard, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: file_name = get_filename(model.dashboard_title, model.id) file_path = f"dashboards/{file_name}.yaml" diff --git a/superset/dashboards/commands/importers/dispatcher.py 
b/superset/dashboards/commands/importers/dispatcher.py index dd0121f3e3500..d5323b4fe4dd1 100644 --- a/superset/dashboards/commands/importers/dispatcher.py +++ b/superset/dashboards/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index e49c931896838..012dbbc5c9663 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -19,7 +19,7 @@ import time from copy import copy from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ from sqlalchemy.orm import make_transient, Session @@ -83,7 +83,7 @@ def import_chart( def import_dashboard( # pylint: disable=too-many-locals,too-many-statements dashboard_to_import: Dashboard, - dataset_id_mapping: Optional[Dict[int, int]] = None, + dataset_id_mapping: Optional[dict[int, int]] = None, import_time: Optional[int] = None, ) -> int: """Imports the dashboard from the object to the database. @@ -97,7 +97,7 @@ def import_dashboard( """ def alter_positions( - dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int] + dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int] ) -> None: """Updates slice_ids in the position json. @@ -166,7 +166,7 @@ def alter_native_filters(dashboard: Dashboard) -> None: dashboard_to_import.slug = None old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}") - old_to_new_slc_id_dict: Dict[int, int] = {} + old_to_new_slc_id_dict: dict[int, int] = {} new_timed_refresh_immune_slices = [] new_expanded_slices = {} new_filter_scopes = {} @@ -268,7 +268,7 @@ def alter_native_filters(dashboard: Dashboard) -> None: return dashboard_to_import.id # type: ignore -def decode_dashboards(o: Dict[str, Any]) -> Any: +def decode_dashboards(o: dict[str, Any]) -> Any: """ Function to be passed into json.loads obj_hook parameter Recreates the dashboard object from a json representation. 
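Note on the pattern applied throughout these hunks: per PEP 585, the builtin container types are subscriptable on Python 3.9+, so the typing aliases Dict, List, Set and Tuple are dropped in favour of dict, list, set and tuple. A minimal before/after sketch (function and variable names are illustrative, not taken from the patch):

    from typing import Dict, List, Optional   # old aliases, shown only for the "before" form

    def remap_ids_old(mapping: Dict[int, int], ids: List[int]) -> List[Optional[int]]:
        return [mapping.get(i) for i in ids]

    def remap_ids(mapping: dict[int, int], ids: list[int]) -> list[Optional[int]]:
        # builtins are subscriptable at runtime on 3.9+, so no typing aliases are needed
        return [mapping.get(i) for i in ids]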
@@ -302,7 +302,7 @@ def import_dashboards( data = json.loads(content, object_hook=decode_dashboards) if not data: raise DashboardImportException(_("No data in file")) - dataset_id_mapping: Dict[int, int] = {} + dataset_id_mapping: dict[int, int] = {} for table in data["datasources"]: new_dataset_id = import_dataset(table, database_id, import_time=import_time) params = json.loads(table.params) @@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand): # pylint: disable=unused-argument def __init__( - self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any + self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any ): self.contents = contents self.database_id = database_id diff --git a/superset/dashboards/commands/importers/v1/__init__.py b/superset/dashboards/commands/importers/v1/__init__.py index 5d83a580bd606..597adba6d9cab 100644 --- a/superset/dashboards/commands/importers/v1/__init__.py +++ b/superset/dashboards/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand): dao = DashboardDAO model_name = "dashboard" prefix = "dashboards/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand): # pylint: disable=too-many-branches, too-many-locals @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover charts and datasets associated with dashboards - chart_uuids: Set[str] = set() - dataset_uuids: Set[str] = set() + chart_uuids: set[str] = set() + dataset_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("dashboards/"): chart_uuids.update(find_chart_uuids(config["position"])) @@ -77,20 +77,20 @@ def _import( dataset_uuids.add(config["dataset_uuid"]) # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids: database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if ( file_name.startswith("datasets/") @@ -105,7 +105,7 @@ def _import( } # import charts with the correct parent ref - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if ( file_name.startswith("charts/") @@ -129,7 +129,7 @@ def _import( ).fetchall() # import dashboards - dashboard_chart_ids: List[Tuple[int, int]] = [] + dashboard_chart_ids: list[tuple[int, int]] = [] for file_name, config in configs.items(): if 
file_name.startswith("dashboards/"): config = update_id_refs(config, chart_ids, dataset_info) diff --git a/superset/dashboards/commands/importers/v1/utils.py b/superset/dashboards/commands/importers/v1/utils.py index 9f0ffc36a1e24..1deb44949a685 100644 --- a/superset/dashboards/commands/importers/v1/utils.py +++ b/superset/dashboards/commands/importers/v1/utils.py @@ -17,7 +17,7 @@ import json import logging -from typing import Any, Dict, Set +from typing import Any from flask import g from sqlalchemy.orm import Session @@ -32,12 +32,12 @@ JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"} -def find_chart_uuids(position: Dict[str, Any]) -> Set[str]: +def find_chart_uuids(position: dict[str, Any]) -> set[str]: return set(build_uuid_to_id_map(position)) -def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]: - uuids: Set[str] = set() +def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]: + uuids: set[str] = set() for native_filter in metadata.get("native_filter_configuration", []): targets = native_filter.get("targets", []) for target in targets: @@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]: return uuids -def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]: +def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]: return { child["meta"]["uuid"]: child["meta"]["chartId"] for child in position.values() @@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]: def update_id_refs( # pylint: disable=too-many-locals - config: Dict[str, Any], - chart_ids: Dict[str, int], - dataset_info: Dict[str, Dict[str, Any]], -) -> Dict[str, Any]: + config: dict[str, Any], + chart_ids: dict[str, int], + dataset_info: dict[str, dict[str, Any]], +) -> dict[str, Any]: """Update dashboard metadata to use new IDs""" fixed = config.copy() @@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals def import_dashboard( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Dashboard: diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index 11833a64be17d..fefa65e3f6f0b 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -16,7 +16,7 @@ # under the License. 
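The importer helpers above mostly annotate lookup structures: set[str] for collected UUIDs and dict[str, int] for UUID-to-id maps. A simplified sketch of that shape, not the exact Superset helper:

    def build_uuid_to_id_map(position: dict[str, dict]) -> dict[str, int]:
        # map chart UUIDs to chart ids for every CHART node in a dashboard layout
        return {
            child["meta"]["uuid"]: child["meta"]["chartId"]
            for child in position.values()
            if isinstance(child, dict) and child.get("type") == "CHART"
        }

    def find_chart_uuids(position: dict[str, dict]) -> set[str]:
        return set(build_uuid_to_id_map(position))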
import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -41,7 +41,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Dashboard] = None @@ -64,9 +64,9 @@ def run(self) -> Model: return dashboard def validate(self) -> None: - exceptions: List[ValidationError] = [] - owners_ids: Optional[List[int]] = self._properties.get("owners") - roles_ids: Optional[List[int]] = self._properties.get("roles") + exceptions: list[ValidationError] = [] + owners_ids: Optional[list[int]] = self._properties.get("owners") + roles_ids: Optional[list[int]] = self._properties.get("roles") slug: Optional[str] = self._properties.get("slug") # Validate/populate model exists diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index 5355d602bec04..d88fb431b757d 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from flask import g from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -68,12 +68,12 @@ def get_by_id_or_slug(cls, id_or_slug: Union[int, str]) -> Dashboard: return dashboard @staticmethod - def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]: + def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]: dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug) return dashboard.datasets_trimmed_for_slices() @staticmethod - def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]: + def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]: return DashboardDAO.get_by_id_or_slug(id_or_slug).slices @staticmethod @@ -173,7 +173,7 @@ def update_charts_owners(model: Dashboard, commit: bool = True) -> Dashboard: return model @staticmethod - def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: @@ -196,8 +196,8 @@ def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None: @staticmethod def set_dash_metadata( # pylint: disable=too-many-locals dashboard: Dashboard, - data: Dict[Any, Any], - old_to_new_slice_ids: Optional[Dict[int, int]] = None, + data: dict[Any, Any], + old_to_new_slice_ids: Optional[dict[int, int]] = None, commit: bool = False, ) -> Dashboard: new_filter_scopes = {} @@ -235,7 +235,7 @@ def set_dash_metadata( # pylint: disable=too-many-locals if "filter_scopes" in data: # replace filter_id and immune ids from old slice id to new slice id: # and remove slice ids that are not in dash anymore - slc_id_dict: Dict[int, int] = {} + slc_id_dict: dict[int, int] = {} if old_to_new_slice_ids: slc_id_dict = { old: new @@ -288,7 +288,7 @@ def set_dash_metadata( # pylint: disable=too-many-locals return dashboard @staticmethod - def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]: + def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]: ids = [dash.id for dash in dashboards] return [ star.obj_id @@ -303,7 +303,7 @@ def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]: 
@classmethod def copy_dashboard( - cls, original_dash: Dashboard, data: Dict[str, Any] + cls, original_dash: Dashboard, data: dict[str, Any] ) -> Dashboard: dash = Dashboard() dash.owners = [g.user] if g.user else [] @@ -311,7 +311,7 @@ def copy_dashboard( dash.css = data.get("css") metadata = json.loads(data["json_metadata"]) - old_to_new_slice_ids: Dict[int, int] = {} + old_to_new_slice_ids: dict[int, int] = {} if data.get("duplicate_slices"): # Duplicating slices as well, mapping old ids to new ones for slc in original_dash.slices: diff --git a/superset/dashboards/filter_sets/commands/create.py b/superset/dashboards/filter_sets/commands/create.py index de1d70daf7879..63c4534786249 100644 --- a/superset/dashboards/filter_sets/commands/create.py +++ b/superset/dashboards/filter_sets/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model @@ -40,7 +40,7 @@ class CreateFilterSetCommand(BaseFilterSetCommand): # pylint: disable=C0103 - def __init__(self, dashboard_id: int, data: Dict[str, Any]): + def __init__(self, dashboard_id: int, data: dict[str, Any]): super().__init__(dashboard_id) self._properties = data.copy() diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py index 07d59f93aee23..722672d6684d8 100644 --- a/superset/dashboards/filter_sets/commands/update.py +++ b/superset/dashboards/filter_sets/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model @@ -31,7 +31,7 @@ class UpdateFilterSetCommand(BaseFilterSetCommand): - def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]): + def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]): super().__init__(dashboard_id) self._filter_set_id = filter_set_id self._properties = data.copy() diff --git a/superset/dashboards/filter_sets/dao.py b/superset/dashboards/filter_sets/dao.py index 949aa6d3fdf25..5f2b0ba418edd 100644 --- a/superset/dashboards/filter_sets/dao.py +++ b/superset/dashboards/filter_sets/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model from sqlalchemy.exc import SQLAlchemyError @@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO): model_cls = FilterSet @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: if cls.model_cls is None: raise DAOConfigError() model = FilterSet() diff --git a/superset/dashboards/filter_sets/schemas.py b/superset/dashboards/filter_sets/schemas.py index c1a13b424e815..2309eea99fabf 100644 --- a/superset/dashboards/filter_sets/schemas.py +++ b/superset/dashboards/filter_sets/schemas.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
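The command constructors in these files all take data: dict[str, Any] and keep a private copy that later validation works on. Schematically (a stand-in class, not the real CreateFilterSetCommand):

    from typing import Any, Optional

    class CreateThingCommand:
        def __init__(self, parent_id: int, data: dict[str, Any]) -> None:
            self._parent_id = parent_id
            # copy so validation and mutation do not leak back to the caller's dict
            self._properties: dict[str, Any] = data.copy()
            self._model: Optional[object] = None

        def validate(self) -> None:
            if "name" not in self._properties:   # hypothetical required field
                raise ValueError("name is required")

        def run(self) -> dict[str, Any]:
            self.validate()
            return self._properties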
-from typing import Any, cast, Dict, Mapping +from collections.abc import Mapping +from typing import Any, cast from marshmallow import fields, post_load, Schema, ValidationError from marshmallow.validate import Length, OneOf @@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema): @post_load def validate( self, data: Mapping[Any, Any], *, many: Any, partial: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: self._validate_json_meta_data(data[JSON_METADATA_FIELD]) if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data: raise ValidationError("owner_id is mandatory when owner_type is User") - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) class FilterSetPutSchema(FilterSetSchema): @@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema): @post_load def validate( # pylint: disable=unused-argument self, data: Mapping[Any, Any], *, many: Any, partial: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if JSON_METADATA_FIELD in data: self._validate_json_meta_data(data[JSON_METADATA_FIELD]) - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) -def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None: +def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None: if first_field in data and second_field not in data: raise ValidationError( - "{} must be included alongside {}".format(first_field, second_field) + f"{first_field} must be included alongside {second_field}" ) diff --git a/superset/dashboards/filter_state/api.py b/superset/dashboards/filter_state/api.py index 7a771d6b54098..a1b855ca9ee6e 100644 --- a/superset/dashboards/filter_state/api.py +++ b/superset/dashboards/filter_state/api.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Type from flask import Response from flask_appbuilder.api import expose, protect, safe @@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi): resource_name = "dashboard" openapi_spec_tag = "Dashboard Filter State" - def get_create_command(self) -> Type[CreateFilterStateCommand]: + def get_create_command(self) -> type[CreateFilterStateCommand]: return CreateFilterStateCommand - def get_update_command(self) -> Type[UpdateFilterStateCommand]: + def get_update_command(self) -> type[UpdateFilterStateCommand]: return UpdateFilterStateCommand - def get_get_command(self) -> Type[GetFilterStateCommand]: + def get_get_command(self) -> type[GetFilterStateCommand]: return GetFilterStateCommand - def get_delete_command(self) -> Type[DeleteFilterStateCommand]: + def get_delete_command(self) -> type[DeleteFilterStateCommand]: return DeleteFilterStateCommand @expose("//filter_state", methods=("POST",)) diff --git a/superset/dashboards/permalink/types.py b/superset/dashboards/permalink/types.py index 91c5a9620cf71..4961d2a17bf67 100644 --- a/superset/dashboards/permalink/types.py +++ b/superset/dashboards/permalink/types.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
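Two smaller rewrites also appear in the schema code above: abstract container types such as Mapping and Iterator now come from collections.abc rather than typing, and str.format() calls are folded into f-strings. A rough, marshmallow-free illustration with made-up names:

    from collections.abc import Mapping
    from typing import Any, cast

    def ensure_pair(first_field: str, second_field: str, data: Mapping[str, Any]) -> dict[str, Any]:
        if first_field in data and second_field not in data:
            # f-string instead of "{} must be included alongside {}".format(...)
            raise ValueError(f"{first_field} must be included alongside {second_field}")
        # cast() accepts builtin generics just as it accepted typing.Dict
        return cast(dict[str, Any], data)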
-from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Optional, TypedDict class DashboardPermalinkState(TypedDict): - dataMask: Optional[Dict[str, Any]] - activeTabs: Optional[List[str]] + dataMask: Optional[dict[str, Any]] + activeTabs: Optional[list[str]] anchor: Optional[str] - urlParams: Optional[List[Tuple[str, str]]] + urlParams: Optional[list[tuple[str, str]]] class DashboardPermalinkValue(TypedDict): diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index ab93e4130f87a..846ed39e825cf 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -16,7 +16,7 @@ # under the License. import json import re -from typing import Any, Dict, Union +from typing import Any, Union from marshmallow import fields, post_load, pre_load, Schema from marshmallow.validate import Length, ValidationError @@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema): @pre_load def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use self, - data: Dict[str, Any], + data: dict[str, Any], **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Remove ``show_native_filters`` from the JSON metadata. @@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema): class BaseDashboardSchema(Schema): # pylint: disable=no-self-use,unused-argument @post_load - def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: if data.get("slug"): data["slug"] = data["slug"].strip() data["slug"] = data["slug"].replace(" ", "-") diff --git a/superset/databases/api.py b/superset/databases/api.py index 77f959618292a..c214065a27d11 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -19,7 +19,7 @@ import logging from datetime import datetime from io import BytesIO -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast, Optional from zipfile import is_zipfile, ZipFile from flask import request, Response, send_file @@ -1328,13 +1328,13 @@ def available(self) -> Response: 500: $ref: '#/components/responses/500' """ - preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", []) + preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", []) available_databases = [] for engine_spec, drivers in get_available_engine_specs().items(): if not drivers: continue - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "name": engine_spec.engine_name, "engine": engine_spec.engine, "available_drivers": sorted(drivers), diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 16d27835b37c3..e3fd667130c2e 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
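TypedDict definitions such as DashboardPermalinkState use the same builtin generics; only TypedDict, Optional and Any still come from typing. A comparable, hypothetical declaration:

    from typing import Any, Optional, TypedDict

    class PermalinkState(TypedDict):
        dataMask: Optional[dict[str, Any]]
        activeTabs: Optional[list[str]]
        anchor: Optional[str]
        urlParams: Optional[list[tuple[str, str]]]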
import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model @@ -47,7 +47,7 @@ class CreateDatabaseCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -128,7 +128,7 @@ def run(self) -> Model: return database def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri") database_name: Optional[str] = self._properties.get("database_name") if not sqlalchemy_uri: diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py index e1f8fc2a25165..889cb86c8f095 100644 --- a/superset/databases/commands/export.py +++ b/superset/databases/commands/export.py @@ -18,7 +18,8 @@ import json import logging -from typing import Any, Dict, Iterator, Tuple +from typing import Any +from collections.abc import Iterator import yaml @@ -33,7 +34,7 @@ logger = logging.getLogger(__name__) -def parse_extra(extra_payload: str) -> Dict[str, Any]: +def parse_extra(extra_payload: str) -> dict[str, Any]: try: extra = json.loads(extra_payload) except json.decoder.JSONDecodeError: @@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand): @staticmethod def _export( model: Database, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: db_file_name = get_filename(model.database_name, model.id, skip_id=True) file_path = f"databases/{db_file_name}.yaml" diff --git a/superset/databases/commands/importers/dispatcher.py b/superset/databases/commands/importers/dispatcher.py index 88d38bf13b857..70031b09e4fe6 100644 --- a/superset/databases/commands/importers/dispatcher.py +++ b/superset/databases/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py index 239bd0977f784..ba119beaaa80f 100644 --- a/superset/databases/commands/importers/v1/__init__.py +++ b/superset/databases/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
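The export commands are generators yielding (file_path, file_content) pairs, so their return type becomes Iterator[tuple[str, str]] with Iterator imported from collections.abc. A stripped-down sketch of that contract (the YAML-ish payload is illustrative only):

    from collections.abc import Iterator

    def export_configs(configs: dict[str, dict[str, str]]) -> Iterator[tuple[str, str]]:
        """Yield (file_path, file_content) pairs, one per exported object."""
        for name, payload in configs.items():
            file_path = f"databases/{name}.yaml"
            file_content = "\n".join(f"{key}: {value}" for key, value in payload.items())
            yield file_path, file_content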
-from typing import Any, Dict +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand): dao = DatabaseDAO model_name = "database" prefix = "databases/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), } @@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # first import databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database(session, config, overwrite=overwrite) diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py index c0c0ee60d99ed..8881f78a9c39c 100644 --- a/superset/databases/commands/importers/v1/utils.py +++ b/superset/databases/commands/importers/v1/utils.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ def import_database( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Database: diff --git a/superset/databases/commands/tables.py b/superset/databases/commands/tables.py index 48e9227dea75e..b7dbb4d461315 100644 --- a/superset/databases/commands/tables.py +++ b/superset/databases/commands/tables.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, cast, Dict +from typing import Any, cast from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable @@ -40,7 +40,7 @@ def __init__(self, db_id: int, schema_name: str, force: bool): self._schema_name = schema_name self._force = force - def run(self) -> Dict[str, Any]: + def run(self) -> dict[str, Any]: self.validate() try: tables = security_manager.get_datasources_accessible_by_user( diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 9809641d5cd6a..2680c5e8c180b 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -17,7 +17,7 @@ import logging import sqlite3 from contextlib import closing -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import current_app as app from flask_babel import gettext as _ @@ -64,7 +64,7 @@ def get_log_connection_action( class TestConnectionDatabaseCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Database] = None diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 746f7a8152a74..f12706fa1d159 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
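Optional and Union are deliberately left alone in this pass: builtin generics (PEP 585) only require Python 3.9, whereas writing unions as X | None at runtime (PEP 604) requires 3.10, presumably above the minimum version supported here. Both spellings type-check the same:

    from typing import Optional, Union

    # still valid on Python 3.9, as used throughout this patch
    def find_model(model_id: Union[int, str]) -> Optional[dict[str, int]]:
        return {"id": int(model_id)} if str(model_id).isdigit() else None

    # equivalent PEP 604 spelling, only usable once the floor is Python 3.10:
    #   def find_model(model_id: int | str) -> dict[str, int] | None: ...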
import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -47,7 +47,7 @@ class UpdateDatabaseCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None @@ -78,7 +78,7 @@ def run(self) -> Model: raise DatabaseConnectionFailedError() from ex # Update database schema permissions - new_schemas: List[str] = [] + new_schemas: list[str] = [] for schema in schemas: old_view_menu_name = security_manager.get_schema_perm( @@ -164,7 +164,7 @@ def _propagate_schema_permissions( chart.schema_perm = new_view_menu_name def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # Validate/populate model exists self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 2a624e32c7abc..d97ad33af9eaf 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -16,7 +16,7 @@ # under the License. import json from contextlib import closing -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import gettext as __ @@ -38,7 +38,7 @@ class ValidateDatabaseParametersCommand(BaseCommand): - def __init__(self, properties: Dict[str, Any]): + def __init__(self, properties: dict[str, Any]): self._properties = properties.copy() self._model: Optional[Database] = None diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py index 346d684a0d2ca..40d88af7457f0 100644 --- a/superset/databases/commands/validate_sql.py +++ b/superset/databases/commands/validate_sql.py @@ -16,7 +16,7 @@ # under the License. import logging import re -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from flask import current_app from flask_babel import gettext as __ @@ -41,13 +41,13 @@ class ValidateSQLCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None - self._validator: Optional[Type[BaseSQLValidator]] = None + self._validator: Optional[type[BaseSQLValidator]] = None - def run(self) -> List[Dict[str, Any]]: + def run(self) -> list[dict[str, Any]]: """ Validates a SQL statement @@ -97,9 +97,7 @@ def validate(self) -> None: if not validators_by_engine or spec.engine not in validators_by_engine: raise NoValidatorConfigFoundError( SupersetError( - message=__( - "no SQL validator is configured for {}".format(spec.engine) - ), + message=__(f"no SQL validator is configured for {spec.engine}"), error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, level=ErrorLevel.ERROR, ), diff --git a/superset/databases/dao.py b/superset/databases/dao.py index c82f0db5745ae..9ce3b5e73ec2b 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
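Class references follow the same rule: typing.Type[X] becomes the builtin type[X], as in the validator handle on ValidateSQLCommand or BaseDAO.model_cls. A small sketch of storing and instantiating a class through such an attribute (BaseValidator here is a stand-in, not a Superset class):

    from typing import Optional

    class BaseValidator:
        def check(self, sql: str) -> list[str]:
            return []

    class SelectOnlyValidator(BaseValidator):
        def check(self, sql: str) -> list[str]:
            return [] if sql.strip().upper().startswith("SELECT") else ["not a SELECT"]

    validator_cls: Optional[type[BaseValidator]] = SelectOnlyValidator
    if validator_cls is not None:
        errors = validator_cls().check("select 1")   # instantiate via the stored class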
import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter @@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO): def update( cls, model: Database, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> Database: """ @@ -93,7 +93,7 @@ def build_db_for_connection_test( ) @classmethod - def get_related_objects(cls, database_id: int) -> Dict[str, Any]: + def get_related_objects(cls, database_id: int) -> dict[str, Any]: database: Any = cls.find_by_id(database_id) datasets = database.tables dataset_ids = [dataset.id for dataset in datasets] diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 86564e8f15a7e..2ca77b77d1c40 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Set +from typing import Any from flask import g from flask_babel import lazy_gettext as _ @@ -30,7 +30,7 @@ def can_access_databases( view_menu_name: str, -) -> Set[str]: +) -> set[str]: return { security_manager.unpack_database_and_schema(vm).database for vm in security_manager.user_view_menu_names(view_menu_name) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 00e8c3ca5381d..01a00e8b80ca0 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -19,7 +19,7 @@ import inspect import json -from typing import Any, Dict, List +from typing import Any from flask import current_app from flask_babel import lazy_gettext as _ @@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods @pre_load def build_sqlalchemy_uri( - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Build SQLAlchemy URI from separate parameters. @@ -325,9 +325,9 @@ def build_sqlalchemy_uri( def rename_encrypted_extra( self: Schema, - data: Dict[str, Any], + data: dict[str, Any], **kwargs: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Rename ``encrypted_extra`` to ``masked_encrypted_extra``. @@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema): class ImportV1DatabaseExtraSchema(Schema): @pre_load def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Fixes for ``schemas_allowed_for_csv_upload``. """ @@ -744,8 +744,8 @@ def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name class ImportV1DatabaseSchema(Schema): @pre_load def fix_allow_csv_upload( - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Fix for ``allow_csv_upload`` . 
""" @@ -775,7 +775,7 @@ def fix_allow_csv_upload( ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) @validates_schema - def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None: + def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None: """If sqlalchemy_uri has a masked password, password is required""" uuid = data["uuid"] existing = db.session.query(Database).filter_by(uuid=uuid).first() @@ -789,7 +789,7 @@ def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None: @validates_schema def validate_ssh_tunnel_credentials( - self, data: Dict[str, Any], **kwargs: Any + self, data: dict[str, Any], **kwargs: Any ) -> None: """If ssh_tunnel has a masked credentials, credentials are required""" uuid = data["uuid"] @@ -829,7 +829,7 @@ def validate_ssh_tunnel_credentials( # or there're times where it's masked. # If both are masked, we need to return a list of errors # so the UI ask for both fields at the same time if needed - exception_messages: List[str] = [] + exception_messages: list[str] = [] if private_key is None or private_key == PASSWORD_MASK: # If we get here we need to ask for the private key exception_messages.append( @@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict): pass -def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore +def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore ret = {} if isinstance(field, EncryptedField): if self.openapi_version.major > 2: diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index 45e5af5f44ea9..9c41b83392dc7 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -34,7 +34,7 @@ class CreateSSHTunnelCommand(BaseCommand): - def __init__(self, database_id: int, data: Dict[str, Any]): + def __init__(self, database_id: int, data: dict[str, Any]): self._properties = data.copy() self._properties["database_id"] = database_id @@ -61,7 +61,7 @@ def run(self) -> Model: def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] database_id: Optional[int] = self._properties.get("database_id") server_address: Optional[str] = self._properties.get("server_address") server_port: Optional[int] = self._properties.get("server_port") diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index 42925d1caa317..37fd4a94b9652 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -34,7 +34,7 @@ class UpdateSSHTunnelCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[SSHTunnel] = None diff --git a/superset/databases/ssh_tunnel/dao.py b/superset/databases/ssh_tunnel/dao.py index 89562fc05dcc0..731f9183b348a 100644 --- a/superset/databases/ssh_tunnel/dao.py +++ b/superset/databases/ssh_tunnel/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from superset.dao.base import BaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel @@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO): def update( cls, model: SSHTunnel, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SSHTunnel: """ diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 3384679cb7896..d9462a63db87f 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any import sqlalchemy as sa from flask import current_app @@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): ] @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: output = { "id": self.id, "server_address": self.server_address, diff --git a/superset/databases/utils.py b/superset/databases/utils.py index 9229bb8cbae84..74943f4747388 100644 --- a/superset/databases/utils.py +++ b/superset/databases/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from sqlalchemy.engine.url import make_url, URL @@ -25,7 +25,7 @@ def get_foreign_keys_metadata( database: Any, table_name: str, schema_name: Optional[str], -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: foreign_keys = database.get_foreign_keys(table_name, schema_name) for fk in foreign_keys: fk["column_names"] = fk.pop("constrained_columns") @@ -35,14 +35,14 @@ def get_foreign_keys_metadata( def get_indexes_metadata( database: Any, table_name: str, schema_name: Optional[str] -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: indexes = database.get_indexes(table_name, schema_name) for idx in indexes: idx["type"] = "index" return indexes -def get_col_type(col: Dict[Any, Any]) -> str: +def get_col_type(col: dict[Any, Any]) -> str: try: dtype = f"{col['type']}" except Exception: # pylint: disable=broad-except @@ -53,7 +53,7 @@ def get_col_type(col: Dict[Any, Any]) -> str: def get_table_metadata( database: Any, table_name: str, schema_name: Optional[str] -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Get table metadata information, including type, pk, fks. This function raises SQLAlchemyError when a schema is not found. 
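The table-metadata helpers above return plain list[dict[str, Any]] payloads after renaming a few SQLAlchemy inspector keys; a minimal stand-in for that normalisation step:

    from typing import Any

    def normalise_foreign_keys(foreign_keys: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # rename the inspector's key to the name the API exposes
        for fk in foreign_keys:
            fk["column_names"] = fk.pop("constrained_columns", [])
        return foreign_keys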
@@ -73,7 +73,7 @@ def get_table_metadata( foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) indexes = get_indexes_metadata(database, table_name, schema_name) keys += foreign_keys + indexes - payload_columns: List[Dict[str, Any]] = [] + payload_columns: list[dict[str, Any]] = [] table_comment = database.get_table_comment(table_name, schema_name) for col in columns: dtype = get_col_type(col) diff --git a/superset/dataframe.py b/superset/dataframe.py index 8abeedc095298..80839932944d5 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -17,7 +17,7 @@ """ Superset utilities for pandas.DataFrame. """ import logging -from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -37,7 +37,7 @@ def _convert_big_integers(val: Any) -> Any: return str(val) if isinstance(val, int) and abs(val) > JS_MAX_INTEGER else val -def df_to_records(dframe: pd.DataFrame) -> List[Dict[str, Any]]: +def df_to_records(dframe: pd.DataFrame) -> list[dict[str, Any]]: """ Convert a DataFrame to a set of records. diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py index 643ac784ec3b3..fd133518098cc 100644 --- a/superset/datasets/commands/bulk_delete.py +++ b/superset/datasets/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset import security_manager from superset.commands.base import BaseCommand @@ -34,9 +34,9 @@ class BulkDeleteDatasetCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[SqlaTable]] = None + self._models: Optional[list[SqlaTable]] = None def run(self) -> None: self.validate() diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 04f54339d0847..1c864ad196d1f 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -37,7 +37,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -55,12 +55,12 @@ def run(self) -> Model: return dataset def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] database_id = self._properties["database"] table_name = self._properties["table_name"] schema = self._properties.get("schema", None) sql = self._properties.get("sql", None) - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate uniqueness if not DatasetDAO.validate_uniqueness(database_id, schema, table_name): diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py index 5fc642cbe3e66..5a4a84fdf9dfe 100644 --- a/superset/datasets/commands/duplicate.py +++ b/superset/datasets/commands/duplicate.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
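df_to_records keeps the same behaviour under the new annotation, returning list[dict[str, Any]]. For reference, the core of such a conversion is just pandas' records orientation (the big-integer handling in the real function is omitted here):

    from typing import Any
    import pandas as pd

    def df_to_records(dframe: pd.DataFrame) -> list[dict[str, Any]]:
        # one dict per row, keyed by column name
        return dframe.to_dict(orient="records")

    records = df_to_records(pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}))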
import logging -from typing import Any, Dict, List +from typing import Any from flask_appbuilder.models.sqla import Model from flask_babel import gettext as __ @@ -43,7 +43,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any]) -> None: self._base_model: SqlaTable = SqlaTable() self._properties = data.copy() @@ -105,7 +105,7 @@ def run(self) -> Model: return table def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] base_model_id = self._properties["base_model_id"] duplicate_name = self._properties["table_name"] diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py index c6fe43c89df33..8c02a23f2967c 100644 --- a/superset/datasets/commands/export.py +++ b/superset/datasets/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml @@ -43,7 +43,7 @@ class ExportDatasetsCommand(ExportModelsCommand): @staticmethod def _export( model: SqlaTable, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: db_file_name = get_filename( model.database.database_name, model.database.id, skip_id=True ) diff --git a/superset/datasets/commands/importers/dispatcher.py b/superset/datasets/commands/importers/dispatcher.py index 74f1129d23bbd..6be8635da20a7 100644 --- a/superset/datasets/commands/importers/dispatcher.py +++ b/superset/datasets/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -43,7 +43,7 @@ class ImportDatasetsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index f706ecf38bcf7..c530be3c14aee 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import yaml from flask_appbuilder import Model @@ -213,7 +213,7 @@ def import_simple_obj( def import_from_dict( - session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None + session: Session, data: dict[str, Any], sync: Optional[list[str]] = None ) -> None: """Imports databases from dictionary""" if not sync: @@ -238,12 +238,12 @@ class ImportDatasetsCommand(BaseCommand): # pylint: disable=unused-argument def __init__( self, - contents: Dict[str, str], + contents: dict[str, str], *args: Any, **kwargs: Any, ): self.contents = contents - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} self.sync = [] if kwargs.get("sync_columns"): diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py index e73213319db6f..e753138ab8fb4 100644 --- a/superset/datasets/commands/importers/v1/__init__.py +++ b/superset/datasets/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -36,7 +36,7 @@ class ImportDatasetsCommand(ImportModelsCommand): dao = DatasetDAO model_name = "dataset" prefix = "datasets/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), } @@ -44,16 +44,16 @@ class ImportDatasetsCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/"): database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 52f46829b5b35..ae47fc411aa36 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -18,7 +18,7 @@ import json import logging import re -from typing import Any, Dict +from typing import Any from urllib import request import pandas as pd @@ -69,7 +69,7 @@ def get_sqla_type(native_type: str) -> VisitableType: raise Exception(f"Unknown type: {native_type}") -def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]: +def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]: return { column.column_name: get_sqla_type(column.type) for column in dataset.columns @@ -101,7 +101,7 @@ def validate_data_uri(data_uri: str) -> None: def import_dataset( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, force_data: bool = False, ignore_permissions: bool = False, diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index cc9f480a41b54..be9625709fdb3 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -16,7 +16,7 @@ # under the License. 
import logging from collections import Counter -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model @@ -52,7 +52,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): def __init__( self, model_id: int, - data: Dict[str, Any], + data: dict[str, Any], override_columns: Optional[bool] = False, ): self._model_id = model_id @@ -76,8 +76,8 @@ def run(self) -> Model: raise DatasetUpdateFailedError() def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate/populate model exists self._model = DatasetDAO.find_by_id(self._model_id) if not self._model: @@ -125,14 +125,14 @@ def validate(self) -> None: raise DatasetInvalidError(exceptions=exceptions) def _validate_columns( - self, columns: List[Dict[str, Any]], exceptions: List[ValidationError] + self, columns: list[dict[str, Any]], exceptions: list[ValidationError] ) -> None: # Validate duplicates on data if self._get_duplicates(columns, "column_name"): exceptions.append(DatasetColumnsDuplicateValidationError()) else: # validate invalid id's - columns_ids: List[int] = [ + columns_ids: list[int] = [ column["id"] for column in columns if "id" in column ] if not DatasetDAO.validate_columns_exist(self._model_id, columns_ids): @@ -140,7 +140,7 @@ def _validate_columns( # validate new column names uniqueness if not self.override_columns: - columns_names: List[str] = [ + columns_names: list[str] = [ column["column_name"] for column in columns if "id" not in column ] if not DatasetDAO.validate_columns_uniqueness( @@ -149,26 +149,26 @@ def _validate_columns( exceptions.append(DatasetColumnsExistsValidationError()) def _validate_metrics( - self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError] + self, metrics: list[dict[str, Any]], exceptions: list[ValidationError] ) -> None: if self._get_duplicates(metrics, "metric_name"): exceptions.append(DatasetMetricsDuplicateValidationError()) else: # validate invalid id's - metrics_ids: List[int] = [ + metrics_ids: list[int] = [ metric["id"] for metric in metrics if "id" in metric ] if not DatasetDAO.validate_metrics_exist(self._model_id, metrics_ids): exceptions.append(DatasetMetricsNotFoundValidationError()) # validate new metric names uniqueness - metric_names: List[str] = [ + metric_names: list[str] = [ metric["metric_name"] for metric in metrics if "id" not in metric ] if not DatasetDAO.validate_metrics_uniqueness(self._model_id, metric_names): exceptions.append(DatasetMetricsExistsValidationError()) @staticmethod - def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]: + def _get_duplicates(data: list[dict[str, Any]], key: str) -> list[str]: duplicates = [ name for name, count in Counter([item[key] for item in data]).items() diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index b158fce1fefe8..f4d46be109799 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
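_get_duplicates above is a small Counter-based helper; annotated with builtin generics, a standalone version reads roughly like this:

    from collections import Counter
    from typing import Any

    def get_duplicates(data: list[dict[str, Any]], key: str) -> list[str]:
        # names that appear more than once under the given key
        return [
            name
            for name, count in Counter(item[key] for item in data).items()
            if count > 1
        ]

    assert get_duplicates(
        [{"metric_name": "count"}, {"metric_name": "count"}], "metric_name"
    ) == ["count"]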
import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -44,7 +44,7 @@ def get_database_by_id(database_id: int) -> Optional[Database]: return None @staticmethod - def get_related_objects(database_id: int) -> Dict[str, Any]: + def get_related_objects(database_id: int) -> dict[str, Any]: charts = ( db.session.query(Slice) .filter( @@ -108,7 +108,7 @@ def validate_update_uniqueness( return not db.session.query(dataset_query.exists()).scalar() @staticmethod - def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool: + def validate_columns_exist(dataset_id: int, columns_ids: list[int]) -> bool: dataset_query = ( db.session.query(TableColumn.id).filter( TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids) @@ -117,7 +117,7 @@ def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool: return len(columns_ids) == len(dataset_query) @staticmethod - def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool: + def validate_columns_uniqueness(dataset_id: int, columns_names: list[str]) -> bool: dataset_query = ( db.session.query(TableColumn.id).filter( TableColumn.table_id == dataset_id, @@ -127,7 +127,7 @@ def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bo return len(dataset_query) == 0 @staticmethod - def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool: + def validate_metrics_exist(dataset_id: int, metrics_ids: list[int]) -> bool: dataset_query = ( db.session.query(SqlMetric.id).filter( SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids) @@ -136,7 +136,7 @@ def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool: return len(metrics_ids) == len(dataset_query) @staticmethod - def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool: + def validate_metrics_uniqueness(dataset_id: int, metrics_names: list[str]) -> bool: dataset_query = ( db.session.query(SqlMetric.id).filter( SqlMetric.table_id == dataset_id, @@ -149,7 +149,7 @@ def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bo def update( cls, model: SqlaTable, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> Optional[SqlaTable]: """ @@ -173,7 +173,7 @@ def update( def update_columns( cls, model: SqlaTable, - property_columns: List[Dict[str, Any]], + property_columns: list[dict[str, Any]], commit: bool = True, override_columns: bool = False, ) -> None: @@ -239,7 +239,7 @@ def update_columns( def update_metrics( cls, model: SqlaTable, - property_metrics: List[Dict[str, Any]], + property_metrics: list[dict[str, Any]], commit: bool = True, ) -> None: """ @@ -304,14 +304,14 @@ def find_dataset_column( def update_column( cls, model: TableColumn, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> TableColumn: return DatasetColumnDAO.update(model, properties, commit=commit) @classmethod def create_column( - cls, properties: Dict[str, Any], commit: bool = True + cls, properties: dict[str, Any], commit: bool = True ) -> TableColumn: """ Creates a Dataset model on the metadata DB @@ -346,7 +346,7 @@ def delete_metric(cls, model: SqlMetric, commit: bool = True) -> SqlMetric: def update_metric( cls, model: SqlMetric, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SqlMetric: return DatasetMetricDAO.update(model, properties, commit=commit) @@ -354,7 +354,7 @@ 
def update_metric( @classmethod def create_metric( cls, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SqlMetric: """ @@ -363,7 +363,7 @@ def create_metric( return DatasetMetricDAO.create(properties, commit=commit) @staticmethod - def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: diff --git a/superset/datasets/models.py b/superset/datasets/models.py index b433709f2c779..50aeea7b51672 100644 --- a/superset/datasets/models.py +++ b/superset/datasets/models.py @@ -24,7 +24,6 @@ These models are not fully implemented, and shouldn't be used yet. """ -from typing import List import sqlalchemy as sa from flask_appbuilder import Model @@ -87,7 +86,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # The relationship between datasets and columns is 1:n, but we use a # many-to-many association table to avoid adding two mutually exclusive # columns(dataset_id and table_id) to Column - columns: List[Column] = relationship( + columns: list[Column] = relationship( "Column", secondary=dataset_column_association_table, cascade="all, delete-orphan", @@ -97,7 +96,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): owners = relationship( security_manager.user_model, secondary=dataset_user_association_table ) - tables: List[Table] = relationship( + tables: list[Table] = relationship( "Table", secondary=dataset_table_association_table, backref="datasets" ) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index f248fc70ff3fa..9a2af980666e8 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -16,7 +16,7 @@ # under the License. import json import re -from typing import Any, Dict +from typing import Any from flask_babel import lazy_gettext as _ from marshmallow import fields, pre_load, Schema, ValidationError @@ -150,7 +150,7 @@ class DatasetRelatedObjectsResponse(Schema): class ImportV1ColumnSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ @@ -176,7 +176,7 @@ def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: class ImportV1MetricSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ @@ -198,7 +198,7 @@ def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: class ImportV1DatasetSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py index 158a32c7fdc6f..4682f070e2bed 100644 --- a/superset/datasource/dao.py +++ b/superset/datasource/dao.py @@ -16,7 +16,7 @@ # under the License. 
import logging -from typing import Dict, Type, Union +from typing import Union from sqlalchemy.orm import Session @@ -34,7 +34,7 @@ class DatasourceDAO(BaseDAO): - sources: Dict[Union[DatasourceType, str], Type[Datasource]] = { + sources: dict[Union[DatasourceType, str], type[Datasource]] = { DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, DatasourceType.SAVEDQUERY: SavedQuery, diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index f19dffd4a3bbe..20cdfcc51f8f1 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -33,7 +33,7 @@ from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Optional import sqlalchemy.databases import sqlalchemy.dialects @@ -58,11 +58,11 @@ def is_engine_spec(obj: Any) -> bool: ) -def load_engine_specs() -> List[Type[BaseEngineSpec]]: +def load_engine_specs() -> list[type[BaseEngineSpec]]: """ Load all engine specs, native and 3rd party. """ - engine_specs: List[Type[BaseEngineSpec]] = [] + engine_specs: list[type[BaseEngineSpec]] = [] # load standard engines db_engine_spec_dir = str(Path(__file__).parent) @@ -85,7 +85,7 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]: return engine_specs -def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]: +def get_engine_spec(backend: str, driver: Optional[str] = None) -> type[BaseEngineSpec]: """ Return the DB engine spec associated with a given SQLAlchemy URL. @@ -120,11 +120,11 @@ def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngi } -def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]: +def get_available_engine_specs() -> dict[type[BaseEngineSpec], set[str]]: """ Return available engine specs and installed drivers for them. """ - drivers: Dict[str, Set[str]] = defaultdict(set) + drivers: dict[str, set[str]] = defaultdict(set) # native SQLAlchemy dialects for attr in sqlalchemy.databases.__all__: diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py index 047952402d2eb..ad6bed113da87 100644 --- a/superset/db_engine_specs/athena.py +++ b/superset/db_engine_specs/athena.py @@ -16,7 +16,8 @@ # under the License. 
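The registry-style attributes above (DatasourceDAO.sources, load_engine_specs, get_available_engine_specs) also swap typing.Type and typing.Set for the built-in type and set. A self-contained sketch of that idiom; the spec classes and driver names are placeholders, not Superset's:

    from collections import defaultdict

    class BaseSpec:
        engine = "base"

    class SQLiteSpec(BaseSpec):
        engine = "sqlite"

    class PostgresSpec(BaseSpec):
        engine = "postgresql"

    # Before: Dict[str, Type[BaseSpec]] and Dict[str, Set[str]]
    # After: built-in generics parameterised with the class object itself (Python 3.9+)
    registry: dict[str, type[BaseSpec]] = {
        spec.engine: spec for spec in (SQLiteSpec, PostgresSpec)
    }
    drivers: dict[str, set[str]] = defaultdict(set)
    drivers["postgresql"].add("psycopg2")

    print(registry["sqlite"].__name__, dict(drivers))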
import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from flask_babel import gettext as __ from sqlalchemy import types @@ -51,7 +52,7 @@ class AthenaEngineSpec(BaseEngineSpec): date_add('day', 1, CAST({col} AS TIMESTAMP))))", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( __( "Please check your query for syntax errors at or " @@ -64,7 +65,7 @@ class AthenaEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a7ff8622722c1..ef922a5e63a44 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -22,22 +22,8 @@ import logging import re from datetime import datetime -from typing import ( - Any, - Callable, - ContextManager, - Dict, - List, - Match, - NamedTuple, - Optional, - Pattern, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from re import Match, Pattern +from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union import pandas as pd import sqlparse @@ -77,7 +63,7 @@ from superset.models.core import Database from superset.models.sql_lab import Query -ColumnTypeMapping = Tuple[ +ColumnTypeMapping = tuple[ Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]], GenericDataType, @@ -90,10 +76,10 @@ class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str function: str - duration: Optional[str] + duration: str | None -builtin_time_grains: Dict[Optional[str], str] = { +builtin_time_grains: dict[str | None, str] = { "PT1S": __("Second"), "PT5S": __("5 second"), "PT30S": __("30 second"), @@ -160,12 +146,12 @@ class MetricType(TypedDict, total=False): metric_name: str expression: str - verbose_name: Optional[str] - metric_type: Optional[str] - description: Optional[str] - d3format: Optional[str] - warning_text: Optional[str] - extra: Optional[str] + verbose_name: str | None + metric_type: str | None + description: str | None + d3format: str | None + warning_text: str | None + extra: str | None class BaseEngineSpec: # pylint: disable=too-many-public-methods @@ -182,19 +168,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods having to add the same aggregation in SELECT. """ - engine_name: Optional[str] = None # for user messages, overridden in child classes + engine_name: str | None = None # for user messages, overridden in child classes # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers; # see the ``supports_url`` and ``supports_backend`` methods below. engine = "base" # str as defined in sqlalchemy.engine.engine - engine_aliases: Set[str] = set() - drivers: Dict[str, str] = {} - default_driver: Optional[str] = None + engine_aliases: set[str] = set() + drivers: dict[str, str] = {} + default_driver: str | None = None disable_ssh_tunneling = False - _date_trunc_functions: Dict[str, str] = {} - _time_grain_expressions: Dict[Optional[str], str] = {} - _default_column_type_mappings: Tuple[ColumnTypeMapping, ...] 
= ( + _date_trunc_functions: dict[str, str] = {} + _time_grain_expressions: dict[str | None, str] = {} + _default_column_type_mappings: tuple[ColumnTypeMapping, ...] = ( ( re.compile(r"^string", re.IGNORECASE), types.String(), @@ -312,7 +298,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ), ) # engine-specific type mappings to check prior to the defaults - column_type_mappings: Tuple[ColumnTypeMapping, ...] = () + column_type_mappings: tuple[ColumnTypeMapping, ...] = () # Does database support join-free timeslot grouping time_groupby_inline = False @@ -351,23 +337,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods allow_limit_clause = True # This set will give keywords for select statements # to consider for the engines with TOP SQL parsing - select_keywords: Set[str] = {"SELECT"} + select_keywords: set[str] = {"SELECT"} # This set will give the keywords for data limit statements # to consider for the engines with TOP SQL parsing - top_keywords: Set[str] = {"TOP"} + top_keywords: set[str] = {"TOP"} # A set of disallowed connection query parameters by driver name - disallow_uri_query_params: Dict[str, Set[str]] = {} + disallow_uri_query_params: dict[str, set[str]] = {} # A Dict of query parameters that will always be used on every connection # by driver name - enforce_uri_query_params: Dict[str, Dict[str, Any]] = {} + enforce_uri_query_params: dict[str, dict[str, Any]] = {} force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - custom_errors: Dict[ - Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]] + custom_errors: dict[ - Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] ] = {} # Whether the engine supports file uploads @@ -422,7 +408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return cls.supports_backend(backend, driver) @classmethod - def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: + def supports_backend(cls, backend: str, driver: str | None = None) -> bool: """ Returns true if the DB engine spec supports a given SQLAlchemy backend/driver. """ @@ -439,7 +425,7 @@ def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: return driver in cls.drivers @classmethod - def get_default_schema(cls, database: Database) -> Optional[str]: + def get_default_schema(cls, database: Database) -> str | None: """ Return the default schema in a given database. """ @@ -450,8 +436,8 @@ def get_default_schema(cls, database: Database) -> Optional[str]: def get_schema_from_engine_params( # pylint: disable=unused-argument cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the schema configured in a SQLAlchemy URI and connection arguments, if any. """ @@ -462,7 +448,7 @@ def get_default_schema_for_query( cls, database: Database, query: Query, - ) -> Optional[str]: + ) -> str | None: """ Return the default schema for a given query.
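base.py goes one step further than most of the other modules and rewrites Optional[X] as X | None (PEP 604). In annotations, that spelling is only safe on the Python 3.8/3.9 interpreters Superset still supported when evaluation is postponed, which is presumably why the modules written with the | form rely on from __future__ import annotations (visible as unchanged context in the crate.py hunk further down), while files such as athena.py and bigquery.py keep Optional. A small sketch of the rewritten signature style; the helper body is illustrative, not Superset's implementation:

    from __future__ import annotations  # lets `X | None` be used on Python < 3.10

    # Before: def get_default_schema(cls, database: Database) -> Optional[str]:
    # After:  def get_default_schema(cls, database: Database) -> str | None:
    def default_schema_from_connect_args(connect_args: dict[str, str]) -> str | None:
        # Illustrative helper: report a configured schema if the driver args name one.
        return connect_args.get("schema")

    print(default_schema_from_connect_args({"schema": "public"}))  # public
    print(default_schema_from_connect_args({}))                    # None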
@@ -501,7 +487,7 @@ def get_default_schema_for_query( return cls.get_default_schema(database) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: """ Each engine can implement and converge its own specific exceptions into Superset DBAPI exceptions @@ -541,7 +527,7 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: @classmethod def get_allow_cost_estimate( # pylint: disable=unused-argument cls, - extra: Dict[str, Any], + extra: dict[str, Any], ) -> bool: return False @@ -561,8 +547,8 @@ def get_text_clause(cls, clause: str) -> TextClause: def get_engine( cls, database: Database, - schema: Optional[str] = None, - source: Optional[utils.QuerySource] = None, + schema: str | None = None, + source: utils.QuerySource | None = None, ) -> ContextManager[Engine]: """ Return an engine context manager. @@ -578,8 +564,8 @@ def get_engine( def get_timestamp_expr( cls, col: ColumnClause, - pdf: Optional[str], - time_grain: Optional[str], + pdf: str | None, + time_grain: str | None, ) -> TimestampExpression: """ Construct a TimestampExpression to be used in a SQLAlchemy query. @@ -616,7 +602,7 @@ def get_timestamp_expr( return TimestampExpression(time_expr, col, type_=col.type) @classmethod - def get_time_grains(cls) -> Tuple[TimeGrain, ...]: + def get_time_grains(cls) -> tuple[TimeGrain, ...]: """ Generate a tuple of supported time grains. @@ -634,8 +620,8 @@ def get_time_grains(cls) -> Tuple[TimeGrain, ...]: @classmethod def _sort_time_grains( - cls, val: Tuple[Optional[str], str], index: int - ) -> Union[float, int, str]: + cls, val: tuple[str | None, str], index: int + ) -> float | int | str: """ Return an ordered time-based value of a portion of a time grain for sorting @@ -695,7 +681,7 @@ def sort_interval() -> float: return plist.get(index, 0) @classmethod - def get_time_grain_expressions(cls) -> Dict[Optional[str], str]: + def get_time_grain_expressions(cls) -> dict[str | None, str]: """ Return a dict of all supported time grains including any potential added grains but excluding any potentially disabled grains in the config file. @@ -706,7 +692,7 @@ def get_time_grain_expressions(cls) -> Dict[Optional[str], str]: time_grain_expressions = cls._time_grain_expressions.copy() grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {})) - denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"] + denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"] for key in denylist: time_grain_expressions.pop(key, None) @@ -723,9 +709,7 @@ def get_time_grain_expressions(cls) -> Dict[Optional[str], str]: ) @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: """ :param cursor: Cursor instance @@ -743,9 +727,9 @@ def fetch_data( @classmethod def expand_data( - cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]] - ) -> Tuple[ - List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType] + cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]] + ) -> tuple[ + list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType] ]: """ Some engines support expanding nested fields. 
See implementation in Presto @@ -759,7 +743,7 @@ def expand_data( return columns, data, [] @classmethod - def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: + def alter_new_orm_column(cls, orm_col: TableColumn) -> None: """Allow altering default column attributes when first detected/added For instance special column like `__time` for Druid can be @@ -789,7 +773,7 @@ def epoch_ms_to_dttm(cls) -> str: return cls.epoch_to_dttm().replace("{col}", "({col}/1000)") @classmethod - def get_datatype(cls, type_code: Any) -> Optional[str]: + def get_datatype(cls, type_code: Any) -> str | None: """ Change column type code from cursor description to string representation. @@ -802,7 +786,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]: @classmethod @deprecated(deprecated_in="3.0") - def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -818,8 +802,8 @@ def extra_table_metadata( # pylint: disable=unused-argument cls, database: Database, table_name: str, - schema_name: Optional[str], - ) -> Dict[str, Any]: + schema_name: str | None, + ) -> dict[str, Any]: """ Returns engine-specific table metadata @@ -872,7 +856,7 @@ def apply_top_to_sql(cls, sql: str, limit: int) -> str: sql_remainder = None sql = sql.strip(" \t\n;") sql_statement = sqlparse.format(sql, strip_comments=True) - query_limit: Optional[int] = sql_parse.extract_top_from_query( + query_limit: int | None = sql_parse.extract_top_from_query( sql_statement, cls.top_keywords ) if not limit: @@ -928,7 +912,7 @@ def top_not_in_sql(cls, sql: str) -> bool: return True @classmethod - def get_limit_from_sql(cls, sql: str) -> Optional[int]: + def get_limit_from_sql(cls, sql: str) -> int | None: """ Extract limit from SQL query @@ -951,7 +935,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: return parsed_query.set_or_update_query_limit(limit) @classmethod - def get_cte_query(cls, sql: str) -> Optional[str]: + def get_cte_query(cls, sql: str) -> str | None: """ Convert the input CTE based SQL to the SQL for virtual table conversion @@ -981,7 +965,7 @@ def df_to_sql( database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -1012,8 +996,8 @@ def df_to_sql( @classmethod def convert_dttm( # pylint: disable=unused-argument - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: """ Convert a Python `datetime` object to a SQL expression. 
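A related effect of postponed evaluation, visible in the alter_new_orm_column hunk above, is that string-quoted forward references such as "TableColumn" can be written as bare names even when the class is only imported under TYPE_CHECKING. A minimal sketch; the module name some_models is made up for illustration:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers, never at runtime.
        from some_models import TableColumn  # hypothetical module, for illustration

    def alter_new_orm_column(orm_col: TableColumn) -> None:
        # With postponed evaluation the annotation is stored as a string,
        # so the TYPE_CHECKING-only import never has to resolve at runtime.
        orm_col.is_dttm = True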
@@ -1044,8 +1028,8 @@ def _extract_error_message(cls, ex: Exception) -> str: @classmethod def extract_errors( - cls, ex: Exception, context: Optional[Dict[str, Any]] = None - ) -> List[SupersetError]: + cls, ex: Exception, context: dict[str, Any] | None = None + ) -> list[SupersetError]: raw_message = cls._extract_error_message(ex) context = context or {} @@ -1076,10 +1060,10 @@ def extract_errors( def adjust_engine_params( # pylint: disable=unused-argument cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: """ Return a new URL and ``connect_args`` for a specific catalog/schema. @@ -1116,7 +1100,7 @@ def get_catalog_names( # pylint: disable=unused-argument cls, database: Database, inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs from database. @@ -1126,7 +1110,7 @@ def get_catalog_names( # pylint: disable=unused-argument return [] @classmethod - def get_schema_names(cls, inspector: Inspector) -> List[str]: + def get_schema_names(cls, inspector: Inspector) -> list[str]: """ Get all schemas from database @@ -1140,8 +1124,8 @@ def get_table_names( # pylint: disable=unused-argument cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the real table names within the specified schema. @@ -1168,8 +1152,8 @@ def get_view_names( # pylint: disable=unused-argument cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -1197,8 +1181,8 @@ def get_indexes( database: Database, # pylint: disable=unused-argument inspector: Inspector, table_name: str, - schema: Optional[str], - ) -> List[Dict[str, Any]]: + schema: str | None, + ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. @@ -1213,8 +1197,8 @@ def get_indexes( @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> Optional[str]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> str | None: """ Get comment of table from a given schema and table @@ -1237,8 +1221,8 @@ def get_table_comment( @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Get all columns from a given schema and table @@ -1255,8 +1239,8 @@ def get_metrics( # pylint: disable=unused-argument database: Database, inspector: Inspector, table_name: str, - schema: Optional[str], - ) -> List[MetricType]: + schema: str | None, + ) -> list[MetricType]: """ Get all metrics from a given schema and table. 
""" @@ -1273,11 +1257,11 @@ def get_metrics( # pylint: disable=unused-argument def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: """ Add a where clause to a query to reference only the most recent partition @@ -1293,7 +1277,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen return None @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: return [column(c["name"]) for c in cols] @classmethod @@ -1302,12 +1286,12 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: """ Generate a "SELECT * from [schema.]table_name" query with appropriate limit. @@ -1326,7 +1310,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals :return: SQL query """ # pylint: disable=redefined-outer-name - fields: Union[str, List[Any]] = "*" + fields: str | list[Any] = "*" cols = cols or [] if (show_cols or latest_partition) and not cols: cols = database.get_columns(table_name, schema) @@ -1355,7 +1339,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. @@ -1367,8 +1351,8 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: """ Format cost estimate. @@ -1405,8 +1389,8 @@ def estimate_query_cost( database: Database, schema: str, sql: str, - source: Optional[utils.QuerySource] = None, - ) -> List[Dict[str, Any]]: + source: utils.QuerySource | None = None, + ) -> list[dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. @@ -1433,7 +1417,7 @@ def estimate_query_cost( @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -1450,9 +1434,9 @@ def get_url_for_impersonation( @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -1490,7 +1474,7 @@ def execute( # pylint: disable=unused-argument raise cls.get_dbapi_mapped_exception(ex) from ex @classmethod - def make_label_compatible(cls, label: str) -> Union[str, quoted_name]: + def make_label_compatible(cls, label: str) -> str | quoted_name: """ Conditionally mutate and/or quote a sqlalchemy expression label. 
If force_column_alias_quotes is set to True, return the label as a @@ -1515,8 +1499,8 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]: @classmethod def get_column_types( cls, - column_type: Optional[str], - ) -> Optional[Tuple[TypeEngine, GenericDataType]]: + column_type: str | None, + ) -> tuple[TypeEngine, GenericDataType] | None: """ Return a sqlalchemy native column type and generic data type that corresponds to the column type defined in the data source (return None to use default type @@ -1598,7 +1582,7 @@ def column_datatype_to_string( def get_function_names( # pylint: disable=unused-argument cls, database: Database, - ) -> List[str]: + ) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -1609,7 +1593,7 @@ def get_function_names( # pylint: disable=unused-argument return [] @staticmethod - def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]: + def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]: """ Convert pyodbc.Row objects from `fetch_data` to tuples. @@ -1634,7 +1618,7 @@ def mutate_db_for_connection_test( # pylint: disable=unused-argument return None @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -1642,7 +1626,7 @@ def get_extra_params(database: Database) -> Dict[str, Any]: :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable """ - extra: Dict[str, Any] = {} + extra: dict[str, Any] = {} if database.extra: try: extra = json.loads(database.extra) @@ -1653,7 +1637,7 @@ def get_extra_params(database: Database) -> Dict[str, Any]: @staticmethod def update_params_from_encrypted_extra( # pylint: disable=invalid-name - database: Database, params: Dict[str, Any] + database: Database, params: dict[str, Any] ) -> None: """ Some databases require some sensitive information which do not conform to @@ -1691,10 +1675,10 @@ def is_select_query(cls, parsed_query: ParsedQuery) -> bool: @classmethod def get_column_spec( # pylint: disable=unused-argument cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, + native_type: str | None, + db_extra: dict[str, Any] | None = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - ) -> Optional[ColumnSpec]: + ) -> ColumnSpec | None: """ Get generic type related specs regarding a native column type. @@ -1714,10 +1698,10 @@ def get_column_spec( # pylint: disable=unused-argument @classmethod def get_sqla_column_type( cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, + native_type: str | None, + db_extra: dict[str, Any] | None = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - ) -> Optional[TypeEngine]: + ) -> TypeEngine | None: """ Converts native database type to sqlalchemy column type. @@ -1761,7 +1745,7 @@ def get_cancel_query_id( # pylint: disable=unused-argument cls, cursor: Any, query: Query, - ) -> Optional[str]: + ) -> str | None: """ Select identifiers from the database engine that uniquely identifies the queries to cancel. 
The identifier is typically a session id, process id @@ -1794,11 +1778,11 @@ def cancel_query( # pylint: disable=unused-argument return False @classmethod - def parse_sql(cls, sql: str) -> List[str]: + def parse_sql(cls, sql: str) -> list[str]: return [str(s).strip(" ;") for s in sqlparse.parse(sql)] @classmethod - def get_impersonation_key(cls, user: Optional[User]) -> Any: + def get_impersonation_key(cls, user: User | None) -> Any: """ Construct an impersonation key, by default it's the given username. @@ -1809,7 +1793,7 @@ def get_impersonation_key(cls, user: Optional[User]) -> Any: return user.username if user else None @classmethod - def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: + def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None: """ Mask ``encrypted_extra``. @@ -1822,9 +1806,7 @@ def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: # pylint: disable=unused-argument @classmethod - def unmask_encrypted_extra( - cls, old: Optional[str], new: Optional[str] - ) -> Optional[str]: + def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None: """ Remove masks from ``encrypted_extra``. @@ -1835,7 +1817,7 @@ def unmask_encrypted_extra( return new @classmethod - def get_public_information(cls) -> Dict[str, Any]: + def get_public_information(cls) -> dict[str, Any]: """ Construct a Dict with properties we want to expose. @@ -1891,12 +1873,12 @@ class BasicParametersSchema(Schema): class BasicParametersType(TypedDict, total=False): - username: Optional[str] - password: Optional[str] + username: str | None + password: str | None host: str port: int database: str - query: Dict[str, Any] + query: dict[str, Any] encryption: bool @@ -1929,13 +1911,13 @@ class BasicParametersMixin: # query parameter to enable encryption in the database connection # for Postgres this would be `{"sslmode": "verify-ca"}`, eg. - encryption_parameters: Dict[str, str] = {} + encryption_parameters: dict[str, str] = {} @classmethod def build_sqlalchemy_uri( # pylint: disable=unused-argument cls, parameters: BasicParametersType, - encrypted_extra: Optional[Dict[str, str]] = None, + encrypted_extra: dict[str, str] | None = None, ) -> str: # make a copy so that we don't update the original query = parameters.get("query", {}).copy() @@ -1958,7 +1940,7 @@ def build_sqlalchemy_uri( # pylint: disable=unused-argument @classmethod def get_parameters_from_uri( # pylint: disable=unused-argument - cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None + cls, uri: str, encrypted_extra: dict[str, Any] | None = None ) -> BasicParametersType: url = make_url_safe(uri) query = { @@ -1982,14 +1964,14 @@ def get_parameters_from_uri( # pylint: disable=unused-argument @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: + ) -> list[SupersetError]: """ Validates any number of parameters, for progressive validation. If only the hostname is present it will check if the name is resolvable. As more parameters are present in the request, more validation is done. 
""" - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] required = {"host", "port", "username", "database"} parameters = properties.get("parameters", {}) diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 1f5068ad04bbd..3b62f4bbb809c 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -18,7 +18,8 @@ import re import urllib from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING import pandas as pd from apispec import APISpec @@ -99,8 +100,8 @@ class BigQueryParametersSchema(Schema): class BigQueryParametersType(TypedDict): - credentials_info: Dict[str, Any] - query: Dict[str, Any] + credentials_info: dict[str, Any] + query: dict[str, Any] class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods @@ -173,7 +174,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met "P1Y": "{func}({col}, YEAR)", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_DATABASE_PERMISSIONS_REGEX: ( __( "Unable to connect. Verify that the following roles are set " @@ -219,7 +220,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -235,7 +236,7 @@ def convert_dttm( @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Support type BigQuery Row, introduced here PR #4071 # google.cloud.bigquery.table.Row @@ -280,7 +281,7 @@ def _truncate_label(cls, label: str) -> str: @classmethod @deprecated(deprecated_in="3.0") - def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -305,7 +306,7 @@ def get_indexes( inspector: Inspector, table_name: str, schema: Optional[str], - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. @@ -321,7 +322,7 @@ def get_indexes( @classmethod def extra_table_metadata( cls, database: "Database", table_name: str, schema_name: Optional[str] - ) -> Dict[str, Any]: + ) -> dict[str, Any]: indexes = database.get_indexes(table_name, schema_name) if not indexes: return {} @@ -354,7 +355,7 @@ def df_to_sql( database: "Database", table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -421,7 +422,7 @@ def estimate_query_cost( schema: str, sql: str, source: Optional[utils.QuerySource] = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. @@ -448,7 +449,7 @@ def get_catalog_names( cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs. 
@@ -462,11 +463,11 @@ def get_catalog_names( return sorted(project.project_id for project in projects) @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: with cls.get_engine(cursor) as engine: client = cls._get_client(engine) job_config = bigquery.QueryJobConfig(dry_run=True) @@ -503,15 +504,15 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: return [{k: str(v) for k, v in row.items()} for row in raw_cost] @classmethod def build_sqlalchemy_uri( cls, parameters: BigQueryParametersType, - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> str: query = parameters.get("query", {}) query_params = urllib.parse.urlencode(query) @@ -533,7 +534,7 @@ def build_sqlalchemy_uri( def get_parameters_from_uri( cls, uri: str, - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> Any: value = make_url_safe(uri) @@ -592,7 +593,7 @@ def unmask_encrypted_extra( return json.dumps(new_config) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from google.auth.exceptions import DefaultCredentialsError @@ -602,7 +603,7 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: def validate_parameters( cls, properties: BasicPropertiesType, # pylint: disable=unused-argument - ) -> List[SupersetError]: + ) -> list[SupersetError]: return [] @classmethod @@ -636,7 +637,7 @@ def select_star( # pylint: disable=too-many-arguments show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: Optional[list[dict[str, Any]]] = None, ) -> str: """ Remove array structures from `SELECT *`. @@ -699,7 +700,7 @@ def select_star( # pylint: disable=too-many-arguments ) @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: """ Label columns using their fully qualified name. 
diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index a62087bc6a635..af38c15e0b378 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -19,7 +19,7 @@ import logging import re from datetime import datetime -from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from flask import current_app from flask_babel import gettext as __ @@ -124,8 +124,8 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -145,7 +145,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec): supports_file_upload = False @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {NewConnectionError: SupersetDBAPIDatabaseError} @classmethod @@ -159,7 +159,7 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: @classmethod @cache_manager.cache.memoize() - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -256,7 +256,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): engine_name = "ClickHouse Connect (Superset)" default_driver = "connect" - _function_names: List[str] = [] + _function_names: list[str] = [] sqlalchemy_uri_placeholder = ( "clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]" @@ -265,7 +265,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): encryption_parameters = {"secure": "true"} @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {} @classmethod @@ -278,7 +278,7 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: return new_exception(str(exception)) @classmethod - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: # pylint: disable=import-outside-toplevel,import-error from clickhouse_connect.driver.exceptions import ClickHouseError @@ -304,7 +304,7 @@ def get_datatype(cls, type_code: str) -> str: def build_sqlalchemy_uri( cls, parameters: BasicParametersType, - encrypted_extra: Optional[Dict[str, str]] = None, + encrypted_extra: dict[str, str] | None = None, ) -> str: url_params = parameters.copy() if url_params.get("encryption"): @@ -318,7 +318,7 @@ def build_sqlalchemy_uri( @classmethod def get_parameters_from_uri( - cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None + cls, uri: str, encrypted_extra: dict[str, Any] | None = None ) -> BasicParametersType: url = make_url_safe(uri) query = url.query @@ -340,7 +340,7 @@ def get_parameters_from_uri( @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: + ) -> list[SupersetError]: # pylint: disable=import-outside-toplevel,import-error from clickhouse_connect.driver import default_port diff --git 
a/superset/db_engine_specs/crate.py b/superset/db_engine_specs/crate.py index 6eafae829edda..d8d91c67962d6 100644 --- a/superset/db_engine_specs/crate.py +++ b/superset/db_engine_specs/crate.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types @@ -53,8 +53,8 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.TIMESTAMP): diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 5f12f3174d363..5df24be65d6b5 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -17,7 +17,7 @@ import json from datetime import datetime -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -135,7 +135,7 @@ class DatabricksODBCEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra) @@ -160,14 +160,14 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) encryption_parameters = {"ssl": "1"} @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ Add a user agent to be used in the requests. 
Trim whitespace from connect_args to avoid databricks driver errors """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)]) connect_args.setdefault("_user_agent_entry", USER_AGENT) @@ -184,7 +184,7 @@ def get_table_names( database: "Database", inspector: Inspector, schema: Optional[str], - ) -> Set[str]: + ) -> set[str]: return super().get_table_names( database, inspector, schema ) - cls.get_view_names(database, inspector, schema) @@ -213,8 +213,8 @@ def build_sqlalchemy_uri( # type: ignore @classmethod def extract_errors( - cls, ex: Exception, context: Optional[Dict[str, Any]] = None - ) -> List[SupersetError]: + cls, ex: Exception, context: Optional[dict[str, Any]] = None + ) -> list[SupersetError]: raw_message = cls._extract_error_message(ex) context = context or {} @@ -271,8 +271,8 @@ def get_parameters_from_uri( # type: ignore def validate_parameters( # type: ignore cls, properties: DatabricksPropertiesType, - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] required = {"access_token", "host", "port", "database", "extra"} extra = json.loads(properties.get("extra", "{}")) engine_params = extra.get("engine_params", {}) diff --git a/superset/db_engine_specs/dremio.py b/superset/db_engine_specs/dremio.py index 7fae3014d6866..7b4c0458cd1a7 100644 --- a/superset/db_engine_specs/dremio.py +++ b/superset/db_engine_specs/dremio.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -46,7 +46,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index 16ac89212ad13..946544863dda7 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
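Beyond the annotation changes, the Databricks get_extra_params hunk above shows the pattern the new dict[str, Any] hints attach to: a chain of setdefault calls that creates the nested engine_params/connect_args structure in place without overwriting user-supplied values. A self-contained sketch of the same pattern; the USER_AGENT value and helper name are illustrative:

    from typing import Any

    USER_AGENT = "Apache Superset"  # illustrative value

    def add_default_user_agent(extra: dict[str, Any]) -> dict[str, Any]:
        # Create (or reuse) the nested dicts, then fill in defaults only
        # where the user has not already set a value.
        engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
        connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
        connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
        connect_args.setdefault("_user_agent_entry", USER_AGENT)
        return extra

    print(add_default_user_agent({}))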
from datetime import datetime -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from urllib import parse from sqlalchemy import types @@ -59,7 +59,7 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -74,10 +74,10 @@ def convert_dttm( def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe="")) @@ -87,7 +87,7 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 83829ec22ac32..43ce310a4061e 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -20,7 +20,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector @@ -79,7 +79,7 @@ def alter_new_orm_column(cls, orm_col: TableColumn) -> None: orm_col.is_dttm = True @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ For Druid, the path to a SSL certificate is placed in `connect_args`. @@ -104,8 +104,8 @@ def get_extra_params(database: Database) -> Dict[str, Any]: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -130,15 +130,15 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Update the Druid type map. 
""" return super().get_columns(inspector, table_name, schema) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 1248287b8408d..3bbf9ecc3834d 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -18,7 +18,8 @@ import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy import types @@ -51,7 +52,7 @@ class DuckDBEngineSpec(BaseEngineSpec): "P1Y": "DATE_TRUNC('year', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __('We can\'t seem to resolve the column "%(column_name)s"'), SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, @@ -65,8 +66,8 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, (types.String, types.DateTime)): @@ -75,6 +76,6 @@ def convert_dttm( @classmethod def get_table_names( - cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + cls, database: Database, inspector: Inspector, schema: str | None + ) -> set[str]: return set(inspector.get_table_names(schema)) diff --git a/superset/db_engine_specs/dynamodb.py b/superset/db_engine_specs/dynamodb.py index c398a9c1dff11..5f7a9e2b71e58 100644 --- a/superset/db_engine_specs/dynamodb.py +++ b/superset/db_engine_specs/dynamodb.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -55,7 +55,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/elasticsearch.py b/superset/db_engine_specs/elasticsearch.py index 934aa0bb03cf6..d717c52bf592a 100644 --- a/superset/db_engine_specs/elasticsearch.py +++ b/superset/db_engine_specs/elasticsearch.py @@ -16,7 +16,7 @@ # under the License. 
import logging from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from packaging.version import Version from sqlalchemy import types @@ -50,10 +50,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho "P1Y": "HISTOGRAM({col}, INTERVAL 1 YEAR)", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-error,import-outside-toplevel import es.exceptions as es_exceptions @@ -65,7 +65,7 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: db_extra = db_extra or {} @@ -117,7 +117,7 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py index c06fbd826dfd3..6da56e2feee8f 100644 --- a/superset/db_engine_specs/exasol.py +++ b/superset/db_engine_specs/exasol.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Tuple +from typing import Any, Optional from superset.db_engine_specs.base import BaseEngineSpec @@ -42,7 +42,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py index 306a642dc3d11..4448074157073 100644 --- a/superset/db_engine_specs/firebird.py +++ b/superset/db_engine_specs/firebird.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -72,7 +72,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/firebolt.py b/superset/db_engine_specs/firebolt.py index 65cd7143523c8..ace3d6b3b232e 100644 --- a/superset/db_engine_specs/firebolt.py +++ b/superset/db_engine_specs/firebolt.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -43,7 +43,7 @@ class FireboltEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 73a66c464f8e5..abf5bac48f96a 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -16,7 +16,8 @@ # under the License. import json import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -56,12 +57,12 @@ class GSheetsParametersSchema(Schema): class GSheetsParametersType(TypedDict): service_account_info: str - catalog: Optional[Dict[str, str]] + catalog: Optional[dict[str, str]] class GSheetsPropertiesType(TypedDict): parameters: GSheetsParametersType - catalog: Dict[str, str] + catalog: dict[str, str] class GSheetsEngineSpec(SqliteEngineSpec): @@ -77,7 +78,7 @@ class GSheetsEngineSpec(SqliteEngineSpec): default_driver = "apsw" sqlalchemy_uri_placeholder = "gsheets://" - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( __( 'Please check your query for syntax errors near "%(server_error)s". ' @@ -110,7 +111,7 @@ def extra_table_metadata( database: "Database", table_name: str, schema_name: Optional[str], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: with database.get_raw_connection(schema=schema_name) as conn: cursor = conn.cursor() cursor.execute(f'SELECT GET_METADATA("{table_name}")') @@ -127,7 +128,7 @@ def build_sqlalchemy_uri( cls, _: GSheetsParametersType, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, Any] + dict[str, Any] ] = None, ) -> str: return "gsheets://" @@ -136,7 +137,7 @@ def build_sqlalchemy_uri( def get_parameters_from_uri( cls, uri: str, # pylint: disable=unused-argument - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> Any: # Building parameters from encrypted_extra and uri if encrypted_extra: @@ -214,8 +215,8 @@ def parameters_json_schema(cls) -> Any: def validate_parameters( cls, properties: GSheetsPropertiesType, - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] # backwards compatible just incase people are send data # via parameters for validation diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py index e579550b2e2d1..108838f9d2a8d 100644 --- a/superset/db_engine_specs/hana.py +++ b/superset/db_engine_specs/hana.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
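The TypedDict parameter payloads above (GSheetsParametersType here, BasicParametersType in base.py) keep the same shape under the new syntax; only the field annotations change. A runnable sketch with placeholder fields, loosely modelled on BasicParametersType rather than copied from it:

    from __future__ import annotations

    from typing import Any, TypedDict

    class ConnectionParameters(TypedDict, total=False):
        username: str | None
        host: str
        port: int
        database: str
        query: dict[str, Any]

    params: ConnectionParameters = {
        "host": "localhost",
        "port": 5432,
        "database": "examples",
        "query": {"sslmode": "require"},
    }
    print(sorted(params))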
from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -45,7 +45,7 @@ class HanaEngineSpec(PostgresBaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 6d8986c1c7bad..7601ebb2cddf5 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -22,7 +22,7 @@ import tempfile import time from datetime import datetime -from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from urllib import parse import numpy as np @@ -150,9 +150,7 @@ def patch(cls) -> None: hive.Cursor.fetch_logs = fetch_logs @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: # pylint: disable=import-outside-toplevel import pyhive from TCLIService import ttypes @@ -168,10 +166,10 @@ def fetch_data( @classmethod def df_to_sql( cls, - database: "Database", + database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -248,8 +246,8 @@ def _get_hive_type(dtype: np.dtype[Any]) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -263,10 +261,10 @@ def convert_dttm( def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema, safe="")) @@ -276,8 +274,8 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the configured schema. 
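hive.py, unlike most of the other specs touched here, uses the PEP 604 spelling `int | None` directly; that is safe on pre-3.10 interpreters only because the module's annotations are postponed (PEP 563, `from __future__ import annotations`) and therefore stored as strings rather than evaluated. A small self-contained sketch, with a hypothetical `FakeCursor` replacing the real pyhive cursor:

    from __future__ import annotations  # PEP 563: keep "int | None" unevaluated on 3.7-3.9

    from typing import Any


    class FakeCursor:
        """Hypothetical cursor used only for this sketch."""

        def __init__(self, rows: list[tuple[Any, ...]]) -> None:
            self._rows = rows

        def fetchall(self) -> list[tuple[Any, ...]]:
            return self._rows

        def fetchmany(self, size: int) -> list[tuple[Any, ...]]:
            return self._rows[:size]


    def fetch_data(cursor: FakeCursor, limit: int | None = None) -> list[tuple[Any, ...]]:
        # Mirrors the new signature; the real Hive spec also unpacks pyhive Row objects.
        return cursor.fetchmany(limit) if limit is not None else cursor.fetchall()


    print(fetch_data(FakeCursor([(1, "a"), (2, "b")]), limit=1))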
""" @@ -292,10 +290,10 @@ def _extract_error_message(cls, ex: Exception) -> str: return msg @classmethod - def progress(cls, log_lines: List[str]) -> int: + def progress(cls, log_lines: list[str]) -> int: total_jobs = 1 # assuming there's at least 1 job current_job = 1 - stages: Dict[int, float] = {} + stages: dict[int, float] = {} for line in log_lines: match = cls.jobs_stats_r.match(line) if match: @@ -323,7 +321,7 @@ def progress(cls, log_lines: List[str]) -> int: return int(progress) @classmethod - def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]: + def get_tracking_url_from_logs(cls, log_lines: list[str]) -> str | None: lkp = "Tracking URL = " for line in log_lines: if lkp in line: @@ -407,19 +405,19 @@ def handle_cursor( # pylint: disable=too-many-locals @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: return inspector.get_columns(table_name, schema) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], - database: "Database", + schema: str | None, + database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: try: col_names, values = cls.latest_partition( table_name, schema, database, show_first=True @@ -437,18 +435,18 @@ def where_latest_partition( # pylint: disable=too-many-arguments return None @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access @classmethod def latest_sub_partition( # type: ignore - cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any + cls, table_name: str, schema: str | None, database: Database, **kwargs: Any ) -> str: # TODO(bogdan): implement` pass @classmethod - def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: + def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None: """Hive partitions look like ds={partition name}/ds={partition name}""" if not df.empty: return [ @@ -461,12 +459,12 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: def _partition_query( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], - indexes: List[Dict[str, Any]], - database: "Database", + schema: str | None, + indexes: list[dict[str, Any]], + database: Database, limit: int = 0, - order_by: Optional[List[Tuple[str, bool]]] = None, - filters: Optional[Dict[Any, Any]] = None, + order_by: list[tuple[str, bool]] | None = None, + filters: dict[Any, Any] | None = None, ) -> str: full_table_name = f"{schema}.{table_name}" if schema else table_name return f"SHOW PARTITIONS {full_table_name}" @@ -474,15 +472,15 @@ def _partition_query( # pylint: disable=too-many-arguments @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database: "Database", + database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: return super( # pylint: 
disable=bad-super-call PrestoEngineSpec, cls @@ -500,7 +498,7 @@ def select_star( # pylint: disable=too-many-arguments @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -516,9 +514,9 @@ def get_url_for_impersonation( @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -549,7 +547,7 @@ def execute( # type: ignore @classmethod @cache_manager.cache.memoize() - def get_function_names(cls, database: "Database") -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -600,10 +598,10 @@ def has_implicit_cancel(cls) -> bool: @classmethod def get_view_names( cls, - database: "Database", + database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -635,9 +633,9 @@ def get_view_names( # TODO: contribute back to pyhive. def fetch_logs( # pylint: disable=protected-access - self: "Cursor", + self: Cursor, _max_rows: int = 1024, - orientation: Optional["TFetchOrientation"] = None, + orientation: TFetchOrientation | None = None, ) -> str: """Mocked. Retrieve the logs produced by the execution of the query. Can be called multiple times to fetch the logs produced after diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index e59c2b74fbee0..cd1c9e47329e2 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -18,7 +18,7 @@ import re import time from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from sqlalchemy import types @@ -57,7 +57,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -68,7 +68,7 @@ def convert_dttm( return None @classmethod - def get_schema_names(cls, inspector: Inspector) -> List[str]: + def get_schema_names(cls, inspector: Inspector) -> list[str]: schemas = [ row[0] for row in inspector.engine.execute("SHOW SCHEMAS") diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 9fddb23d26185..17147d5cc059f 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -16,7 +16,7 @@ # under the License. 
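The Kusto hunks that follow change `get_dbapi_exception_mapping` to return `dict[type[Exception], type[Exception]]`, i.e. a mapping whose keys and values are exception classes. A rough illustration with stand-in exception types (not the real sqlalchemy_kusto or Superset classes):

    class ExampleDriverError(Exception):
        """Stand-in for a driver-specific exception (illustrative only)."""


    class SupersetDBAPIDatabaseError(Exception):
        """Stand-in for the Superset-level exception (illustrative only)."""


    def get_dbapi_exception_mapping() -> dict[type[Exception], type[Exception]]:
        # Keys and values are exception classes, not instances.
        return {ExampleDriverError: SupersetDBAPIDatabaseError}


    def translate(exc: Exception) -> Exception:
        new_type = get_dbapi_exception_mapping().get(type(exc))
        return new_type(str(exc)) if new_type else exc


    print(type(translate(ExampleDriverError("boom"))).__name__)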
import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME @@ -61,7 +61,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method " DATEDIFF(week, 0, DATEADD(day, -1, {col})), 0)", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed column_type_mappings = ( ( @@ -72,7 +72,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method ) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel,import-error import sqlalchemy_kusto.errors as kusto_exceptions @@ -84,7 +84,7 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -128,10 +128,10 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "P1Y": "datetime_diff('year',CreateDate, datetime(0001-01-01 00:00:00))+1", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel,import-error import sqlalchemy_kusto.errors as kusto_exceptions @@ -143,7 +143,7 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -168,7 +168,7 @@ def is_select_query(cls, parsed_query: ParsedQuery) -> bool: return not parsed_query.sql.startswith(".") @classmethod - def parse_sql(cls, sql: str) -> List[str]: + def parse_sql(cls, sql: str) -> list[str]: """ Kusto supports a single query statement, but it could include sub queries and variables declared via let keyword. diff --git a/superset/db_engine_specs/kylin.py b/superset/db_engine_specs/kylin.py index e340daea51f95..f522602a48e7a 100644 --- a/superset/db_engine_specs/kylin.py +++ b/superset/db_engine_specs/kylin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -42,7 +42,7 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 8b38ec742190f..3e0879b90415c 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -17,7 +17,8 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from flask_babel import gettext as __ from sqlalchemy import types @@ -80,7 +81,7 @@ class MssqlEngineSpec(BaseEngineSpec): ), ) - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __( 'Either the username "%(username)s", password, ' @@ -115,7 +116,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -132,7 +133,7 @@ def convert_dttm( @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 6258f6b21a4c6..9f853d577c30b 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -16,7 +16,8 @@ # under the License. 
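Exasol and MSSQL both route `fetch_data` through `pyodbc_rows_to_tuples`, since pyodbc returns `Row` objects rather than plain tuples. A plausible sketch of that helper, assuming it simply re-packs each row (the real implementation may differ):

    from typing import Any


    def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]:
        # Assumed behaviour: re-pack each list-like pyodbc.Row into a plain tuple.
        return [tuple(row) for row in data]


    print(pyodbc_rows_to_tuples([["a", 1], ["b", 2]]))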
import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from urllib import parse from flask_babel import gettext as __ @@ -143,9 +144,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): "INTERVAL 1 DAY)) - 1 DAY))", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -186,7 +187,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -201,10 +202,10 @@ def convert_dttm( def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: uri, new_connect_args = super().adjust_engine_params( uri, connect_args, @@ -221,7 +222,7 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index 4b8a59117e956..59fa52a656a7e 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -17,7 +17,8 @@ import re import threading -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Pattern, Set, Tuple +from re import Pattern +from typing import Any, Callable, List, NamedTuple, Optional from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -178,15 +179,13 @@ def _polygon_to_geo_json( # Sanitization function for column values SanitizeFunc = Callable[[Any], Any] + # Represents a pair of a column index and the sanitization function # to apply to its values. -PlacedSanitizeFunc = NamedTuple( - "PlacedSanitizeFunc", - [ - ("column_index", int), - ("sanitize_func", SanitizeFunc), - ], -) +class PlacedSanitizeFunc(NamedTuple): + column_index: int + sanitize_func: SanitizeFunc + # This map contains functions used to sanitize values for column types # that cannot be processed natively by Superset. @@ -199,7 +198,7 @@ def _polygon_to_geo_json( try: from pyocient import TypeCodes - _sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = { + _sanitized_ocient_type_codes: dict[int, SanitizeFunc] = { TypeCodes.BINARY: _to_hex, TypeCodes.ST_POINT: _point_to_geo_json, TypeCodes.IP: str, @@ -211,7 +210,7 @@ def _polygon_to_geo_json( _sanitized_ocient_type_codes = {} -def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]: +def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]: """ Cleans the column value for consumption by Superset. 
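The Ocient hunk above replaces the functional `NamedTuple("PlacedSanitizeFunc", [...])` call with the equivalent class-based form, which keeps the field annotations readable:

    from typing import Any, Callable, NamedTuple

    SanitizeFunc = Callable[[Any], Any]


    # Class-based NamedTuple, equivalent to the old functional form
    # NamedTuple("PlacedSanitizeFunc", [("column_index", int), ("sanitize_func", SanitizeFunc)]).
    class PlacedSanitizeFunc(NamedTuple):
        column_index: int
        sanitize_func: SanitizeFunc


    placed = PlacedSanitizeFunc(column_index=2, sanitize_func=str)
    print(placed.column_index, placed.sanitize_func(3.14))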
@@ -238,10 +237,10 @@ class OcientEngineSpec(BaseEngineSpec): # Store mapping of superset Query id -> Ocient ID # These are inserted into the cache when executing the query # They are then removed, either upon cancellation or query completion - query_id_mapping: Dict[str, str] = dict() + query_id_mapping: dict[str, str] = dict() query_id_mapping_lock = threading.Lock() - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_INVALID_USERNAME_REGEX: ( __('The username "%(username)s" does not exist.'), SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR, @@ -309,15 +308,15 @@ class OcientEngineSpec(BaseEngineSpec): @classmethod def get_table_names( cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: return inspector.get_table_names(schema) @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: try: - rows: List[Tuple[Any, ...]] = super().fetch_data(cursor, limit) + rows: list[tuple[Any, ...]] = super().fetch_data(cursor, limit) except Exception as exception: with OcientEngineSpec.query_id_mapping_lock: del OcientEngineSpec.query_id_mapping[ @@ -329,7 +328,7 @@ def fetch_data( if len(rows) > 0 and type(rows[0]).__name__ == "Row": # Peek at the schema to determine which column values, if any, # require sanitization. - columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize( + columns_to_sanitize: list[PlacedSanitizeFunc] = _find_columns_to_sanitize( cursor ) @@ -341,7 +340,7 @@ def identity(x: Any) -> Any: # Use the identity function if the column type doesn't need to be # sanitized. - sanitization_functions: List[SanitizeFunc] = [ + sanitization_functions: list[SanitizeFunc] = [ identity for _ in range(len(cursor.description)) ] for info in columns_to_sanitize: diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py index 4a219919bb537..1199b74406d2a 100644 --- a/superset/db_engine_specs/oracle.py +++ b/superset/db_engine_specs/oracle.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from sqlalchemy import types @@ -43,7 +43,7 @@ class OracleEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -68,7 +68,7 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: """ :param cursor: Cursor instance :param limit: Maximum number of rows to be returned by the cursor diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py index cebdd693a4c7a..bfec8b294716d 100644 --- a/superset/db_engine_specs/pinot.py +++ b/superset/db_engine_specs/pinot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
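The Ocient `fetch_data` path pairs each column index with a sanitization function and falls back to an identity function for the remaining columns. A condensed, illustrative version of that flow (`sanitize_rows` is a hypothetical helper, not Superset code):

    from typing import Any, Callable

    SanitizeFunc = Callable[[Any], Any]


    def sanitize_rows(
        rows: list[tuple[Any, ...]], sanitizers: dict[int, SanitizeFunc]
    ) -> list[tuple[Any, ...]]:
        def identity(value: Any) -> Any:
            return value

        width = len(rows[0]) if rows else 0
        # Columns without a registered sanitizer pass through identity().
        funcs: list[SanitizeFunc] = [sanitizers.get(i, identity) for i in range(width)]
        return [tuple(func(value) for func, value in zip(funcs, row)) for row in rows]


    # Column 0 holds binary data that cannot be rendered natively.
    print(sanitize_rows([(b"\x01\x02", 10)], {0: bytes.hex}))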
-from typing import Dict, Optional +from typing import Optional from sqlalchemy.sql.expression import ColumnClause @@ -30,7 +30,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method allows_alias_in_orderby = False # Pinot does its own conversion below - _time_grain_expressions: Dict[Optional[str], str] = { + _time_grain_expressions: dict[Optional[str], str] = { "PT1S": "1:SECONDS", "PT1M": "1:MINUTES", "PT5M": "5:MINUTES", @@ -45,7 +45,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "P1Y": "year", } - _python_to_java_time_patterns: Dict[str, str] = { + _python_to_java_time_patterns: dict[str, str] = { "%Y": "yyyy", "%m": "MM", "%d": "dd", @@ -54,7 +54,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "%S": "ss", } - _use_date_trunc_function: Dict[str, bool] = { + _use_date_trunc_function: dict[str, bool] = { "PT1S": False, "PT1M": False, "PT5M": False, diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index e809187af66a9..2088782f83bae 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,7 +18,8 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON @@ -73,7 +74,7 @@ SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P.*?)"') -def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]: +def parse_options(connect_args: dict[str, Any]) -> dict[str, str]: """ Parse ``options`` from ``connect_args`` into a dictionary. """ @@ -109,7 +110,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): "P1Y": "DATE_TRUNC('year', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_INVALID_USERNAME_REGEX: ( __('The username "%(username)s" does not exist.'), SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR, @@ -169,7 +170,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: if not cursor.description: return [] return super().fetch_data(cursor, limit) @@ -221,7 +222,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. 
@@ -253,10 +254,10 @@ def get_schema_from_engine_params( def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: if not schema: return uri, connect_args @@ -269,11 +270,11 @@ def adjust_engine_params( return uri, connect_args @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: sql = f"EXPLAIN {statement}" cursor.execute(sql) @@ -289,8 +290,8 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: return [{k: str(v) for k, v in row.items()} for row in raw_cost] @classmethod @@ -298,7 +299,7 @@ def get_catalog_names( cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Return all catalogs. @@ -317,7 +318,7 @@ def get_catalog_names( @classmethod def get_table_names( cls, database: "Database", inspector: PGInspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: """Need to consider foreign tables for PostgreSQL""" return set(inspector.get_table_names(schema)) | set( inspector.get_foreign_table_names(schema) @@ -325,7 +326,7 @@ def get_table_names( @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -337,7 +338,7 @@ def convert_dttm( return None @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ For Postgres, the path to a SSL certificate is placed in `connect_args`. diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 82b05e53e319d..d5a2ab7605517 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -23,19 +23,9 @@ from abc import ABCMeta from collections import defaultdict, deque from datetime import datetime +from re import Pattern from textwrap import dedent -from typing import ( - Any, - cast, - Dict, - List, - Optional, - Pattern, - Set, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, Optional, TYPE_CHECKING from urllib import parse import pandas as pd @@ -78,7 +68,7 @@ # need try/catch because pyhive may not be installed try: - from pyhive.presto import Cursor # pylint: disable=unused-import + from pyhive.presto import Cursor except ImportError: pass @@ -107,7 +97,7 @@ logger = logging.getLogger(__name__) -def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]: +def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]: """ Get the children of a complex Presto type (row or array). 
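`query_cost_formatter` in the Postgres spec stringifies every value of each EXPLAIN row so the result can be rendered uniformly; the comprehension itself is unchanged by the patch apart from the annotations. Stand-alone version:

    from typing import Any


    def query_cost_formatter(raw_cost: list[dict[str, Any]]) -> list[dict[str, str]]:
        # Stringify every value of each EXPLAIN row.
        return [{k: str(v) for k, v in row.items()} for row in raw_cost]


    print(query_cost_formatter([{"Total cost": 14.99, "Rows": 1000}]))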
@@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: """ Convert a Python `datetime` object to a SQL expression. :param target_type: The target type of expression @@ -304,10 +294,10 @@ def epoch_to_dttm(cls) -> str: def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: database = uri.database if schema and database: schema = parse.quote(schema, safe="") @@ -323,8 +313,8 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the configured schema. @@ -341,7 +331,7 @@ def get_schema_from_engine_params( return parse.unquote(database.split("/")[1]) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. :param statement: A single SQL statement @@ -369,8 +359,8 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: """ Format cost estimate. :param raw_cost: JSON estimate from Trino @@ -401,7 +391,7 @@ def humanize(value: Any, suffix: str) -> str: ("networkCost", "Network cost", ""), ] for row in raw_cost: - estimate: Dict[str, float] = row.get("estimate", {}) + estimate: dict[str, float] = row.get("estimate", {}) statement_cost = {} for key, label, suffix in columns: if key in estimate: @@ -412,7 +402,7 @@ def humanize(value: Any, suffix: str) -> str: @classmethod @cache_manager.data_cache.memoize() - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -426,12 +416,12 @@ def get_function_names(cls, database: Database) -> List[str]: def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument cls, table_name: str, - schema: Optional[str], - indexes: List[Dict[str, Any]], + schema: str | None, + indexes: list[dict[str, Any]], database: Database, limit: int = 0, - order_by: Optional[List[Tuple[str, bool]]] = None, - filters: Optional[Dict[Any, Any]] = None, + order_by: list[tuple[str, bool]] | None = None, + filters: dict[Any, Any] | None = None, ) -> str: """ Return a partition query. 
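`adjust_engine_params` in the Presto base spec percent-encodes the schema before attaching it to the URL's database field. A tiny sketch of that encoding step (the `attach_schema` helper is hypothetical):

    from __future__ import annotations

    from urllib import parse


    def attach_schema(database: str, schema: str | None) -> str:
        # Percent-encode the schema before appending it to the catalog portion.
        return f"{database}/{parse.quote(schema, safe='')}" if schema else database


    print(attach_schema("hive", "my schema"))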
@@ -449,7 +439,7 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus order :param filters: dict of field name and filter value combinations """ - limit_clause = "LIMIT {}".format(limit) if limit else "" + limit_clause = f"LIMIT {limit}" if limit else "" order_by_clause = "" if order_by: l = [] @@ -492,11 +482,11 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus def where_latest_partition( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: try: col_names, values = cls.latest_partition( table_name, schema, database, show_first=True @@ -525,7 +515,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments return query @classmethod - def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: + def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None: if not df.empty: return df.to_records(index=False)[0].item() return None @@ -535,10 +525,10 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: def latest_partition( cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, show_first: bool = False, - ) -> Tuple[List[str], Optional[List[str]]]: + ) -> tuple[list[str], list[str] | None]: """Returns col name and the latest (max) partition value for a table :param table_name: the name of the table @@ -589,7 +579,7 @@ def latest_partition( @classmethod def latest_sub_partition( - cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any + cls, table_name: str, schema: str | None, database: Database, **kwargs: Any ) -> Any: """Returns the latest (max) partition value for a table @@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( 'We can\'t seem to resolve the column "%(column_name)s" at ' @@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): } @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: version = extra.get("version") return version is not None and Version(version) >= Version("0.319") @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -741,8 +731,8 @@ def get_table_names( cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the real table names within the specified schema. @@ -769,8 +759,8 @@ def get_view_names( cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -817,7 +807,7 @@ def get_catalog_names( cls, database: Database, inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs. 
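Besides the typing changes, the patch converts the remaining `str.format` calls to f-strings, e.g. the `LIMIT` clause in `_partition_query`:

    def build_limit_clause(limit: int = 0) -> str:
        # f-string replacing "LIMIT {}".format(limit); empty string when no limit.
        return f"LIMIT {limit}" if limit else ""


    print(build_limit_clause(100))
    print(repr(build_limit_clause()))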
""" @@ -826,7 +816,7 @@ def get_catalog_names( @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create column info object :param name: column name @@ -836,7 +826,7 @@ def _create_column_info( return {"name": name, "type": f"{data_type}"} @classmethod - def _get_full_name(cls, names: List[Tuple[str, str]]) -> str: + def _get_full_name(cls, names: list[tuple[str, str]]) -> str: """ Get the full column name :param names: list of all individual column names @@ -860,7 +850,7 @@ def _has_nested_data_types(cls, component_type: str) -> bool: ) @classmethod - def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]: + def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]: """ Split data type based on given delimiter. Do not split the string if the delimiter is enclosed in quotes @@ -869,16 +859,14 @@ def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]: comma, whitespace) :return: list of strings after breaking it by the delimiter """ - return re.split( - r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type - ) + return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type) @classmethod def _parse_structural_column( # pylint: disable=too-many-locals cls, parent_column_name: str, parent_data_type: str, - result: List[Dict[str, Any]], + result: list[dict[str, Any]], ) -> None: """ Parse a row or array column @@ -893,7 +881,7 @@ def _parse_structural_column( # pylint: disable=too-many-locals # split on open parenthesis ( to get the structural # data type and its component types data_types = cls._split_data_type(full_data_type, r"\(") - stack: List[Tuple[str, str]] = [] + stack: list[tuple[str, str]] = [] for data_type in data_types: # split on closed parenthesis ) to track which component # types belong to what structural data type @@ -962,8 +950,8 @@ def _parse_structural_column( # pylint: disable=too-many-locals @classmethod def _show_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[ResultRow]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[ResultRow]: """ Show presto column names :param inspector: object that performs database schema inspection @@ -974,13 +962,13 @@ def _show_columns( quote = inspector.engine.dialect.identifier_preparer.quote_identifier full_table = quote(table_name) if schema: - full_table = "{}.{}".format(quote(schema), full_table) + full_table = f"{quote(schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Get columns from a Presto data source. This includes handling row and array data types @@ -991,7 +979,7 @@ def get_columns( (i.e. 
column name and data type) """ columns = cls._show_columns(inspector, table_name, schema) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for column in columns: # parse column if it is a row or array if is_feature_enabled("PRESTO_EXPAND_DATA") and ( @@ -1031,7 +1019,7 @@ def _is_column_name_quoted(cls, column_name: str) -> bool: return column_name.startswith('"') and column_name.endswith('"') @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: """ Format column clauses where names are in quotes and labels are specified :param cols: columns @@ -1053,7 +1041,7 @@ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: # quote each column name if it is not already quoted for index, col_name in enumerate(col_names): if not cls._is_column_name_quoted(col_name): - col_names[index] = '"{}"'.format(col_name) + col_names[index] = f'"{col_name}"' quoted_col_name = ".".join( col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"' for col_name in col_names @@ -1069,12 +1057,12 @@ def select_star( # pylint: disable=too-many-arguments database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: """ Include selecting properties of row objects. We cannot easily break arrays into @@ -1102,9 +1090,9 @@ def select_star( # pylint: disable=too-many-arguments @classmethod def expand_data( # pylint: disable=too-many-locals - cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]] - ) -> Tuple[ - List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType] + cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]] + ) -> tuple[ + list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType] ]: """ We do not immediately display rows and arrays clearly in the data grid. This @@ -1133,7 +1121,7 @@ def expand_data( # pylint: disable=too-many-locals # process each column, unnesting ARRAY types and # expanding ROW types into new columns to_process = deque((column, 0) for column in columns) - all_columns: List[ResultSetColumnType] = [] + all_columns: list[ResultSetColumnType] = [] expanded_columns = [] current_array_level = None while to_process: @@ -1147,11 +1135,11 @@ def expand_data( # pylint: disable=too-many-locals # added by the first. every time we change a level in the nested arrays # we reinitialize this. 
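`_get_fields` quotes each column-name segment unless it is already quoted, now with an f-string instead of `'"{}"'.format(...)`. A simplified illustration (`quote_col` is a hypothetical stand-in for the inline logic):

    def quote_col(name: str) -> str:
        # Quote the segment only if it is not already double-quoted.
        already_quoted = name.startswith('"') and name.endswith('"')
        return name if already_quoted else f'"{name}"'


    print(".".join(quote_col(part) for part in ["my schema", "my_table", "col"]))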
if level != current_array_level: - unnested_rows: Dict[int, int] = defaultdict(int) + unnested_rows: dict[int, int] = defaultdict(int) current_array_level = level name = column["name"] - values: Optional[Union[str, List[Any]]] + values: str | list[Any] | None if column["type"] and column["type"].startswith("ARRAY("): # keep processing array children; we append to the right so that @@ -1198,7 +1186,7 @@ def expand_data( # pylint: disable=too-many-locals for row in data: values = row.get(name) or [] if isinstance(values, str): - values = cast(Optional[List[Any]], destringify(values)) + values = cast(Optional[list[Any]], destringify(values)) row[name] = values for value, col in zip(values or [], expanded): row[col["name"]] = value @@ -1211,8 +1199,8 @@ def expand_data( # pylint: disable=too-many-locals @classmethod def extra_table_metadata( - cls, database: Database, table_name: str, schema_name: Optional[str] - ) -> Dict[str, Any]: + cls, database: Database, table_name: str, schema_name: str | None + ) -> dict[str, Any]: metadata = {} if indexes := database.get_indexes(table_name, schema_name): @@ -1243,8 +1231,8 @@ def extra_table_metadata( @classmethod def get_create_view( - cls, database: Database, schema: Optional[str], table: str - ) -> Optional[str]: + cls, database: Database, schema: str | None, table: str + ) -> str | None: """ Return a CREATE VIEW statement, or `None` if not a view. @@ -1267,7 +1255,7 @@ def get_create_view( return rows[0][0] @classmethod - def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: + def get_tracking_url(cls, cursor: Cursor) -> str | None: try: if cursor.last_query_id: # pylint: disable=protected-access, line-too-long @@ -1277,7 +1265,7 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: return None @classmethod - def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: """Updates progress information""" if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index 27b749e418cfa..2e746a6349365 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -16,7 +16,8 @@ # under the License. import logging import re -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional import pandas as pd from flask_babel import gettext as __ @@ -66,7 +67,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): encryption_parameters = {"sslmode": "verify-ca"} - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -106,7 +107,7 @@ def df_to_sql( database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. diff --git a/superset/db_engine_specs/rockset.py b/superset/db_engine_specs/rockset.py index cc215054be5f7..71adca0b10ba7 100644 --- a/superset/db_engine_specs/rockset.py +++ b/superset/db_engine_specs/rockset.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from sqlalchemy import types @@ -51,7 +51,7 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 69ccf55931922..32ade649b0af3 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -18,7 +18,8 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from urllib import parse from apispec import APISpec @@ -107,7 +108,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): "P1Y": "DATE_TRUNC('YEAR', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { OBJECT_DOES_NOT_EXIST_REGEX: ( __("%(object)s does not exist in this database."), SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR, @@ -124,13 +125,13 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): } @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ Add a user agent to be used in the requests. """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("application", USER_AGENT) @@ -140,10 +141,10 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: database = uri.database if "/" in database: database = database.split("/")[0] @@ -157,7 +158,7 @@ def adjust_engine_params( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. @@ -174,7 +175,7 @@ def get_catalog_names( cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Return all catalogs. 
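The Snowflake `get_extra_params` hunk shows the nested-`setdefault` idiom used to inject a user agent without overwriting anything the administrator already configured; the Trino spec does the same further down. A self-contained sketch (the `USER_AGENT` value here is a placeholder):

    from typing import Any

    USER_AGENT = "Apache Superset"  # placeholder value


    def add_user_agent(extra: dict[str, Any]) -> dict[str, Any]:
        # setdefault() creates the nested dicts only when missing and never
        # overwrites values that are already present.
        engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
        connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
        connect_args.setdefault("application", USER_AGENT)
        return extra


    print(add_user_agent({}))
    print(add_user_agent({"engine_params": {"connect_args": {"application": "custom"}}}))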
@@ -197,7 +198,7 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -261,7 +262,7 @@ def build_sqlalchemy_uri( cls, parameters: SnowflakeParametersType, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, Any] + dict[str, Any] ] = None, ) -> str: return str( @@ -283,7 +284,7 @@ def get_parameters_from_uri( cls, uri: str, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, str] + dict[str, str] ] = None, ) -> Any: url = make_url_safe(uri) @@ -300,8 +301,8 @@ def get_parameters_from_uri( @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] required = { "warehouse", "username", @@ -346,7 +347,7 @@ def parameters_json_schema(cls) -> Any: @staticmethod def update_params_from_encrypted_extra( database: "Database", - params: Dict[str, Any], + params: dict[str, Any], ) -> None: if not database.encrypted_extra: return diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index a414143296338..767d0a20ad6ca 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -16,7 +16,8 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy import types @@ -60,7 +61,7 @@ class SqliteEngineSpec(BaseEngineSpec): ), } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __('We can\'t seem to resolve the column "%(column_name)s"'), SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, @@ -74,7 +75,7 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, (types.String, types.DateTime)): @@ -84,6 +85,6 @@ def convert_dttm( @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: """Need to disregard the schema for Sqlite""" return set(inspector.get_table_names()) diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py index f687fdbdb31bc..63269439af60c 100644 --- a/superset/db_engine_specs/starrocks.py +++ b/superset/db_engine_specs/starrocks.py @@ -17,7 +17,8 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Type +from re import Pattern +from typing import Any, Optional from urllib import parse from flask_babel import gettext as __ @@ -40,11 +41,11 @@ logger = logging.getLogger(__name__) -class TINYINT(Integer): # pylint: disable=no-init +class TINYINT(Integer): __visit_name__ = "TINYINT" -class DOUBLE(Numeric): # pylint: disable=no-init +class DOUBLE(Numeric): __visit_name__ = "DOUBLE" @@ -52,7 +53,7 @@ class ARRAY(TypeEngine): # pylint: 
disable=no-init __visit_name__ = "ARRAY" @property - def python_type(self) -> Optional[Type[List[Any]]]: + def python_type(self) -> Optional[type[list[Any]]]: return list @@ -60,7 +61,7 @@ class MAP(TypeEngine): # pylint: disable=no-init __visit_name__ = "MAP" @property - def python_type(self) -> Optional[Type[Dict[Any, Any]]]: + def python_type(self) -> Optional[type[dict[Any, Any]]]: return dict @@ -68,7 +69,7 @@ class STRUCT(TypeEngine): # pylint: disable=no-init __visit_name__ = "STRUCT" @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @@ -117,7 +118,7 @@ class StarRocksEngineSpec(MySQLEngineSpec): (re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING), ) - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -134,10 +135,10 @@ class StarRocksEngineSpec(MySQLEngineSpec): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: database = uri.database if schema and database: schema = parse.quote(schema, safe="") @@ -152,9 +153,9 @@ def adjust_engine_params( @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: columns = cls._show_columns(inspector, table_name, schema) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for column in columns: column_spec = cls.get_column_spec(column.Type) column_type = column_spec.sqla_type if column_spec else None @@ -174,7 +175,7 @@ def get_columns( @classmethod def _show_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[ResultRow]: + ) -> list[ResultRow]: """ Show starrocks column names :param inspector: object that performs database schema inspection @@ -185,13 +186,13 @@ def _show_columns( quote = inspector.engine.dialect.identifier_preparer.quote_identifier full_table = quote(table_name) if schema: - full_table = "{}.{}".format(quote(schema), full_table) + full_table = f"{quote(schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create column info object :param name: column name @@ -204,7 +205,7 @@ def _create_column_info( def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. 
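The StarRocks custom SQLAlchemy types now annotate `python_type` as `Optional[type[list[Any]]]` and the like, returning the class object itself. A stripped-down illustration (this `ARRAY` is not the real SQLAlchemy type):

    from typing import Any, Optional


    class ARRAY:
        """Sketch of a custom SQLAlchemy-style type; not sqlalchemy.types.ARRAY."""

        @property
        def python_type(self) -> Optional[type[list[Any]]]:
            # type[...] replaces typing.Type[...]; the value is the class itself.
            return list


    print(ARRAY().python_type is list)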
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 0fa4d05cbce0d..f05bd67ec35ab 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import simplejson as json from flask import current_app @@ -53,8 +53,8 @@ def extra_table_metadata( cls, database: Database, table_name: str, - schema_name: Optional[str], - ) -> Dict[str, Any]: + schema_name: str | None, + ) -> dict[str, Any]: metadata = {} if indexes := database.get_indexes(table_name, schema_name): @@ -68,12 +68,12 @@ def extra_table_metadata( metadata["partitions"] = { "cols": sorted( list( - set( + { column_name for index in indexes if index.get("name") == "partition" for column_name in index.get("column_names", []) - ) + } ) ), "latest": dict(zip(col_names, latest_parts)), @@ -95,9 +95,9 @@ def extra_table_metadata( @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -118,7 +118,7 @@ def update_impersonation_config( @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -131,11 +131,11 @@ def get_url_for_impersonation( return url @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def get_tracking_url(cls, cursor: Cursor) -> Optional[str]: + def get_tracking_url(cls, cursor: Cursor) -> str | None: try: return cursor.info_uri except AttributeError: @@ -199,7 +199,7 @@ def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: return True @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. 
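The Trino `extra_table_metadata` hunk below swaps `set(generator)` for a set comprehension when collecting partition column names; the behaviour is identical. Stand-alone version with sample index metadata:

    indexes = [
        {"name": "partition", "column_names": ["ds", "hour"]},
        {"name": "other", "column_names": ["x"]},
    ]

    # Set comprehension instead of set(<generator>); same result, clearer intent.
    partition_cols = sorted(
        {
            column_name
            for index in indexes
            if index.get("name") == "partition"
            for column_name in index.get("column_names", [])
        }
    )
    print(partition_cols)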
@@ -207,9 +207,9 @@ def get_extra_params(database: Database) -> Dict[str, Any]: :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("source", USER_AGENT) @@ -222,7 +222,7 @@ def get_extra_params(database: Database) -> Dict[str, Any]: @staticmethod def update_params_from_encrypted_extra( database: Database, - params: Dict[str, Any], + params: dict[str, Any], ) -> None: if not database.encrypted_extra: return @@ -262,7 +262,7 @@ def update_params_from_encrypted_extra( raise ex @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions diff --git a/superset/embedded/dao.py b/superset/embedded/dao.py index 957a7242a77d3..27ca3385023be 100644 --- a/superset/embedded/dao.py +++ b/superset/embedded/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from superset.dao.base import BaseDAO from superset.extensions import db @@ -31,7 +31,7 @@ class EmbeddedDAO(BaseDAO): id_column_name = "uuid" @staticmethod - def upsert(dashboard: Dashboard, allowed_domains: List[str]) -> EmbeddedDashboard: + def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboard: """ Sets up a dashboard to be embeddable. Upsert is used to preserve the embedded_dashboard uuid across updates. @@ -45,7 +45,7 @@ def upsert(dashboard: Dashboard, allowed_domains: List[str]) -> EmbeddedDashboar return embedded @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Any: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Any: """ Use EmbeddedDAO.upsert() instead. At least, until we are ok with more than one embedded instance per dashboard. diff --git a/superset/errors.py b/superset/errors.py index 5261848687f2f..6f68f2466c456 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -16,7 +16,7 @@ # under the License. 
from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ @@ -204,7 +204,7 @@ class SupersetError: message: str error_type: SupersetErrorType level: ErrorLevel - extra: Optional[Dict[str, Any]] = None + extra: Optional[dict[str, Any]] = None def __post_init__(self) -> None: """ @@ -227,7 +227,7 @@ def __post_init__(self) -> None: } ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = {"message": self.message, "error_type": self.error_type} if self.extra: rv["extra"] = self.extra # type: ignore diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 5d167b02d0627..e18f6e4632a01 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -55,7 +55,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/big_data.py b/superset/examples/big_data.py index 8c0f2e267c555..ed738d2c9682a 100644 --- a/superset/examples/big_data.py +++ b/superset/examples/big_data.py @@ -16,7 +16,6 @@ # under the License. import random import string -from typing import List import sqlalchemy.sql.sqltypes @@ -36,7 +35,7 @@ def load_big_data() -> None: print("Creating table `wide_table` with 100 columns") - columns: List[ColumnInfo] = [] + columns: list[ColumnInfo] = [] for i in range(100): column: ColumnInfo = { "name": f"col{i}", diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 8da041550e92a..45a3b39eb337f 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -16,7 +16,7 @@ # under the License. import json import textwrap -from typing import Dict, List, Tuple, Union +from typing import Union import pandas as pd from sqlalchemy import DateTime, inspect, String @@ -42,7 +42,7 @@ def gen_filter( subject: str, comparator: str, operator: str = "==" -) -> Dict[str, Union[bool, str]]: +) -> dict[str, Union[bool, str]]: return { "clause": "WHERE", "comparator": comparator, @@ -152,7 +152,7 @@ def _add_table_metrics(datasource: SqlaTable) -> None: datasource.metrics = metrics -def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]: +def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]: metrics = [ { "expressionType": "SIMPLE", @@ -529,7 +529,7 @@ def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]: return slices, misc_slices -def create_dashboard(slices: List[Slice]) -> Dashboard: +def create_dashboard(slices: list[Slice]) -> Dashboard: print("Creating a dashboard") dash = db.session.query(Dashboard).filter_by(slug="births").first() if not dash: diff --git a/superset/examples/countries.py b/superset/examples/countries.py index 8f1d5466ae22a..2ea12baae73dd 100644 --- a/superset/examples/countries.py +++ b/superset/examples/countries.py @@ -16,9 +16,9 @@ # under the License. 
"""This module contains data related to countries and is used for geo mapping""" # pylint: disable=too-many-lines -from typing import Any, Dict, List, Optional +from typing import Any, Optional -countries: List[Dict[str, Any]] = [ +countries: list[dict[str, Any]] = [ { "name": "Angola", "area": 1246700, @@ -2491,7 +2491,7 @@ }, ] -all_lookups: Dict[str, Dict[str, Dict[str, Any]]] = {} +all_lookups: dict[str, dict[str, dict[str, Any]]] = {} lookups = ["cioc", "cca2", "cca3", "name"] for lookup in lookups: all_lookups[lookup] = {} @@ -2499,7 +2499,7 @@ all_lookups[lookup][country[lookup].lower()] = country -def get(field: str, symbol: str) -> Optional[Dict[str, Any]]: +def get(field: str, symbol: str) -> Optional[dict[str, Any]]: """ Get country data based on a standard code and a symbol """ diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py index e26e05e49739a..9f893f1ccca36 100644 --- a/superset/examples/helpers.py +++ b/superset/examples/helpers.py @@ -17,7 +17,7 @@ """Loads datasets, dashboards and slices in a new superset instance""" import json import os -from typing import Any, Dict, List, Set +from typing import Any from superset import app, db from superset.connectors.sqla.models import SqlaTable @@ -25,7 +25,7 @@ BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/" -misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard +misc_dash_slices: set[str] = set() # slices assembled in a 'Misc Chart' dashboard def get_table_connector_registry() -> Any: @@ -36,7 +36,7 @@ def get_examples_folder() -> str: return os.path.join(app.config["BASE_DIR"], "examples") -def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]: +def update_slice_ids(pos: dict[Any, Any]) -> list[Slice]: """Update slice ids in position_json and return the slices found.""" slice_components = [ component @@ -44,7 +44,7 @@ def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]: if isinstance(component, dict) and component.get("type") == "CHART" ] slices = {} - for name in set(component["meta"]["sliceName"] for component in slice_components): + for name in {component["meta"]["sliceName"] for component in slice_components}: slc = db.session.query(Slice).filter_by(slice_name=name).first() if slc: slices[name] = slc @@ -64,7 +64,7 @@ def merge_slice(slc: Slice) -> None: db.session.commit() -def get_slice_json(defaults: Dict[Any, Any], **kwargs: Any) -> str: +def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str: defaults_copy = defaults.copy() defaults_copy.update(kwargs) return json.dumps(defaults_copy, indent=4, sort_keys=True) diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index de9630ef58503..6bad2a7ac252b 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Dict, Optional, Tuple +from typing import Optional import pandas as pd from sqlalchemy import BigInteger, Date, DateTime, inspect, String @@ -85,7 +85,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals obj.main_dttm_col = "ds" obj.database = database obj.filter_select_enabled = True - dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = { + dttm_and_expr_dict: dict[str, tuple[Optional[str], None]] = { "ds": (None, None), "ds2": (None, None), "epoch_s": ("epoch_s", None), diff --git a/superset/examples/paris.py b/superset/examples/paris.py index a54a3706b13c0..1180c428feb21 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -52,7 +52,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index 6011b82b09651..76c039afb88a3 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -54,7 +54,7 @@ def load_sf_population_polygons( index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index e6d7557debf09..bce50c854eb5f 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -19,7 +19,6 @@ import json import textwrap -from typing import List from sqlalchemy import inspect @@ -40,7 +39,7 @@ DASH_SLUG = "supported_charts_dash" -def create_slices(tbl: SqlaTable) -> List[Slice]: +def create_slices(tbl: SqlaTable) -> list[Slice]: slice_kwargs = { "datasource_id": tbl.id, "datasource_type": DatasourceType.TABLE, diff --git a/superset/examples/utils.py b/superset/examples/utils.py index 8c2cfea23c4e6..52d58e3e4a99f 100644 --- a/superset/examples/utils.py +++ b/superset/examples/utils.py @@ -17,7 +17,7 @@ import logging import re from pathlib import Path -from typing import Any, Dict +from typing import Any import yaml from pkg_resources import resource_isdir, resource_listdir, resource_stream @@ -42,13 +42,13 @@ def load_examples_from_configs( command.run() -def load_contents(load_test_data: bool = False) -> Dict[str, Any]: +def load_contents(load_test_data: bool = False) -> dict[str, Any]: """Traverse configs directory and load contents""" root = Path("examples/configs") resource_names = resource_listdir("superset", str(root)) queue = [root / resource_name for resource_name in resource_names] - contents: Dict[Path, str] = {} + contents: dict[Path, str] = {} while queue: path_name = queue.pop() test_re = re.compile(r"\.test\.|metadata\.yaml$") @@ -74,7 +74,7 @@ def load_configs_from_directory( """ Load all the examples from a given directory. 
""" - contents: Dict[str, str] = {} + contents: dict[str, str] = {} queue = [root] while queue: path_name = queue.pop() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 2972188e0267e..9f9f6bb7005d0 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -17,7 +17,6 @@ """Loads datasets, dashboards and slices in a new superset instance""" import json import os -from typing import List import pandas as pd from sqlalchemy import DateTime, inspect, String @@ -139,7 +138,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s db.session.commit() -def create_slices(tbl: BaseDatasource) -> List[Slice]: +def create_slices(tbl: BaseDatasource) -> list[Slice]: metric = "sum__SP_POP_TOTL" metrics = ["sum__SP_POP_TOTL"] secondary_metric = { diff --git a/superset/exceptions.py b/superset/exceptions.py index 32b06203cdb1e..018d1b6dfb6b3 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from collections import defaultdict -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from marshmallow import ValidationError @@ -47,7 +47,7 @@ def exception(self) -> Optional[Exception]: def error_type(self) -> Optional[SupersetErrorType]: return self._error_type - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = {} if hasattr(self, "message"): rv["message"] = self.message @@ -67,7 +67,7 @@ def __init__(self, error: SupersetError, status: Optional[int] = None) -> None: if status is not None: self.status = status - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return self.error.to_dict() @@ -94,7 +94,7 @@ def __init__( error_type: SupersetErrorType, message: str, level: ErrorLevel, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( SupersetError( @@ -107,7 +107,7 @@ class SupersetErrorsException(SupersetException): """Exceptions with multiple SupersetErrorType associated with them""" def __init__( - self, errors: List[SupersetError], status: Optional[int] = None + self, errors: list[SupersetError], status: Optional[int] = None ) -> None: super().__init__(str(errors)) self.errors = errors @@ -119,7 +119,7 @@ class SupersetSyntaxErrorException(SupersetErrorsException): status = 422 error_type = SupersetErrorType.SYNTAX_ERROR - def __init__(self, errors: List[SupersetError]) -> None: + def __init__(self, errors: list[SupersetError]) -> None: super().__init__(errors) @@ -134,7 +134,7 @@ def __init__( self, message: str, level: ErrorLevel = ErrorLevel.ERROR, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( SupersetErrorType.GENERIC_DB_ENGINE_ERROR, @@ -152,7 +152,7 @@ def __init__( message: str, error: SupersetErrorType, level: ErrorLevel = ErrorLevel.ERROR, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( error, @@ -166,7 +166,7 @@ class SupersetSecurityException(SupersetErrorException): status = 403 def __init__( - self, error: SupersetError, payload: Optional[Dict[str, Any]] = None + self, error: SupersetError, payload: Optional[dict[str, Any]] = None ) -> None: super().__init__(error) self.payload = payload diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py index 
fb690a9d75238..490d198360dad 100644 --- a/superset/explore/commands/get.py +++ b/superset/explore/commands/get.py @@ -16,7 +16,7 @@ # under the License. import logging from abc import ABC -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Optional import simplejson as json from flask import current_app, request @@ -60,7 +60,7 @@ def __init__( self._slice_id = params.slice_id # pylint: disable=too-many-locals,too-many-branches,too-many-statements - def run(self) -> Optional[Dict[str, Any]]: + def run(self) -> Optional[dict[str, Any]]: initial_form_data = {} if self._permalink_key is not None: @@ -147,7 +147,7 @@ def run(self) -> Optional[Dict[str, Any]]: utils.merge_request_params(form_data, request.args) # TODO: this is a dummy placeholder - should be refactored to being just `None` - datasource_data: Dict[str, Any] = { + datasource_data: dict[str, Any] = { "type": self._datasource_type, "name": datasource_name, "columns": [], diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 90e64f6df7266..97a8bcbf09ed4 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__(self, state: Dict[str, Any]): + def __init__(self, state: dict[str, Any]): self.chart_id: Optional[int] = state["formData"].get("slice_id") self.datasource: str = state["formData"]["datasource"] self.state = state diff --git a/superset/explore/permalink/types.py b/superset/explore/permalink/types.py index 393f0ed8d5890..7eb4a7cb6b124 100644 --- a/superset/explore/permalink/types.py +++ b/superset/explore/permalink/types.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Optional, TypedDict class ExplorePermalinkState(TypedDict, total=False): - formData: Dict[str, Any] - urlParams: Optional[List[Tuple[str, str]]] + formData: dict[str, Any] + urlParams: Optional[list[tuple[str, str]]] class ExplorePermalinkValue(TypedDict): diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index f63338597284a..c2e84f700f5ce 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -16,7 +16,7 @@ # under the License. 
import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import celery from cachelib.base import BaseCache @@ -58,7 +58,7 @@ def should_use_msgpack(self) -> bool: class UIManifestProcessor: def __init__(self, app_dir: str) -> None: self.app: Optional[Flask] = None - self.manifest: Dict[str, Dict[str, List[str]]] = {} + self.manifest: dict[str, dict[str, list[str]]] = {} self.manifest_file = f"{app_dir}/static/assets/manifest.json" def init_app(self, app: Flask) -> None: @@ -70,10 +70,10 @@ def init_app(self, app: Flask) -> None: def register_processor(self, app: Flask) -> None: app.template_context_processors[None].append(self.get_manifest) - def get_manifest(self) -> Dict[str, Callable[[str], List[str]]]: + def get_manifest(self) -> dict[str, Callable[[str], list[str]]]: loaded_chunks = set() - def get_files(bundle: str, asset_type: str = "js") -> List[str]: + def get_files(bundle: str, asset_type: str = "js") -> list[str]: files = self.get_manifest_files(bundle, asset_type) filtered_files = [f for f in files if f not in loaded_chunks] for f in filtered_files: @@ -88,7 +88,7 @@ def get_files(bundle: str, asset_type: str = "js") -> List[str]: def parse_manifest_json(self) -> None: try: - with open(self.manifest_file, "r") as f: + with open(self.manifest_file) as f: # the manifest includes non-entry files we only need entries in # templates full_manifest = json.load(f) @@ -96,7 +96,7 @@ def parse_manifest_json(self) -> None: except Exception: # pylint: disable=broad-except pass - def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]: + def get_manifest_files(self, bundle: str, asset_type: str) -> list[str]: if self.app and self.app.debug: self.parse_manifest_json() return self.manifest.get(bundle, {}).get(asset_type, []) @@ -117,7 +117,7 @@ def init_app(self, app: Flask) -> None: celery_app = celery.Celery() csrf = CSRFProtect() db = SQLA() -_event_logger: Dict[str, Any] = {} +_event_logger: dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index f69276c908430..6e928c0d5f023 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Optional from uuid import UUID, uuid3 from flask import Flask @@ -37,7 +37,7 @@ def __init__(self, namespace: UUID, default_timeout: int = 300) -> None: @classmethod def factory( - cls, app: Flask, config: Dict[str, Any], args: List[Any], kwargs: Dict[str, Any] + cls, app: Flask, config: dict[str, Any], args: list[Any], kwargs: dict[str, Any] ) -> BaseCache: seed = config.get("CACHE_KEY_PREFIX", "") kwargs["namespace"] = get_uuid_namespace(seed) diff --git a/superset/forms.py b/superset/forms.py index c9b29dfcd088f..f1e220ba952f7 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -16,7 +16,7 @@ # under the License. 
"""Contains the logic to create cohesive forms on the explore view""" import json -from typing import Any, List, Optional +from typing import Any, Optional from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from wtforms import Field @@ -24,12 +24,12 @@ class JsonListField(Field): widget = BS3TextFieldWidget() - data: List[str] = [] + data: list[str] = [] def _value(self) -> str: return json.dumps(self.data) - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist and valuelist[0]: self.data = json.loads(valuelist[0]) else: @@ -38,7 +38,7 @@ def process_formdata(self, valuelist: List[str]) -> None: class CommaSeparatedListField(Field): widget = BS3TextFieldWidget() - data: List[str] = [] + data: list[str] = [] def _value(self) -> str: if self.data: @@ -46,14 +46,14 @@ def _value(self) -> str: return "" - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist: self.data = [x.strip() for x in valuelist[0].split(",")] else: self.data = [] -def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]: +def filter_not_empty_values(values: Optional[list[Any]]) -> Optional[list[Any]]: """Returns a list of non empty values or None""" if not values: return None diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index c489cc323cb64..bbe25f498b4ee 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -19,7 +19,7 @@ import logging import os import sys -from typing import Any, Callable, Dict, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING import wtforms_json from deprecation import deprecated @@ -68,7 +68,7 @@ def __init__(self, app: SupersetApp) -> None: self.superset_app = app self.config = app.config - self.manifest: Dict[Any, Any] = {} + self.manifest: dict[Any, Any] = {} @deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore @property @@ -597,7 +597,7 @@ def __init__(self, app: Flask) -> None: self.app = app def __call__( - self, environ: Dict[str, Any], start_response: Callable[..., Any] + self, environ: dict[str, Any], start_response: Callable[..., Any] ) -> Any: # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore # content-length and read the stream till the end. 
diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 360c4fc1f4517..f096b65cd1617 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -18,17 +18,7 @@ import json import re from functools import lru_cache, partial -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Optional, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ @@ -71,14 +61,14 @@ @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def context_addons() -> Dict[str, Any]: +def context_addons() -> dict[str, Any]: return current_app.config.get("JINJA_CONTEXT_ADDONS", {}) class Filter(TypedDict): op: str # pylint: disable=C0103 col: str - val: Union[None, Any, List[Any]] + val: Union[None, Any, list[Any]] class ExtraCache: @@ -100,9 +90,9 @@ class ExtraCache: def __init__( self, - extra_cache_keys: Optional[List[Any]] = None, - applied_filters: Optional[List[str]] = None, - removed_filters: Optional[List[str]] = None, + extra_cache_keys: Optional[list[Any]] = None, + applied_filters: Optional[list[str]] = None, + removed_filters: Optional[list[str]] = None, dialect: Optional[Dialect] = None, ): self.extra_cache_keys = extra_cache_keys @@ -206,7 +196,7 @@ def url_param( def filter_values( self, column: str, default: Optional[str] = None, remove_filter: bool = False - ) -> List[Any]: + ) -> list[Any]: """Gets a values for a particular filter as a list This is useful if: @@ -230,7 +220,7 @@ def filter_values( only apply to the inner query :return: returns a list of filter values """ - return_val: List[Any] = [] + return_val: list[Any] = [] filters = self.get_filters(column, remove_filter) for flt in filters: val = flt.get("val") @@ -245,7 +235,7 @@ def filter_values( return return_val - def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]: + def get_filters(self, column: str, remove_filter: bool = False) -> list[Filter]: """Get the filters applied to the given column. In addition to returning values like the filter_values function the get_filters function returns the operator specified in the explorer UI. 
@@ -316,10 +306,10 @@ def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]: convert_legacy_filters_into_adhoc(form_data) merge_extra_filters(form_data) - filters: List[Filter] = [] + filters: list[Filter] = [] for flt in form_data.get("adhoc_filters", []): - val: Union[Any, List[Any]] = flt.get("comparator") + val: Union[Any, list[Any]] = flt.get("comparator") op: str = flt["operator"].upper() if flt.get("operator") else None # fltOpName: str = flt.get("filterOptionName") if ( @@ -370,7 +360,7 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: return return_value -def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]: +def validate_context_types(context: dict[str, Any]) -> dict[str, Any]: for key in context: arg_type = type(context[key]).__name__ if arg_type not in ALLOWED_TYPES and key not in context_addons(): @@ -395,8 +385,8 @@ def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]: def validate_template_context( - engine: Optional[str], context: Dict[str, Any] -) -> Dict[str, Any]: + engine: Optional[str], context: dict[str, Any] +) -> dict[str, Any]: if engine and engine in context: # validate engine context separately to allow for engine-specific methods engine_context = validate_context_types(context.pop(engine)) @@ -407,7 +397,7 @@ def validate_template_context( return validate_context_types(context) -def where_in(values: List[Any], mark: str = "'") -> str: +def where_in(values: list[Any], mark: str = "'") -> str: """ Given a list of values, build a parenthesis list suitable for an IN expression. @@ -439,9 +429,9 @@ def __init__( database: "Database", query: Optional["Query"] = None, table: Optional["SqlaTable"] = None, - extra_cache_keys: Optional[List[Any]] = None, - removed_filters: Optional[List[str]] = None, - applied_filters: Optional[List[str]] = None, + extra_cache_keys: Optional[list[Any]] = None, + removed_filters: Optional[list[str]] = None, + applied_filters: Optional[list[str]] = None, **kwargs: Any, ) -> None: self._database = database @@ -454,7 +444,7 @@ def __init__( self._extra_cache_keys = extra_cache_keys self._applied_filters = applied_filters self._removed_filters = removed_filters - self._context: Dict[str, Any] = {} + self._context: dict[str, Any] = {} self._env = SandboxedEnvironment(undefined=DebugUndefined) self.set_context(**kwargs) @@ -530,7 +520,7 @@ def set_context(self, **kwargs: Any) -> None: @staticmethod def _schema_table( table_name: str, schema: Optional[str] - ) -> Tuple[str, Optional[str]]: + ) -> tuple[str, Optional[str]]: if "." 
in table_name: schema, table_name = table_name.split(".") return table_name, schema @@ -547,7 +537,7 @@ def first_latest_partition(self, table_name: str) -> Optional[str]: latest_partitions = self.latest_partitions(table_name) return latest_partitions[0] if latest_partitions else None - def latest_partitions(self, table_name: str) -> Optional[List[str]]: + def latest_partitions(self, table_name: str) -> Optional[list[str]]: """ Gets the array of all latest partitions @@ -603,7 +593,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def get_template_processors() -> Dict[str, Any]: +def get_template_processors() -> dict[str, Any]: processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {}) for engine, processor in DEFAULT_PROCESSORS.items(): # do not overwrite engine-specific CUSTOM_TEMPLATE_PROCESSORS @@ -631,7 +621,7 @@ def get_template_processor( def dataset_macro( dataset_id: int, include_metrics: bool = False, - columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, ) -> str: """ Given a dataset ID, return the SQL that represents it. diff --git a/superset/key_value/types.py b/superset/key_value/types.py index fb9c31899f705..b2a47336c3d3d 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -21,7 +21,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, Optional, TypedDict +from typing import Any, TypedDict from uuid import UUID from marshmallow import Schema, ValidationError @@ -34,14 +34,14 @@ @dataclass class Key: - id: Optional[int] - uuid: Optional[UUID] + id: int | None + uuid: UUID | None class KeyValueFilter(TypedDict, total=False): resource: str - id: Optional[int] - uuid: Optional[UUID] + id: int | None + uuid: UUID | None class KeyValueResource(str, Enum): diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index 2468618a81b62..6b487c278c0d0 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -18,7 +18,7 @@ from hashlib import md5 from secrets import token_urlsafe -from typing import Any, Union +from typing import Any from uuid import UUID, uuid3 import hashids @@ -35,7 +35,7 @@ def random_key() -> str: return token_urlsafe(48) -def get_filter(resource: KeyValueResource, key: Union[int, UUID]) -> KeyValueFilter: +def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter: try: filter_: KeyValueFilter = {"resource": resource.value} if isinstance(key, UUID): diff --git a/superset/legacy.py b/superset/legacy.py index 168b9c0b60a5e..03c1eff7dd5ee 100644 --- a/superset/legacy.py +++ b/superset/legacy.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. 
"""Code related with dealing with legacy / change management""" -from typing import Any, Dict +from typing import Any -def update_time_range(form_data: Dict[str, Any]) -> None: +def update_time_range(form_data: dict[str, Any]) -> None: """Move since and until to time_range.""" if "since" in form_data or "until" in form_data: form_data["time_range"] = "{} : {}".format( diff --git a/superset/migrations/env.py b/superset/migrations/env.py index e3779bb65bcc2..130fb367fb6f6 100755 --- a/superset/migrations/env.py +++ b/superset/migrations/env.py @@ -17,7 +17,6 @@ import logging import urllib.parse from logging.config import fileConfig -from typing import List from alembic import context from alembic.operations.ops import MigrationScript @@ -85,7 +84,7 @@ def run_migrations_online() -> None: # when there are no changes to the schema # reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html def process_revision_directives( # pylint: disable=redefined-outer-name, unused-argument - context: MigrationContext, revision: str, directives: List[MigrationScript] + context: MigrationContext, revision: str, directives: list[MigrationScript] ) -> None: if getattr(config.cmd_opts, "autogenerate", False): script = directives[0] diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 5ea23551ead57..d3b2efa7a0cca 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -18,7 +18,7 @@ import copy import json -from typing import Any, Dict, Set +from typing import Any from alembic import op from sqlalchemy import and_, Column, Integer, String, Text @@ -44,8 +44,8 @@ class Slice(Base): # type: ignore class MigrateViz: - remove_keys: Set[str] = set() - rename_keys: Dict[str, str] = {} + remove_keys: set[str] = set() + rename_keys: dict[str, str] = {} source_viz_type: str target_viz_type: str has_x_axis_control: bool = False @@ -85,7 +85,7 @@ def _migrate(self) -> None: def _post_action(self) -> None: """Some actions after migrate""" - def _migrate_temporal_filter(self, rv_data: Dict[str, Any]) -> None: + def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None: """Adds a temporal filter.""" granularity_sqla = rv_data.pop("granularity_sqla", None) time_range = rv_data.pop("time_range", None) or conf.get("DEFAULT_TIME_FILTER") diff --git a/superset/migrations/shared/security_converge.py b/superset/migrations/shared/security_converge.py index 19caa3932b874..9b1730a2a1464 100644 --- a/superset/migrations/shared/security_converge.py +++ b/superset/migrations/shared/security_converge.py @@ -16,7 +16,6 @@ # under the License. 
import logging from dataclasses import dataclass -from typing import Dict, List, Tuple from sqlalchemy import ( Column, @@ -41,7 +40,7 @@ class Pvm: permission: str -PvmMigrationMapType = Dict[Pvm, Tuple[Pvm, ...]] +PvmMigrationMapType = dict[Pvm, tuple[Pvm, ...]] # Partial freeze of the current metadata db schema @@ -162,8 +161,8 @@ def _find_pvm(session: Session, view_name: str, permission_name: str) -> Permiss def add_pvms( - session: Session, pvm_data: Dict[str, Tuple[str, ...]], commit: bool = False -) -> List[PermissionView]: + session: Session, pvm_data: dict[str, tuple[str, ...]], commit: bool = False +) -> list[PermissionView]: """ Checks if exists and adds new Permissions, Views and PermissionView's """ @@ -181,7 +180,7 @@ def add_pvms( def _delete_old_permissions( - session: Session, pvm_map: Dict[PermissionView, List[PermissionView]] + session: Session, pvm_map: dict[PermissionView, list[PermissionView]] ) -> None: """ Delete old permissions: @@ -222,7 +221,7 @@ def migrate_roles( Migrates all existing roles that have the permissions to be migrated """ # Collect a map of PermissionView objects for migration - pvm_map: Dict[PermissionView, List[PermissionView]] = {} + pvm_map: dict[PermissionView, list[PermissionView]] = {} for old_pvm_key, new_pvms_ in pvm_key_map.items(): old_pvm = _find_pvm(session, old_pvm_key.view, old_pvm_key.permission) if old_pvm: @@ -252,8 +251,8 @@ def migrate_roles( session.commit() -def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> Dict[str, Tuple[str, ...]]: - reversed_pvms: Dict[str, Tuple[str, ...]] = {} +def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> dict[str, tuple[str, ...]]: + reversed_pvms: dict[str, tuple[str, ...]] = {} for old_pvm, new_pvms in pvm_map.items(): if old_pvm.view not in reversed_pvms: reversed_pvms[old_pvm.view] = (old_pvm.permission,) diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index e05b1d357f2e4..32e7dc1a3992e 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -18,7 +18,8 @@ import logging import os import time -from typing import Any, Callable, Dict, Iterator, Optional, Union +from collections.abc import Iterator +from typing import Any, Callable, Optional, Union from uuid import uuid4 from alembic import op @@ -127,7 +128,7 @@ def paginated_update( print_page_progress(processed, total) -def try_load_json(data: Optional[str]) -> Dict[str, Any]: +def try_load_json(data: Optional[str]) -> dict[str, Any]: try: return data and json.loads(data) or {} except json.decoder.JSONDecodeError: diff --git a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py index 56d5f887b3e0e..1f3dbab6367aa 100644 --- a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py +++ b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py @@ -59,7 +59,7 @@ def upgrade(): slc.params = json.dumps(d, indent=2, sort_keys=True) session.merge(slc) session.commit() - print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name)) + print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}") except Exception as ex: print(slc.slice_name + " error: " + str(ex)) diff --git a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py index 
04a39a31f5805..8e97ada3cd69f 100644 --- a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py +++ b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py @@ -82,7 +82,7 @@ def upgrade(): url.url = newurl session.merge(url) session.commit() - print("Updating url ({}/{})".format(i, urls_len)) + print(f"Updating url ({i}/{urls_len})") session.close() diff --git a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py index 26cfb93b991ca..f6d5610d97ef1 100644 --- a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py +++ b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py @@ -69,7 +69,7 @@ def upgrade(): batch_op.add_column(sa.Column("datasource_id", sa.Integer)) batch_op.create_foreign_key( - "fk_{}_datasource_id_datasources".format(foreign), + f"fk_{foreign}_datasource_id_datasources", "datasources", ["datasource_id"], ["id"], @@ -102,7 +102,7 @@ def upgrade(): for name in names: batch_op.drop_constraint( - name or "fk_{}_datasource_name_datasources".format(foreign), + name or f"fk_{foreign}_datasource_name_datasources", type_="foreignkey", ) @@ -148,7 +148,7 @@ def downgrade(): batch_op.add_column(sa.Column("datasource_name", sa.String(255))) batch_op.create_foreign_key( - "fk_{}_datasource_name_datasources".format(foreign), + f"fk_{foreign}_datasource_name_datasources", "datasources", ["datasource_name"], ["datasource_name"], @@ -174,7 +174,7 @@ def downgrade(): with op.batch_alter_table(foreign, naming_convention=conv) as batch_op: # Drop the datasource_id column and associated constraint. batch_op.drop_constraint( - "fk_{}_datasource_id_datasources".format(foreign), type_="foreignkey" + f"fk_{foreign}_datasource_id_datasources", type_="foreignkey" ) batch_op.drop_column("datasource_id") @@ -201,7 +201,7 @@ def downgrade(): # Re-create the foreign key associated with the cluster_name column. 
batch_op.create_foreign_key( - "fk_{}_datasource_id_datasources".format(foreign), + f"fk_{foreign}_datasource_id_datasources", "clusters", ["cluster_name"], ["cluster_name"], diff --git a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py index 5593af0eb6650..4b1b807a6fddc 100644 --- a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py +++ b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py @@ -59,7 +59,7 @@ def upgrade(): { "annotationType": "INTERVAL", "style": "solid", - "name": "Layer {}".format(layer), + "name": f"Layer {layer}", "show": True, "overrides": {"since": None, "until": None}, "value": layer, diff --git a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py index 286be8a5fc9c6..bf6276d702c7b 100644 --- a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py +++ b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py @@ -51,7 +51,7 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("Upgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id)) + print(f"Upgrading ({i}/{len(dashboards)}): {dashboard.id}") positions = json.loads(dashboard.position_json or "{}") for pos in positions: if pos.get("v", 0) == 0: @@ -74,7 +74,7 @@ def downgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("Downgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id)) + print(f"Downgrading ({i}/{len(dashboards)}): {dashboard.id}") positions = json.loads(dashboard.position_json or "{}") for pos in positions: if pos.get("v", 0) == 1: diff --git a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py index dbe3f0ace4220..c73399fb92e17 100644 --- a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py +++ b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py @@ -49,7 +49,7 @@ def upgrade(): for table, column in names.items(): with op.batch_alter_table(table, naming_convention=conv) as batch_op: batch_op.create_unique_constraint( - "uq_{}_{}".format(table, column), [column, "datasource_id"] + f"uq_{table}_{column}", [column, "datasource_id"] ) @@ -71,6 +71,6 @@ def downgrade(): with op.batch_alter_table(table, naming_convention=conv) as batch_op: batch_op.drop_constraint( generic_find_uq_constraint_name(table, {column, "datasource_id"}, insp) - or "uq_{}_{}".format(table, column), + or f"uq_{table}_{column}", type_="unique", ) diff --git a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py index 3e2b81c17a82c..49b19b9c696fd 100644 --- a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py +++ b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py @@ -61,7 +61,7 @@ def upgrade(): slc.params = json.dumps(params, indent=2, sort_keys=True) session.merge(slc) session.commit() - print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name)) + print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}") except Exception as ex: print(slc.slice_name + " error: " + str(ex)) diff 
--git a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py index ec03328271599..6292e2860ab58 100644 --- a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py +++ b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py @@ -28,8 +28,6 @@ import json -import uuid -from collections import defaultdict from alembic import op from sqlalchemy import Column, Integer, Text diff --git a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py index 2e491e9303f4e..a2dd50bf9ca81 100644 --- a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py +++ b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py @@ -77,24 +77,24 @@ def isodate_duration_to_string(obj): if obj.tdelta: if not obj.months and not obj.years: return format_seconds(obj.tdelta.total_seconds()) - raise Exception("Unable to convert: {0}".format(obj)) + raise Exception(f"Unable to convert: {obj}") if obj.months % 12 != 0: months = obj.months + 12 * obj.years - return "{0} months".format(months) + return f"{months} months" - return "{0} years".format(obj.years + obj.months // 12) + return f"{obj.years + obj.months // 12} years" def timedelta_to_string(obj): if obj.microseconds: - raise Exception("Unable to convert: {0}".format(obj)) + raise Exception(f"Unable to convert: {obj}") elif obj.seconds: return format_seconds(obj.total_seconds()) elif obj.days % 7 == 0: - return "{0} weeks".format(obj.days // 7) + return f"{obj.days // 7} weeks" else: - return "{0} days".format(obj.days) + return f"{obj.days} days" def format_seconds(value): @@ -106,7 +106,7 @@ def format_seconds(value): else: period = "second" - return "{0} {1}{2}".format(value, period, "s" if value > 1 else "") + return "{} {}{}".format(value, period, "s" if value > 1 else "") def compute_time_compare(granularity, periods): @@ -120,11 +120,11 @@ def compute_time_compare(granularity, periods): obj = isodate.parse_duration(granularity) * periods except isodate.isoerror.ISO8601Error: # if parse_human_timedelta can parse it, return it directly - delta = "{0} {1}{2}".format(periods, granularity, "s" if periods > 1 else "") + delta = "{} {}{}".format(periods, granularity, "s" if periods > 1 else "") obj = parse_human_timedelta(delta) if obj: return delta - raise Exception("Unable to parse: {0}".format(granularity)) + raise Exception(f"Unable to parse: {granularity}") if isinstance(obj, isodate.duration.Duration): return isodate_duration_to_string(obj) diff --git a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py index 13c4e61718cc4..620e2c5008e62 100644 --- a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py +++ b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py @@ -173,7 +173,7 @@ def get_header_component(title): def get_row_container(): return { "type": ROW_TYPE, - "id": "DASHBOARD_ROW_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_ROW_TYPE-{generate_id()}", "children": [], "meta": {"background": BACKGROUND_TRANSPARENT}, } @@ -182,7 +182,7 @@ def get_row_container(): def 
get_col_container(): return { "type": COLUMN_TYPE, - "id": "DASHBOARD_COLUMN_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_COLUMN_TYPE-{generate_id()}", "children": [], "meta": {"background": BACKGROUND_TRANSPARENT}, } @@ -203,18 +203,18 @@ def get_chart_holder(position): if len(code): markdown_content = code elif slice_name.strip(): - markdown_content = "##### {}".format(slice_name) + markdown_content = f"##### {slice_name}" return { "type": MARKDOWN_TYPE, - "id": "DASHBOARD_MARKDOWN_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_MARKDOWN_TYPE-{generate_id()}", "children": [], "meta": {"width": width, "height": height, "code": markdown_content}, } return { "type": CHART_TYPE, - "id": "DASHBOARD_CHART_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_CHART_TYPE-{generate_id()}", "children": [], "meta": {"width": width, "height": height, "chartId": int(slice_id)}, } @@ -584,10 +584,10 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards))) + print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>") position_json = json.loads(dashboard.position_json or "[]") if not is_v2_dash(position_json): - print("Converting dashboard... dash_id: {}".format(dashboard.id)) + print(f"Converting dashboard... dash_id: {dashboard.id}") position_dict = {} positions = [] slices = dashboard.slices @@ -650,7 +650,7 @@ def upgrade(): session.merge(dashboard) session.commit() else: - print("Skip converted dash_id: {}".format(dashboard.id)) + print(f"Skip converted dash_id: {dashboard.id}") session.close() diff --git a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py index bfb7a66161def..3c6979f961e30 100644 --- a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py +++ b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py @@ -51,7 +51,7 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards))) + print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>") if dashboard.json_metadata: json_metadata = json.loads(dashboard.json_metadata) has_update = False @@ -74,7 +74,7 @@ def upgrade(): # if user already defined __time_range, # just abandon __from and __to if "__time_range" not in val: - val["__time_range"] = "{} : {}".format(__from, __to) + val["__time_range"] = f"{__from} : {__to}" json_metadata["default_filters"] = json.dumps(filters) has_update = True except Exception: diff --git a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py index 1d0690c5e0477..073bfdc47447c 100644 --- a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py +++ b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py @@ -29,8 +29,6 @@ import copy import json import logging -import uuid -from collections import defaultdict from alembic import op from sqlalchemy import Column, Integer, Text diff --git a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py 
b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py index 404ea96e4402a..3b7c3951cd023 100644 --- a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py +++ b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py @@ -26,14 +26,13 @@ revision = "b5998378c225" down_revision = "72428d1ea401" -from typing import Dict import sqlalchemy as sa from alembic import op def upgrade(): - kwargs: Dict[str, str] = {} + kwargs: dict[str, str] = {} bind = op.get_bind() op.add_column( "dbs", diff --git a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py index 6b63c468eca0c..4202de45609fd 100644 --- a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py +++ b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py @@ -21,7 +21,6 @@ Create Date: 2020-08-12 00:24:39.617899 """ -import collections import json import logging import uuid @@ -77,7 +76,7 @@ class Dashboard(Base): def create_new_markdown_component(chart_position, url): return { "type": "MARKDOWN", - "id": "MARKDOWN-{}".format(uuid.uuid4().hex[:8]), + "id": f"MARKDOWN-{uuid.uuid4().hex[:8]}", "children": [], "parents": chart_position["parents"], "meta": { diff --git a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py index a6db4c2cb6153..45f091c38e2e7 100644 --- a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py +++ b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py @@ -167,7 +167,7 @@ def upgrade(): orphaned_faulty_view_menus = [] for faulty_view_menu in faulty_view_menus: # Get the dataset id from the view_menu name - match_ds_id = re.match("\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name) + match_ds_id = re.match(r"\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name) if match_ds_id: dataset_id = int(match_ds_id.group(1)) dataset = session.query(SqlaTable).get(dataset_id) diff --git a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py index 79b032894f058..64396b6abe71c 100644 --- a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py +++ b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py @@ -27,7 +27,8 @@ down_revision = "085f06488938" import json -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any from alembic import op from sqlalchemy import Column, Integer, Text @@ -77,7 +78,7 @@ class Dashboard(Base): ) -def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None: +def upgrade_select_filters(native_filters: Iterable[dict[str, Any]]) -> None: """ Add `defaultToFirstItem` to `controlValues` of `select_filter` components """ @@ -89,7 +90,7 @@ def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None: control_values["defaultToFirstItem"] = value -def upgrade_filter_set(filter_set: Dict[str, Any]) -> int: +def upgrade_filter_set(filter_set: dict[str, Any]) 
-> int: changed_filters = 0 upgrade_select_filters(filter_set.get("nativeFilters", {}).values()) data_mask = filter_set.get("dataMask", {}) @@ -124,7 +125,7 @@ def upgrade_filter_set(filter_set: Dict[str, Any]) -> int: return changed_filters -def downgrade_filter_set(filter_set: Dict[str, Any]) -> int: +def downgrade_filter_set(filter_set: dict[str, Any]) -> int: changed_filters = 0 old_data_mask = filter_set.pop("dataMask", {}) native_filters = {} diff --git a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py index ec8f8e1cc0566..42368ce896957 100644 --- a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py +++ b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py @@ -27,7 +27,8 @@ down_revision = "d416d0d715cc" import json -from typing import Any, Dict, Iterable, Tuple +from collections.abc import Iterable +from typing import Any from alembic import op from sqlalchemy import Column, Integer, Text @@ -46,7 +47,7 @@ class Dashboard(Base): json_metadata = Column(Text) -def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: +def upgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int: """ Move `defaultValue` into `defaultDataMask.filterState` """ @@ -61,7 +62,7 @@ def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: return changed_filters -def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: +def downgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int: """ Move `defaultDataMask.filterState` into `defaultValue` """ @@ -76,7 +77,7 @@ def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: return changed_filters -def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: +def upgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata # upgrade native select filter metadata @@ -119,7 +120,7 @@ def upgrade(): print(f"Upgraded {changed_filters} filters and {changed_filter_sets} filter sets.") -def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: +def downgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata if native_filters := dashboard.get("native_filter_configuration"): diff --git a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py index 888925a888c7b..8be11d3cf6bd8 100644 --- a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py +++ b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py @@ -27,7 +27,6 @@ down_revision = "e323605f370a" import json -from typing import Any, Dict, List, Tuple from alembic import op from sqlalchemy import and_, Column, Integer, String, Text diff --git a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py index e44bdae7820d3..ab852c324bba1 100644 --- 
a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py +++ b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py @@ -27,7 +27,6 @@ down_revision = "3ebe0993c770" import json -import re from alembic import op from sqlalchemy import and_, Column, Integer, String, Text diff --git a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py index db1b87e546062..b85e9397e6b61 100644 --- a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py +++ b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py @@ -27,7 +27,6 @@ down_revision = "60dc453f4e2e" import json -from typing import Dict from alembic import op from sqlalchemy import Column, Integer, Text @@ -45,7 +44,7 @@ class Slice(Base): params = Column(Text) -def migrate(mapping: Dict[str, str]) -> None: +def migrate(mapping: dict[str, str]) -> None: bind = op.get_bind() session = db.Session(bind=bind) diff --git a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py index 286a0731fc7e6..b51f6c78ac470 100644 --- a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py +++ b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py @@ -29,7 +29,7 @@ import json import os from datetime import datetime -from typing import List, Optional, Set, Type, Union +from typing import Optional, Union from uuid import uuid4 import sqlalchemy as sa @@ -86,7 +86,7 @@ def changed_by_fk(cls): def insert_from_select( - target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select + target: Union[str, sa.Table, type[Base]], source: sa.sql.expression.Select ) -> None: """ Execute INSERT FROM SELECT to copy data from a SELECT query to the target table. @@ -274,8 +274,8 @@ def find_tables( session: Session, database_id: int, default_schema: Optional[str], - tables: Set[Table], -) -> List[int]: + tables: set[Table], +) -> list[int]: """ Look for NewTable's of from a specific database """ diff --git a/superset/models/annotations.py b/superset/models/annotations.py index 3185460bf5f22..54de94e7f6457 100644 --- a/superset/models/annotations.py +++ b/superset/models/annotations.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """a collection of Annotation-related models""" -from typing import Any, Dict +from typing import Any from flask_appbuilder import Model from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text @@ -54,7 +54,7 @@ class Annotation(Model, AuditMixinNullable): __table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: return { "layer_id": self.layer_id, "start_dttm": self.start_dttm, diff --git a/superset/models/core.py b/superset/models/core.py index ee50f063456f7..3c2b12d3782ba 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -16,6 +16,7 @@ # under the License. 
# pylint: disable=line-too-long """A collection of ORM sqlalchemy models for Superset""" +import builtins import enum import json import logging @@ -25,7 +26,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -194,7 +195,7 @@ def allows_subquery(self) -> bool: return self.db_engine_spec.allows_subqueries @property - def function_names(self) -> List[str]: + def function_names(self) -> list[str]: try: return self.db_engine_spec.get_function_names(self) except Exception as ex: # pylint: disable=broad-except @@ -234,7 +235,7 @@ def disable_data_preview(self) -> bool: return True @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: return { "id": self.id, "name": self.database_name, @@ -271,7 +272,7 @@ def masked_encrypted_extra(self) -> Optional[str]: return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra) @property - def parameters(self) -> Dict[str, Any]: + def parameters(self) -> dict[str, Any]: # Database parameters are a dictionary of values that are used to make up # the sqlalchemy_uri # When returning the parameters we should use the masked SQLAlchemy URI and the @@ -296,7 +297,7 @@ def parameters(self) -> Dict[str, Any]: return parameters @property - def parameters_schema(self) -> Dict[str, Any]: + def parameters_schema(self) -> dict[str, Any]: try: parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore except Exception: # pylint: disable=broad-except @@ -304,7 +305,7 @@ def parameters_schema(self) -> Dict[str, Any]: return parameters_schema @property - def metadata_cache_timeout(self) -> Dict[str, Any]: + def metadata_cache_timeout(self) -> dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) @property @@ -324,15 +325,15 @@ def table_cache_timeout(self) -> Optional[int]: return self.metadata_cache_timeout.get("table_cache_timeout") @property - def default_schemas(self) -> List[str]: + def default_schemas(self) -> list[str]: return self.get_extra().get("default_schemas", []) @property - def connect_args(self) -> Dict[str, Any]: + def connect_args(self) -> dict[str, Any]: return self.get_extra().get("engine_params", {}).get("connect_args", {}) @property - def engine_information(self) -> Dict[str, Any]: + def engine_information(self) -> dict[str, Any]: try: engine_information = self.db_engine_spec.get_public_information() except Exception: # pylint: disable=broad-except @@ -540,7 +541,7 @@ def quote_identifier(self) -> Callable[[str], str]: """Add quotes to potential identifiter expressions if needed""" return self.get_dialect().identifier_preparer.quote - def get_reserved_words(self) -> Set[str]: + def get_reserved_words(self) -> set[str]: return self.get_dialect().preparer.reserved_words def get_df( # pylint: disable=too-many-locals @@ -629,7 +630,7 @@ def select_star( # pylint: disable=too-many-arguments show_cols: bool = False, indent: bool = True, latest_partition: bool = False, - cols: Optional[List[Dict[str, Any]]] = None, + cols: Optional[list[dict[str, Any]]] = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) @@ -670,7 +671,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool 
= False, - ) -> Set[Tuple[str, str]]: + ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -706,7 +707,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> Set[Tuple[str, str]]: + ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -750,7 +751,7 @@ def get_all_schema_names( # pylint: disable=unused-argument cache_timeout: Optional[int] = None, force: bool = False, ssh_tunnel: Optional["SSHTunnel"] = None, - ) -> List[str]: + ) -> list[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -768,13 +769,15 @@ def get_all_schema_names( # pylint: disable=unused-argument raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @property - def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: + def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]: url = make_url_safe(self.sqlalchemy_uri_decrypted) return self.get_db_engine_spec(url) @classmethod @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) - def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: + def get_db_engine_spec( + cls, url: URL + ) -> builtins.type[db_engine_specs.BaseEngineSpec]: backend = url.get_backend_name() try: driver = url.get_driver_name() @@ -784,7 +787,7 @@ def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: return db_engine_specs.get_engine_spec(backend, driver) - def grains(self) -> Tuple[TimeGrain, ...]: + def grains(self) -> tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. 
The idea here is to make it easy for users to change the time grain @@ -795,10 +798,10 @@ def grains(self) -> Tuple[TimeGrain, ...]: """ return self.db_engine_spec.get_time_grains() - def get_extra(self) -> Dict[str, Any]: + def get_extra(self) -> dict[str, Any]: return self.db_engine_spec.get_extra_params(self) - def get_encrypted_extra(self) -> Dict[str, Any]: + def get_encrypted_extra(self) -> dict[str, Any]: encrypted_extra = {} if self.encrypted_extra: try: @@ -809,7 +812,7 @@ def get_encrypted_extra(self) -> Dict[str, Any]: return encrypted_extra # pylint: disable=invalid-name - def update_params_from_encrypted_extra(self, params: Dict[str, Any]) -> None: + def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None: self.db_engine_spec.update_params_from_encrypted_extra(self, params) def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: @@ -832,7 +835,7 @@ def get_table_comment( def get_columns( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_columns(inspector, table_name, schema) @@ -840,19 +843,19 @@ def get_metrics( self, table_name: str, schema: Optional[str] = None, - ) -> List[MetricType]: + ) -> list[MetricType]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) def get_indexes( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: with self.get_inspector_with_context() as inspector: pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} @@ -866,13 +869,13 @@ def _convert(value: Any) -> Any: def get_foreign_keys( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, - ) -> List[str]: + ) -> list[str]: allowed_databases = self.get_extra().get("schemas_allowed_for_file_upload", []) if isinstance(allowed_databases, str): @@ -932,7 +935,7 @@ def _has_view( view_name: str, schema: Optional[str] = None, ) -> bool: - view_names: List[str] = [] + view_names: list[str] = [] try: view_names = dialect.get_view_names(connection=conn, schema=schema) except Exception: # pylint: disable=broad-except diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 9afd74f5e383a..f3b9c08794793 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -21,7 +21,7 @@ import uuid from collections import defaultdict from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable import sqlalchemy as sqla from flask import current_app @@ -68,9 +68,7 @@ logger = logging.getLogger(__name__) -def copy_dashboard( - _mapper: Mapper, connection: Connection, target: "Dashboard" -) -> None: +def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -> None: dashboard_id = config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: return @@ -146,7 +144,7 @@ class 
Dashboard(Model, AuditMixinNullable, ImportExportMixin): certification_details = Column(Text) json_metadata = Column(Text) slug = Column(String(255), unique=True) - slices: List[Slice] = relationship( + slices: list[Slice] = relationship( Slice, secondary=dashboard_slices, backref="dashboards" ) owners = relationship(security_manager.user_model, secondary=dashboard_user) @@ -187,14 +185,14 @@ def url(self) -> str: return f"/superset/dashboard/{self.slug or self.id}/" @staticmethod - def get_url(id_: int, slug: Optional[str] = None) -> str: + def get_url(id_: int, slug: str | None = None) -> str: # To be able to generate URL's without instanciating a Dashboard object return f"/superset/dashboard/{slug or id_}/" @property - def datasources(self) -> Set[BaseDatasource]: + def datasources(self) -> set[BaseDatasource]: # Verbose but efficient database enumeration of dashboard datasources. - datasources_by_cls_model: Dict[Type["BaseDatasource"], Set[int]] = defaultdict( + datasources_by_cls_model: dict[type[BaseDatasource], set[int]] = defaultdict( set ) @@ -210,14 +208,14 @@ def datasources(self) -> Set[BaseDatasource]: } @property - def filter_sets(self) -> Dict[int, FilterSet]: + def filter_sets(self) -> dict[int, FilterSet]: return {fs.id: fs for fs in self._filter_sets} @property - def filter_sets_lst(self) -> Dict[int, FilterSet]: + def filter_sets_lst(self) -> dict[int, FilterSet]: if security_manager.is_admin(): return self._filter_sets - filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []} + filter_sets_by_owner_type: dict[str, list[Any]] = {"Dashboard": [], "User": []} for fs in self._filter_sets: filter_sets_by_owner_type[fs.owner_type].append(fs) user_filter_sets = list( @@ -232,7 +230,7 @@ def filter_sets_lst(self) -> Dict[int, FilterSet]: } @property - def charts(self) -> List[str]: + def charts(self) -> list[str]: return [slc.chart for slc in self.slices] @property @@ -281,7 +279,7 @@ def changed_by_url(self) -> str: return f"/superset/profile/{self.changed_by.username}" @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: positions = self.position_json if positions: positions = json.loads(positions) @@ -305,16 +303,16 @@ def data(self) -> Dict[str, Any]: make_name=lambda fname: f"{fname}-v1.0", unless=lambda: not is_feature_enabled("DASHBOARD_CACHE"), ) - def datasets_trimmed_for_slices(self) -> List[Dict[str, Any]]: + def datasets_trimmed_for_slices(self) -> list[dict[str, Any]]: # Verbose but efficient database enumeration of dashboard datasources. 
- slices_by_datasource: Dict[ - Tuple[Type["BaseDatasource"], int], Set[Slice] + slices_by_datasource: dict[ + tuple[type[BaseDatasource], int], set[Slice] ] = defaultdict(set) for slc in self.slices: slices_by_datasource[(slc.cls_model, slc.datasource_id)].add(slc) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for (cls_model, datasource_id), slices in slices_by_datasource.items(): datasource = ( @@ -336,7 +334,7 @@ def params(self, value: str) -> None: self.json_metadata = value @property - def position(self) -> Dict[str, Any]: + def position(self) -> dict[str, Any]: if self.position_json: return json.loads(self.position_json) return {} @@ -380,7 +378,7 @@ def clear_cache_for_datasource(cls, datasource_id: int) -> None: @classmethod def export_dashboards( # pylint: disable=too-many-locals - cls, dashboard_ids: List[int] + cls, dashboard_ids: list[int] ) -> str: copied_dashboards = [] datasource_ids = set() @@ -413,7 +411,7 @@ def export_dashboards( # pylint: disable=too-many-locals slices.append(copied_slc) json_metadata = json.loads(dashboard.json_metadata) - native_filter_configuration: List[Dict[str, Any]] = json_metadata.get( + native_filter_configuration: list[dict[str, Any]] = json_metadata.get( "native_filter_configuration", [] ) for native_filter in native_filter_configuration: @@ -449,12 +447,12 @@ def export_dashboards( # pylint: disable=too-many-locals ) @classmethod - def get(cls, id_or_slug: Union[str, int]) -> Dashboard: + def get(cls, id_or_slug: str | int) -> Dashboard: qry = db.session.query(Dashboard).filter(id_or_slug_filter(id_or_slug)) return qry.one_or_none() -def is_uuid(value: Union[str, int]) -> bool: +def is_uuid(value: str | int) -> bool: try: uuid.UUID(str(value)) return True @@ -462,7 +460,7 @@ def is_uuid(value: Union[str, int]) -> bool: return False -def is_int(value: Union[str, int]) -> bool: +def is_int(value: str | int) -> bool: try: int(value) return True @@ -470,7 +468,7 @@ def is_int(value: Union[str, int]) -> bool: return False -def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression: +def id_or_slug_filter(id_or_slug: int | str) -> BinaryExpression: if is_int(id_or_slug): return Dashboard.id == int(id_or_slug) if is_uuid(id_or_slug): @@ -490,7 +488,7 @@ def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression: def clear_dashboard_cache( _mapper: Mapper, _connection: Connection, - obj: Union[Slice, BaseDatasource, Dashboard], + obj: Slice | BaseDatasource | Dashboard, check_modified: bool = True, ) -> None: if check_modified and not object_session(obj).is_modified(obj): diff --git a/superset/models/datasource_access_request.py b/superset/models/datasource_access_request.py index 1f286f96d8b40..23df4cffae38a 100644 --- a/superset/models/datasource_access_request.py +++ b/superset/models/datasource_access_request.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Optional, Type, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from flask import Markup from flask_appbuilder import Model @@ -41,7 +41,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): ROLES_DENYLIST = set(config["ROBOT_PERMISSION_ROLES"]) @property - def cls_model(self) -> Type["BaseDatasource"]: + def cls_model(self) -> type["BaseDatasource"]: # pylint: disable=import-outside-toplevel from superset.datasource.dao import DatasourceDAO @@ -77,7 +77,7 @@ def roles_with_datasource(self) -> str: f"datasource_id={self.datasource_id}&" f"created_by={self.created_by.username}&role_to_grant={role.name}" ) - link = '<a href="{}">Grant {} Role</a>'.format(href, role.name) + link = f'<a href="{href}">Grant {role.name} Role</a>' action_list = action_list + "<li>" + link + "</li>" return "<ul>" + action_list + "</ul>" @@ -90,8 +90,8 @@ def user_roles(self) -> str: f"datasource_id={self.datasource_id}&" f"created_by={self.created_by.username}&role_to_extend={role.name}" ) - link = '<a href="{}">Extend {} Role</a>'.format(href, role.name) + link = f'<a href="{href}">Extend {role.name} Role</a>' if role.name in self.ROLES_DENYLIST: - link = "{} Role".format(role.name) + link = f"{role.name} Role" action_list = action_list + "<li>" + link + "</li>" return "<ul>" + action_list + "</ul>"
diff --git a/superset/models/embedded_dashboard.py b/superset/models/embedded_dashboard.py index 7718bc886f97a..32a8e4abcef37 100644 --- a/superset/models/embedded_dashboard.py +++ b/superset/models/embedded_dashboard.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import uuid -from typing import List from flask_appbuilder import Model from sqlalchemy import Column, ForeignKey, Integer, Text @@ -49,7 +48,7 @@ class EmbeddedDashboard(Model, AuditMixinNullable): ) @property - def allowed_domains(self) -> List[str]: + def allowed_domains(self) -> list[str]: """ A list of domains which are allowed to embed the dashboard. An empty list means any domain can embed.
diff --git a/superset/models/filter_set.py b/superset/models/filter_set.py index 1ace5bca32df9..ac25b114ff0c1 100644 --- a/superset/models/filter_set.py +++ b/superset/models/filter_set.py @@ -18,7 +18,7 @@ import json import logging -from typing import Any, Dict +from typing import Any from flask import current_app from flask_appbuilder import Model @@ -75,7 +75,7 @@ def changed_by_url(self) -> str: return "" return f"/superset/profile/{self.changed_by.username}" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "name": self.name, @@ -105,7 +105,7 @@ def get_by_dashboard_id(cls, dashboard_id: int) -> FilterSet: return qry.all() @property - def params(self) -> Dict[str, Any]: + def params(self) -> dict[str, Any]: if self.json_metadata: return json.loads(self.json_metadata) return {}
diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 32c6f5ff6ab4a..42d5a24174217 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -16,29 +16,17 @@ # under the License.
# pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" +import builtins import dataclasses import json import logging import re import uuid from collections import defaultdict +from collections.abc import Hashable from datetime import datetime, timedelta from json.decoder import JSONDecodeError -from typing import ( - Any, - cast, - Dict, - Hashable, - List, - NamedTuple, - Optional, - Set, - Text, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union import dateutil.parser import humanize @@ -145,7 +133,7 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) -def json_to_dict(json_str: str) -> Dict[Any, Any]: +def json_to_dict(json_str: str) -> dict[Any, Any]: if json_str: val = re.sub(",[ \t\r\n]+}", "}", json_str) val = re.sub(",[ \t\r\n]+\\]", "]", val) @@ -179,22 +167,22 @@ class ImportExportMixin: # The name of the attribute # with the SQL Alchemy back reference - export_children: List[str] = [] + export_children: list[str] = [] # List of (str) names of attributes # with the SQL Alchemy forward references - export_fields: List[str] = [] + export_fields: list[str] = [] # The names of the attributes # that are available for import and export - extra_import_fields: List[str] = [] + extra_import_fields: list[str] = [] # Additional fields that should be imported, # even though they were not exported __mapper__: Mapper @classmethod - def _unique_constrains(cls) -> List[Set[str]]: + def _unique_constrains(cls) -> list[set[str]]: """Get all (single column and multi column) unique constraints""" unique = [ {c.name for c in u.columns} @@ -207,7 +195,7 @@ def _unique_constrains(cls) -> List[Set[str]]: return unique @classmethod - def parent_foreign_key_mappings(cls) -> Dict[str, str]: + def parent_foreign_key_mappings(cls) -> dict[str, str]: """Get a mapping of foreign name to the local name of foreign keys""" parent_rel = cls.__mapper__.relationships.get(cls.export_parent) if parent_rel: @@ -217,7 +205,7 @@ def parent_foreign_key_mappings(cls) -> Dict[str, str]: @classmethod def export_schema( cls, recursive: bool = True, include_parent_ref: bool = False - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Export schema as a dictionary""" parent_excludes = set() if not include_parent_ref: @@ -227,12 +215,12 @@ def export_schema( def formatter(column: sa.Column) -> str: return ( - "{0} Default ({1})".format(str(column.type), column.default.arg) + f"{str(column.type)} Default ({column.default.arg})" if column.default else str(column.type) ) - schema: Dict[str, Any] = { + schema: dict[str, Any] = { column.name: formatter(column) for column in cls.__table__.columns # type: ignore if (column.name in cls.export_fields and column.name not in parent_excludes) @@ -252,10 +240,10 @@ def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals cls, session: Session, - dict_rep: Dict[Any, Any], + dict_rep: dict[Any, Any], parent: Optional[Any] = None, recursive: bool = True, - sync: Optional[List[str]] = None, + sync: Optional[list[str]] = None, ) -> Any: """Import obj from a dictionary""" if sync is None: @@ -281,9 +269,7 @@ def import_from_dict( if cls.export_parent: for prnt in parent_refs.keys(): if prnt not in dict_rep: - raise RuntimeError( - "{0}: Missing field {1}".format(cls.__name__, prnt) - ) + raise RuntimeError(f"{cls.__name__}: Missing field {prnt}") else: # Set foreign keys to parent obj for k, v in 
parent_refs.items(): @@ -371,7 +357,7 @@ def export_to_dict( include_parent_ref: bool = False, include_defaults: bool = False, export_uuids: bool = False, - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: """Export obj to dictionary""" export_fields = set(self.export_fields) if export_uuids: @@ -457,18 +443,18 @@ def reset_ownership(self) -> None: self.owners = [g.user] @property - def params_dict(self) -> Dict[Any, Any]: + def params_dict(self) -> dict[Any, Any]: return json_to_dict(self.params) @property - def template_params_dict(self) -> Dict[Any, Any]: + def template_params_dict(self) -> dict[Any, Any]: return json_to_dict(self.template_params) # type: ignore def _user_link(user: User) -> Union[Markup, str]: if not user: return "" - url = "/superset/profile/{}/".format(user.username) + url = f"/superset/profile/{user.username}/" return Markup('{}'.format(url, escape(user) or "")) @@ -505,13 +491,13 @@ def changed_by_fk(self) -> sa.Column: @property def created_by_name(self) -> str: if self.created_by: - return escape("{}".format(self.created_by)) + return escape(f"{self.created_by}") return "" @property def changed_by_name(self) -> str: if self.changed_by: - return escape("{}".format(self.changed_by)) + return escape(f"{self.changed_by}") return "" @renders("created_by") @@ -565,12 +551,12 @@ def __init__( # pylint: disable=too-many-arguments df: pd.DataFrame, query: str, duration: timedelta, - applied_template_filters: Optional[List[str]] = None, - applied_filter_columns: Optional[List[ColumnTyping]] = None, - rejected_filter_columns: Optional[List[ColumnTyping]] = None, + applied_template_filters: Optional[list[str]] = None, + applied_filter_columns: Optional[list[ColumnTyping]] = None, + rejected_filter_columns: Optional[list[ColumnTyping]] = None, status: str = QueryStatus.SUCCESS, error_message: Optional[str] = None, - errors: Optional[List[Dict[str, Any]]] = None, + errors: Optional[list[dict[str, Any]]] = None, from_dttm: Optional[datetime] = None, to_dttm: Optional[datetime] = None, ) -> None: @@ -593,7 +579,7 @@ class ExtraJSONMixin: extra_json = sa.Column(sa.Text, default="{}") @property - def extra(self) -> Dict[str, Any]: + def extra(self) -> dict[str, Any]: try: return json.loads(self.extra_json or "{}") or {} except (TypeError, JSONDecodeError) as exc: @@ -603,7 +589,7 @@ def extra(self) -> Dict[str, Any]: return {} @extra.setter - def extra(self, extras: Dict[str, Any]) -> None: + def extra(self, extras: dict[str, Any]) -> None: self.extra_json = json.dumps(extras) def set_extra_json_key(self, key: str, value: Any) -> None: @@ -615,7 +601,7 @@ def set_extra_json_key(self, key: str, value: Any) -> None: def ensure_extra_json_is_not_none( # pylint: disable=no-self-use self, _: str, - value: Optional[Dict[str, Any]], + value: Optional[dict[str, Any]], ) -> Any: if value is None: return "{}" @@ -627,7 +613,7 @@ class CertificationMixin: extra = sa.Column(sa.Text, default="{}") - def get_extra_dict(self) -> Dict[str, Any]: + def get_extra_dict(self) -> dict[str, Any]: try: return json.loads(self.extra) except (TypeError, json.JSONDecodeError): @@ -652,8 +638,8 @@ def warning_markdown(self) -> Optional[str]: def clone_model( target: Model, - ignore: Optional[List[str]] = None, - keep_relations: Optional[List[str]] = None, + ignore: Optional[list[str]] = None, + keep_relations: Optional[list[str]] = None, **kwargs: Any, ) -> Model: """ @@ -676,22 +662,22 @@ def clone_model( # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): - 
applied_template_filters: Optional[List[str]] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] - labels_expected: List[str] - prequeries: List[str] + applied_template_filters: Optional[list[str]] + applied_filter_columns: list[ColumnTyping] + rejected_filter_columns: list[ColumnTyping] + labels_expected: list[str] + prequeries: list[str] sql: str class SqlaQuery(NamedTuple): - applied_template_filters: List[str] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] + applied_template_filters: list[str] + applied_filter_columns: list[ColumnTyping] + rejected_filter_columns: list[ColumnTyping] cte: Optional[str] - extra_cache_keys: List[Any] - labels_expected: List[str] - prequeries: List[str] + extra_cache_keys: list[Any] + labels_expected: list[str] + prequeries: list[str] sqla_query: Select @@ -719,7 +705,7 @@ def type(self) -> str: raise NotImplementedError() @property - def db_extra(self) -> Optional[Dict[str, Any]]: + def db_extra(self) -> Optional[dict[str, Any]]: raise NotImplementedError() def query(self, query_obj: QueryObjectDict) -> QueryResult: @@ -730,11 +716,11 @@ def database_id(self) -> int: raise NotImplementedError() @property - def owners_data(self) -> List[Any]: + def owners_data(self) -> list[Any]: raise NotImplementedError() @property - def metrics(self) -> List[Any]: + def metrics(self) -> list[Any]: return [] @property @@ -750,7 +736,7 @@ def cache_timeout(self) -> int: raise NotImplementedError() @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: raise NotImplementedError() @property @@ -762,15 +748,15 @@ def main_dttm_col(self) -> Optional[str]: raise NotImplementedError() @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: raise NotImplementedError() @property - def db_engine_spec(self) -> Type["BaseEngineSpec"]: + def db_engine_spec(self) -> builtins.type["BaseEngineSpec"]: raise NotImplementedError() @property - def database(self) -> Type["Database"]: + def database(self) -> builtins.type["Database"]: raise NotImplementedError() @property @@ -782,7 +768,7 @@ def sql(self) -> str: raise NotImplementedError() @property - def columns(self) -> List[Any]: + def columns(self) -> list[Any]: raise NotImplementedError() def get_fetch_values_predicate( @@ -790,7 +776,7 @@ def get_fetch_values_predicate( ) -> TextClause: raise NotImplementedError() - def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: + def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: @@ -799,7 +785,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, - ) -> List[TextClause]: + ) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the current user. A custom username can be passed when the user is not present in the @@ -808,8 +794,8 @@ def get_sqla_row_level_filters( :param template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. 
""" - all_filters: List[TextClause] = [] - filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + all_filters: list[TextClause] = [] + filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): clause = self.text( @@ -923,8 +909,8 @@ def _normalize_prequery_result_type( self, row: pd.Series, dimension: str, - columns_by_name: Dict[str, "TableColumn"], - ) -> Union[str, int, float, bool, Text]: + columns_by_name: dict[str, "TableColumn"], + ) -> Union[str, int, float, bool, str]: """ Convert a prequery result type to its equivalent Python type. @@ -944,7 +930,7 @@ def _normalize_prequery_result_type( value = value.item() column_ = columns_by_name[dimension] - db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore + db_extra: dict[str, Any] = self.database.get_extra() # type: ignore if isinstance(column_, dict): if ( @@ -969,7 +955,7 @@ def _normalize_prequery_result_type( return value def make_orderby_compatible( - self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement] ) -> None: """ If needed, make sure aliases for selected columns are not used in @@ -1088,7 +1074,7 @@ def text(self, clause: str) -> TextClause: def get_from_clause( self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Tuple[Union[TableClause, Alias], Optional[str]]: + ) -> tuple[Union[TableClause, Alias], Optional[str]]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a @@ -1117,7 +1103,7 @@ def get_from_clause( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument + columns_by_name: dict[str, "TableColumn"], # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ @@ -1151,7 +1137,7 @@ def adhoc_metric_to_sqla( return self.make_sqla_column_compatible(sqla_metric, label) @property - def template_params_dict(self) -> Dict[Any, Any]: + def template_params_dict(self) -> dict[Any, Any]: return {} @staticmethod @@ -1162,9 +1148,9 @@ def filter_values_handler( # pylint: disable=too-many-arguments target_native_type: Optional[str] = None, is_list_target: bool = False, db_engine_spec: Optional[ - Type["BaseEngineSpec"] + builtins.type["BaseEngineSpec"] ] = None, # fix(hughhh): Optional[Type[BaseEngineSpec]] - db_extra: Optional[Dict[str, Any]] = None, + db_extra: Optional[dict[str, Any]] = None, ) -> Optional[FilterValues]: if values is None: return None @@ -1217,8 +1203,8 @@ def get_query_str(self, query_obj: QueryObjectDict) -> str: def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Dict[str, "SqlMetric"], - columns_by_name: Dict[str, "TableColumn"], + metrics_by_name: dict[str, "SqlMetric"], + columns_by_name: dict[str, "TableColumn"], template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): @@ -1248,9 +1234,9 @@ def adhoc_column_to_sqla( def _get_top_groups( self, df: pd.DataFrame, - dimensions: List[str], - groupby_exprs: Dict[str, Any], - columns_by_name: Dict[str, "TableColumn"], + dimensions: list[str], + groupby_exprs: dict[str, Any], + columns_by_name: dict[str, "TableColumn"], ) -> ColumnElement: groups = [] for _unused, row in df.iterrows(): @@ -1335,7 +1321,7 @@ def 
get_time_filter( # pylint: disable=too-many-arguments ) return and_(*l) - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -1369,7 +1355,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: def get_timestamp_expression( self, - column: Dict[str, Any], + column: dict[str, Any], time_grain: Optional[str], label: Optional[str] = None, template_processor: Optional[BaseTemplateProcessor] = None, @@ -1417,23 +1403,23 @@ def convert_tbl_column_to_sqla_col( def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, apply_fetch_values_predicate: bool = False, - columns: Optional[List[Column]] = None, - extras: Optional[Dict[str, Any]] = None, + columns: Optional[list[Column]] = None, + extras: Optional[dict[str, Any]] = None, filter: Optional[ # pylint: disable=redefined-builtin - List[utils.QueryObjectFilterClause] + list[utils.QueryObjectFilterClause] ] = None, from_dttm: Optional[datetime] = None, granularity: Optional[str] = None, - groupby: Optional[List[Column]] = None, + groupby: Optional[list[Column]] = None, inner_from_dttm: Optional[datetime] = None, inner_to_dttm: Optional[datetime] = None, is_rowcount: bool = False, is_timeseries: bool = True, - metrics: Optional[List[Metric]] = None, - orderby: Optional[List[OrderBy]] = None, + metrics: Optional[list[Metric]] = None, + orderby: Optional[list[OrderBy]] = None, order_desc: bool = True, to_dttm: Optional[datetime] = None, - series_columns: Optional[List[Column]] = None, + series_columns: Optional[list[Column]] = None, series_limit: Optional[int] = None, series_limit_metric: Optional[Metric] = None, row_limit: Optional[int] = None, @@ -1464,23 +1450,23 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma } columns = columns or [] groupby = groupby or [] - rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] - applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + rejected_adhoc_filters_columns: list[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: list[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: series_limit = timeseries_limit series_limit_metric = series_limit_metric or timeseries_limit_metric template_kwargs.update(self.template_params_dict) - extra_cache_keys: List[Any] = [] + extra_cache_keys: list[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys - removed_filters: List[str] = [] - applied_template_filters: List[str] = [] + removed_filters: list[str] = [] + applied_template_filters: list[str] = [] template_kwargs["removed_filters"] = removed_filters template_kwargs["applied_filters"] = applied_template_filters template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.db_engine_spec - prequeries: List[str] = [] + prequeries: list[str] = [] orderby = orderby or [] need_groupby = bool(metrics is not None or groupby) metrics = metrics or [] @@ -1489,11 +1475,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { + columns_by_name: 
dict[str, "TableColumn"] = { col.column_name: col for col in self.columns } - metrics_by_name: Dict[str, "SqlMetric"] = { + metrics_by_name: dict[str, "SqlMetric"] = { m.metric_name: m for m in self.metrics } @@ -1507,7 +1493,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if not metrics and not columns and not groupby: raise QueryObjectValidationError(_("Empty query?")) - metrics_exprs: List[ColumnElement] = [] + metrics_exprs: list[ColumnElement] = [] for metric in metrics: if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) @@ -1542,7 +1528,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} # Since orderby may use adhoc metrics, too; we need to process them first - orderby_exprs: List[ColumnElement] = [] + orderby_exprs: list[ColumnElement] = [] for orig_col, ascending in orderby: col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): @@ -1582,7 +1568,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma _("Unknown column used in orderby: %(col)s", col=orig_col) ) - select_exprs: List[Union[Column, Label]] = [] + select_exprs: list[Union[Column, Label]] = [] groupby_all_columns = {} groupby_series_columns = {} diff --git a/superset/models/slice.py b/superset/models/slice.py index 6835215338f49..15dddfc7e1eaf 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -18,7 +18,7 @@ import json import logging -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from urllib import parse import sqlalchemy as sqla @@ -70,7 +70,7 @@ class Slice( # pylint: disable=too-many-public-methods ): """A slice is essentially a report or a view on data""" - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None __tablename__ = "slices" id = Column(Integer, primary_key=True) @@ -134,17 +134,17 @@ def __repr__(self) -> str: return self.slice_name or str(self.id) @property - def cls_model(self) -> Type["BaseDatasource"]: + def cls_model(self) -> type[BaseDatasource]: # pylint: disable=import-outside-toplevel from superset.datasource.dao import DatasourceDAO return DatasourceDAO.sources[self.datasource_type] @property - def datasource(self) -> Optional["BaseDatasource"]: + def datasource(self) -> BaseDatasource | None: return self.get_datasource - def clone(self) -> "Slice": + def clone(self) -> Slice: return Slice( slice_name=self.slice_name, datasource_id=self.datasource_id, @@ -158,7 +158,7 @@ def clone(self) -> "Slice": # pylint: disable=using-constant-test @datasource.getter # type: ignore - def get_datasource(self) -> Optional["BaseDatasource"]: + def get_datasource(self) -> BaseDatasource | None: return ( db.session.query(self.cls_model) .filter_by(id=self.datasource_id) @@ -166,20 +166,20 @@ def get_datasource(self) -> Optional["BaseDatasource"]: ) @renders("datasource_name") - def datasource_link(self) -> Optional[Markup]: + def datasource_link(self) -> Markup | None: # pylint: disable=no-member datasource = self.datasource return datasource.link if datasource else None @renders("datasource_url") - def datasource_url(self) -> Optional[str]: + def datasource_url(self) -> str | None: # pylint: disable=no-member if self.table: return self.table.explore_url datasource = self.datasource return datasource.explore_url if datasource else None - def datasource_name_text(self) -> Optional[str]: 
+ def datasource_name_text(self) -> str | None: # pylint: disable=no-member if self.table: if self.table.schema: @@ -192,7 +192,7 @@ def datasource_name_text(self) -> Optional[str]: return None @property - def datasource_edit_url(self) -> Optional[str]: + def datasource_edit_url(self) -> str | None: # pylint: disable=no-member datasource = self.datasource return datasource.url if datasource else None @@ -200,7 +200,7 @@ def datasource_edit_url(self) -> Optional[str]: # pylint: enable=using-constant-test @property - def viz(self) -> Optional[BaseViz]: + def viz(self) -> BaseViz | None: form_data = json.loads(self.params) viz_class = viz_types.get(self.viz_type) datasource = self.datasource @@ -213,9 +213,9 @@ def description_markeddown(self) -> str: return utils.markdown(self.description) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """Data used to render slice in templates""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} self.token = "" try: viz = self.viz @@ -261,8 +261,8 @@ def json_data(self) -> str: return json.dumps(self.data) @property - def form_data(self) -> Dict[str, Any]: - form_data: Dict[str, Any] = {} + def form_data(self) -> dict[str, Any]: + form_data: dict[str, Any] = {} try: form_data = json.loads(self.params) except Exception as ex: # pylint: disable=broad-except @@ -272,7 +272,7 @@ def form_data(self) -> Dict[str, Any]: { "slice_id": self.id, "viz_type": self.viz_type, - "datasource": "{}__{}".format(self.datasource_id, self.datasource_type), + "datasource": f"{self.datasource_id}__{self.datasource_type}", } ) @@ -281,7 +281,7 @@ def form_data(self) -> Dict[str, Any]: update_time_range(form_data) return form_data - def get_query_context(self) -> Optional[QueryContext]: + def get_query_context(self) -> QueryContext | None: if self.query_context: try: return self.get_query_context_factory().create( @@ -295,13 +295,13 @@ def get_query_context(self) -> Optional[QueryContext]: def get_explore_url( self, base_url: str = "/explore", - overrides: Optional[Dict[str, Any]] = None, + overrides: dict[str, Any] | None = None, ) -> str: return self.build_explore_url(self.id, base_url, overrides) @staticmethod def build_explore_url( - id_: int, base_url: str = "/explore", overrides: Optional[Dict[str, Any]] = None + id_: int, base_url: str = "/explore", overrides: dict[str, Any] | None = None ) -> str: overrides = overrides or {} form_data = {"slice_id": id_} diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index b2f0c8c1ed2f1..b9ab153798f9c 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. 
"""A collection of ORM sqlalchemy models for SQL Lab""" +import builtins import inspect import logging import re +from collections.abc import Hashable from datetime import datetime -from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import simplejson as json import sqlalchemy as sqla @@ -131,7 +133,7 @@ class Query( def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(query=self, database=self.database, **kwargs) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "changedOn": self.changed_on, "changed_on": self.changed_on.isoformat(), @@ -181,11 +183,11 @@ def username(self) -> str: return self.user.username @property - def sql_tables(self) -> List[Table]: + def sql_tables(self) -> list[Table]: return list(ParsedQuery(self.sql).tables) @property - def columns(self) -> List["TableColumn"]: + def columns(self) -> list["TableColumn"]: from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel TableColumn, ) @@ -204,11 +206,11 @@ def columns(self) -> List["TableColumn"]: return columns @property - def db_extra(self) -> Optional[Dict[str, Any]]: + def db_extra(self) -> Optional[dict[str, Any]]: return None @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: order_by_choices = [] for col in self.columns: column_name = str(col.column_name or "") @@ -247,11 +249,13 @@ def raise_for_access(self) -> None: security_manager.raise_for_access(query=self) @property - def db_engine_spec(self) -> Type["BaseEngineSpec"]: + def db_engine_spec( + self, + ) -> builtins.type["BaseEngineSpec"]: # pylint: disable=unsubscriptable-object return self.database.db_engine_spec @property - def owners_data(self) -> List[Dict[str, Any]]: + def owners_data(self) -> list[dict[str, Any]]: return [] @property @@ -267,7 +271,7 @@ def cache_timeout(self) -> int: return 0 @property - def column_names(self) -> List[Any]: + def column_names(self) -> list[Any]: return [col.column_name for col in self.columns] @property @@ -282,7 +286,7 @@ def main_dttm_col(self) -> Optional[str]: return None @property - def dttm_cols(self) -> List[Any]: + def dttm_cols(self) -> list[Any]: return [col.column_name for col in self.columns if col.is_dttm] @property @@ -298,7 +302,7 @@ def default_endpoint(self) -> str: return "" @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]: + def get_extra_cache_keys(query_obj: dict[str, Any]) -> list[Hashable]: return [] @property @@ -322,7 +326,7 @@ def tracking_url(self) -> Optional[str]: def tracking_url(self, value: str) -> None: self.tracking_url_raw = value - def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]: + def get_column(self, column_name: Optional[str]) -> Optional[dict[str, Any]]: if not column_name: return None for col in self.columns: @@ -397,7 +401,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): def __repr__(self) -> str: return str(self.label) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, } @@ -421,10 +425,10 @@ def sqlalchemy_uri(self) -> URL: return self.database.sqlalchemy_uri def url(self) -> str: - return "/superset/sqllab?savedQueryId={0}".format(self.id) + return f"/superset/sqllab?savedQueryId={self.id}" @property - def sql_tables(self) -> List[Table]: + def sql_tables(self) -> list[Table]: return 
list(ParsedQuery(self.sql).tables) @property @@ -483,7 +487,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin): ) saved_query = relationship("SavedQuery", foreign_keys=[saved_query_id]) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "user_id": self.user_id, @@ -520,7 +524,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin): expanded = Column(Boolean, default=False) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: try: description = json.loads(self.description) except json.JSONDecodeError: diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index c496f750399c8..234581dfb4c2c 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=abstract-method, no-init -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql.sqltypes import DATE, Integer, TIMESTAMP @@ -33,7 +33,7 @@ class TinyInteger(Integer): """ @property - def python_type(self) -> Type[int]: + def python_type(self) -> type[int]: return int @classmethod @@ -47,7 +47,7 @@ class Interval(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @classmethod @@ -61,7 +61,7 @@ class Array(TypeEngine): """ @property - def python_type(self) -> Optional[Type[List[Any]]]: + def python_type(self) -> Optional[type[list[Any]]]: return list @classmethod @@ -75,7 +75,7 @@ class Map(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Dict[Any, Any]]]: + def python_type(self) -> Optional[type[dict[Any, Any]]]: return dict @classmethod @@ -89,7 +89,7 @@ class Row(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @classmethod diff --git a/superset/queries/dao.py b/superset/queries/dao.py index 642a5dd4cbb1c..e9fe15cac5c7c 100644 --- a/superset/queries/dao.py +++ b/superset/queries/dao.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Union +from typing import Any, Union from superset import sql_lab from superset.common.db_query_status import QueryStatus @@ -56,14 +56,14 @@ def update_saved_query_exec_info(query_id: int) -> None: db.session.commit() @staticmethod - def save_metadata(query: Query, payload: Dict[str, Any]) -> None: + def save_metadata(query: Query, payload: dict[str, Any]) -> None: # pull relevant data from payload and store in extra_json columns = payload.get("columns", {}) db.session.add(query) query.set_extra_json_key("columns", columns) @staticmethod - def get_queries_changed_after(last_updated_ms: Union[float, int]) -> List[Query]: + def get_queries_changed_after(last_updated_ms: Union[float, int]) -> list[Query]: # UTC date time, same that is stored in the DB. last_updated_dt = datetime.utcfromtimestamp(last_updated_ms / 1000) diff --git a/superset/queries/saved_queries/commands/bulk_delete.py b/superset/queries/saved_queries/commands/bulk_delete.py index c96afd31e58a6..fb230180c8137 100644 --- a/superset/queries/saved_queries/commands/bulk_delete.py +++ b/superset/queries/saved_queries/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging -from typing import List, Optional +from typing import Optional from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -30,9 +30,9 @@ class BulkDeleteSavedQueryCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Dashboard]] = None + self._models: Optional[list[Dashboard]] = None def run(self) -> None: self.validate() diff --git a/superset/queries/saved_queries/commands/export.py b/superset/queries/saved_queries/commands/export.py index 8c5357159e604..323256306ac20 100644 --- a/superset/queries/saved_queries/commands/export.py +++ b/superset/queries/saved_queries/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml from werkzeug.utils import secure_filename @@ -39,7 +39,7 @@ class ExportSavedQueriesCommand(ExportModelsCommand): @staticmethod def _export( model: SavedQuery, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: # build filename based on database, optional schema, and label database_slug = secure_filename(model.database.database_name) schema_slug = secure_filename(model.schema) diff --git a/superset/queries/saved_queries/commands/importers/dispatcher.py b/superset/queries/saved_queries/commands/importers/dispatcher.py index 828320222567f..c2208f0e2af0a 100644 --- a/superset/queries/saved_queries/commands/importers/dispatcher.py +++ b/superset/queries/saved_queries/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -40,7 +40,7 @@ class ImportSavedQueriesCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/queries/saved_queries/commands/importers/v1/__init__.py b/superset/queries/saved_queries/commands/importers/v1/__init__.py index 1412dbd356125..79ec04f54b4f4 100644 --- a/superset/queries/saved_queries/commands/importers/v1/__init__.py +++ b/superset/queries/saved_queries/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -38,7 +38,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand): dao = SavedQueryDAO model_name = "saved_queries" prefix = "queries/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "queries/": ImportV1SavedQuerySchema(), } @@ -46,16 +46,16 @@ class ImportSavedQueriesCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover databases associated with saved queries - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("queries/"): database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) diff --git a/superset/queries/saved_queries/commands/importers/v1/utils.py b/superset/queries/saved_queries/commands/importers/v1/utils.py index f2d090bf11e5b..813f3c2295f58 100644 --- a/superset/queries/saved_queries/commands/importers/v1/utils.py +++ b/superset/queries/saved_queries/commands/importers/v1/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -23,7 +23,7 @@ def import_saved_query( - session: Session, config: Dict[str, Any], overwrite: bool = False + session: Session, config: dict[str, Any], overwrite: bool = False ) -> SavedQuery: existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first() if existing: diff --git a/superset/queries/saved_queries/dao.py b/superset/queries/saved_queries/dao.py index c6bcfa035ccea..daae1de8f5bd8 100644 --- a/superset/queries/saved_queries/dao.py +++ b/superset/queries/saved_queries/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -33,7 +33,7 @@ class SavedQueryDAO(BaseDAO): base_filter = SavedQueryFilter @staticmethod - def bulk_delete(models: Optional[List[SavedQuery]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[SavedQuery]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(SavedQuery).filter(SavedQuery.id.in_(item_ids)).delete( diff --git a/superset/queries/schemas.py b/superset/queries/schemas.py index b139784c5b58c..850664e92f49e 100644 --- a/superset/queries/schemas.py +++ b/superset/queries/schemas.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import List from marshmallow import fields, Schema @@ -73,7 +72,7 @@ class Meta: # pylint: disable=too-few-public-methods include_relationships = True # pylint: disable=no-self-use - def get_sql_tables(self, obj: Query) -> List[Table]: + def get_sql_tables(self, obj: Query) -> list[Table]: return obj.sql_tables diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py index c5b4709447fa1..41163dc064085 100644 --- a/superset/reports/commands/alert.py +++ b/superset/reports/commands/alert.py @@ -20,7 +20,7 @@ import logging from operator import eq, ge, gt, le, lt, ne from timeit import default_timer -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -54,7 +54,7 @@ class AlertCommand(BaseCommand): def __init__(self, report_schedule: ReportSchedule): self._report_schedule = report_schedule - self._result: Optional[float] = None + self._result: float | None = None def run(self) -> bool: """ diff --git a/superset/reports/commands/base.py b/superset/reports/commands/base.py index 4fee6a8824568..598370576b370 100644 --- a/superset/reports/commands/base.py +++ b/superset/reports/commands/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from marshmallow import ValidationError @@ -36,7 +36,7 @@ class BaseReportScheduleCommand(BaseCommand): - _properties: Dict[str, Any] + _properties: dict[str, Any] def run(self) -> Any: pass @@ -45,7 +45,7 @@ def validate(self) -> None: pass def validate_chart_dashboard( - self, exceptions: List[ValidationError], update: bool = False + self, exceptions: list[ValidationError], update: bool = False ) -> None: """Validate chart or dashboard relation""" chart_id = self._properties.get("chart") diff --git a/superset/reports/commands/bulk_delete.py b/superset/reports/commands/bulk_delete.py index 28a39a2fb6ea2..7d6e1ed791310 100644 --- a/superset/reports/commands/bulk_delete.py +++ b/superset/reports/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset import security_manager from superset.commands.base import BaseCommand @@ -33,9 +33,9 @@ class BulkDeleteReportScheduleCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[ReportSchedule]] = None + self._models: Optional[list[ReportSchedule]] = None def run(self) -> None: self.validate() diff --git a/superset/reports/commands/create.py b/superset/reports/commands/create.py index 27626170d6458..04cf6ef43fe0d 100644 --- a/superset/reports/commands/create.py +++ b/superset/reports/commands/create.py @@ -16,7 +16,7 @@ # under the License. 
import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from marshmallow import ValidationError @@ -46,7 +46,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> ReportSchedule: @@ -59,8 +59,8 @@ def run(self) -> ReportSchedule: return report_schedule def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") name = self._properties.get("name", "") report_type = self._properties.get("type") creation_method = self._properties.get("creation_method") @@ -119,7 +119,7 @@ def validate(self) -> None: if exceptions: raise ReportScheduleInvalidError(exceptions=exceptions) - def _validate_report_extra(self, exceptions: List[ValidationError]) -> None: + def _validate_report_extra(self, exceptions: list[ValidationError]) -> None: extra: Optional[ReportScheduleExtra] = self._properties.get("extra") dashboard = self._properties.get("dashboard") diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py index 22aff0727da46..cba12e0786c21 100644 --- a/superset/reports/commands/exceptions.py +++ b/superset/reports/commands/exceptions.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List from flask_babel import lazy_gettext as _ @@ -263,13 +262,13 @@ class ReportScheduleStateNotFoundError(CommandException): class ReportScheduleSystemErrorsException(CommandException, SupersetErrorsException): - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] message = _("Report schedule system error") class ReportScheduleClientErrorsException(CommandException, SupersetErrorsException): status = 400 - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] message = _("Report schedule client error") diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index f5f7bf4130254..608b2564a2907 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID import pandas as pd @@ -80,7 +80,7 @@ class BaseReportState: - current_states: List[ReportState] = [] + current_states: list[ReportState] = [] initial: bool = False def __init__( @@ -195,7 +195,7 @@ def _get_url( **kwargs, ) - def _get_screenshots(self) -> List[bytes]: + def _get_screenshots(self) -> list[bytes]: """ Get chart or dashboard screenshots :raises: ReportScheduleScreenshotFailedError @@ -394,14 +394,14 @@ def _get_notification_content(self) -> NotificationContent: def _send( self, notification_content: NotificationContent, - recipients: List[ReportRecipients], + recipients: list[ReportRecipients], ) -> None: """ Sends a notification to all recipients :raises: CommandException """ - notification_errors: List[SupersetError] = [] + notification_errors: list[SupersetError] = [] for recipient in recipients: notification = create_notification(recipient, notification_content) try: diff --git 
a/superset/reports/commands/update.py b/superset/reports/commands/update.py index 0c4f18f1b842f..5ca3ac849a5ee 100644 --- a/superset/reports/commands/update.py +++ b/superset/reports/commands/update.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -42,7 +42,7 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[ReportSchedule] = None @@ -57,8 +57,8 @@ def run(self) -> Model: return report_schedule def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") report_type = self._properties.get("type", ReportScheduleType.ALERT) name = self._properties.get("name", "") diff --git a/superset/reports/dao.py b/superset/reports/dao.py index be5ee8053c48b..64777e959ae11 100644 --- a/superset/reports/dao.py +++ b/superset/reports/dao.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder import Model from sqlalchemy.exc import SQLAlchemyError @@ -47,7 +47,7 @@ class ReportScheduleDAO(BaseDAO): base_filter = ReportScheduleFilter @staticmethod - def find_by_chart_id(chart_id: int) -> List[ReportSchedule]: + def find_by_chart_id(chart_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.chart_id == chart_id) @@ -55,7 +55,7 @@ def find_by_chart_id(chart_id: int) -> List[ReportSchedule]: ) @staticmethod - def find_by_chart_ids(chart_ids: List[int]) -> List[ReportSchedule]: + def find_by_chart_ids(chart_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.chart_id.in_(chart_ids)) @@ -63,7 +63,7 @@ def find_by_chart_ids(chart_ids: List[int]) -> List[ReportSchedule]: ) @staticmethod - def find_by_dashboard_id(dashboard_id: int) -> List[ReportSchedule]: + def find_by_dashboard_id(dashboard_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.dashboard_id == dashboard_id) @@ -71,7 +71,7 @@ def find_by_dashboard_id(dashboard_id: int) -> List[ReportSchedule]: ) @staticmethod - def find_by_dashboard_ids(dashboard_ids: List[int]) -> List[ReportSchedule]: + def find_by_dashboard_ids(dashboard_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.dashboard_id.in_(dashboard_ids)) @@ -79,7 +79,7 @@ def find_by_dashboard_ids(dashboard_ids: List[int]) -> List[ReportSchedule]: ) @staticmethod - def find_by_database_id(database_id: int) -> List[ReportSchedule]: + def find_by_database_id(database_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.database_id == database_id) @@ -87,7 +87,7 @@ def find_by_database_id(database_id: int) -> List[ReportSchedule]: ) @staticmethod - def find_by_database_ids(database_ids: List[int]) -> List[ReportSchedule]: + def find_by_database_ids(database_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) 
.filter(ReportSchedule.database_id.in_(database_ids)) @@ -96,7 +96,7 @@ def find_by_database_ids(database_ids: List[int]) -> List[ReportSchedule]: @staticmethod def bulk_delete( - models: Optional[List[ReportSchedule]], commit: bool = True + models: Optional[list[ReportSchedule]], commit: bool = True ) -> None: item_ids = [model.id for model in models] if models else [] try: @@ -156,7 +156,7 @@ def validate_update_uniqueness( return found_id is None or found_id == expect_id @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> ReportSchedule: + def create(cls, properties: dict[str, Any], commit: bool = True) -> ReportSchedule: """ create a report schedule and nested recipients :raises: DAOCreateFailedError @@ -187,7 +187,7 @@ def create(cls, properties: Dict[str, Any], commit: bool = True) -> ReportSchedu @classmethod def update( - cls, model: Model, properties: Dict[str, Any], commit: bool = True + cls, model: Model, properties: dict[str, Any], commit: bool = True ) -> ReportSchedule: """ create a report schedule and nested recipients @@ -219,7 +219,7 @@ def update( raise DAOCreateFailedError(str(ex)) from ex @staticmethod - def find_active(session: Optional[Session] = None) -> List[ReportSchedule]: + def find_active(session: Optional[Session] = None) -> list[ReportSchedule]: """ Find all active reports. If session is passed it will be used instead of the default `db.session`, this is useful when on a celery worker session context diff --git a/superset/reports/filters.py b/superset/reports/filters.py index 5fb87e0563345..a03238b640849 100644 --- a/superset/reports/filters.py +++ b/superset/reports/filters.py @@ -52,6 +52,6 @@ def apply(self, query: Query, value: Any) -> Query: or_( ReportSchedule.name.ilike(ilike_value), ReportSchedule.description.ilike(ilike_value), - ReportSchedule.sql.ilike((ilike_value)), + ReportSchedule.sql.ilike(ilike_value), ) ) diff --git a/superset/reports/logs/api.py b/superset/reports/logs/api.py index 8ad8455cc7cc5..f0c272caee7c7 100644 --- a/superset/reports/logs/api.py +++ b/superset/reports/logs/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe @@ -83,7 +83,7 @@ def ensure_alert_reports_enabled(self) -> Optional[Response]: @staticmethod def _apply_layered_relation_to_rison( # pylint: disable=invalid-name - layer_id: int, rison_parameters: Dict[str, Any] + layer_id: int, rison_parameters: dict[str, Any] ) -> None: if "filters" not in rison_parameters: rison_parameters["filters"] = [] diff --git a/superset/reports/notifications/__init__.py b/superset/reports/notifications/__init__.py index c466f59abd5b3..f2ac40bb4671f 100644 --- a/superset/reports/notifications/__init__.py +++ b/superset/reports/notifications/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py index 6eb2405d0ff67..640b326fc53d8 100644 --- a/superset/reports/notifications/base.py +++ b/superset/reports/notifications/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from dataclasses import dataclass -from typing import Any, List, Optional, Type +from typing import Any, Optional import pandas as pd @@ -29,7 +28,7 @@ class NotificationContent: name: str header_data: HeaderDataType # this is optional to account for error states csv: Optional[bytes] = None # bytes for csv file - screenshots: Optional[List[bytes]] = None # bytes for a list of screenshots + screenshots: Optional[list[bytes]] = None # bytes for a list of screenshots text: Optional[str] = None description: Optional[str] = "" url: Optional[str] = None # url to chart/dashboard for this screenshot @@ -44,7 +43,7 @@ class BaseNotification: # pylint: disable=too-few-public-methods notification type """ - plugins: List[Type["BaseNotification"]] = [] + plugins: list[type["BaseNotification"]] = [] type: Optional[ReportRecipientType] = None """ Child classes set their notification type ex: `type = "email"` this string will be diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py index 10a76e757387d..1b9e4ade72f01 100644 --- a/superset/reports/notifications/email.py +++ b/superset/reports/notifications/email.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -20,7 +19,7 @@ import textwrap from dataclasses import dataclass from email.utils import make_msgid, parseaddr -from typing import Any, Dict, Optional +from typing import Any, Optional import nh3 from flask_babel import gettext as __ @@ -69,8 +68,8 @@ class EmailContent: body: str header_data: Optional[HeaderDataType] = None - data: Optional[Dict[str, Any]] = None - images: Optional[Dict[str, bytes]] = None + data: Optional[dict[str, Any]] = None + images: Optional[dict[str, bytes]] = None class EmailNotification(BaseNotification): # pylint: disable=too-few-public-methods diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py index b89a700ef9c3e..4c3f2ee419a5c 100644 --- a/superset/reports/notifications/slack.py +++ b/superset/reports/notifications/slack.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,8 +16,9 @@ # under the License. import json import logging +from collections.abc import Sequence from io import IOBase -from typing import Sequence, Union +from typing import Union import backoff from flask_babel import gettext as __ diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py index a45ee4cc38576..83dea02f8fa7e 100644 --- a/superset/reports/schemas.py +++ b/superset/reports/schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Union from croniter import croniter from flask_babel import gettext as _ @@ -212,7 +212,7 @@ class ReportSchedulePostSchema(Schema): @validates_schema def validate_report_references( # pylint: disable=unused-argument,no-self-use - self, data: Dict[str, Any], **kwargs: Any + self, data: dict[str, Any], **kwargs: Any ) -> None: if data["type"] == ReportScheduleType.REPORT: if "database" in data: diff --git a/superset/result_set.py b/superset/result_set.py index 9aa06bba09ca1..f707b91dce58d 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -19,7 +19,7 @@ import datetime import json import logging -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional import numpy as np import pandas as pd @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List[str]: +def dedup(l: list[str], suffix: str = "__", case_sensitive: bool = True) -> list[str]: """De-duplicates a list of string by suffixing a counter Always returns the same number of entries as provided, and always returns @@ -46,8 +46,8 @@ def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List ) foo,bar,bar__1,bar__2,Bar__3 """ - new_l: List[str] = [] - seen: Dict[str, int] = {} + new_l: list[str] = [] + seen: dict[str, int] = {} for item in l: s_fixed_case = item if case_sensitive else item.lower() if s_fixed_case in seen: @@ -104,14 +104,14 @@ def __init__( # pylint: disable=too-many-locals self, data: DbapiResult, cursor_description: DbapiDescription, - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], ): self.db_engine_spec = db_engine_spec data = data or [] - column_names: List[str] = [] - pa_data: List[pa.Array] = [] - deduped_cursor_desc: List[Tuple[Any, ...]] = [] - numpy_dtype: List[Tuple[str, ...]] = [] + column_names: list[str] = [] + pa_data: list[pa.Array] = [] + deduped_cursor_desc: list[tuple[Any, ...]] = [] + numpy_dtype: list[tuple[str, ...]] = [] stringified_arr: NDArray[Any] if cursor_description: @@ -181,7 +181,7 @@ def __init__( # pylint: disable=too-many-locals column_names = [] self.table = pa.Table.from_arrays(pa_data, names=column_names) - self._type_dict: Dict[str, Any] = {} + self._type_dict: dict[str, Any] = {} try: # The driver may not be passing a cursor.description self._type_dict = { @@ -245,7 +245,7 @@ def size(self) -> int: return self.table.num_rows @property - def columns(self) -> List[ResultSetColumnType]: + def columns(self) -> list[ResultSetColumnType]: if not self.table.column_names: return [] diff --git a/superset/row_level_security/commands/bulk_delete.py b/superset/row_level_security/commands/bulk_delete.py index a6d4625a91f2b..a3703346cc9ed 100644 --- a/superset/row_level_security/commands/bulk_delete.py +++ b/superset/row_level_security/commands/bulk_delete.py @@ -16,7 +16,6 @@ # under the License. 
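For context, a usage sketch of the dedup() helper touched above; the expected outputs are inferred from its docstring and seen-counter logic, and the column names are made up.

    cols = ["foo", "bar", "bar", "bar", "Bar"]
    # case-sensitive (default): "Bar" is distinct from "bar"
    # dedup(cols) -> ["foo", "bar", "bar__1", "bar__2", "Bar"]
    # case-insensitive: "Bar" counts as another "bar" and gets a suffix
    # dedup(cols, case_sensitive=False) -> ["foo", "bar", "bar__1", "bar__2", "Bar__3"]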
import logging -from typing import List from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -31,9 +30,9 @@ class BulkDeleteRLSRuleCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: List[ReportSchedule] = [] + self._models: list[ReportSchedule] = [] def run(self) -> None: self.validate() diff --git a/superset/row_level_security/commands/create.py b/superset/row_level_security/commands/create.py index 0c348e10c02c5..5552feeda02b3 100644 --- a/superset/row_level_security/commands/create.py +++ b/superset/row_level_security/commands/create.py @@ -17,7 +17,7 @@ import logging -from typing import Any, Dict +from typing import Any from superset.commands.base import BaseCommand from superset.commands.exceptions import DatasourceNotFoundValidationError @@ -31,7 +31,7 @@ class CreateRLSRuleCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() self._tables = self._properties.get("tables", []) self._roles = self._properties.get("roles", []) diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py index 8c276ee2c4e20..a206fc3a393c2 100644 --- a/superset/row_level_security/commands/update.py +++ b/superset/row_level_security/commands/update.py @@ -17,7 +17,7 @@ import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from superset.commands.base import BaseCommand from superset.commands.exceptions import DatasourceNotFoundValidationError @@ -32,7 +32,7 @@ class UpdateRLSRuleCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._tables = self._properties.get("tables", []) diff --git a/superset/security/api.py b/superset/security/api.py index 7aac6ae22be0d..aff536519d80b 100644 --- a/superset/security/api.py +++ b/superset/security/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask import request, Response from flask_appbuilder import expose @@ -56,8 +56,8 @@ class ResourceSchema(PermissiveSchema): @post_load def convert_enum_to_value( # pylint: disable=no-self-use - self, data: Dict[str, Any], **kwargs: Any # pylint: disable=unused-argument - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any # pylint: disable=unused-argument + ) -> dict[str, Any]: # we don't care about the enum, we want the value inside data["type"] = data["type"].value return data diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py index 44b59c1dbbb11..a8dc2e3393bf6 100644 --- a/superset/security/guest_token.py +++ b/superset/security/guest_token.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
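A minimal sketch of the command-constructor idiom these RLS hunks keep while re-typing the payload: copy the incoming dict so later mutations by the caller cannot leak into the command. The class name is hypothetical.

    from typing import Any


    class CreateRuleLikeCommand:
        def __init__(self, data: dict[str, Any]):
            self._properties = data.copy()  # defensive copy of the payload
            self._tables = self._properties.get("tables", [])
            self._roles = self._properties.get("roles", [])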
from enum import Enum -from typing import List, Optional, TypedDict, Union +from typing import Optional, TypedDict, Union from flask_appbuilder.security.sqla.models import Role from flask_login import AnonymousUserMixin @@ -36,7 +36,7 @@ class GuestTokenResource(TypedDict): id: Union[str, int] -GuestTokenResources = List[GuestTokenResource] +GuestTokenResources = list[GuestTokenResource] class GuestTokenRlsRule(TypedDict): @@ -49,7 +49,7 @@ class GuestToken(TypedDict): exp: float user: GuestTokenUser resources: GuestTokenResources - rls_rules: List[GuestTokenRlsRule] + rls_rules: list[GuestTokenRlsRule] class GuestUser(AnonymousUserMixin): @@ -76,7 +76,7 @@ def is_anonymous(self) -> bool: """ return False - def __init__(self, token: GuestToken, roles: List[Role]): + def __init__(self, token: GuestToken, roles: list[Role]): user = token["user"] self.guest_token = token self.username = user.get("username", "guest_user") diff --git a/superset/security/manager.py b/superset/security/manager.py index db6e631d918b0..94a731a3ffdc0 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -20,18 +20,7 @@ import re import time from collections import defaultdict -from typing import ( - Any, - Callable, - cast, - Dict, - List, - NamedTuple, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union from flask import current_app, Flask, g, Request from flask_appbuilder import Model @@ -479,7 +468,7 @@ def get_datasource_access_error_object( # pylint: disable=invalid-name ) def get_table_access_error_msg( # pylint: disable=no-self-use - self, tables: Set["Table"] + self, tables: set["Table"] ) -> str: """ Return the error message for the denied SQL tables. @@ -492,7 +481,7 @@ def get_table_access_error_msg( # pylint: disable=no-self-use return f"""You need access to the following tables: {", ".join(quoted_tables)}, `all_database_access` or `all_datasource_access` permission""" - def get_table_access_error_object(self, tables: Set["Table"]) -> SupersetError: + def get_table_access_error_object(self, tables: set["Table"]) -> SupersetError: """ Return the error object for the denied SQL tables. @@ -510,7 +499,7 @@ def get_table_access_error_object(self, tables: Set["Table"]) -> SupersetError: ) def get_table_access_link( # pylint: disable=unused-argument,no-self-use - self, tables: Set["Table"] + self, tables: set["Table"] ) -> Optional[str]: """ Return the access link for the denied SQL tables. @@ -521,7 +510,7 @@ def get_table_access_link( # pylint: disable=unused-argument,no-self-use return current_app.config.get("PERMISSION_INSTRUCTIONS_LINK") - def get_user_datasources(self) -> List["BaseDatasource"]: + def get_user_datasources(self) -> list["BaseDatasource"]: """ Collect datasources which the user has explicit permissions to. 
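A sketch of the TypedDict-plus-builtin-alias style used in guest_token.py above, assuming Python 3.9+; the names here are simplified stand-ins rather than Superset's real definitions.

    from typing import TypedDict, Union


    class ResourceStub(TypedDict):
        type: str
        id: Union[str, int]


    ResourceList = list[ResourceStub]  # was: List[ResourceStub]


    def count_dashboards(resources: ResourceList) -> int:
        return sum(1 for r in resources if r["type"] == "dashboard")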
@@ -542,7 +531,7 @@ def get_user_datasources(self) -> List["BaseDatasource"]: # group all datasources by database session = self.get_session all_datasources = SqlaTable.get_all_datasources(session) - datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set) + datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set) for datasource in all_datasources: datasources_by_database[datasource.database].add(datasource) @@ -569,7 +558,7 @@ def can_access_table(self, database: "Database", table: "Table") -> bool: return True - def user_view_menu_names(self, permission_name: str) -> Set[str]: + def user_view_menu_names(self, permission_name: str) -> set[str]: base_query = ( self.get_session.query(self.viewmenu_model.name) .join(self.permissionview_model) @@ -599,7 +588,7 @@ def user_view_menu_names(self, permission_name: str) -> Set[str]: return {s.name for s in view_menu_names} return set() - def get_accessible_databases(self) -> List[int]: + def get_accessible_databases(self) -> list[int]: """ Return the list of databases accessible by the user. @@ -613,8 +602,8 @@ def get_accessible_databases(self) -> List[int]: ] def get_schemas_accessible_by_user( - self, database: "Database", schemas: List[str], hierarchical: bool = True - ) -> List[str]: + self, database: "Database", schemas: list[str], hierarchical: bool = True + ) -> list[str]: """ Return the list of SQL schemas accessible by the user. @@ -654,9 +643,9 @@ def get_schemas_accessible_by_user( def get_datasources_accessible_by_user( # pylint: disable=invalid-name self, database: "Database", - datasource_names: List[DatasourceName], + datasource_names: list[DatasourceName], schema: Optional[str] = None, - ) -> List[DatasourceName]: + ) -> list[DatasourceName]: """ Return the list of SQL tables accessible by the user. @@ -802,7 +791,7 @@ def sync_role_definitions(self) -> None: self.get_session.commit() self.clean_perms() - def _get_pvms_from_builtin_role(self, role_name: str) -> List[PermissionView]: + def _get_pvms_from_builtin_role(self, role_name: str) -> list[PermissionView]: """ Gets a list of model PermissionView permissions inferred from a builtin role definition @@ -821,7 +810,7 @@ def _get_pvms_from_builtin_role(self, role_name: str) -> List[PermissionView]: role_from_permissions.append(pvm) return role_from_permissions - def find_roles_by_id(self, role_ids: List[int]) -> List[Role]: + def find_roles_by_id(self, role_ids: list[int]) -> list[Role]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -1179,7 +1168,7 @@ def _update_vm_datasources_access( # pylint: disable=too-many-locals connection: Connection, old_database_name: str, target: "Database", - ) -> List[ViewMenu]: + ) -> list[ViewMenu]: """ Helper method that Updates all datasource access permission when a database name changes. @@ -1205,7 +1194,7 @@ def _update_vm_datasources_access( # pylint: disable=too-many-locals .filter(SqlaTable.database_id == target.id) .all() ) - updated_view_menus: List[ViewMenu] = [] + updated_view_menus: list[ViewMenu] = [] for dataset in datasets: old_dataset_vm_name = self.get_dataset_perm( dataset.id, dataset.table_name, old_database_name @@ -1768,7 +1757,7 @@ def on_permission_view_after_delete( """ @staticmethod - def get_exclude_users_from_lists() -> List[str]: + def get_exclude_users_from_lists() -> list[str]: """ Override to dynamically identify a list of usernames to exclude from all UI dropdown lists, owners, created_by filters etc... 
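A toy sketch of the set-based filtering that the security-manager hunks above re-type as set[str] / list[str]; the real checks also consult hierarchical database permissions, which are omitted here.

    def accessible_schemas(requested: list[str], permitted: set[str]) -> list[str]:
        # keep request order, drop anything not explicitly permitted
        return [schema for schema in requested if schema in permitted]


    assert accessible_schemas(["main", "secret"], {"main"}) == ["main"]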
@@ -1896,7 +1885,7 @@ def get_user_by_username( def get_anonymous_user(self) -> User: # pylint: disable=no-self-use return AnonymousUserMixin() - def get_user_roles(self, user: Optional[User] = None) -> List[Role]: + def get_user_roles(self, user: Optional[User] = None) -> list[Role]: if not user: user = g.user if user.is_anonymous: @@ -1906,7 +1895,7 @@ def get_user_roles(self, user: Optional[User] = None) -> List[Role]: def get_guest_rls_filters( self, dataset: "BaseDatasource" - ) -> List[GuestTokenRlsRule]: + ) -> list[GuestTokenRlsRule]: """ Retrieves the row level security filters for the current user and the dataset, if the user is authenticated with a guest token. @@ -1922,7 +1911,7 @@ def get_guest_rls_filters( ] return [] - def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: + def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and the passed table. @@ -1990,7 +1979,7 @@ def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: ) return query.all() - def get_rls_ids(self, table: "BaseDatasource") -> List[int]: + def get_rls_ids(self, table: "BaseDatasource") -> list[int]: """ Retrieves the appropriate row level security filters IDs for the current user and the passed table. @@ -2002,10 +1991,10 @@ def get_rls_ids(self, table: "BaseDatasource") -> List[int]: ids.sort() # Combinations rather than permutations return ids - def get_guest_rls_filters_str(self, table: "BaseDatasource") -> List[str]: + def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]: return [f.get("clause", "") for f in self.get_guest_rls_filters(table)] - def get_rls_cache_key(self, datasource: "BaseDatasource") -> List[str]: + def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]: rls_ids = [] if datasource.is_rls_supported: rls_ids = self.get_rls_ids(datasource) @@ -2122,7 +2111,7 @@ def create_guest_access_token( self, user: GuestTokenUser, resources: GuestTokenResources, - rls: List[GuestTokenRlsRule], + rls: list[GuestTokenRlsRule], ) -> bytes: secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] @@ -2183,7 +2172,7 @@ def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) - def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: + def parse_jwt_guest_token(self, raw_token: str) -> dict[str, Any]: """ Parses a guest token. Raises an error if the jwt fails standard claims checks. 
:param raw_token: the token gotten from the request diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 9ea881fadf95b..678da79fa78be 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -20,7 +20,7 @@ from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Optional, Union import backoff import msgpack @@ -88,9 +88,9 @@ def handle_query_error( ex: Exception, query: Query, session: Session, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, prefix_message: str = "", -) -> Dict[str, Any]: +) -> dict[str, Any]: """Local method handling error while processing the SQL""" payload = payload or {} msg = f"{prefix_message} {str(ex)}".strip() @@ -122,7 +122,7 @@ def handle_query_error( return payload -def get_query_backoff_handler(details: Dict[Any, Any]) -> None: +def get_query_backoff_handler(details: dict[Any, Any]) -> None: query_id = details["kwargs"]["query_id"] logger.error( "Query with id `%s` could not be retrieved", str(query_id), exc_info=True @@ -168,8 +168,8 @@ def get_sql_results( # pylint: disable=too-many-arguments username: Optional[str] = None, start_time: Optional[float] = None, expand_data: bool = False, - log_params: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: + log_params: Optional[dict[str, Any]] = None, +) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" with session_scope(not ctask.request.called_directly) as session: with override_user(security_manager.find_user(username)): @@ -196,7 +196,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem query: Query, session: Session, cursor: Any, - log_params: Optional[Dict[str, Any]], + log_params: Optional[dict[str, Any]], apply_ctas: bool = False, ) -> SupersetResultSet: """Executes a single SQL statement""" @@ -332,7 +332,7 @@ def apply_limit_if_exists( def _serialize_payload( - payload: Dict[Any, Any], use_msgpack: Optional[bool] = False + payload: dict[Any, Any], use_msgpack: Optional[bool] = False ) -> Union[bytes, str]: logger.debug("Serializing to msgpack: %r", use_msgpack) if use_msgpack: @@ -346,10 +346,10 @@ def _serialize_and_expand_data( db_engine_spec: BaseEngineSpec, use_msgpack: Optional[bool] = False, expand_data: bool = False, -) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]: +) -> tuple[Union[bytes, str], list[Any], list[Any], list[Any]]: selected_columns = result_set.columns - all_columns: List[Any] - expanded_columns: List[Any] + all_columns: list[Any] + expanded_columns: list[Any] if use_msgpack: with stats_timing( @@ -383,15 +383,15 @@ def execute_sql_statements( session: Session, start_time: Optional[float], expand_data: bool, - log_params: Optional[Dict[str, Any]], -) -> Optional[Dict[str, Any]]: + log_params: Optional[dict[str, Any]], +) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" if store_results and start_time: # only asynchronous queries stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time) query = get_query(query_id, session) - payload: Dict[str, Any] = dict(query_id=query_id) + payload: dict[str, Any] = dict(query_id=query_id) database = query.database db_engine_spec = database.db_engine_spec db_engine_spec.patch() diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 034daeb7af05d..974d7eacd4b7c 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ 
-16,9 +16,10 @@ # under the License. import logging import re +from collections.abc import Iterator from dataclasses import dataclass from enum import Enum -from typing import Any, cast, Iterator, List, Optional, Set, Tuple +from typing import Any, cast, Optional from urllib import parse import sqlparse @@ -97,7 +98,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: def extract_top_from_query( - statement: TokenList, top_keywords: Set[str] + statement: TokenList, top_keywords: set[str] ) -> Optional[int]: """ Extract top clause value from SQL statement. @@ -122,7 +123,7 @@ def extract_top_from_query( return top -def get_cte_remainder_query(sql: str) -> Tuple[Optional[str], str]: +def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: """ parse the SQL and return the CTE and rest of the block to the caller @@ -192,8 +193,8 @@ def __init__(self, sql_statement: str, strip_comments: bool = False): sql_statement = sqlparse.format(sql_statement, strip_comments=True) self.sql: str = sql_statement - self._tables: Set[Table] = set() - self._alias_names: Set[str] = set() + self._tables: set[Table] = set() + self._alias_names: set[str] = set() self._limit: Optional[int] = None logger.debug("Parsing with sqlparse statement: %s", self.sql) @@ -202,7 +203,7 @@ def __init__(self, sql_statement: str, strip_comments: bool = False): self._limit = _extract_limit_from_query(statement) @property - def tables(self) -> Set[Table]: + def tables(self) -> set[Table]: if not self._tables: for statement in self._parsed: self._extract_from_token(statement) @@ -282,7 +283,7 @@ def stripped(self) -> str: def strip_comments(self) -> str: return sqlparse.format(self.stripped(), strip_comments=True) - def get_statements(self) -> List[str]: + def get_statements(self) -> list[str]: """Returns a list of SQL statements as strings, stripped""" statements = [] for statement in self._parsed: @@ -737,7 +738,7 @@ def insert_rls( def extract_table_references( sql_text: str, sqla_dialect: str, show_warning: bool = True -) -> Set["Table"]: +) -> set["Table"]: """ Return all the dependencies from a SQL sql_text. """ diff --git a/superset/sql_validators/__init__.py b/superset/sql_validators/__init__.py index c448f696a12f3..ad048a86a564b 100644 --- a/superset/sql_validators/__init__.py +++ b/superset/sql_validators/__init__.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, Type +from typing import Optional from . import base, postgres, presto_db from .base import SQLValidationAnnotation -def get_validator_by_name(name: str) -> Optional[Type[base.BaseSQLValidator]]: +def get_validator_by_name(name: str) -> Optional[type[base.BaseSQLValidator]]: return { "PrestoDBSQLValidator": presto_db.PrestoDBSQLValidator, "PostgreSQLValidator": postgres.PostgreSQLValidator, diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py index de29a96e8e3f8..8344fc9264d64 100644 --- a/superset/sql_validators/base.py +++ b/superset/sql_validators/base.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
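A small sketch of the collections.abc import style sql_parse.py switches to above (the typing.Iterator alias is deprecated since Python 3.9); the splitter below is a toy, not Superset's parser.

    from collections.abc import Iterator


    def split_statements(sql: str) -> Iterator[str]:
        for statement in sql.split(";"):
            statement = statement.strip()
            if statement:
                yield statement


    assert list(split_statements("SELECT 1; SELECT 2;")) == ["SELECT 1", "SELECT 2"]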
-from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset.models.core import Database @@ -34,7 +34,7 @@ def __init__( self.start_column = start_column self.end_column = end_column - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return a dictionary representation of this annotation""" return { "line_number": self.line_number, @@ -53,6 +53,6 @@ class BaseSQLValidator: # pylint: disable=too-few-public-methods @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: + ) -> list[SQLValidationAnnotation]: """Check that the given SQL querystring is valid for the given engine""" raise NotImplementedError diff --git a/superset/sql_validators/postgres.py b/superset/sql_validators/postgres.py index f62be39f03a44..60c15ca034c27 100644 --- a/superset/sql_validators/postgres.py +++ b/superset/sql_validators/postgres.py @@ -16,7 +16,7 @@ # under the License. import re -from typing import List, Optional +from typing import Optional from pgsanity.pgsanity import check_string @@ -32,8 +32,8 @@ class PostgreSQLValidator(BaseSQLValidator): # pylint: disable=too-few-public-m @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: - annotations: List[SQLValidationAnnotation] = [] + ) -> list[SQLValidationAnnotation]: + annotations: list[SQLValidationAnnotation] = [] valid, error = check_string(sql, add_semicolon=True) if valid: return annotations diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index c5ecf4c96e20c..9d3e7641a6fd0 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -18,7 +18,7 @@ import logging import time from contextlib import closing -from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset import app, security_manager from superset.models.core import Database @@ -109,7 +109,7 @@ def validate_statement( raise PrestoSQLValidationError( "The pyhive presto client returned an unhandled " "database error." ) from db_error - error_args: Dict[str, Any] = db_error.args[0] + error_args: dict[str, Any] = db_error.args[0] # Confirm the two fields we need to be able to present an annotation # are present in the error response -- a message, and a location. @@ -148,7 +148,7 @@ def validate_statement( @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: + ) -> list[SQLValidationAnnotation]: """ Presto supports query-validation queries by running them with a prepended explain. @@ -167,7 +167,7 @@ def validate( ) as engine: # Sharing a single connection and cursor across the # execution of all statements (if many) - annotations: List[SQLValidationAnnotation] = [] + annotations: list[SQLValidationAnnotation] = [] with closing(engine.raw_connection()) as conn: cursor = conn.cursor() for statement in parsed_query.get_statements(): diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 3c24bf1c26681..35d110d8fca14 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
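A minimal sketch of the validator contract these hunks re-type: validate() returns a list of annotation objects, each serializable via to_dict(). Everything below is hypothetical scaffolding, not one of Superset's validators.

    from typing import Any


    class ToyAnnotation:
        def __init__(self, line_number: int, message: str) -> None:
            self.line_number = line_number
            self.message = message

        def to_dict(self) -> dict[str, Any]:
            return {"line_number": self.line_number, "message": self.message}


    def validate_sql(sql: str) -> list[ToyAnnotation]:
        # report an annotation only when the statement is empty
        return [] if sql.strip() else [ToyAnnotation(1, "Query is empty")]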
import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Optional from urllib import parse import simplejson as json @@ -326,7 +326,7 @@ def execute_sql_query(self) -> FlaskResponse: @staticmethod def _create_sql_json_command( - execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]] + execution_context: SqlJsonExecutionContext, log_params: Optional[dict[str, Any]] ) -> ExecuteSqlCommand: query_dao = QueryDAO() sql_json_executor = SqlLabRestApi._create_sql_json_executor( diff --git a/superset/sqllab/commands/estimate.py b/superset/sqllab/commands/estimate.py index 2b8c5814b953c..bf1d6c4fa57d0 100644 --- a/superset/sqllab/commands/estimate.py +++ b/superset/sqllab/commands/estimate.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List +from typing import Any from flask_babel import gettext as __ @@ -40,7 +40,7 @@ class QueryEstimationCommand(BaseCommand): _database_id: int _sql: str - _template_params: Dict[str, Any] + _template_params: dict[str, Any] _schema: str _database: Database @@ -64,7 +64,7 @@ def validate(self) -> None: def run( self, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: self.validate() sql = self._sql @@ -96,7 +96,7 @@ def run( ) from ex spec = self._database.db_engine_spec - query_cost_formatters: Dict[str, Any] = app.config[ + query_cost_formatters: dict[str, Any] = app.config[ "QUERY_COST_FORMATTERS_BY_ENGINE" ] query_cost_formatter = query_cost_formatters.get( diff --git a/superset/sqllab/commands/execute.py b/superset/sqllab/commands/execute.py index 97c8514d5d8d6..09b0769ce21b7 100644 --- a/superset/sqllab/commands/execute.py +++ b/superset/sqllab/commands/execute.py @@ -19,7 +19,7 @@ import copy import logging -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask_babel import gettext as __ @@ -51,7 +51,7 @@ logger = logging.getLogger(__name__) -CommandResult = Dict[str, Any] +CommandResult = dict[str, Any] class ExecuteSqlCommand(BaseCommand): @@ -63,7 +63,7 @@ class ExecuteSqlCommand(BaseCommand): _sql_json_executor: SqlJsonExecutor _execution_context_convertor: ExecutionContextConvertor _sqllab_ctas_no_limit: bool - _log_params: Optional[Dict[str, Any]] = None + _log_params: dict[str, Any] | None = None def __init__( self, @@ -75,7 +75,7 @@ def __init__( sql_json_executor: SqlJsonExecutor, execution_context_convertor: ExecutionContextConvertor, sqllab_ctas_no_limit_flag: bool, - log_params: Optional[Dict[str, Any]] = None, + log_params: dict[str, Any] | None = None, ) -> None: self._execution_context = execution_context self._query_dao = query_dao @@ -122,7 +122,7 @@ def run( # pylint: disable=too-many-statements,useless-suppression except Exception as ex: raise SqlLabException(self._execution_context, exception=ex) from ex - def _try_get_existing_query(self) -> Optional[Query]: + def _try_get_existing_query(self) -> Query | None: return self._query_dao.find_one_or_none( client_id=self._execution_context.client_id, user_id=self._execution_context.user_id, @@ -130,7 +130,7 @@ def _try_get_existing_query(self) -> Optional[Query]: ) @classmethod - def is_query_handled(cls, query: Optional[Query]) -> bool: + def is_query_handled(cls, query: Query | None) -> bool: return query is not None and query.status in [ QueryStatus.RUNNING, QueryStatus.PENDING, @@ -166,7 +166,7 @@ def _get_the_query_db(self) -> Database: return mydb @classmethod - def _validate_query_db(cls, database: Optional[Database]) 
-> None: + def _validate_query_db(cls, database: Database | None) -> None: if not database: raise SupersetGenericErrorException( __( diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py index e9559be3b97f4..1b9b0e03442fa 100644 --- a/superset/sqllab/commands/export.py +++ b/superset/sqllab/commands/export.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, cast, List, TypedDict +from typing import Any, cast, TypedDict import pandas as pd from flask_babel import gettext as __ @@ -40,7 +40,7 @@ class SqlExportResult(TypedDict): query: Query count: int - data: List[Any] + data: list[Any] class SqlResultExportCommand(BaseCommand): diff --git a/superset/sqllab/commands/results.py b/superset/sqllab/commands/results.py index d6c415a09fef6..83c8aa8f6a51c 100644 --- a/superset/sqllab/commands/results.py +++ b/superset/sqllab/commands/results.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast from flask_babel import gettext as __ @@ -40,14 +40,14 @@ class SqlExecutionResultsCommand(BaseCommand): _key: str - _rows: Optional[int] + _rows: int | None _blob: Any _query: Query def __init__( self, key: str, - rows: Optional[int] = None, + rows: int | None = None, ) -> None: self._key = key self._rows = rows @@ -100,7 +100,7 @@ def validate(self) -> None: def run( self, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Runs arbitrary sql and returns data as json""" self.validate() payload = utils.zlib_decompress( diff --git a/superset/sqllab/exceptions.py b/superset/sqllab/exceptions.py index 70e4fa9752b95..f06cc8dd2eeac 100644 --- a/superset/sqllab/exceptions.py +++ b/superset/sqllab/exceptions.py @@ -17,7 +17,7 @@ from __future__ import annotations import os -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask_babel import lazy_gettext as _ @@ -31,15 +31,15 @@ class SqlLabException(SupersetException): sql_json_execution_context: SqlJsonExecutionContext failed_reason_msg: str - suggestion_help_msg: Optional[str] + suggestion_help_msg: str | None def __init__( # pylint: disable=too-many-arguments self, sql_json_execution_context: SqlJsonExecutionContext, - error_type: Optional[SupersetErrorType] = None, - reason_message: Optional[str] = None, - exception: Optional[Exception] = None, - suggestion_help_msg: Optional[str] = None, + error_type: SupersetErrorType | None = None, + reason_message: str | None = None, + exception: Exception | None = None, + suggestion_help_msg: str | None = None, ) -> None: self.sql_json_execution_context = sql_json_execution_context self.failed_reason_msg = self._get_reason(reason_message, exception) @@ -68,21 +68,21 @@ def _generate_message(self) -> str: if self.failed_reason_msg: msg = msg + self.failed_reason_msg if self.suggestion_help_msg is not None: - msg = "{} {} {}".format(msg, os.linesep, self.suggestion_help_msg) + msg = f"{msg} {os.linesep} {self.suggestion_help_msg}" return msg @classmethod def _get_reason( - cls, reason_message: Optional[str] = None, exception: Optional[Exception] = None + cls, reason_message: str | None = None, exception: Exception | None = None ) -> str: if reason_message is not None: - return ": {}".format(reason_message) + return f": {reason_message}" if exception is not None: if hasattr(exception, "get_message"): - return ": {}".format(exception.get_message()) + return f": {exception.get_message()}" if hasattr(exception, "message"): - return 
": {}".format(exception.message) - return ": {}".format(str(exception)) + return f": {exception.message}" + return f": {str(exception)}" return "" @@ -93,7 +93,7 @@ class QueryIsForbiddenToAccessException(SqlLabException): def __init__( self, sql_json_execution_context: SqlJsonExecutionContext, - exception: Optional[Exception] = None, + exception: Exception | None = None, ) -> None: super().__init__( sql_json_execution_context, diff --git a/superset/sqllab/execution_context_convertor.py b/superset/sqllab/execution_context_convertor.py index f49fbd9a31db5..430db0d52ff9c 100644 --- a/superset/sqllab/execution_context_convertor.py +++ b/superset/sqllab/execution_context_convertor.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import simplejson as json @@ -36,7 +36,7 @@ class ExecutionContextConvertor: _max_row_in_display_configuration: int # pylint: disable=invalid-name _exc_status: SqlJsonExecutionStatus - payload: Dict[str, Any] + payload: dict[str, Any] def set_max_row_in_display(self, value: int) -> None: self._max_row_in_display_configuration = value # pylint: disable=invalid-name diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index 1369e78db15de..db1adf43bab34 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, no-self-use, too-few-public-methods, too-many-arguments from __future__ import annotations -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import gettext as __, ngettext from jinja2 import TemplateError @@ -124,16 +124,16 @@ def _raise_template_exception( class SqlQueryRenderException(SqlLabException): - _extra: Optional[Dict[str, Any]] + _extra: dict[str, Any] | None def __init__( self, sql_json_execution_context: SqlJsonExecutionContext, error_type: SupersetErrorType, - reason_message: Optional[str] = None, - exception: Optional[Exception] = None, - suggestion_help_msg: Optional[str] = None, - extra: Optional[Dict[str, Any]] = None, + reason_message: str | None = None, + exception: Exception | None = None, + suggestion_help_msg: str | None = None, + extra: dict[str, Any] | None = None, ) -> None: super().__init__( sql_json_execution_context, @@ -145,10 +145,10 @@ def __init__( self._extra = extra @property - def extra(self) -> Optional[Dict[str, Any]]: + def extra(self) -> dict[str, Any] | None: return self._extra - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = super().to_dict() if self._extra: rv["extra"] = self._extra diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py index e4e6b60654b87..124f477e9625f 100644 --- a/superset/sqllab/sql_json_executer.py +++ b/superset/sqllab/sql_json_executer.py @@ -20,7 +20,7 @@ import dataclasses import logging from abc import ABC -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import gettext as __ @@ -43,7 +43,7 @@ QueryStatus = utils.QueryStatus logger = logging.getLogger(__name__) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] GetSqlResultsTask = Callable[..., SqlResults] @@ -53,7 +53,7 @@ def execute( self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: 
raise NotImplementedError() @@ -88,7 +88,7 @@ def execute( self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: query_id = execution_context.query.id try: @@ -120,8 +120,8 @@ def _get_sql_results_with_timeout( self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], - ) -> Optional[SqlResults]: + log_params: dict[str, Any] | None, + ) -> SqlResults | None: with utils.timeout( seconds=self._timeout_duration_in_seconds, error_message=self._get_timeout_error_msg(), @@ -132,8 +132,8 @@ def _get_sql_results( self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], - ) -> Optional[SqlResults]: + log_params: dict[str, Any] | None, + ) -> SqlResults | None: return self._get_sql_results_task( execution_context.query.id, rendered_query, @@ -161,7 +161,7 @@ def execute( self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: query_id = execution_context.query.id logger.info("Query %i: Running query on a Celery worker", query_id) diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py index 644c978b32765..22277804ee621 100644 --- a/superset/sqllab/sqllab_execution_context.py +++ b/superset/sqllab/sqllab_execution_context.py @@ -19,7 +19,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, cast, Dict, Optional, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from flask import g from sqlalchemy.orm.exc import DetachedInstanceError @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] @dataclass @@ -45,7 +45,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes database_id: int schema: str sql: str - template_params: Dict[str, Any] + template_params: dict[str, Any] async_flag: bool limit: int status: str @@ -53,14 +53,14 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes client_id_or_short_id: str sql_editor_id: str tab_name: str - user_id: Optional[int] + user_id: int | None expand_data: bool - create_table_as_select: Optional[CreateTableAsSelect] - database: Optional[Database] + create_table_as_select: CreateTableAsSelect | None + database: Database | None query: Query - _sql_result: Optional[SqlResults] + _sql_result: SqlResults | None - def __init__(self, query_params: Dict[str, Any]): + def __init__(self, query_params: dict[str, Any]): self.create_table_as_select = None self.database = None self._init_from_query_params(query_params) @@ -70,7 +70,7 @@ def __init__(self, query_params: Dict[str, Any]): def set_query(self, query: Query) -> None: self.query = query - def _init_from_query_params(self, query_params: Dict[str, Any]) -> None: + def _init_from_query_params(self, query_params: dict[str, Any]) -> None: self.database_id = cast(int, query_params.get("database_id")) self.schema = cast(str, query_params.get("schema")) self.sql = cast(str, query_params.get("sql")) @@ -90,7 +90,7 @@ def _init_from_query_params(self, query_params: Dict[str, Any]) -> None: ) @staticmethod - def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]: + def _get_template_params(query_params: dict[str, Any]) -> dict[str, Any]: try: template_params = 
json.loads(query_params.get("templateParams") or "{}") except json.JSONDecodeError: @@ -102,7 +102,7 @@ def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]: return template_params @staticmethod - def _get_limit_param(query_params: Dict[str, Any]) -> int: + def _get_limit_param(query_params: dict[str, Any]) -> int: limit = apply_max_row_limit(query_params.get("queryLimit") or 0) if limit < 0: logger.warning( @@ -125,7 +125,7 @@ def set_database(self, database: Database) -> None: schema_name = self._get_ctas_target_schema_name(database) self.create_table_as_select.target_schema_name = schema_name # type: ignore - def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]: + def _get_ctas_target_schema_name(self, database: Database) -> str | None: if database.force_ctas_schema: return database.force_ctas_schema return get_cta_schema_name(database, g.user, self.schema, self.sql) @@ -134,10 +134,10 @@ def _validate_db(self, database: Database) -> None: # TODO validate db.id is equal to self.database_id pass - def get_execution_result(self) -> Optional[SqlResults]: + def get_execution_result(self) -> SqlResults | None: return self._sql_result - def set_execution_result(self, sql_result: Optional[SqlResults]) -> None: + def set_execution_result(self, sql_result: SqlResults | None) -> None: self._sql_result = sql_result def create_query(self) -> Query: @@ -178,15 +178,15 @@ def get_query_details(self) -> str: try: if hasattr(self, "query"): if self.query.id: - return "query '{}' - '{}'".format(self.query.id, self.query.sql) + return f"query '{self.query.id}' - '{self.query.sql}'" except DetachedInstanceError: pass - return "query '{}'".format(self.sql) + return f"query '{self.sql}'" class CreateTableAsSelect: # pylint: disable=too-few-public-methods ctas_method: CtasMethod - target_schema_name: Optional[str] + target_schema_name: str | None target_table_name: str def __init__( @@ -197,7 +197,7 @@ def __init__( self.target_table_name = target_table_name @staticmethod - def create_from(query_params: Dict[str, Any]) -> CreateTableAsSelect: + def create_from(query_params: dict[str, Any]) -> CreateTableAsSelect: ctas_method = query_params.get("ctas_method", CtasMethod.TABLE) schema = cast(str, query_params.get("schema")) tmp_table_name = cast(str, query_params.get("tmp_table_name")) diff --git a/superset/sqllab/utils.py b/superset/sqllab/utils.py index 3bcd7308a1281..abceaaf136cc4 100644 --- a/superset/sqllab/utils.py +++ b/superset/sqllab/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
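A self-contained sketch of the str.format-to-f-string rewrites in the sqllab hunks above; the values are toys, and the rendered output is identical either way.

    query_id, sql = 42, "SELECT 1"
    assert "query '{}' - '{}'".format(query_id, sql) == f"query '{query_id}' - '{sql}'"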
-from typing import Any, Dict +from typing import Any import pyarrow as pa @@ -22,8 +22,8 @@ def apply_display_max_row_configuration_if_require( # pylint: disable=invalid-name - sql_results: Dict[str, Any], max_rows_in_result: int -) -> Dict[str, Any]: + sql_results: dict[str, Any], max_rows_in_result: int +) -> dict[str, Any]: """ Given a `sql_results` nested structure, applies a limit to the number of rows diff --git a/superset/stats_logger.py b/superset/stats_logger.py index 4b869042a90df..fc223f752967b 100644 --- a/superset/stats_logger.py +++ b/superset/stats_logger.py @@ -54,22 +54,20 @@ def incr(self, key: str) -> None: logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL) def decr(self, key: str) -> None: - logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL)) + logger.debug(Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL) def timing(self, key: str, value: float) -> None: logger.debug( - (Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL) + Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL ) def gauge(self, key: str, value: float) -> None: logger.debug( - ( - Fore.CYAN - + "[stats_logger] (gauge) " - + f"{key}" - + f"{value}" - + Style.RESET_ALL - ) + Fore.CYAN + + "[stats_logger] (gauge) " + + f"{key}" + + f"{value}" + + Style.RESET_ALL ) diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 8eaea54176ed3..7c21df6a88b88 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from collections.abc import Sequence from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, TYPE_CHECKING, Union -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from werkzeug.wrappers import Response if TYPE_CHECKING: @@ -69,8 +70,8 @@ class ResultSetColumnType(TypedDict): is_dttm: bool -CacheConfig = Dict[str, Any] -DbapiDescriptionRow = Tuple[ +CacheConfig = dict[str, Any] +DbapiDescriptionRow = tuple[ Union[str, bytes], str, Optional[str], @@ -79,27 +80,27 @@ class ResultSetColumnType(TypedDict): Optional[int], bool, ] -DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]] -DbapiResult = Sequence[Union[List[Any], Tuple[Any, ...]]] +DbapiDescription = Union[list[DbapiDescriptionRow], tuple[DbapiDescriptionRow, ...]] +DbapiResult = Sequence[Union[list[Any], tuple[Any, ...]]] FilterValue = Union[bool, datetime, float, int, str] -FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] -FormData = Dict[str, Any] -Granularity = Union[str, Dict[str, Union[str, float]]] +FilterValues = Union[FilterValue, list[FilterValue], tuple[FilterValue]] +FormData = dict[str, Any] +Granularity = Union[str, dict[str, Union[str, float]]] Column = Union[AdhocColumn, str] Metric = Union[AdhocMetric, str] -OrderBy = Tuple[Metric, bool] -QueryObjectDict = Dict[str, Any] -VizData = Optional[Union[List[Any], Dict[Any, Any]]] -VizPayload = Dict[str, Any] +OrderBy = tuple[Metric, bool] +QueryObjectDict = dict[str, Any] +VizData = Optional[Union[list[Any], dict[Any, Any]]] +VizPayload = dict[str, Any] # Flask response. 
Base = Union[bytes, str] Status = Union[int, str] -Headers = Dict[str, Any] +Headers = dict[str, Any] FlaskResponse = Union[ Response, Base, - Tuple[Base, Status], - Tuple[Base, Status, Headers], - Tuple[Response, Status], + tuple[Base, Status], + tuple[Base, Status, Headers], + tuple[Response, Status], ] diff --git a/superset/tables/models.py b/superset/tables/models.py index 9a0c07fdcf5a4..a24035fb97dd4 100644 --- a/superset/tables/models.py +++ b/superset/tables/models.py @@ -24,7 +24,8 @@ These models are not fully implemented, and shouldn't be used yet. """ -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import sqlalchemy as sa from flask_appbuilder import Model @@ -87,7 +88,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # The relationship between datasets and columns is 1:n, but we use a # many-to-many association table to avoid adding two mutually exclusive # columns(dataset_id and table_id) to Column - columns: List[Column] = relationship( + columns: list[Column] = relationship( "Column", secondary=table_column_association_table, cascade="all, delete-orphan", @@ -96,7 +97,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # is loaded. backref="tables", ) - datasets: List["Dataset"] # will be populated by Dataset.tables backref + datasets: list["Dataset"] # will be populated by Dataset.tables backref # We use ``sa.Text`` for these attributes because (1) in modern databases the # performance is the same as ``VARCHAR``[1] and (2) because some table names can be @@ -130,7 +131,7 @@ def sync_columns(self) -> None: existing_columns = {column.name: column for column in self.columns} quote_identifier = self.database.quote_identifier - def update_or_create_column(column_meta: Dict[str, Any]) -> Column: + def update_or_create_column(column_meta: dict[str, Any]) -> Column: column_name: str = column_meta["name"] if column_name in existing_columns: column = existing_columns[column_name] @@ -153,8 +154,8 @@ def bulk_load_or_create( table_names: Iterable[TableName], default_schema: Optional[str] = None, sync_columns: Optional[bool] = False, - default_props: Optional[Dict[str, Any]] = None, - ) -> List["Table"]: + default_props: Optional[dict[str, Any]] = None, + ) -> list["Table"]: """ Load or create multiple Table instances. """ diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index 1e886e2af65a1..20327b54f01cd 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List from superset.commands.base import BaseCommand, CreateMixin from superset.dao.exceptions import DAOCreateFailedError @@ -28,7 +27,7 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand): - def __init__(self, object_type: ObjectTypes, object_id: int, tags: List[str]): + def __init__(self, object_type: ObjectTypes, object_id: int, tags: list[str]): self._object_type = object_type self._object_id = object_id self._tags = tags diff --git a/superset/tags/commands/delete.py b/superset/tags/commands/delete.py index acec01661935b..08189b5ac55d8 100644 --- a/superset/tags/commands/delete.py +++ b/superset/tags/commands/delete.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
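To make the typing convention in these hunks concrete: function and method annotations switch to the PEP 585 built-in generics (`dict`, `list`, `tuple`, `set`) and to `collections.abc` for `Iterable`/`Iterator`/`Sequence`, while module-level type aliases such as the ones in `superset_typing.py` keep `Union`/`Optional`, presumably because an alias is evaluated at import time and the `X | Y` operator on types only exists at runtime from Python 3.10. A minimal illustrative sketch of the two cases, not taken from the patch:

```python
from __future__ import annotations  # annotations become strings, so "str | None" parses pre-3.10

from collections.abc import Sequence
from typing import Any, Optional, Union


def first_or_none(items: Sequence[str] | None) -> str | None:
    # Annotation-only usage: lazily evaluated, so the "|" syntax is safe before 3.10.
    return items[0] if items else None


# Aliases are executed at import time; on Python 3.9 the subscripted built-ins work,
# but unions still need typing.Union/Optional rather than "|".
Headers = dict[str, Any]
Status = Union[int, str]
MaybeStatus = Optional[Status]
```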
import logging -from typing import List from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -90,7 +89,7 @@ def validate(self) -> None: class DeleteTagsCommand(DeleteMixin, BaseCommand): - def __init__(self, tags: List[str]): + def __init__(self, tags: list[str]): self._tags = tags def run(self) -> None: diff --git a/superset/tags/dao.py b/superset/tags/dao.py index c676b4ab3c25c..9ea61f5c90d0b 100644 --- a/superset/tags/dao.py +++ b/superset/tags/dao.py @@ -16,7 +16,7 @@ # under the License. import logging from operator import and_ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -45,7 +45,7 @@ def validate_tag_name(tag_name: str) -> bool: @staticmethod def create_custom_tagged_objects( - object_type: ObjectTypes, object_id: int, tag_names: List[str] + object_type: ObjectTypes, object_id: int, tag_names: list[str] ) -> None: tagged_objects = [] for name in tag_names: @@ -95,7 +95,7 @@ def delete_tagged_object( raise DAODeleteFailedError(exception=ex) from ex @staticmethod - def delete_tags(tag_names: List[str]) -> None: + def delete_tags(tag_names: list[str]) -> None: """ deletes tags from a list of tag names """ @@ -158,8 +158,8 @@ def find_tagged_object( @staticmethod def get_tagged_objects_for_tags( - tags: Optional[List[str]] = None, obj_types: Optional[List[str]] = None - ) -> List[Dict[str, Any]]: + tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None + ) -> list[dict[str, Any]]: """ returns a list of tagged objects filtered by tag names and object types if no filters applied returns all tagged objects @@ -174,7 +174,7 @@ def get_tagged_objects_for_tags( # filter types - results: List[Dict[str, Any]] = [] + results: list[dict[str, Any]] = [] # dashboards if (not obj_types) or ("dashboard" in obj_types): diff --git a/superset/tags/models.py b/superset/tags/models.py index 797308c30675a..bb845303ffd20 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -14,16 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from __future__ import ( - absolute_import, - annotations, - division, - print_function, - unicode_literals, -) +from __future__ import annotations import enum -from typing import List, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask_appbuilder import Model from sqlalchemy import Column, Enum, ForeignKey, Integer, String @@ -122,26 +116,26 @@ def get_object_type(class_name: str) -> ObjectTypes: try: return mapping[class_name.lower()] except KeyError as ex: - raise Exception("No mapping found for {0}".format(class_name)) from ex + raise Exception(f"No mapping found for {class_name}") from ex class ObjectUpdater: - object_type: Optional[str] = None + object_type: str | None = None @classmethod def get_owners_ids( - cls, target: Union[Dashboard, FavStar, Slice, Query, SqlaTable] - ) -> List[int]: + cls, target: Dashboard | FavStar | Slice | Query | SqlaTable + ) -> list[int]: raise NotImplementedError("Subclass should implement `get_owners_ids`") @classmethod def _add_owners( cls, session: Session, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): - name = "owner:{0}".format(owner_id) + name = f"owner:{owner_id}" tag = get_tag(name, session, TagTypes.owner) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type @@ -153,7 +147,7 @@ def after_insert( cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -161,7 +155,7 @@ def after_insert( cls._add_owners(session, target) # add `type:` tags - tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type) + tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type ) @@ -174,7 +168,7 @@ def after_update( cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -203,7 +197,7 @@ def after_delete( cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -220,7 +214,7 @@ class ChartUpdater(ObjectUpdater): object_type = "chart" @classmethod - def get_owners_ids(cls, target: Slice) -> List[int]: + def get_owners_ids(cls, target: Slice) -> list[int]: return [owner.id for owner in target.owners] @@ -228,7 +222,7 @@ class DashboardUpdater(ObjectUpdater): object_type = "dashboard" @classmethod - def get_owners_ids(cls, target: Dashboard) -> List[int]: + def get_owners_ids(cls, target: Dashboard) -> list[int]: return [owner.id for owner in target.owners] @@ -236,7 +230,7 @@ class QueryUpdater(ObjectUpdater): object_type = "query" @classmethod - def get_owners_ids(cls, target: Query) -> List[int]: + def get_owners_ids(cls, target: Query) -> list[int]: return [target.user_id] @@ -244,7 +238,7 @@ class DatasetUpdater(ObjectUpdater): object_type = "dataset" @classmethod - def get_owners_ids(cls, target: SqlaTable) -> List[int]: + def get_owners_ids(cls, target: SqlaTable) -> list[int]: return [owner.id for owner in target.owners] @@ -254,7 +248,7 @@ def after_insert( cls, _mapper: Mapper, 
connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) - name = "favorited_by:{0}".format(target.user_id) + name = f"favorited_by:{target.user_id}" tag = get_tag(name, session, TagTypes.favorited_by) tagged_object = TaggedObject( tag_id=tag.id, diff --git a/superset/tasks/__init__.py b/superset/tasks/__init__.py index fd9417fe5c1e9..13a83393a9124 100644 --- a/superset/tasks/__init__.py +++ b/superset/tasks/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index ffd92c262747e..cfcb3e31c6ba8 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -18,7 +18,7 @@ import copy import logging -from typing import Any, cast, Dict, Optional, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded from flask import current_app, g @@ -45,12 +45,12 @@ ] # TODO: new config key -def set_form_data(form_data: Dict[str, Any]) -> None: +def set_form_data(form_data: dict[str, Any]) -> None: # pylint: disable=assigning-non-slot g.form_data = form_data -def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: +def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext: try: return ChartDataQueryContextSchema().load(form_data) except KeyError as ex: @@ -61,8 +61,8 @@ def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( - job_metadata: Dict[str, Any], - form_data: Dict[str, Any], + job_metadata: dict[str, Any], + form_data: dict[str, Any], ) -> None: # pylint: disable=import-outside-toplevel from superset.charts.data.commands.get_data_command import ChartDataCommand @@ -104,9 +104,9 @@ def load_chart_data_into_cache( @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout) def load_explore_json_into_cache( # pylint: disable=too-many-locals - job_metadata: Dict[str, Any], - form_data: Dict[str, Any], - response_type: Optional[str] = None, + job_metadata: dict[str, Any], + form_data: dict[str, Any], + response_type: str | None = None, force: bool = False, ) -> None: cache_key_prefix = "ejr-" # ejr: explore_json request diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index bdbf8add7eaba..448271269a209 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
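The f-string conversions in `superset/tags/models.py` also surface the implicit tag-naming scheme used by the updater classes: every tracked object receives `owner:<owner id>` and `type:<object type>` tags, and favoriting adds a `favorited_by:<user id>` tag. A small stand-alone sketch of that convention (the helper names below are illustrative, not from the codebase):

```python
def owner_tag(owner_id: int) -> str:
    # One tag per owner of a chart/dashboard/query/dataset, e.g. "owner:42".
    return f"owner:{owner_id}"


def type_tag(object_type: str) -> str:
    # Added on insert, e.g. "type:dashboard".
    return f"type:{object_type}"


def favorited_by_tag(user_id: int) -> str:
    # Added when a FavStar row is inserted, e.g. "favorited_by:7".
    return f"favorited_by:{user_id}"


assert owner_tag(42) == "owner:42"
assert type_tag("dashboard") == "type:dashboard"
assert favorited_by_tag(7) == "favorited_by:7"
```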
import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from urllib import request from urllib.error import URLError @@ -72,7 +72,7 @@ class Strategy: # pylint: disable=too-few-public-methods def __init__(self) -> None: pass - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: raise NotImplementedError("Subclasses must implement get_urls!") @@ -94,7 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dummy" - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: session = db.create_scoped_session() charts = session.query(Slice).all() @@ -126,7 +126,7 @@ def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None: self.top_n = top_n self.since = parse_human_datetime(since) if since else None - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: urls = [] session = db.create_scoped_session() @@ -165,11 +165,11 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dashboard_tags" - def __init__(self, tags: Optional[List[str]] = None) -> None: + def __init__(self, tags: Optional[list[str]] = None) -> None: super().__init__() self.tags = tags or [] - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: urls = [] session = db.create_scoped_session() @@ -216,7 +216,7 @@ def get_urls(self) -> List[str]: @celery_app.task(name="fetch_url") -def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]: +def fetch_url(url: str, headers: dict[str, str]) -> dict[str, str]: """ Celery job to fetch url """ @@ -242,7 +242,7 @@ def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]: @celery_app.task(name="cache-warmup") def cache_warmup( strategy_name: str, *args: Any, **kwargs: Any -) -> Union[Dict[str, List[str]], str]: +) -> Union[dict[str, list[str]], str]: """ Warm up cache. @@ -272,7 +272,7 @@ def cache_warmup( cookies = MachineAuthProvider.get_auth_cookies(user) headers = {"Cookie": f"session={cookies.get('session', '')}"} - results: Dict[str, List[str]] = {"scheduled": [], "errors": []} + results: dict[str, list[str]] = {"scheduled": [], "errors": []} for url in strategy.get_urls(): try: logger.info("Scheduling %s", url) diff --git a/superset/tasks/cron_util.py b/superset/tasks/cron_util.py index 9c275addf6b71..19d342ebdcf86 100644 --- a/superset/tasks/cron_util.py +++ b/superset/tasks/cron_util.py @@ -16,8 +16,8 @@ # under the License. import logging +from collections.abc import Iterator from datetime import datetime, timedelta, timezone as dt_timezone -from typing import Iterator from croniter import croniter from pytz import timezone as pytz_timezone, UnknownTimeZoneError diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py index 9c1dab82202b8..5012330bbd43e 100644 --- a/superset/tasks/utils.py +++ b/superset/tasks/utils.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import List, Optional, Tuple, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask import current_app, g @@ -32,10 +32,10 @@ # pylint: disable=too-many-branches def get_executor( - executor_types: List[ExecutorType], - model: Union[Dashboard, ReportSchedule, Slice], - current_user: Optional[str] = None, -) -> Tuple[ExecutorType, str]: + executor_types: list[ExecutorType], + model: Dashboard | ReportSchedule | Slice, + current_user: str | None = None, +) -> tuple[ExecutorType, str]: """ Extract the user that should be used to execute a scheduled task. 
Certain executor types extract the user from the underlying object (e.g. CREATOR), the constant @@ -86,7 +86,7 @@ def get_executor( raise ExecutorNotFoundError() -def get_current_user() -> Optional[str]: +def get_current_user() -> str | None: user = g.user if hasattr(g, "user") and g.user else None if user and not user.is_anonymous: return user.username diff --git a/superset/translations/utils.py b/superset/translations/utils.py index 79d01539a16e1..23eca1dd8c810 100644 --- a/superset/translations/utils.py +++ b/superset/translations/utils.py @@ -16,15 +16,15 @@ # under the License. import json import os -from typing import Any, Dict, Optional +from typing import Any, Optional # Global caching for JSON language packs -ALL_LANGUAGE_PACKS: Dict[str, Dict[str, Any]] = {"en": {}} +ALL_LANGUAGE_PACKS: dict[str, dict[str, Any]] = {"en": {}} DIR = os.path.dirname(os.path.abspath(__file__)) -def get_language_pack(locale: str) -> Optional[Dict[str, Any]]: +def get_language_pack(locale: str) -> Optional[dict[str, Any]]: """Get/cache a language pack Returns the language pack from cache if it exists, caches otherwise @@ -34,7 +34,7 @@ def get_language_pack(locale: str) -> Optional[Dict[str, Any]]: """ pack = ALL_LANGUAGE_PACKS.get(locale) if not pack: - filename = DIR + "/{}/LC_MESSAGES/messages.json".format(locale) + filename = DIR + f"/{locale}/LC_MESSAGES/messages.json" try: with open(filename, encoding="utf8") as f: pack = json.load(f) diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py index 71559aaa3dcbc..1913fc1decc81 100644 --- a/superset/utils/async_query_manager.py +++ b/superset/utils/async_query_manager.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional import jwt import redis @@ -38,7 +38,7 @@ class AsyncQueryJobException(Exception): def build_job_metadata( channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "channel_id": channel_id, "job_id": job_id, @@ -49,7 +49,7 @@ def build_job_metadata( } -def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]: +def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]: event_id = event_data[0] event_payload = event_data[1]["data"] return {"id": event_id, **json.loads(event_payload)} @@ -149,7 +149,7 @@ def validate_session(response: Response) -> Response: return response - def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]: + def parse_jwt_from_request(self, req: Request) -> dict[str, Any]: token = req.cookies.get(self._jwt_cookie_name) if not token: raise AsyncQueryTokenException("Token not preset") @@ -160,7 +160,7 @@ def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]: logger.warning("Parse jwt failed", exc_info=True) raise AsyncQueryTokenException("Failed to parse token") from ex - def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]: + def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: job_id = str(uuid.uuid4()) return build_job_metadata( channel_id, job_id, user_id, status=self.STATUS_PENDING @@ -168,14 +168,14 @@ def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]: def read_events( self, channel: str, last_id: Optional[str] - ) -> List[Optional[Dict[str, Any]]]: + ) -> list[Optional[dict[str, Any]]]: stream_name = f"{self._stream_prefix}{channel}" start_id = increment_id(last_id) if 
last_id else "-" results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT) return [] if not results else list(map(parse_event, results)) def update_job( - self, job_metadata: Dict[str, Any], status: str, **kwargs: Any + self, job_metadata: dict[str, Any], status: str, **kwargs: Any ) -> None: if "channel_id" not in job_metadata: raise AsyncQueryJobException("No channel ID specified") diff --git a/superset/utils/cache.py b/superset/utils/cache.py index a632b04b374a5..693f3a73bcfe5 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -20,7 +20,7 @@ import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING from flask import current_app as app, request from flask_caching import Cache @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) -def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str: +def generate_cache_key(values_dict: dict[str, Any], key_prefix: str = "") -> str: hash_str = md5_sha_from_dict(values_dict, default=json_int_dttm_ser) return f"{key_prefix}{hash_str}" @@ -49,9 +49,9 @@ def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str def set_and_log_cache( cache_instance: Cache, cache_key: str, - cache_value: Dict[str, Any], - cache_timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + cache_value: dict[str, Any], + cache_timeout: int | None = None, + datasource_uid: str | None = None, ) -> None: if isinstance(cache_instance.cache, NullCache): return @@ -91,11 +91,11 @@ def set_and_log_cache( def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument args_hash = hash(frozenset(request.args.items())) - return "view/{}/{}".format(request.path, args_hash) + return f"view/{request.path}/{args_hash}" def memoized_func( - key: Optional[str] = None, cache: Cache = cache_manager.cache + key: str | None = None, cache: Cache = cache_manager.cache ) -> Callable[..., Any]: """ Decorator with configurable key and cache backend. @@ -152,10 +152,10 @@ def wrapped_f(*args: Any, **kwargs: Any) -> Any: def etag_cache( cache: Cache = cache_manager.cache, - get_last_modified: Optional[Callable[..., datetime]] = None, - max_age: Optional[Union[int, float]] = None, - raise_for_access: Optional[Callable[..., Any]] = None, - skip: Optional[Callable[..., bool]] = None, + get_last_modified: Callable[..., datetime] | None = None, + max_age: int | float | None = None, + raise_for_access: Callable[..., Any] | None = None, + skip: Callable[..., bool] | None = None, ) -> Callable[..., Any]: """ A decorator for caching views and handling etag conditional requests. diff --git a/superset/utils/celery.py b/superset/utils/celery.py index 474fc98d9416e..35771791456ce 100644 --- a/superset/utils/celery.py +++ b/superset/utils/celery.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. 
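The `superset/utils/cache.py` hunk above shows `generate_cache_key` composing a prefix with an MD5 digest of the query payload via `md5_sha_from_dict` (whose signature appears later in the `superset/utils/hashing.py` hunk). A minimal self-contained sketch of that idea, using stand-in helpers rather than the actual Superset ones:

```python
import hashlib
import json
from typing import Any


def md5_from_dict(values: dict[str, Any]) -> str:
    # sort_keys makes the digest deterministic for equivalent dicts;
    # default=str is a stand-in for Superset's datetime-aware serializer.
    payload = json.dumps(values, sort_keys=True, default=str)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()


def generate_cache_key(values: dict[str, Any], key_prefix: str = "") -> str:
    return f"{key_prefix}{md5_from_dict(values)}"


print(generate_cache_key({"slice_id": 1, "filters": []}, key_prefix="superset_"))
```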
import logging +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterator from sqlalchemy import create_engine from sqlalchemy.exc import SQLAlchemyError diff --git a/superset/utils/core.py b/superset/utils/core.py index c537abf459648..24e539b2b6f11 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -35,6 +35,7 @@ import traceback import uuid import zlib +from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass from datetime import date, datetime, time, timedelta @@ -47,24 +48,7 @@ from io import BytesIO from timeit import default_timer from types import TracebackType -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - Iterator, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar from urllib.parse import unquote_plus from zipfile import ZipFile @@ -197,11 +181,11 @@ class LoggerLevel(str, Enum): class HeaderDataType(TypedDict): notification_format: str - owners: List[int] + owners: list[int] notification_type: str - notification_source: Optional[str] - chart_id: Optional[int] - dashboard_id: Optional[int] + notification_source: str | None + chart_id: int | None + dashboard_id: int | None class DatasourceDict(TypedDict): @@ -212,20 +196,20 @@ class DatasourceDict(TypedDict): class AdhocFilterClause(TypedDict, total=False): clause: str expressionType: str - filterOptionName: Optional[str] - comparator: Optional[FilterValues] + filterOptionName: str | None + comparator: FilterValues | None operator: str subject: str - isExtra: Optional[bool] - sqlExpression: Optional[str] + isExtra: bool | None + sqlExpression: str | None class QueryObjectFilterClause(TypedDict, total=False): col: Column op: str # pylint: disable=invalid-name - val: Optional[FilterValues] - grain: Optional[str] - isExtra: Optional[bool] + val: FilterValues | None + grain: str | None + isExtra: bool | None class ExtraFiltersTimeColumnType(str, Enum): @@ -351,9 +335,9 @@ class ReservedUrlParameters(str, Enum): EDIT_MODE = "edit" @staticmethod - def is_standalone_mode() -> Optional[bool]: + def is_standalone_mode() -> bool | None: standalone_param = request.args.get(ReservedUrlParameters.STANDALONE.value) - standalone: Optional[bool] = bool( + standalone: bool | None = bool( standalone_param and standalone_param != "false" and standalone_param != "0" ) return standalone @@ -370,10 +354,10 @@ class ColumnTypeSource(Enum): class ColumnSpec(NamedTuple): - sqla_type: Union[TypeEngine, str] + sqla_type: TypeEngine | str generic_type: GenericDataType is_dttm: bool - python_date_format: Optional[str] = None + python_date_format: str | None = None try: @@ -407,8 +391,8 @@ def flasher(msg: str, severity: str = "message") -> None: def parse_js_uri_path_item( - item: Optional[str], unquote: bool = True, eval_undefined: bool = False -) -> Optional[str]: + item: str | None, unquote: bool = True, eval_undefined: bool = False +) -> str | None: """Parse a uri path item made with js. 
:param item: a uri path component @@ -421,7 +405,7 @@ def parse_js_uri_path_item( return unquote_plus(item) if unquote and item else item -def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float, int]]: +def cast_to_num(value: float | int | str | None) -> float | int | None: """Casts a value to an int/float >>> cast_to_num('1 ') @@ -457,7 +441,7 @@ def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float return None -def cast_to_boolean(value: Any) -> Optional[bool]: +def cast_to_boolean(value: Any) -> bool | None: """Casts a value to an int/float >>> cast_to_boolean(1) @@ -487,7 +471,7 @@ def cast_to_boolean(value: Any) -> Optional[bool]: return False -def list_minus(l: List[Any], minus: List[Any]) -> List[Any]: +def list_minus(l: list[Any], minus: list[Any]) -> list[Any]: """Returns l without what is in minus >>> list_minus([1, 2, 3], [2]) @@ -501,12 +485,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.sort_keys = True - def default(self, o: Any) -> Union[Dict[Any, Any], str]: + def default(self, o: Any) -> dict[Any, Any] | str: if isinstance(o, uuid.UUID): return str(o) try: vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} - return {"__{}__".format(o.__class__.__name__): vals} + return {f"__{o.__class__.__name__}__": vals} except Exception: # pylint: disable=broad-except if isinstance(o, datetime): return {"__datetime__": o.replace(microsecond=0).isoformat()} @@ -519,13 +503,13 @@ class JSONEncodedDict(TypeDecorator): # pylint: disable=abstract-method impl = TEXT def process_bind_param( - self, value: Optional[Dict[Any, Any]], dialect: str - ) -> Optional[str]: + self, value: dict[Any, Any] | None, dialect: str + ) -> str | None: return json.dumps(value) if value is not None else None def process_result_value( - self, value: Optional[str], dialect: str - ) -> Optional[Dict[Any, Any]]: + self, value: str | None, dialect: str + ) -> dict[Any, Any] | None: return json.loads(value) if value is not None else None @@ -634,7 +618,7 @@ def json_int_dttm_ser(obj: Any) -> Any: return base_json_conv(obj) -def json_dumps_w_dates(payload: Dict[Any, Any], sort_keys: bool = False) -> str: +def json_dumps_w_dates(payload: dict[Any, Any], sort_keys: bool = False) -> str: """Dumps payload to JSON with Datetime objects properly converted""" return json.dumps(payload, default=json_int_dttm_ser, sort_keys=sort_keys) @@ -662,7 +646,7 @@ def error_msg_from_exception(ex: Exception) -> str: return msg or str(ex) -def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str: +def markdown(raw: str, markup_wrap: bool | None = False) -> str: safe_markdown_tags = { "h1", "h2", @@ -709,15 +693,15 @@ def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str: return safe -def readfile(file_path: str) -> Optional[str]: +def readfile(file_path: str) -> str | None: with open(file_path) as f: content = f.read() return content def generic_find_constraint_name( - table: str, columns: Set[str], referenced: str, database: SQLA -) -> Optional[str]: + table: str, columns: set[str], referenced: str, database: SQLA +) -> str | None: """Utility to find a constraint name in alembic migrations""" tbl = sa.Table( table, database.metadata, autoload=True, autoload_with=database.engine @@ -731,8 +715,8 @@ def generic_find_constraint_name( def generic_find_fk_constraint_name( - table: str, columns: Set[str], referenced: str, insp: Inspector -) -> Optional[str]: + table: str, columns: set[str], 
referenced: str, insp: Inspector +) -> str | None: """Utility to find a foreign-key constraint name in alembic migrations""" for fk in insp.get_foreign_keys(table): if ( @@ -745,8 +729,8 @@ def generic_find_fk_constraint_name( def generic_find_fk_constraint_names( # pylint: disable=invalid-name - table: str, columns: Set[str], referenced: str, insp: Inspector -) -> Set[str]: + table: str, columns: set[str], referenced: str, insp: Inspector +) -> set[str]: """Utility to find foreign-key constraint names in alembic migrations""" names = set() @@ -761,8 +745,8 @@ def generic_find_fk_constraint_names( # pylint: disable=invalid-name def generic_find_uq_constraint_name( - table: str, columns: Set[str], insp: Inspector -) -> Optional[str]: + table: str, columns: set[str], insp: Inspector +) -> str | None: """Utility to find a unique constraint name in alembic migrations""" for uq in insp.get_unique_constraints(table): @@ -773,14 +757,14 @@ def generic_find_uq_constraint_name( def get_datasource_full_name( - database_name: str, datasource_name: str, schema: Optional[str] = None + database_name: str, datasource_name: str, schema: str | None = None ) -> str: if not schema: - return "[{}].[{}]".format(database_name, datasource_name) - return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) + return f"[{database_name}].[{datasource_name}]" + return f"[{database_name}].[{schema}].[{datasource_name}]" -def validate_json(obj: Union[bytes, bytearray, str]) -> None: +def validate_json(obj: bytes | bytearray | str) -> None: if obj: try: json.loads(obj) @@ -851,7 +835,7 @@ def __exit__( # pylint: disable=redefined-outer-name,redefined-builtin # Windows has no support for SIGALRM, so we use the timer based timeout -timeout: Union[Type[TimerTimeout], Type[SigalrmTimeout]] = ( +timeout: type[TimerTimeout] | type[SigalrmTimeout] = ( TimerTimeout if platform.system() == "Windows" else SigalrmTimeout ) @@ -897,9 +881,9 @@ def notify_user_about_perm_udate( # pylint: disable=too-many-arguments granter: User, user: User, role: Role, - datasource: "BaseDatasource", + datasource: BaseDatasource, tpl_name: str, - config: Dict[str, Any], + config: dict[str, Any], ) -> None: msg = render_template( tpl_name, granter=granter, user=user, role=role, datasource=datasource @@ -923,15 +907,15 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many to: str, subject: str, html_content: str, - config: Dict[str, Any], - files: Optional[List[str]] = None, - data: Optional[Dict[str, str]] = None, - images: Optional[Dict[str, bytes]] = None, + config: dict[str, Any], + files: list[str] | None = None, + data: dict[str, str] | None = None, + images: dict[str, bytes] | None = None, dryrun: bool = False, - cc: Optional[str] = None, - bcc: Optional[str] = None, + cc: str | None = None, + bcc: str | None = None, mime_subtype: str = "mixed", - header_data: Optional[HeaderDataType] = None, + header_data: HeaderDataType | None = None, ) -> None: """ Send an email with html content, eg: @@ -1000,9 +984,9 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many def send_mime_email( e_from: str, - e_to: List[str], + e_to: list[str], mime_msg: MIMEMultipart, - config: Dict[str, Any], + config: dict[str, Any], dryrun: bool = False, ) -> None: smtp_host = config["SMTP_HOST"] @@ -1035,8 +1019,8 @@ def send_mime_email( smtp.quit() -def get_email_address_list(address_string: str) -> List[str]: - address_string_list: List[str] = [] +def get_email_address_list(address_string: str) -> 
list[str]: + address_string_list: list[str] = [] if isinstance(address_string, str): address_string_list = re.split(r",|\s|;", address_string) return [x.strip() for x in address_string_list if x.strip()] @@ -1049,12 +1033,12 @@ def get_email_address_str(address_string: str) -> str: return address_list_str -def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]: +def choicify(values: Iterable[Any]) -> list[tuple[Any, Any]]: """Takes an iterable and makes an iterable of tuples with it""" return [(v, v) for v in values] -def zlib_compress(data: Union[bytes, str]) -> bytes: +def zlib_compress(data: bytes | str) -> bytes: """ Compress things in a py2/3 safe fashion >>> json_str = '{"test": 1}' @@ -1065,7 +1049,7 @@ def zlib_compress(data: Union[bytes, str]) -> bytes: return zlib.compress(data) -def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, str]: +def zlib_decompress(blob: bytes, decode: bool | None = True) -> bytes | str: """ Decompress things to a string in a py2/3 safe fashion >>> json_str = '{"test": 1}' @@ -1094,12 +1078,12 @@ def simple_filter_to_adhoc( } if filter_clause.get("isExtra"): result["isExtra"] = True - result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result)) return result -def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterClause: +def form_data_to_adhoc(form_data: dict[str, Any], clause: str) -> AdhocFilterClause: if clause not in ("where", "having"): raise ValueError(__("Unsupported clause type: %(clause)s", clause=clause)) result: AdhocFilterClause = { @@ -1107,19 +1091,19 @@ def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterCla "expressionType": "SQL", "sqlExpression": form_data.get(clause), } - result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result)) return result -def merge_extra_form_data(form_data: Dict[str, Any]) -> None: +def merge_extra_form_data(form_data: dict[str, Any]) -> None: """ Merge extra form data (appends and overrides) into the main payload and add applied time extras to the payload. """ filter_keys = ["filters", "adhoc_filters"] extra_form_data = form_data.pop("extra_form_data", {}) - append_filters: List[QueryObjectFilterClause] = extra_form_data.get("filters", None) + append_filters: list[QueryObjectFilterClause] = extra_form_data.get("filters", None) # merge append extras for key in [key for key in EXTRA_FORM_DATA_APPEND_KEYS if key not in filter_keys]: @@ -1144,9 +1128,9 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: if extras: form_data["extras"] = extras - adhoc_filters: List[AdhocFilterClause] = form_data.get("adhoc_filters", []) + adhoc_filters: list[AdhocFilterClause] = form_data.get("adhoc_filters", []) form_data["adhoc_filters"] = adhoc_filters - append_adhoc_filters: List[AdhocFilterClause] = extra_form_data.get( + append_adhoc_filters: list[AdhocFilterClause] = extra_form_data.get( "adhoc_filters", [] ) adhoc_filters.extend( @@ -1170,7 +1154,7 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: adhoc_filter["comparator"] = form_data["time_range"] -def merge_extra_filters(form_data: Dict[str, Any]) -> None: +def merge_extra_filters(form_data: dict[str, Any]) -> None: # extra_filters are temporary/contextual filters (using the legacy constructs) # that are external to the slice definition. 
We use those for dynamic # interactive filters like the ones emitted by the "Filter Box" visualization. @@ -1193,7 +1177,7 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None: # Grab list of existing filters 'keyed' on the column and operator - def get_filter_key(f: Dict[str, Any]) -> str: + def get_filter_key(f: dict[str, Any]) -> str: if "expressionType" in f: return "{}__{}".format(f["subject"], f["operator"]) @@ -1244,7 +1228,7 @@ def get_filter_key(f: Dict[str, Any]) -> str: del form_data["extra_filters"] -def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> None: +def merge_request_params(form_data: dict[str, Any], params: dict[str, Any]) -> None: """ Merge request parameters to the key `url_params` in form_data. Only updates or appends parameters to `form_data` that are defined in `params; pre-existing @@ -1261,7 +1245,7 @@ def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> N form_data["url_params"] = url_params -def user_label(user: User) -> Optional[str]: +def user_label(user: User) -> str | None: """Given a user ORM FAB object, returns a label""" if user: if user.first_name and user.last_name: @@ -1272,7 +1256,7 @@ def user_label(user: User) -> Optional[str]: return None -def get_example_default_schema() -> Optional[str]: +def get_example_default_schema() -> str | None: """ Return the default schema of the examples database, if any. """ @@ -1295,7 +1279,7 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: ) -def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]: +def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]: axis_cols = [ col for col in columns or [] @@ -1304,14 +1288,12 @@ def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]: return tuple(get_column_name(col) for col in axis_cols) -def get_xaxis_label(columns: Optional[List[Column]]) -> Optional[str]: +def get_xaxis_label(columns: list[Column] | None) -> str | None: labels = get_base_axis_labels(columns) return labels[0] if labels else None -def get_column_name( - column: Column, verbose_map: Optional[Dict[str, Any]] = None -) -> str: +def get_column_name(column: Column, verbose_map: dict[str, Any] | None = None) -> str: """ Extract label from column @@ -1336,9 +1318,7 @@ def get_column_name( raise ValueError("Missing label") -def get_metric_name( - metric: Metric, verbose_map: Optional[Dict[str, Any]] = None -) -> str: +def get_metric_name(metric: Metric, verbose_map: dict[str, Any] | None = None) -> str: """ Extract label from metric @@ -1374,9 +1354,9 @@ def get_metric_name( def get_column_names( - columns: Optional[Sequence[Column]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> List[str]: + columns: Sequence[Column] | None, + verbose_map: dict[str, Any] | None = None, +) -> list[str]: return [ column for column in [get_column_name(column, verbose_map) for column in columns or []] @@ -1385,9 +1365,9 @@ def get_column_names( def get_metric_names( - metrics: Optional[Sequence[Metric]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> List[str]: + metrics: Sequence[Metric] | None, + verbose_map: dict[str, Any] | None = None, +) -> list[str]: return [ metric for metric in [get_metric_name(metric, verbose_map) for metric in metrics or []] @@ -1396,9 +1376,9 @@ def get_metric_names( def get_first_metric_name( - metrics: Optional[Sequence[Metric]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> Optional[str]: + metrics: Sequence[Metric] | None, + verbose_map: 
dict[str, Any] | None = None, +) -> str | None: metric_labels = get_metric_names(metrics, verbose_map) return metric_labels[0] if metric_labels else None @@ -1417,7 +1397,7 @@ def convert_legacy_filters_into_adhoc( # pylint: disable=invalid-name mapping = {"having": "having_filters", "where": "filters"} if not form_data.get("adhoc_filters"): - adhoc_filters: List[AdhocFilterClause] = [] + adhoc_filters: list[AdhocFilterClause] = [] form_data["adhoc_filters"] = adhoc_filters for clause, filters in mapping.items(): @@ -1475,17 +1455,13 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name sql_where_filters.append(sql_expression) elif clause == "HAVING": sql_having_filters.append(sql_expression) - form_data["where"] = " AND ".join( - ["({})".format(sql) for sql in sql_where_filters] - ) - form_data["having"] = " AND ".join( - ["({})".format(sql) for sql in sql_having_filters] - ) + form_data["where"] = " AND ".join([f"({sql})" for sql in sql_where_filters]) + form_data["having"] = " AND ".join([f"({sql})" for sql in sql_having_filters]) form_data["having_filters"] = simple_having_filters form_data["filters"] = simple_where_filters -def get_username() -> Optional[str]: +def get_username() -> str | None: """ Get username (if defined) associated with the current user. @@ -1498,7 +1474,7 @@ def get_username() -> Optional[str]: return None -def get_user_id() -> Optional[int]: +def get_user_id() -> int | None: """ Get the user identifier (if defined) associated with the current user. @@ -1517,7 +1493,7 @@ def get_user_id() -> Optional[int]: @contextmanager -def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]: +def override_user(user: User | None, force: bool = True) -> Iterator[Any]: """ Temporarily override the current user per `flask.g` with the specified user. @@ -1583,7 +1559,7 @@ def create_ssl_cert_file(certificate: str) -> str: def time_function( func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any -) -> Tuple[float, Any]: +) -> tuple[float, Any]: """ Measures the amount of time a function takes to execute in ms @@ -1603,7 +1579,7 @@ def MediumText() -> Variant: # pylint:disable=invalid-name def shortid() -> str: - return "{}".format(uuid.uuid4())[-12:] + return f"{uuid.uuid4()}"[-12:] class DatasourceName(NamedTuple): @@ -1611,7 +1587,7 @@ class DatasourceName(NamedTuple): schema: str -def get_stacktrace() -> Optional[str]: +def get_stacktrace() -> str | None: if current_app.config["SHOW_STACKTRACE"]: return traceback.format_exc() return None @@ -1649,7 +1625,7 @@ def split( yield string[i:] -def get_iterable(x: Any) -> List[Any]: +def get_iterable(x: Any) -> list[Any]: """ Get an iterable (list) representation of the object. @@ -1659,7 +1635,7 @@ def get_iterable(x: Any) -> List[Any]: return x if isinstance(x, list) else [x] -def get_form_data_token(form_data: Dict[str, Any]) -> str: +def get_form_data_token(form_data: dict[str, Any]) -> str: """ Return the token contained within form data or generate a new one. @@ -1669,7 +1645,7 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str: return form_data.get("token") or "token_" + uuid.uuid4().hex[:8] -def get_column_name_from_column(column: Column) -> Optional[str]: +def get_column_name_from_column(column: Column) -> str | None: """ Extract the physical column that a column is referencing. If the column is an adhoc column, always returns `None`. 
@@ -1682,7 +1658,7 @@ def get_column_name_from_column(column: Column) -> Optional[str]: return column # type: ignore -def get_column_names_from_columns(columns: List[Column]) -> List[str]: +def get_column_names_from_columns(columns: list[Column]) -> list[str]: """ Extract the physical columns that a list of columns are referencing. Ignore adhoc columns @@ -1693,7 +1669,7 @@ def get_column_names_from_columns(columns: List[Column]) -> List[str]: return [col for col in map(get_column_name_from_column, columns) if col] -def get_column_name_from_metric(metric: Metric) -> Optional[str]: +def get_column_name_from_metric(metric: Metric) -> str | None: """ Extract the column that a metric is referencing. If the metric isn't a simple metric, always returns `None`. @@ -1704,11 +1680,11 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]: if is_adhoc_metric(metric): metric = cast(AdhocMetric, metric) if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE: - return cast(Dict[str, Any], metric["column"])["column_name"] + return cast(dict[str, Any], metric["column"])["column_name"] return None -def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: +def get_column_names_from_metrics(metrics: list[Metric]) -> list[str]: """ Extract the columns that a list of metrics are referencing. Excludes all SQL metrics. @@ -1721,12 +1697,12 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: def extract_dataframe_dtypes( df: pd.DataFrame, - datasource: Optional[Union[BaseDatasource, Query]] = None, -) -> List[GenericDataType]: + datasource: BaseDatasource | Query | None = None, +) -> list[GenericDataType]: """Serialize pandas/numpy dtypes to generic types""" # omitting string types as those will be the default type - inferred_type_map: Dict[str, GenericDataType] = { + inferred_type_map: dict[str, GenericDataType] = { "floating": GenericDataType.NUMERIC, "integer": GenericDataType.NUMERIC, "mixed-integer-float": GenericDataType.NUMERIC, @@ -1737,7 +1713,7 @@ def extract_dataframe_dtypes( "date": GenericDataType.TEMPORAL, } - columns_by_name: Dict[str, Any] = {} + columns_by_name: dict[str, Any] = {} if datasource: for column in datasource.columns: if isinstance(column, dict): @@ -1745,7 +1721,7 @@ def extract_dataframe_dtypes( else: columns_by_name[column.column_name] = column - generic_types: List[GenericDataType] = [] + generic_types: list[GenericDataType] = [] for column in df.columns: column_object = columns_by_name.get(column) series = df[column] @@ -1767,7 +1743,7 @@ def extract_dataframe_dtypes( return generic_types -def extract_column_dtype(col: "BaseColumn") -> GenericDataType: +def extract_column_dtype(col: BaseColumn) -> GenericDataType: if col.is_temporal: return GenericDataType.TEMPORAL if col.is_numeric: @@ -1776,11 +1752,9 @@ def extract_column_dtype(col: "BaseColumn") -> GenericDataType: return GenericDataType.STRING -def indexed( - items: List[Any], key: Union[str, Callable[[Any], Any]] -) -> Dict[Any, List[Any]]: +def indexed(items: list[Any], key: str | Callable[[Any], Any]) -> dict[Any, list[Any]]: """Build an index for a list of objects""" - idx: Dict[Any, Any] = {} + idx: dict[Any, Any] = {} for item in items: key_ = getattr(item, key) if isinstance(key, str) else key(item) idx.setdefault(key_, []).append(item) @@ -1792,14 +1766,14 @@ def is_test() -> bool: def get_time_filter_status( - datasource: "BaseDatasource", - applied_time_extras: Dict[str, str], -) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - temporal_columns: 
Set[Any] = { + datasource: BaseDatasource, + applied_time_extras: dict[str, str], +) -> tuple[list[dict[str, str]], list[dict[str, str]]]: + temporal_columns: set[Any] = { col.column_name for col in datasource.columns if col.is_dttm } - applied: List[Dict[str, str]] = [] - rejected: List[Dict[str, str]] = [] + applied: list[dict[str, str]] = [] + rejected: list[dict[str, str]] = [] if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL): if time_column in temporal_columns: applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL}) @@ -1844,14 +1818,14 @@ def format_list(items: Sequence[str], sep: str = ", ", quote: str = '"') -> str: return sep.join(f"{quote}{x.replace(quote, quote_escaped)}{quote}" for x in items) -def find_duplicates(items: Iterable[InputType]) -> List[InputType]: +def find_duplicates(items: Iterable[InputType]) -> list[InputType]: """Find duplicate items in an iterable.""" return [item for item, count in collections.Counter(items).items() if count > 1] def remove_duplicates( - items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None -) -> List[InputType]: + items: Iterable[InputType], key: Callable[[InputType], Any] | None = None +) -> list[InputType]: """Remove duplicate items in an iterable.""" if not key: return list(dict.fromkeys(items).keys()) @@ -1868,9 +1842,9 @@ def remove_duplicates( @dataclass class DateColumn: col_label: str - timestamp_format: Optional[str] = None - offset: Optional[int] = None - time_shift: Optional[str] = None + timestamp_format: str | None = None + offset: int | None = None + time_shift: str | None = None def __hash__(self) -> int: return hash(self.col_label) @@ -1881,9 +1855,9 @@ def __eq__(self, other: object) -> bool: @classmethod def get_legacy_time_column( cls, - timestamp_format: Optional[str], - offset: Optional[int], - time_shift: Optional[str], + timestamp_format: str | None, + offset: int | None, + time_shift: str | None, ) -> DateColumn: return cls( timestamp_format=timestamp_format, @@ -1895,7 +1869,7 @@ def get_legacy_time_column( def normalize_dttm_col( df: pd.DataFrame, - dttm_cols: Tuple[DateColumn, ...] = tuple(), + dttm_cols: tuple[DateColumn, ...] 
= tuple(), ) -> None: for _col in dttm_cols: if _col.col_label not in df.columns: @@ -1925,7 +1899,7 @@ def normalize_dttm_col( df[_col.col_label] += parse_human_timedelta(_col.time_shift) -def parse_boolean_string(bool_str: Optional[str]) -> bool: +def parse_boolean_string(bool_str: str | None) -> bool: """ Convert a string representation of a true/false value into a boolean @@ -1956,7 +1930,7 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool: def apply_max_row_limit( limit: int, - max_limit: Optional[int] = None, + max_limit: int | None = None, ) -> int: """ Override row limit if max global limit is defined @@ -1979,7 +1953,7 @@ def apply_max_row_limit( return max_limit -def create_zip(files: Dict[str, Any]) -> BytesIO: +def create_zip(files: dict[str, Any]) -> BytesIO: buf = BytesIO() with ZipFile(buf, "w") as bundle: for filename, contents in files.items(): @@ -1989,7 +1963,7 @@ def create_zip(files: Dict[str, Any]) -> BytesIO: return buf -def remove_extra_adhoc_filters(form_data: Dict[str, Any]) -> None: +def remove_extra_adhoc_filters(form_data: dict[str, Any]) -> None: """ Remove filters from slice data that originate from a filter box or native filter """ diff --git a/superset/utils/csv.py b/superset/utils/csv.py index a6c834b834f2c..bab14058f2842 100644 --- a/superset/utils/csv.py +++ b/superset/utils/csv.py @@ -17,7 +17,7 @@ import logging import re import urllib.request -from typing import Any, Dict, Optional +from typing import Any, Optional from urllib.error import URLError import numpy as np @@ -81,7 +81,7 @@ def df_to_escaped_csv(df: pd.DataFrame, **kwargs: Any) -> Any: def get_chart_csv_data( - chart_url: str, auth_cookies: Optional[Dict[str, str]] = None + chart_url: str, auth_cookies: Optional[dict[str, str]] = None ) -> Optional[bytes]: content = None if auth_cookies: @@ -98,7 +98,7 @@ def get_chart_csv_data( def get_chart_dataframe( - chart_url: str, auth_cookies: Optional[Dict[str, str]] = None + chart_url: str, auth_cookies: Optional[dict[str, str]] = None ) -> Optional[pd.DataFrame]: # Disable all the unnecessary-lambda violations in this function # pylint: disable=unnecessary-lambda diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py index c0ee64370d1ce..ce89b2a255b76 100644 --- a/superset/utils/dashboard_filter_scopes_converter.py +++ b/superset/utils/dashboard_filter_scopes_converter.py @@ -17,7 +17,7 @@ import json import logging from collections import defaultdict -from typing import Any, Dict, List +from typing import Any from shortid import ShortId @@ -27,11 +27,11 @@ def convert_filter_scopes( - json_metadata: Dict[Any, Any], filter_boxes: List[Slice] -) -> Dict[int, Dict[str, Dict[str, Any]]]: + json_metadata: dict[Any, Any], filter_boxes: list[Slice] +) -> dict[int, dict[str, dict[str, Any]]]: filter_scopes = {} - immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] - immuned_by_column: Dict[str, List[int]] = defaultdict(list) + immuned_by_id: list[int] = json_metadata.get("filter_immune_slices") or [] + immuned_by_column: dict[str, list[int]] = defaultdict(list) for slice_id, columns in json_metadata.get( "filter_immune_slice_fields", {} ).items(): @@ -39,7 +39,7 @@ def convert_filter_scopes( immuned_by_column[column].append(int(slice_id)) def add_filter_scope( - filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int + filter_fields: dict[str, dict[str, Any]], filter_field: str, filter_id: int ) -> None: # in case filter field is invalid 
if isinstance(filter_field, str): @@ -54,7 +54,7 @@ def add_filter_scope( logging.info("slice [%i] has invalid field: %s", filter_id, filter_field) for filter_box in filter_boxes: - filter_fields: Dict[str, Dict[str, Any]] = {} + filter_fields: dict[str, dict[str, Any]] = {} filter_id = filter_box.id slice_params = json.loads(filter_box.params or "{}") configs = slice_params.get("filter_configs") or [] @@ -75,10 +75,10 @@ def add_filter_scope( def copy_filter_scopes( - old_to_new_slc_id_dict: Dict[int, int], - old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]], -) -> Dict[str, Dict[Any, Any]]: - new_filter_scopes: Dict[str, Dict[Any, Any]] = {} + old_to_new_slc_id_dict: dict[int, int], + old_filter_scopes: dict[int, dict[str, dict[str, Any]]], +) -> dict[str, dict[Any, Any]]: + new_filter_scopes: dict[str, dict[Any, Any]] = {} for filter_id, scopes in old_filter_scopes.items(): new_filter_key = old_to_new_slc_id_dict.get(int(filter_id)) if new_filter_key: @@ -93,10 +93,10 @@ def copy_filter_scopes( def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements - json_metadata: Dict[str, Any], - position_json: Dict[str, Any], - filter_boxes: List[Slice], -) -> List[Dict[str, Any]]: + json_metadata: dict[str, Any], + position_json: dict[str, Any], + filter_boxes: list[Slice], +) -> list[dict[str, Any]]: """ Convert the legacy filter scopes et al. to the native filter configuration. @@ -121,11 +121,11 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too filter_scopes = json_metadata.get("filter_scopes", {}) filter_box_ids = {filter_box.id for filter_box in filter_boxes} - filter_scope_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict( + filter_scope_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict( dict ) - filter_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict) + filter_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) # Dense representation of filter scopes, falling back to chart level filter configs # if the respective filter scope is not defined at the dashboard level. 
@@ -150,7 +150,7 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too for field, filter_scope in filter_scope_by_key_and_field[key].items(): default = default_filters.get(key, {}).get(field) - fltr: Dict[str, Any] = { + fltr: dict[str, Any] = { "cascadeParentIds": [], "id": f"NATIVE_FILTER-{shortid.generate()}", "scope": { diff --git a/superset/utils/database.py b/superset/utils/database.py index 750d873d1c9cf..70730554f393c 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask import current_app @@ -32,7 +32,7 @@ # TODO: duplicate code with DatabaseDao, below function should be moved or use dao def get_or_create_db( - database_name: str, sqlalchemy_uri: str, always_create: Optional[bool] = True + database_name: str, sqlalchemy_uri: str, always_create: bool | None = True ) -> Database: # pylint: disable=import-outside-toplevel from superset import db diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 7cdc23784a737..438e379a96cf0 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -20,7 +20,7 @@ from datetime import datetime, timedelta from functools import lru_cache from time import struct_time -from typing import Dict, List, Optional, Tuple +from typing import Optional import pandas as pd import parsedatetime @@ -75,7 +75,7 @@ def parse_human_datetime(human_readable: str) -> datetime: return dttm -def normalize_time_delta(human_readable: str) -> Dict[str, int]: +def normalize_time_delta(human_readable: str) -> dict[str, int]: x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # pylint: disable=line-too-long,useless-suppression matched = re.match(x_unit, human_readable, re.IGNORECASE) if not matched: @@ -149,7 +149,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m time_shift: Optional[str] = None, relative_start: Optional[str] = None, relative_end: Optional[str] = None, -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[Optional[datetime], Optional[datetime]]: """Return `since` and `until` date time tuple from string representations of time_range, since, until and time_shift. 
@@ -227,7 +227,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m ] since_and_until_partition = [_.strip() for _ in time_range.split(separator, 1)] - since_and_until: List[Optional[str]] = [] + since_and_until: list[Optional[str]] = [] for part in since_and_until_partition: if not part: # if since or until is "", set as None diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index e77a5599057a1..4ecd2eca98679 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -17,9 +17,10 @@ from __future__ import annotations import time +from collections.abc import Iterator from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, Dict, Iterator, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING from flask import current_app, Response @@ -32,7 +33,7 @@ from superset.stats_logger import BaseStatsLogger -def statsd_gauge(metric_prefix: Optional[str] = None) -> Callable[..., Any]: +def statsd_gauge(metric_prefix: str | None = None) -> Callable[..., Any]: def decorate(f: Callable[..., Any]) -> Callable[..., Any]: """ Handle sending statsd gauge metric from any method or function @@ -83,13 +84,13 @@ def arghash(args: Any, kwargs: Any) -> int: return hash(sorted_args) -def debounce(duration: Union[float, int] = 0.1) -> Callable[..., Any]: +def debounce(duration: float | int = 0.1) -> Callable[..., Any]: """Ensure a function called with the same arguments executes only once per `duration` (default: 100ms). """ def decorate(f: Callable[..., Any]) -> Callable[..., Any]: - last: Dict[str, Any] = {"t": None, "input": None, "output": None} + last: dict[str, Any] = {"t": None, "input": None, "output": None} def wrapped(*args: Any, **kwargs: Any) -> Any: now = time.time() diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 93070732e78f7..f3fb1bbd6cdea 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: +def export_schema_to_dict(back_references: bool) -> dict[str, Any]: """Exports the supported import/export schema to a dictionary""" databases = [ Database.export_schema(recursive=True, include_parent_ref=back_references) @@ -39,7 +39,7 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: def export_to_dict( session: Session, recursive: bool, back_references: bool, include_defaults: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: """Exports databases and druid clusters to a dictionary""" logger.info("Starting export") dbs = session.query(Database) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index 52b784bb23f88..c812581ac498d 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -16,7 +16,7 @@ # under the License. 
import logging from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import Flask from flask_babel import lazy_gettext as _ @@ -31,9 +31,9 @@ class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-meth @abstractmethod def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]], + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: pass @@ -43,9 +43,9 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods ): def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]], + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: if app_config: return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs) @@ -56,7 +56,7 @@ def create( class EncryptedFieldFactory: def __init__(self) -> None: self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None - self._config: Optional[Dict[str, Any]] = None + self._config: Optional[dict[str, Any]] = None def init_app(self, app: Flask) -> None: self._config = app.config @@ -65,7 +65,7 @@ def init_app(self, app: Flask) -> None: ]() def create( - self, *args: List[Any], **kwargs: Optional[Dict[str, Any]] + self, *args: list[Any], **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if self._concrete_type_adapter: return self._concrete_type_adapter.create(self._config, *args, **kwargs) @@ -81,14 +81,14 @@ def __init__(self, previous_secret_key: str) -> None: self._previous_secret_key = previous_secret_key self._dialect: Dialect = db.engine.url.get_dialect() - def discover_encrypted_fields(self) -> Dict[str, Dict[str, EncryptedType]]: + def discover_encrypted_fields(self) -> dict[str, dict[str, EncryptedType]]: """ Iterates over SqlAlchemy's metadata, looking for EncryptedType columns along the way. Builds up a dict of table_name -> dict of col_name: enc type instance :return: """ - meta_info: Dict[str, Any] = {} + meta_info: dict[str, Any] = {} for table_name, table in self._db.metadata.tables.items(): for col_name, col in table.columns.items(): @@ -120,7 +120,7 @@ def _read_bytes(col_name: str, value: Any) -> Optional[bytes]: @staticmethod def _select_columns_from_table( - conn: Connection, column_names: List[str], table_name: str + conn: Connection, column_names: list[str], table_name: str ) -> Row: return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}") @@ -129,7 +129,7 @@ def _re_encrypt_row( conn: Connection, row: Row, table_name: str, - columns: Dict[str, EncryptedType], + columns: dict[str, EncryptedType], ) -> None: """ Re encrypts all columns in a Row diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py index 9874656722e64..ea295c776c7c3 100644 --- a/superset/utils/feature_flag_manager.py +++ b/superset/utils/feature_flag_manager.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
from copy import deepcopy -from typing import Dict from flask import Flask @@ -25,7 +24,7 @@ def __init__(self) -> None: super().__init__() self._get_feature_flags_func = None self._is_feature_enabled_func = None - self._feature_flags: Dict[str, bool] = {} + self._feature_flags: dict[str, bool] = {} def init_app(self, app: Flask) -> None: self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"] @@ -33,7 +32,7 @@ def init_app(self, app: Flask) -> None: self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"] self._feature_flags.update(app.config["FEATURE_FLAGS"]) - def get_feature_flags(self) -> Dict[str, bool]: + def get_feature_flags(self) -> dict[str, bool]: if self._get_feature_flags_func: return self._get_feature_flags_func(deepcopy(self._feature_flags)) if callable(self._is_feature_enabled_func): diff --git a/superset/utils/filters.py b/superset/utils/filters.py index 4772f49ba0a36..88154a40b3d74 100644 --- a/superset/utils/filters.py +++ b/superset/utils/filters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any from flask_appbuilder import Model from sqlalchemy import or_ @@ -22,7 +22,7 @@ def get_dataset_access_filters( - base_model: Type[Model], + base_model: type[Model], *args: Any, ) -> BooleanClauseList: # pylint: disable=import-outside-toplevel diff --git a/superset/utils/hashing.py b/superset/utils/hashing.py index 66983582cac54..fff654263e4a5 100644 --- a/superset/utils/hashing.py +++ b/superset/utils/hashing.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import hashlib -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import simplejson as json @@ -25,7 +25,7 @@ def md5_sha_from_str(val: str) -> str: def md5_sha_from_dict( - obj: Dict[Any, Any], + obj: dict[Any, Any], ignore_nan: bool = False, default: Optional[Callable[[Any], Any]] = None, ) -> str: diff --git a/superset/utils/log.py b/superset/utils/log.py index f2379fe11c491..5430accb43ac2 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -22,25 +22,14 @@ import logging import textwrap from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import contextmanager from datetime import datetime, timedelta -from typing import ( - Any, - Callable, - cast, - Dict, - Iterator, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Literal, TYPE_CHECKING from flask import current_app, g, request from flask_appbuilder.const import API_URI_RIS_KEY from sqlalchemy.exc import SQLAlchemyError -from typing_extensions import Literal from superset.extensions import stats_logger_manager from superset.utils.core import get_user_id, LoggerLevel @@ -51,12 +40,12 @@ logger = logging.getLogger(__name__) -def collect_request_payload() -> Dict[str, Any]: +def collect_request_payload() -> dict[str, Any]: """Collect log payload identifiable from request context""" if not request: return {} - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "path": request.path, **request.form.to_dict(), # url search params can overwrite POST body @@ -81,7 +70,7 @@ def collect_request_payload() -> Dict[str, Any]: def get_logger_from_status( status: int, -) -> Tuple[Callable[..., None], str]: +) -> tuple[Callable[..., None], str]: """ Return logger method by status of exception. 
Maps logger level to status code level @@ -101,10 +90,10 @@ class AbstractEventLogger(ABC): def __call__( self, action: str, - object_ref: Optional[str] = None, + object_ref: str | None = None, log_to_statsd: bool = True, - duration: Optional[timedelta] = None, - **payload_override: Dict[str, Any], + duration: timedelta | None = None, + **payload_override: dict[str, Any], ) -> object: # pylint: disable=W0201 self.action = action @@ -130,12 +119,12 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @abstractmethod def log( # pylint: disable=too-many-arguments self, - user_id: Optional[int], + user_id: int | None, action: str, - dashboard_id: Optional[int], - duration_ms: Optional[int], - slice_id: Optional[int], - referrer: Optional[str], + dashboard_id: int | None, + duration_ms: int | None, + slice_id: int | None, + referrer: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -144,10 +133,10 @@ def log( # pylint: disable=too-many-arguments def log_with_context( # pylint: disable=too-many-locals self, action: str, - duration: Optional[timedelta] = None, - object_ref: Optional[str] = None, + duration: timedelta | None = None, + object_ref: str | None = None, log_to_statsd: bool = True, - **payload_override: Optional[Dict[str, Any]], + **payload_override: dict[str, Any] | None, ) -> None: # pylint: disable=import-outside-toplevel from superset.views.core import get_form_data @@ -176,7 +165,7 @@ def log_with_context( # pylint: disable=too-many-locals if payload_override: payload.update(payload_override) - dashboard_id: Optional[int] = None + dashboard_id: int | None = None try: dashboard_id = int(payload.get("dashboard_id")) # type: ignore except (TypeError, ValueError): @@ -218,7 +207,7 @@ def log_with_context( # pylint: disable=too-many-locals def log_context( self, action: str, - object_ref: Optional[str] = None, + object_ref: str | None = None, log_to_statsd: bool = True, ) -> Iterator[Callable[..., None]]: """ @@ -242,9 +231,9 @@ def log_context( def _wrapper( self, f: Callable[..., Any], - action: Optional[Union[str, Callable[..., str]]] = None, - object_ref: Optional[Union[str, Callable[..., str], Literal[False]]] = None, - allow_extra_payload: Optional[bool] = False, + action: str | Callable[..., str] | None = None, + object_ref: str | Callable[..., str] | Literal[False] | None = None, + allow_extra_payload: bool | None = False, **wrapper_kwargs: Any, ) -> Callable[..., Any]: @functools.wraps(f) @@ -314,7 +303,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger: ) ) - event_logger_type = cast(Type[Any], cfg_value) + event_logger_type = cast(type[Any], cfg_value) result = event_logger_type() # Verify that we have a valid logger impl @@ -333,12 +322,12 @@ class DBEventLogger(AbstractEventLogger): def log( # pylint: disable=too-many-arguments,too-many-locals self, - user_id: Optional[int], + user_id: int | None, action: str, - dashboard_id: Optional[int], - duration_ms: Optional[int], - slice_id: Optional[int], - referrer: Optional[str], + dashboard_id: int | None, + duration_ms: int | None, + slice_id: int | None, + referrer: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -348,7 +337,7 @@ def log( # pylint: disable=too-many-arguments,too-many-locals records = kwargs.get("records", []) logs = [] for record in records: - json_string: Optional[str] + json_string: str | None try: json_string = json.dumps(record) except Exception: # pylint: disable=broad-except diff --git a/superset/utils/machine_auth.py 
b/superset/utils/machine_auth.py index 02c04abe6ae97..7e45fc0f31115 100644 --- a/superset/utils/machine_auth.py +++ b/superset/utils/machine_auth.py @@ -19,7 +19,7 @@ import importlib import logging -from typing import Callable, Dict, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from flask import current_app, Flask, request, Response, session from flask_login import login_user @@ -71,7 +71,7 @@ def authenticate_webdriver( return driver @staticmethod - def get_auth_cookies(user: User) -> Dict[str, str]: + def get_auth_cookies(user: User) -> dict[str, str]: # Login with the user specified to get the reports with current_app.test_request_context("/login"): login_user(user) diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 4b156cc10c10d..7462d1454054e 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -21,8 +21,9 @@ import random import string import sys +from collections.abc import Iterator from datetime import date, datetime, time, timedelta -from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Type +from typing import Any, Callable, cast, Optional from uuid import uuid4 import sqlalchemy.sql.sqltypes @@ -39,17 +40,14 @@ logger = logging.getLogger(__name__) -ColumnInfo = TypedDict( - "ColumnInfo", - { - "name": str, - "type": VisitableType, - "nullable": bool, - "default": Optional[Any], - "autoincrement": str, - "primary_key": int, - }, -) + +class ColumnInfo(TypedDict): + name: str + type: VisitableType + nullable: bool + default: Optional[Any] + autoincrement: str + primary_key: int example_column = { @@ -167,7 +165,7 @@ def get_type_generator( # pylint: disable=too-many-return-statements,too-many-b def add_data( - columns: Optional[List[ColumnInfo]], + columns: Optional[list[ColumnInfo]], num_rows: int, table_name: str, append: bool = True, @@ -212,16 +210,16 @@ def add_data( engine.execute(table.insert(), data) -def get_column_objects(columns: List[ColumnInfo]) -> List[Column]: +def get_column_objects(columns: list[ColumnInfo]) -> list[Column]: out = [] for column in columns: - kwargs = cast(Dict[str, Any], column.copy()) + kwargs = cast(dict[str, Any], column.copy()) kwargs["type_"] = kwargs.pop("type") out.append(Column(**kwargs)) return out -def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, Any]]: +def generate_data(columns: list[ColumnInfo], num_rows: int) -> list[dict[str, Any]]: keys = [column["name"] for column in columns] return [ dict(zip(keys, row)) @@ -229,13 +227,13 @@ def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, An ] -def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]: +def generate_column_data(column: ColumnInfo, num_rows: int) -> list[Any]: gen = get_type_generator(column["type"]) return [gen() for _ in range(num_rows)] def add_sample_rows( - session: Session, model: Type[Model], count: int + session: Session, model: type[Model], count: int ) -> Iterator[Model]: """ Add entities of a given model. 
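# Illustrative aside, not part of the patch: the mock_data.py hunk above swaps the
# functional TypedDict("ColumnInfo", {...}) call for the class-based syntax from PEP 589.
# A minimal, standalone sketch of the same pattern, using hypothetical field names:
from typing import Optional, TypedDict


class PointInfo(TypedDict):
    name: str
    x: float
    y: float
    label: Optional[str]


# At runtime a TypedDict instance is a plain dict; the declared keys are only
# enforced by type checkers such as mypy.
point: PointInfo = {"name": "origin", "x": 0.0, "y": 0.0, "label": None}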
diff --git a/superset/utils/network.py b/superset/utils/network.py index 7a1aea5a7178c..fea3cfc6b2c0e 100644 --- a/superset/utils/network.py +++ b/superset/utils/network.py @@ -32,10 +32,10 @@ def is_port_open(host: str, port: int) -> bool: s = socket.socket(af, socket.SOCK_STREAM) try: s.settimeout(PORT_TIMEOUT) - s.connect((sockaddr)) + s.connect(sockaddr) s.shutdown(socket.SHUT_RDWR) return True - except socket.error as _: + except OSError as _: continue finally: s.close() diff --git a/superset/utils/pandas_postprocessing/aggregate.py b/superset/utils/pandas_postprocessing/aggregate.py index a863d260c5737..1116e4ec70beb 100644 --- a/superset/utils/pandas_postprocessing/aggregate.py +++ b/superset/utils/pandas_postprocessing/aggregate.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import Any from pandas import DataFrame @@ -26,7 +26,7 @@ @validate_column_args("groupby") def aggregate( - df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] + df: DataFrame, groupby: list[str], aggregates: dict[str, dict[str, Any]] ) -> DataFrame: """ Apply aggregations to a DataFrame. diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py index 399cf569fb25c..f9fed40e59618 100644 --- a/superset/utils/pandas_postprocessing/boxplot.py +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np from flask_babel import gettext as _ @@ -27,11 +27,11 @@ def boxplot( df: DataFrame, - groupby: List[str], - metrics: List[str], + groupby: list[str], + metrics: list[str], whisker_type: PostProcessingBoxplotWhiskerType, percentiles: Optional[ - Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]] + Union[list[Union[int, float]], tuple[Union[int, float], Union[int, float]]] ] = None, ) -> DataFrame: """ @@ -102,12 +102,12 @@ def whisker_low(series: Series) -> float: whisker_high = np.max whisker_low = np.min - def outliers(series: Series) -> Set[float]: + def outliers(series: Series) -> set[float]: above = series[series > whisker_high(series)] below = series[series < whisker_low(series)] return above.tolist() + below.tolist() - operators: Dict[str, Callable[[Any], Any]] = { + operators: dict[str, Callable[[Any], Any]] = { "mean": np.mean, "median": np.median, "max": whisker_high, @@ -117,7 +117,7 @@ def outliers(series: Series) -> Set[float]: "count": np.ma.count, "outliers": outliers, } - aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = { + aggregates: dict[str, dict[str, Union[str, Callable[..., Any]]]] = { f"{metric}__{operator_name}": {"column": metric, "operator": operator} for operator_name, operator in operators.items() for metric in metrics diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py index f7c8365508750..b20682027f4a1 100644 --- a/superset/utils/pandas_postprocessing/compare.py +++ b/superset/utils/pandas_postprocessing/compare.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import List, Optional +from typing import Optional import pandas as pd from flask_babel import gettext as _ @@ -29,8 +29,8 @@ @validate_column_args("source_columns", "compare_columns") def compare( # pylint: disable=too-many-arguments df: DataFrame, - source_columns: List[str], - compare_columns: List[str], + source_columns: list[str], + compare_columns: list[str], compare_type: PandasPostprocessingCompare, drop_original_columns: Optional[bool] = False, precision: Optional[int] = 4, diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index f8519f39a9729..d383312f751a4 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from decimal import Decimal -from typing import List, Optional +from typing import Optional from flask_babel import gettext as _ from pandas import DataFrame @@ -31,8 +31,8 @@ def contribution( orientation: Optional[ PostProcessingContributionOrientation ] = PostProcessingContributionOrientation.COLUMN, - columns: Optional[List[str]] = None, - rename_columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, + rename_columns: Optional[list[str]] = None, ) -> DataFrame: """ Calculate cell contribution to row/column total for numeric columns. diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py index b94f048e5cd62..128fa970f5f79 100644 --- a/superset/utils/pandas_postprocessing/cum.py +++ b/superset/utils/pandas_postprocessing/cum.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict from flask_babel import gettext as _ from pandas import DataFrame @@ -31,7 +30,7 @@ def cum( df: DataFrame, operator: str, - columns: Dict[str, str], + columns: dict[str, str], ) -> DataFrame: """ Calculate cumulative sum/product/min/max for select columns. diff --git a/superset/utils/pandas_postprocessing/diff.py b/superset/utils/pandas_postprocessing/diff.py index 0cead2de8d232..de68d39439ba4 100644 --- a/superset/utils/pandas_postprocessing/diff.py +++ b/superset/utils/pandas_postprocessing/diff.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict from pandas import DataFrame @@ -28,7 +27,7 @@ @validate_column_args("columns") def diff( df: DataFrame, - columns: Dict[str, str], + columns: dict[str, str], periods: int = 1, axis: PandasAxis = PandasAxis.ROW, ) -> DataFrame: diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index da9954ef111f6..40db86db0607c 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. 
-from collections.abc import Iterable -from typing import Any, Sequence, Union +from collections.abc import Iterable, Sequence +from typing import Any, Union import pandas as pd diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py index 33a27c2df4074..79046cb71a1b2 100644 --- a/superset/utils/pandas_postprocessing/geography.py +++ b/superset/utils/pandas_postprocessing/geography.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, Tuple +from typing import Optional import geohash as geohash_lib from flask_babel import gettext as _ @@ -95,7 +95,7 @@ def geodetic_parse( :return: DataFrame with decoded longitudes and latitudes """ - def _parse_location(location: str) -> Tuple[float, float, float]: + def _parse_location(location: str) -> tuple[float, float, float]: """ Parse a string containing a geodetic point and return latitude, longitude and altitude diff --git a/superset/utils/pandas_postprocessing/pivot.py b/superset/utils/pandas_postprocessing/pivot.py index df5fa7e37cf94..28e8ff380fcab 100644 --- a/superset/utils/pandas_postprocessing/pivot.py +++ b/superset/utils/pandas_postprocessing/pivot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from pandas import DataFrame @@ -30,9 +30,9 @@ @validate_column_args("index", "columns") def pivot( # pylint: disable=too-many-arguments df: DataFrame, - index: List[str], - aggregates: Dict[str, Dict[str, Any]], - columns: Optional[List[str]] = None, + index: list[str], + aggregates: dict[str, dict[str, Any]], + columns: Optional[list[str]] = None, metric_fill_value: Optional[Any] = None, column_fill_value: Optional[str] = NULL_STRING, drop_missing_columns: Optional[bool] = True, diff --git a/superset/utils/pandas_postprocessing/rename.py b/superset/utils/pandas_postprocessing/rename.py index 0e35a651a8073..4bcd19782c683 100644 --- a/superset/utils/pandas_postprocessing/rename.py +++ b/superset/utils/pandas_postprocessing/rename.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional, Union +from typing import Optional, Union import pandas as pd from flask_babel import gettext as _ @@ -27,7 +27,7 @@ @validate_column_args("columns") def rename( df: pd.DataFrame, - columns: Dict[str, Union[str, None]], + columns: dict[str, Union[str, None]], inplace: bool = False, level: Optional[Level] = None, ) -> pd.DataFrame: diff --git a/superset/utils/pandas_postprocessing/rolling.py b/superset/utils/pandas_postprocessing/rolling.py index 885032eb1780c..f93a047be989e 100644 --- a/superset/utils/pandas_postprocessing/rolling.py +++ b/superset/utils/pandas_postprocessing/rolling.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from flask_babel import gettext as _ from pandas import DataFrame @@ -31,9 +31,9 @@ def rolling( # pylint: disable=too-many-arguments df: DataFrame, rolling_type: str, - columns: Dict[str, str], + columns: dict[str, str], window: Optional[int] = None, - rolling_type_options: Optional[Dict[str, Any]] = None, + rolling_type_options: Optional[dict[str, Any]] = None, center: bool = False, win_type: Optional[str] = None, min_periods: Optional[int] = None, @@ -62,7 +62,7 @@ def rolling( # pylint: disable=too-many-arguments rolling_type_options = rolling_type_options or {} df_rolling = df.loc[:, columns.keys()] - kwargs: Dict[str, Union[str, int]] = {} + kwargs: dict[str, Union[str, int]] = {} if window is None: raise InvalidPostProcessingError(_("Undefined window for rolling operation")) if window == 0: diff --git a/superset/utils/pandas_postprocessing/select.py b/superset/utils/pandas_postprocessing/select.py index 59fe886d4d9c2..c4e02508dfab5 100644 --- a/superset/utils/pandas_postprocessing/select.py +++ b/superset/utils/pandas_postprocessing/select.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional +from typing import Optional from pandas import DataFrame @@ -24,9 +24,9 @@ @validate_column_args("columns", "drop", "rename") def select( df: DataFrame, - columns: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - rename: Optional[Dict[str, str]] = None, + columns: Optional[list[str]] = None, + exclude: Optional[list[str]] = None, + rename: Optional[dict[str, str]] = None, ) -> DataFrame: """ Only select a subset of columns in the original dataset. Can be useful for diff --git a/superset/utils/pandas_postprocessing/sort.py b/superset/utils/pandas_postprocessing/sort.py index 66041a7166b9f..b6470c3546a62 100644 --- a/superset/utils/pandas_postprocessing/sort.py +++ b/superset/utils/pandas_postprocessing/sort.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Union +from typing import Optional, Union from pandas import DataFrame @@ -26,8 +26,8 @@ def sort( df: DataFrame, is_sort_index: bool = False, - by: Optional[Union[List[str], str]] = None, - ascending: Union[List[bool], bool] = True, + by: Optional[Union[list[str], str]] = None, + ascending: Union[list[bool], bool] = True, ) -> DataFrame: """ Sort a DataFrame. diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py index 2b754fbbefae4..37d53697cb89b 100644 --- a/superset/utils/pandas_postprocessing/utils.py +++ b/superset/utils/pandas_postprocessing/utils.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable import numpy as np import pandas as pd @@ -24,7 +25,7 @@ from superset.exceptions import InvalidPostProcessingError -NUMPY_FUNCTIONS: Dict[str, Callable[..., Any]] = { +NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = { "average": np.average, "argmin": np.argmin, "argmax": np.argmax, @@ -133,8 +134,8 @@ def wrapped(df: DataFrame, **options: Any) -> Any: def _get_aggregate_funcs( df: DataFrame, - aggregates: Dict[str, Dict[str, Any]], -) -> Dict[str, NamedAgg]: + aggregates: dict[str, dict[str, Any]], +) -> dict[str, NamedAgg]: """ Converts a set of aggregate config objects into functions that pandas can use as aggregators. Currently only numpy aggregators are supported. @@ -143,7 +144,7 @@ def _get_aggregate_funcs( :param aggregates: Mapping from column name to aggregate config. :return: Mapping from metric name to function that takes a single input argument. """ - agg_funcs: Dict[str, NamedAgg] = {} + agg_funcs: dict[str, NamedAgg] = {} for name, agg_obj in aggregates.items(): column = agg_obj.get("column", name) if column not in df: @@ -180,7 +181,7 @@ def _get_aggregate_funcs( def _append_columns( - base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str] + base_df: DataFrame, append_df: DataFrame, columns: dict[str, str] ) -> DataFrame: """ Function for adding columns from one DataFrame to another DataFrame. Calls the diff --git a/superset/utils/retries.py b/superset/utils/retries.py index d1c294714607b..8a1e6b95eadcb 100644 --- a/superset/utils/retries.py +++ b/superset/utils/retries.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Generator, List, Optional, Type +from collections.abc import Generator +from typing import Any, Callable, Optional import backoff @@ -24,9 +25,9 @@ def retry_call( func: Callable[..., Any], *args: Any, strategy: Callable[..., Generator[int, None, None]] = backoff.constant, - exception: Type[Exception] = Exception, - fargs: Optional[List[Any]] = None, - fkwargs: Optional[Dict[str, Any]] = None, + exception: type[Exception] = Exception, + fargs: Optional[list[Any]] = None, + fkwargs: Optional[dict[str, Any]] = None, **kwargs: Any ) -> Any: """ diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 88b97901b2893..5c699e9e194f8 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -18,7 +18,7 @@ import logging from io import BytesIO -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask import current_app @@ -53,16 +53,16 @@ class BaseScreenshot: def __init__(self, url: str, digest: str): self.digest: str = digest self.url = url - self.screenshot: Optional[bytes] = None + self.screenshot: bytes | None = None - def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy: + def driver(self, window_size: WindowSize | None = None) -> WebDriverProxy: window_size = window_size or self.window_size return WebDriverProxy(self.driver_type, window_size) def cache_key( self, - window_size: Optional[Union[bool, WindowSize]] = None, - thumb_size: Optional[Union[bool, WindowSize]] = None, + window_size: bool | WindowSize | None = None, + thumb_size: bool | WindowSize | None = None, ) -> str: window_size = window_size or self.window_size thumb_size = thumb_size or self.thumb_size @@ -76,8 +76,8 @@ def cache_key( return 
md5_sha_from_dict(args) def get_screenshot( - self, user: User, window_size: Optional[WindowSize] = None - ) -> Optional[bytes]: + self, user: User, window_size: WindowSize | None = None + ) -> bytes | None: driver = self.driver(window_size) self.screenshot = driver.get_screenshot(self.url, self.element, user) return self.screenshot @@ -86,8 +86,8 @@ def get( self, user: User = None, cache: Cache = None, - thumb_size: Optional[WindowSize] = None, - ) -> Optional[BytesIO]: + thumb_size: WindowSize | None = None, + ) -> BytesIO | None: """ Get thumbnail screenshot has BytesIO from cache or fetch @@ -95,7 +95,7 @@ def get( :param cache: The cache to use :param thumb_size: Override thumbnail site """ - payload: Optional[bytes] = None + payload: bytes | None = None cache_key = self.cache_key(self.window_size, thumb_size) if cache: payload = cache.get(cache_key) @@ -112,14 +112,14 @@ def get( def get_from_cache( self, cache: Cache, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, - ) -> Optional[BytesIO]: + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, + ) -> BytesIO | None: cache_key = self.cache_key(window_size, thumb_size) return self.get_from_cache_key(cache, cache_key) @staticmethod - def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: + def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None: logger.info("Attempting to get from cache: %s", cache_key) if payload := cache.get(cache_key): return BytesIO(payload) @@ -129,11 +129,11 @@ def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: def compute_and_cache( # pylint: disable=too-many-arguments self, user: User = None, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, cache: Cache = None, force: bool = True, - ) -> Optional[bytes]: + ) -> bytes | None: """ Fetches the screenshot, computes the thumbnail and caches the result @@ -178,7 +178,7 @@ def resize_image( cls, img_bytes: bytes, output: str = "png", - thumb_size: Optional[WindowSize] = None, + thumb_size: WindowSize | None = None, crop: bool = True, ) -> bytes: thumb_size = thumb_size or cls.thumb_size @@ -207,8 +207,8 @@ def __init__( self, url: str, digest: str, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, ): # Chart reports are in standalone="true" mode url = modify_url_query( @@ -228,8 +228,8 @@ def __init__( self, url: str, digest: str, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, ): # per the element above, dashboard screenshots # should always capture in standalone diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py index 48ada98dccaaf..8421350f8c140 100644 --- a/superset/utils/ssh_tunnel.py +++ b/superset/utils/ssh_tunnel.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict +from typing import Any from superset.constants import PASSWORD_MASK from superset.databases.ssh_tunnel.models import SSHTunnel -def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: +def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]: if ssh_tunnel.pop("password", None) is not None: ssh_tunnel["password"] = PASSWORD_MASK if ssh_tunnel.pop("private_key", None) is not None: @@ -32,8 +32,8 @@ def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: def unmask_password_info( - ssh_tunnel: Dict[str, Any], model: SSHTunnel -) -> Dict[str, Any]: + ssh_tunnel: dict[str, Any], model: SSHTunnel +) -> dict[str, Any]: if ssh_tunnel.get("password") == PASSWORD_MASK: ssh_tunnel["password"] = model.password if ssh_tunnel.get("private_key") == PASSWORD_MASK: diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py index fbd9c800b0a88..11e40267b30c5 100644 --- a/superset/utils/url_map_converters.py +++ b/superset/utils/url_map_converters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List +from typing import Any from werkzeug.routing import BaseConverter, Map @@ -22,7 +22,7 @@ class RegexConverter(BaseConverter): - def __init__(self, url_map: Map, *items: List[str]) -> None: + def __init__(self, url_map: Map, *items: list[str]) -> None: super().__init__(url_map) self.regex = items[0] diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 05dbee674052d..c302ab89214a4 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -20,7 +20,7 @@ import logging from enum import Enum from time import sleep -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask import current_app from selenium.common.exceptions import ( @@ -37,7 +37,7 @@ from superset.extensions import machine_auth_provider_factory from superset.utils.retries import retry_call -WindowSize = Tuple[int, int] +WindowSize = tuple[int, int] logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class ChartStandaloneMode(Enum): SHOW_NAV = 0 -def find_unexpected_errors(driver: WebDriver) -> List[str]: +def find_unexpected_errors(driver: WebDriver) -> list[str]: error_messages = [] try: @@ -111,7 +111,7 @@ def find_unexpected_errors(driver: WebDriver) -> List[str]: class WebDriverProxy: - def __init__(self, driver_type: str, window: Optional[WindowSize] = None): + def __init__(self, driver_type: str, window: WindowSize | None = None): self._driver_type = driver_type self._window: WindowSize = window or (800, 600) self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"] @@ -124,7 +124,7 @@ def create(self) -> WebDriver: options = firefox.options.Options() profile = FirefoxProfile() profile.set_preference("layout.css.devPixelsPerPx", str(pixel_density)) - kwargs: Dict[Any, Any] = dict(options=options, firefox_profile=profile) + kwargs: dict[Any, Any] = dict(options=options, firefox_profile=profile) elif self._driver_type == "chrome": driver_class = chrome.webdriver.WebDriver options = chrome.options.Options() @@ -164,13 +164,11 @@ def destroy(driver: WebDriver, tries: int = 2) -> None: except Exception: # pylint: disable=broad-except pass - def get_screenshot( - self, url: str, element_name: str, user: User - ) -> Optional[bytes]: + def get_screenshot(self, url: str, element_name: str, user: 
User) -> bytes | None: driver = self.auth(user) driver.set_window_size(*self._window) driver.get(url) - img: Optional[bytes] = None + img: bytes | None = None selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"] logger.debug("Sleeping for %i seconds", selenium_headstart) sleep(selenium_headstart) diff --git a/superset/views/__init__.py b/superset/views/__init__.py index 5247f215c1870..b5a21c77f0b32 100644 --- a/superset/views/__init__.py +++ b/superset/views/__init__.py @@ -21,8 +21,6 @@ base, core, css_templates, - dashboard, - datasource, dynamic_plugins, health, redirects, diff --git a/superset/views/all_entities.py b/superset/views/all_entities.py index 4031d81d2129e..3de53be461971 100644 --- a/superset/views/all_entities.py +++ b/superset/views/all_entities.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import logging diff --git a/superset/views/base.py b/superset/views/base.py index 97d8da1d69e62..3a72096ac2fc1 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -20,7 +20,7 @@ import os import traceback from datetime import datetime -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Optional, Union import simplejson as json import yaml @@ -140,11 +140,11 @@ def get_error_msg() -> str: def json_error_response( msg: Optional[str] = None, status: int = 500, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, link: Optional[str] = None, ) -> FlaskResponse: if not payload: - payload = {"error": "{}".format(msg)} + payload = {"error": f"{msg}"} if link: payload["link"] = link @@ -156,9 +156,9 @@ def json_error_response( def json_errors_response( - errors: List[SupersetError], + errors: list[SupersetError], status: int = 500, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, ) -> FlaskResponse: if not payload: payload = {} @@ -182,7 +182,7 @@ def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskRe def generate_download_headers( extension: str, filename: Optional[str] = None -) -> Dict[str, Any]: +) -> dict[str, Any]: filename = filename if filename else datetime.now().strftime("%Y%m%d_%H%M%S") content_disp = f"attachment; filename={filename}.{extension}" headers = {"Content-Disposition": content_disp} @@ -332,7 +332,7 @@ def render_app_template(self) -> FlaskResponse: ) -def menu_data(user: User) -> Dict[str, Any]: +def menu_data(user: User) -> dict[str, Any]: menu = appbuilder.menu.get_data() languages = {} @@ -396,7 +396,7 @@ def menu_data(user: User) -> Dict[str, Any]: @cache_manager.cache.memoize(timeout=60) -def cached_common_bootstrap_data(user: User) -> Dict[str, Any]: +def cached_common_bootstrap_data(user: User) -> dict[str, Any]: """Common data always sent to the client The function is memoized as the return value only changes when user permissions @@ -439,7 +439,7 @@ def cached_common_bootstrap_data(user: User) -> Dict[str, Any]: return bootstrap_data -def common_bootstrap_payload(user: User) -> Dict[str, Any]: +def common_bootstrap_payload(user: User) -> dict[str, Any]: return { **(cached_common_bootstrap_data(user)), "flash_messages": get_flashed_messages(with_categories=True), @@ -548,7 +548,7 @@ def show_unexpected_exception(ex: Exception) -> FlaskResponse: @superset_app.context_processor -def 
get_common_bootstrap_data() -> Dict[str, Any]: +def get_common_bootstrap_data() -> dict[str, Any]: def serialize_bootstrap_data() -> str: return json.dumps( {"common": common_bootstrap_payload(g.user)}, @@ -606,7 +606,7 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods @action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download") def yaml_export( - self, items: Union[ImportExportMixin, List[ImportExportMixin]] + self, items: Union[ImportExportMixin, list[ImportExportMixin]] ) -> FlaskResponse: if not isinstance(items, list): items = [items] @@ -663,7 +663,7 @@ def _delete(self: BaseView, primary_key: int) -> None: @action( "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False ) - def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse: + def muldelete(self: BaseView, items: list[Model]) -> FlaskResponse: if not items: abort(404) for item in items: @@ -709,7 +709,7 @@ class XlsxResponse(Response): def bind_field( - _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] + _: Any, form: DynamicForm, unbound_field: UnboundField, options: dict[Any, Any] ) -> Field: """ Customize how fields are bound by stripping all whitespace. diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 30d25382f37ee..dca7a96b1d90c 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -18,7 +18,7 @@ import functools import logging -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, cast from flask import request, Response from flask_appbuilder import Model, ModelRestApi @@ -87,7 +87,7 @@ def requires_json(f: Callable[..., Any]) -> Callable[..., Any]: Require JSON-like formatted request to the REST API """ - def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response: + def wraps(self: BaseSupersetModelRestApi, *args: Any, **kwargs: Any) -> Response: if not request.is_json: raise InvalidPayloadFormatError(message="Request is not JSON") return f(self, *args, **kwargs) @@ -135,7 +135,7 @@ def wraps(self: BaseSupersetApiMixin, *args: Any, **kwargs: Any) -> Response: class RelatedFieldFilter: # data class to specify what filter to use on a /related endpoint # pylint: disable=too-few-public-methods - def __init__(self, field_name: str, filter_class: Type[BaseFilter]): + def __init__(self, field_name: str, filter_class: type[BaseFilter]): self.field_name = field_name self.filter_class = filter_class @@ -150,7 +150,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods arg_name = "" class_name = "" """ The FavStar class_name to user """ - model: Type[Union[Dashboard, Slice, SqllabQuery]] = Dashboard + model: type[Dashboard | Slice | SqllabQuery] = Dashboard """ The SQLAlchemy model """ def apply(self, query: Query, value: Any) -> Query: @@ -178,7 +178,7 @@ class BaseTagFilter(BaseFilter): # pylint: disable=too-few-public-methods arg_name = "" class_name = "" """ The Tag class_name to user """ - model: Type[Union[Dashboard, Slice, SqllabQuery, SqlaTable]] = Dashboard + model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard """ The SQLAlchemy model """ def apply(self, query: Query, value: Any) -> Query: @@ -229,7 +229,7 @@ def timing_stats(self, action: str, func_name: str, value: float) -> None: ) def send_stats_metrics( - self, response: Response, key: str, time_delta: Optional[float] = None + self, response: Response, key: str, time_delta: float | None 
= None ) -> None: """ Helper function to handle sending statsd metrics @@ -280,7 +280,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): "viz_types": "list", } - order_rel_fields: Dict[str, Tuple[str, str]] = {} + order_rel_fields: dict[str, tuple[str, str]] = {} """ Impose ordering on related fields query:: @@ -290,7 +290,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - base_related_field_filters: Dict[str, BaseFilter] = {} + base_related_field_filters: dict[str, BaseFilter] = {} """ This is used to specify a base filter for related fields when they are accessed through the '/related/' endpoint. @@ -302,7 +302,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - related_field_filters: Dict[str, Union[RelatedFieldFilter, str]] = {} + related_field_filters: dict[str, RelatedFieldFilter | str] = {} """ Specify a filter for related fields when they are accessed through the '/related/' endpoint. @@ -313,10 +313,10 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): "": ) } """ - allowed_rel_fields: Set[str] = set() + allowed_rel_fields: set[str] = set() # Declare a set of allowed related fields that the `related` endpoint supports. - text_field_rel_fields: Dict[str, str] = {} + text_field_rel_fields: dict[str, str] = {} """ Declare an alternative for the human readable representation of the Model object:: @@ -325,7 +325,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - extra_fields_rel_fields: Dict[str, List[str]] = {"owners": ["email", "active"]} + extra_fields_rel_fields: dict[str, list[str]] = {"owners": ["email", "active"]} """ Declare extra fields for the representation of the Model object:: @@ -334,12 +334,12 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - allowed_distinct_fields: Set[str] = set() + allowed_distinct_fields: set[str] = set() - add_columns: List[str] - edit_columns: List[str] - list_columns: List[str] - show_columns: List[str] + add_columns: list[str] + edit_columns: list[str] + list_columns: list[str] + show_columns: list[str] def __init__(self) -> None: super().__init__() @@ -347,8 +347,8 @@ def __init__(self) -> None: if self.apispec_parameter_schemas is None: # type: ignore self.apispec_parameter_schemas = {} self.apispec_parameter_schemas["get_related_schema"] = get_related_schema - self.openapi_spec_component_schemas: Tuple[ - Type[Schema], ... + self.openapi_spec_component_schemas: tuple[ + type[Schema], ... 
] = self.openapi_spec_component_schemas + ( RelatedResponseSchema, DistincResponseSchema, @@ -409,7 +409,7 @@ def _get_text_for_model(self, model: Model, column_name: str) -> str: def _get_extra_field_for_model( self, model: Model, column_name: str - ) -> Dict[str, str]: + ) -> dict[str, str]: ret = {} if column_name in self.extra_fields_rel_fields: model_column_names = self.extra_fields_rel_fields.get(column_name) @@ -419,8 +419,8 @@ def _get_extra_field_for_model( return ret def _get_result_from_rows( - self, datamodel: SQLAInterface, rows: List[Model], column_name: str - ) -> List[Dict[str, Any]]: + self, datamodel: SQLAInterface, rows: list[Model], column_name: str + ) -> list[dict[str, Any]]: return [ { "value": datamodel.get_pk_value(row), @@ -434,8 +434,8 @@ def _add_extra_ids_to_result( self, datamodel: SQLAInterface, column_name: str, - ids: List[int], - result: List[Dict[str, Any]], + ids: list[int], + result: list[dict[str, Any]], ) -> None: if ids: # Filter out already present values on the result diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index 8f4ed7735cc06..2107558dc7b5c 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Optional, Union from flask import current_app, g from flask_appbuilder import Model @@ -54,7 +55,7 @@ def load( # pylint: disable=arguments-differ self, data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], many: Optional[bool] = None, - partial: Union[bool, Sequence[str], Set[str], None] = None, + partial: Union[bool, Sequence[str], set[str], None] = None, instance: Optional[Model] = None, **kwargs: Any, ) -> Any: @@ -67,7 +68,7 @@ def load( # pylint: disable=arguments-differ @post_load def make_object( - self, data: Dict[Any, Any], discard: Optional[List[str]] = None + self, data: dict[Any, Any], discard: Optional[list[str]] = None ) -> Model: """ Creates a Model object from POST or PUT requests. PUT will use self.instance @@ -95,7 +96,7 @@ class BaseOwnedSchema(BaseSupersetSchema): @post_load def make_object( - self, data: Dict[str, Any], discard: Optional[List[str]] = None + self, data: dict[str, Any], discard: Optional[list[str]] = None ) -> Model: discard = discard or [] discard.append(self.owners_field_name) @@ -107,13 +108,13 @@ def make_object( return instance @pre_load - def pre_load(self, data: Dict[Any, Any]) -> None: + def pre_load(self, data: dict[Any, Any]) -> None: # if PUT request don't set owners to empty list if not self.instance: data[self.owners_field_name] = data.get(self.owners_field_name, []) @staticmethod - def set_owners(instance: Model, owners: List[int]) -> None: + def set_owners(instance: Model, owners: list[int]) -> None: owner_objs = [] user_id = get_user_id() if user_id and user_id not in owners: diff --git a/superset/views/core.py b/superset/views/core.py index 24bc16c3106f1..3b63eb74d81ad 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -21,7 +21,7 @@ import re from contextlib import closing from datetime import datetime -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Optional from urllib import parse import backoff @@ -202,7 +202,7 @@ "your query again." 
) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @@ -300,13 +300,11 @@ def request_access(self) -> FlaskResponse: datasources.add(datasource) has_access_ = all( - ( - datasource and security_manager.can_access_datasource(datasource) - for datasource in datasources - ) + datasource and security_manager.can_access_datasource(datasource) + for datasource in datasources ) if has_access_: - return redirect("/superset/dashboard/{}".format(dashboard_id)) + return redirect(f"/superset/dashboard/{dashboard_id}") if request.args.get("action") == "go": for datasource in datasources: @@ -483,7 +481,7 @@ def send_data_payload_response(viz_obj: BaseViz, payload: Any) -> FlaskResponse: return data_payload_response(*viz_obj.payload_json_and_has_error(payload)) def generate_json( - self, viz_obj: BaseViz, response_type: Optional[str] = None + self, viz_obj: BaseViz, response_type: str | None = None ) -> FlaskResponse: if response_type == ChartDataResultFormat.CSV: return CsvResponse( @@ -618,7 +616,7 @@ def explore_json_data(self, cache_key: str) -> FlaskResponse: @check_resource_permissions(check_datasource_perms) @deprecated(eol_version="3.0") def explore_json( - self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None + self, datasource_type: str | None = None, datasource_id: int | None = None ) -> FlaskResponse: """Serves all request that GET or POST form_data @@ -631,7 +629,7 @@ def explore_json( TODO: break into one endpoint for each return shape""" response_type = ChartDataResultFormat.JSON.value - responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list( + responses: list[ChartDataResultFormat | ChartDataResultType] = list( ChartDataResultFormat ) responses.extend(list(ChartDataResultType)) @@ -814,9 +812,9 @@ def get_redirect_url() -> str: # pylint: disable=too-many-locals,too-many-branches,too-many-statements def explore( self, - datasource_type: Optional[str] = None, - datasource_id: Optional[int] = None, - key: Optional[str] = None, + datasource_type: str | None = None, + datasource_id: int | None = None, + key: str | None = None, ) -> FlaskResponse: if request.method == "GET": return redirect(Superset.get_redirect_url()) @@ -879,7 +877,7 @@ def explore( # fallback unknown datasource to table type datasource_type = SqlaTable.type - datasource: Optional[BaseDatasource] = None + datasource: BaseDatasource | None = None if datasource_id is not None: try: datasource = DatasourceDAO.get_datasource( @@ -965,7 +963,7 @@ def explore( ) standalone_mode = ReservedUrlParameters.is_standalone_mode() force = request.args.get("force") in {"force", "1", "true"} - dummy_datasource_data: Dict[str, Any] = { + dummy_datasource_data: dict[str, Any] = { "type": datasource_type, "name": datasource_name, "columns": [], @@ -1058,14 +1056,14 @@ def filter( # pylint: disable=no-self-use @staticmethod def save_or_overwrite_slice( # pylint: disable=too-many-arguments,too-many-locals - slc: Optional[Slice], + slc: Slice | None, slice_add_perm: bool, slice_overwrite_perm: bool, slice_download_perm: bool, datasource_id: int, datasource_type: str, datasource_name: str, - query_context: Optional[str] = None, + query_context: str | None = None, ) -> FlaskResponse: """Save or overwrite a slice""" slice_name = request.args.get("slice_name") @@ -1100,7 +1098,7 @@ def save_or_overwrite_slice( flash(msg, "success") # Adding slice to a dashboard if requested - dash: Optional[Dashboard] = None + dash: 
Dashboard | None = None save_to_dashboard_id = request.args.get("save_to_dashboard_id") new_dashboard_name = request.args.get("new_dashboard_name") @@ -1293,7 +1291,7 @@ def copy_dash( # pylint: disable=no-self-use dash.dashboard_title = data["dashboard_title"] dash.css = data.get("css") - old_to_new_slice_ids: Dict[int, int] = {} + old_to_new_slice_ids: dict[int, int] = {} if data["duplicate_slices"]: # Duplicating slices as well, mapping old ids to new ones for slc in original_dash.slices: @@ -1480,7 +1478,7 @@ def testconn(self) -> FlaskResponse: ) @staticmethod - def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: + def get_user_activity_access_error(user_id: int) -> FlaskResponse | None: try: security_manager.raise_for_user_activity_access(user_id) except SupersetSecurityException as ex: @@ -1567,7 +1565,7 @@ def fave_dashboards(self, user_id: int) -> FlaskResponse: if o.Dashboard.created_by: user = o.Dashboard.created_by dash["creator"] = str(user) - dash["creator_url"] = "/superset/profile/{}/".format(user.username) + dash["creator_url"] = f"/superset/profile/{user.username}/" payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1607,7 +1605,7 @@ def created_dashboards(self, user_id: int) -> FlaskResponse: @expose("/user_slices", methods=("GET",)) @expose("/user_slices//", methods=("GET",)) @deprecated(new_target="/api/v1/chart/") - def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def user_slices(self, user_id: int | None = None) -> FlaskResponse: """List of slices a user owns, created, modified or faved""" if not user_id: user_id = cast(int, get_user_id()) @@ -1660,7 +1658,7 @@ def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: @expose("/created_slices", methods=("GET",)) @expose("/created_slices//", methods=("GET",)) @deprecated(new_target="api/v1/chart/") - def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def created_slices(self, user_id: int | None = None) -> FlaskResponse: """List of slices created by this user""" if not user_id: user_id = cast(int, get_user_id()) @@ -1691,7 +1689,7 @@ def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: @expose("/fave_slices", methods=("GET",)) @expose("/fave_slices//", methods=("GET",)) @deprecated(new_target="api/v1/chart/") - def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def fave_slices(self, user_id: int | None = None) -> FlaskResponse: """Favorite slices for a user""" if user_id is None: user_id = cast(int, get_user_id()) @@ -1721,7 +1719,7 @@ def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: if o.Slice.created_by: user = o.Slice.created_by dash["creator"] = str(user) - dash["creator_url"] = "/superset/profile/{}/".format(user.username) + dash["creator_url"] = f"/superset/profile/{user.username}/" payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1745,7 +1743,7 @@ def warm_up_cache( # pylint: disable=too-many-locals,no-self-use table_name = request.args.get("table_name") db_name = request.args.get("db_name") extra_filters = request.args.get("extra_filters") - slices: List[Slice] = [] + slices: list[Slice] = [] if not slice_id and not (table_name and db_name): return json_error_response( @@ -1869,7 +1867,7 @@ def dashboard( self, dashboard_id_or_slug: str, # pylint: disable=unused-argument add_extra_log_payload: Callable[..., None] = lambda **kwargs: None, - dashboard: 
Optional[Dashboard] = None, + dashboard: Dashboard | None = None, ) -> FlaskResponse: """ Server side rendering for a dashboard @@ -2112,7 +2110,7 @@ def extra_table_metadata( # pylint: disable=no-self-use @event_logger.log_this @deprecated(new_target="api/v1/sqllab/estimate/") def estimate_query_cost( # pylint: disable=no-self-use - self, database_id: int, schema: Optional[str] = None + self, database_id: int, schema: str | None = None ) -> FlaskResponse: mydb = db.session.query(Database).get(database_id) @@ -2135,7 +2133,7 @@ def estimate_query_cost( # pylint: disable=no-self-use return json_error_response(utils.error_msg_from_exception(ex)) spec = mydb.db_engine_spec - query_cost_formatters: Dict[str, Any] = app.config[ + query_cost_formatters: dict[str, Any] = app.config[ "QUERY_COST_FORMATTERS_BY_ENGINE" ] query_cost_formatter = query_cost_formatters.get( @@ -2334,14 +2332,14 @@ def validate_sql_json( mydb = session.query(Database).filter_by(id=database_id).one_or_none() if not mydb: return json_error_response( - "Database with id {} is missing.".format(database_id), status=400 + f"Database with id {database_id} is missing.", status=400 ) spec = mydb.db_engine_spec validators_by_engine = app.config["SQL_VALIDATORS_BY_ENGINE"] if not validators_by_engine or spec.engine not in validators_by_engine: return json_error_response( - "no SQL validator is configured for {}".format(spec.engine), status=400 + f"no SQL validator is configured for {spec.engine}", status=400 ) validator_name = validators_by_engine[spec.engine] validator = get_validator_by_name(validator_name) @@ -2403,7 +2401,7 @@ def sql_json(self) -> FlaskResponse: @staticmethod def _create_sql_json_command( - execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]] + execution_context: SqlJsonExecutionContext, log_params: dict[str, Any] | None ) -> ExecuteSqlCommand: query_dao = QueryDAO() sql_json_executor = Superset._create_sql_json_executor( @@ -2556,7 +2554,7 @@ def fetch_datasource_metadata(self) -> FlaskResponse: # pylint: disable=no-self @expose("/queries/") @expose("/queries/") @deprecated(new_target="api/v1/query/updated_since") - def queries(self, last_updated_ms: Union[float, int]) -> FlaskResponse: + def queries(self, last_updated_ms: float | int) -> FlaskResponse: """ Get the updated queries. 
@@ -2566,7 +2564,7 @@ def queries(self, last_updated_ms: Union[float, int]) -> FlaskResponse: return self.queries_exec(last_updated_ms) @staticmethod - def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse: + def queries_exec(last_updated_ms: float | int) -> FlaskResponse: stats_logger.incr("queries") if not get_user_id(): return json_error_response( @@ -2714,7 +2712,7 @@ def profile(self, username: str) -> FlaskResponse: ) @staticmethod - def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]: + def _get_sqllab_tabs(user_id: int | None) -> dict[str, Any]: # send list of tab state ids tabs_state = ( db.session.query(TabState.id, TabState.label) @@ -2730,13 +2728,13 @@ def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]: .first() ) - databases: Dict[int, Any] = {} + databases: dict[int, Any] = {} for database in DatabaseDAO.find_all(): databases[database.id] = { k: v for k, v in database.to_json().items() if k in DATABASE_KEYS } databases[database.id]["backend"] = database.backend - queries: Dict[str, Any] = {} + queries: dict[str, Any] = {} # These are unnecessary if sqllab backend persistence is disabled if is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"): diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py index 4f122067719c2..71ef212f6d42c 100644 --- a/superset/views/dashboard/views.py +++ b/superset/views/dashboard/views.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import builtins import json import re -from typing import Callable, List, Union +from typing import Callable, Union from flask import g, redirect, request, Response from flask_appbuilder import expose @@ -64,12 +65,13 @@ def list(self) -> FlaskResponse: @action("mulexport", __("Export"), __("Export dashboards?"), "fa-database") def mulexport( # pylint: disable=no-self-use - self, items: Union["DashboardModelView", List["DashboardModelView"]] + self, + items: Union["DashboardModelView", builtins.list["DashboardModelView"]], ) -> FlaskResponse: if not isinstance(items, list): items = [items] - ids = "".join("&id={}".format(d.id) for d in items) - return redirect("/dashboard/export_dashboards_form?{}".format(ids[1:])) + ids = "".join(f"&id={d.id}" for d in items) + return redirect(f"/dashboard/export_dashboards_form?{ids[1:]}") @event_logger.log_this @has_access diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py index 5e2347528a3b0..b906e5e70b880 100644 --- a/superset/views/database/forms.py +++ b/superset/views/database/forms.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
"""Contains the logic to create cohesive forms on the explore view""" -from typing import List from flask_appbuilder.fields import QuerySelectField from flask_appbuilder.fieldwidgets import BS3TextFieldWidget @@ -44,7 +43,7 @@ class UploadToDatabaseForm(DynamicForm): @staticmethod - def file_allowed_dbs() -> List[Database]: + def file_allowed_dbs() -> list[Database]: file_enabled_dbs = ( db.session.query(Database).filter_by(allow_file_upload=True).all() ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index efd0b6c6eb25e..deb1b88f1f89d 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -227,7 +227,7 @@ def pre_delete(self, database: Database) -> None: # pylint: disable=no-self-use Markup( "Cannot delete a database that has tables attached. " "Here's the list of associated tables: " - + ", ".join("{}".format(table) for table in database.tables) + + ", ".join(f"{table}" for table in database.tables) ) ) diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py index 29d80611a2421..2ee49c8210736 100644 --- a/superset/views/database/validators.py +++ b/superset/views/database/validators.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Type +from typing import Optional from flask_babel import lazy_gettext as _ from marshmallow import ValidationError @@ -27,7 +27,7 @@ def sqlalchemy_uri_validator( - uri: str, exception: Type[ValidationError] = ValidationError + uri: str, exception: type[ValidationError] = ValidationError ) -> None: """ Check if a user has submitted a valid SQLAlchemy URI diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index b71b3defa850a..5b1700708ad82 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from marshmallow import fields, post_load, pre_load, Schema, validate from typing_extensions import TypedDict @@ -76,7 +76,7 @@ class SamplesPayloadSchema(Schema): @pre_load # pylint: disable=no-self-use, unused-argument - def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def handle_none(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: if data is None: return {} return data diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 42cddf416794d..a4cf0c5e9063f 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Optional +from typing import Any, Optional from superset import app, db from superset.common.chart_data import ChartDataResultType @@ -27,7 +27,7 @@ from superset.views.datasource.schemas import SamplesPayloadSchema -def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]: +def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> dict[str, int]: samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000) limit = samples_row_limit offset = 0 @@ -50,7 +50,7 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals page: int = 1, per_page: int = 1000, payload: Optional[SamplesPayloadSchema] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = DatasourceDAO.get_datasource( session=db.session, datasource_type=datasource_type, diff --git a/superset/views/log/dao.py b/superset/views/log/dao.py index 71d8a62348641..87bc0817daf98 100644 --- a/superset/views/log/dao.py +++ b/superset/views/log/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any import humanize from sqlalchemy import and_, or_ @@ -34,8 +34,8 @@ class LogDAO(BaseDAO): @staticmethod def get_recent_activity( - user_id: int, actions: List[str], distinct: bool, page: int, page_size: int - ) -> List[Dict[str, Any]]: + user_id: int, actions: list[str], distinct: bool, page: int, page_size: int + ) -> list[dict[str, Any]]: has_subject_title = or_( and_( Dashboard.dashboard_title is not None, diff --git a/superset/views/tags.py b/superset/views/tags.py index bd4f43a0d978a..4f9d55aed75f8 100644 --- a/superset/views/tags.py +++ b/superset/views/tags.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import logging diff --git a/superset/views/users/__init__.py b/superset/views/users/__init__.py index fd9417fe5c1e9..13a83393a9124 100644 --- a/superset/views/users/__init__.py +++ b/superset/views/users/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information diff --git a/superset/views/utils.py b/superset/views/utils.py index a366ac683c206..9b515edc26f2e 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -17,7 +17,7 @@ import logging from collections import defaultdict from functools import wraps -from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, DefaultDict, Optional, Union from urllib import parse import msgpack @@ -55,12 +55,12 @@ logger = logging.getLogger(__name__) stats_logger = app.config["STATS_LOGGER"] -REJECTED_FORM_DATA_KEYS: List[str] = [] +REJECTED_FORM_DATA_KEYS: list[str] = [] if not feature_flag_manager.is_feature_enabled("ENABLE_JAVASCRIPT_CONTROLS"): REJECTED_FORM_DATA_KEYS = ["js_tooltip", "js_onclick_href", "js_data_mutator"] -def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]: +def sanitize_datasource_data(datasource_data: dict[str, Any]) -> dict[str, Any]: if datasource_data: datasource_database = datasource_data.get("database") if datasource_database: @@ -69,7 +69,7 @@ def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]: return datasource_data -def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]: +def bootstrap_user_data(user: User, include_perms: bool = False) -> dict[str, Any]: if user.is_anonymous: payload = {} user.roles = (security_manager.find_role("Public"),) @@ -103,7 +103,7 @@ def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, An def get_permissions( user: User, -) -> Tuple[Dict[str, List[Tuple[str]]], DefaultDict[str, List[str]]]: +) -> tuple[dict[str, list[tuple[str]]], DefaultDict[str, list[str]]]: if not user.roles: raise AttributeError("User object does not have roles") @@ -138,7 +138,7 @@ def get_viz( return viz_obj -def loads_request_json(request_json_data: str) -> Dict[Any, Any]: +def loads_request_json(request_json_data: str) -> dict[Any, Any]: try: return json.loads(request_json_data) except (TypeError, json.JSONDecodeError): @@ -148,9 +148,9 @@ def loads_request_json(request_json_data: str) -> Dict[Any, Any]: def get_form_data( # pylint: disable=too-many-locals slice_id: Optional[int] = None, use_slice_data: bool = False, - initial_form_data: Optional[Dict[str, Any]] = None, -) -> Tuple[Dict[str, Any], Optional[Slice]]: - form_data: Dict[str, Any] = initial_form_data or {} + initial_form_data: Optional[dict[str, Any]] = None, +) -> tuple[dict[str, Any], Optional[Slice]]: + form_data: dict[str, Any] = initial_form_data or {} if has_request_context(): # chart data API requests are JSON @@ -222,7 +222,7 @@ def get_form_data( # pylint: disable=too-many-locals return form_data, slc -def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any: +def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any: """ SQLLab can include a "filters" attribute in the templateParams. 
The filters attribute is a list of filters to include in the @@ -244,7 +244,7 @@ def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any: def get_datasource_info( datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData -) -> Tuple[int, Optional[str]]: +) -> tuple[int, Optional[str]]: """ Compatibility layer for handling of datasource info @@ -277,8 +277,8 @@ def get_datasource_info( def apply_display_max_row_limit( - sql_results: Dict[str, Any], rows: Optional[int] = None -) -> Dict[str, Any]: + sql_results: dict[str, Any], rows: Optional[int] = None +) -> dict[str, Any]: """ Given a `sql_results` nested structure, applies a limit to the number of rows @@ -311,7 +311,7 @@ def apply_display_max_row_limit( def get_dashboard_extra_filters( slice_id: int, dashboard_id: int -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: session = db.session() dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() @@ -348,11 +348,11 @@ def get_dashboard_extra_filters( def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-blocks - layout: Dict[str, Dict[str, Any]], - filter_scopes: Dict[str, Dict[str, Any]], - default_filters: Dict[str, Dict[str, List[Any]]], + layout: dict[str, dict[str, Any]], + filter_scopes: dict[str, dict[str, Any]], + default_filters: dict[str, dict[str, list[Any]]], slice_id: int, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: extra_filters = [] # do not apply filters if chart is not in filter's scope or chart is immune to the @@ -360,7 +360,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc for filter_id, columns in default_filters.items(): filter_slice = db.session.query(Slice).filter_by(id=filter_id).one_or_none() - filter_configs: List[Dict[str, Any]] = [] + filter_configs: list[dict[str, Any]] = [] if filter_slice: filter_configs = ( json.loads(filter_slice.params or "{}").get("filter_configs") or [] @@ -403,7 +403,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc def is_slice_in_container( - layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int + layout: dict[str, dict[str, Any]], container_id: str, slice_id: int ) -> bool: if container_id == "ROOT_ID": return True @@ -551,7 +551,7 @@ def check_slice_perms(_self: Any, slice_id: int) -> None: def _deserialize_results_payload( payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False -) -> Dict[str, Any]: +) -> dict[str, Any]: logger.debug("Deserializing from msgpack: %r", use_msgpack) if use_msgpack: with stats_timing( diff --git a/superset/viz.py b/superset/viz.py index a7b4a8952abe9..3bb6204524c9d 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -30,19 +30,7 @@ from collections import defaultdict, OrderedDict from datetime import date, datetime, timedelta from itertools import product -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Optional, TYPE_CHECKING import geohash import numpy as np @@ -124,7 +112,7 @@ class BaseViz: # pylint: disable=too-many-public-methods """All visualizations derive this base class""" - viz_type: Optional[str] = None + viz_type: str | None = None verbose_name = "Base Viz" credits = "" is_timeseries = False @@ -134,8 +122,8 @@ class BaseViz: # pylint: disable=too-many-public-methods @deprecated(deprecated_in="3.0") def __init__( self, - datasource: "BaseDatasource", - form_data: 
Dict[str, Any], + datasource: BaseDatasource, + form_data: dict[str, Any], force: bool = False, force_cached: bool = False, ) -> None: @@ -150,25 +138,25 @@ def __init__( self.query = "" self.token = utils.get_form_data_token(form_data) - self.groupby: List[Column] = self.form_data.get("groupby") or [] + self.groupby: list[Column] = self.form_data.get("groupby") or [] self.time_shift = timedelta() - self.status: Optional[str] = None + self.status: str | None = None self.error_msg = "" - self.results: Optional[QueryResult] = None - self.applied_filter_columns: List[Column] = [] - self.rejected_filter_columns: List[Column] = [] - self.errors: List[Dict[str, Any]] = [] + self.results: QueryResult | None = None + self.applied_filter_columns: list[Column] = [] + self.rejected_filter_columns: list[Column] = [] + self.errors: list[dict[str, Any]] = [] self.force = force self._force_cached = force_cached - self.from_dttm: Optional[datetime] = None - self.to_dttm: Optional[datetime] = None - self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = [] + self.from_dttm: datetime | None = None + self.to_dttm: datetime | None = None + self._extra_chart_data: list[tuple[str, pd.DataFrame]] = [] self.process_metrics() - self.applied_filters: List[Dict[str, str]] = [] - self.rejected_filters: List[Dict[str, str]] = [] + self.applied_filters: list[dict[str, str]] = [] + self.rejected_filters: list[dict[str, str]] = [] @property @deprecated(deprecated_in="3.0") @@ -196,8 +184,8 @@ def process_metrics(self) -> None: @staticmethod @deprecated(deprecated_in="3.0") def handle_js_int_overflow( - data: Dict[str, List[Dict[str, Any]]] - ) -> Dict[str, List[Dict[str, Any]]]: + data: dict[str, list[dict[str, Any]]] + ) -> dict[str, list[dict[str, Any]]]: for record in data.get("records", {}): for k, v in list(record.items()): if isinstance(v, int): @@ -259,7 +247,7 @@ def apply_rolling(self, df: pd.DataFrame) -> pd.DataFrame: return df @deprecated(deprecated_in="3.0") - def get_samples(self) -> Dict[str, Any]: + def get_samples(self) -> dict[str, Any]: query_obj = self.query_obj() query_obj.update( { @@ -281,7 +269,7 @@ def get_samples(self) -> Dict[str, Any]: } @deprecated(deprecated_in="3.0") - def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: + def get_df(self, query_obj: QueryObjectDict | None = None) -> pd.DataFrame: """Returns a pandas dataframe based on the query object""" if not query_obj: query_obj = self.query_obj() @@ -346,10 +334,10 @@ def process_query_filters(self) -> None: @staticmethod @deprecated(deprecated_in="3.0") - def dedup_columns(*columns_args: Optional[List[Column]]) -> List[Column]: + def dedup_columns(*columns_args: list[Column] | None) -> list[Column]: # dedup groupby and columns while preserving order - labels: List[str] = [] - deduped_columns: List[Column] = [] + labels: list[str] = [] + deduped_columns: list[Column] = [] for columns in columns_args: for column in columns or []: label = get_column_name(column) @@ -492,7 +480,7 @@ def cache_key(self, query_obj: QueryObjectDict, **extra: Any) -> str: return md5_sha_from_str(json_data) @deprecated(deprecated_in="3.0") - def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload: + def get_payload(self, query_obj: QueryObjectDict | None = None) -> VizPayload: """Returns a payload of metadata and data""" try: @@ -534,8 +522,8 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload @deprecated(deprecated_in="3.0") def get_df_payload( # pylint: 
disable=too-many-statements - self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any - ) -> Dict[str, Any]: + self, query_obj: QueryObjectDict | None = None, **kwargs: Any + ) -> dict[str, Any]: """Handles caching around the df payload retrieval""" if not query_obj: query_obj = self.query_obj() @@ -587,7 +575,7 @@ def get_df_payload( # pylint: disable=too-many-statements ) + get_column_names_from_columns(query_obj.get("groupby") or []) + utils.get_column_names_from_metrics( - cast(List[Metric], query_obj.get("metrics") or []) + cast(list[Metric], query_obj.get("metrics") or []) ) if col not in self.datasource.column_names ] @@ -676,12 +664,12 @@ def has_error(payload: VizPayload) -> bool: ) @deprecated(deprecated_in="3.0") - def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: + def payload_json_and_has_error(self, payload: VizPayload) -> tuple[str, bool]: return self.json_dumps(payload), self.has_error(payload) @property @deprecated(deprecated_in="3.0") - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """This is the data object serialized to the js layer""" content = { "form_data": self.form_data, @@ -692,7 +680,7 @@ def data(self) -> Dict[str, Any]: return content @deprecated(deprecated_in="3.0") - def get_csv(self) -> Optional[str]: + def get_csv(self) -> str | None: df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) @@ -766,8 +754,8 @@ def process_metrics(self) -> None: else QueryMode.AGGREGATE ) - columns: List[str] # output columns sans time and percent_metric column - percent_columns: List[str] = [] # percent columns that needs extra computation + columns: list[str] # output columns sans time and percent_metric column + percent_columns: list[str] = [] # percent columns that needs extra computation if self.query_mode == QueryMode.RAW: columns = get_metric_names(self.form_data.get("all_columns")) @@ -906,7 +894,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: return None columns = None - values: Union[List[str], str] = self.metric_labels + values: list[str] | str = self.metric_labels if self.form_data.get("groupby"): values = self.metric_labels[0] columns = get_column_names(self.form_data.get("groupby")) @@ -948,10 +936,8 @@ def query_obj(self) -> QueryObjectDict: if transpose and not columns: raise QueryObjectValidationError( _( - ( - "Please choose at least one 'Columns' field when " - "select 'Transpose Pivot' option" - ) + "Please choose at least one 'Columns' field when " + "select 'Transpose Pivot' option" ) ) if not metrics: @@ -973,8 +959,8 @@ def query_obj(self) -> QueryObjectDict: @staticmethod @deprecated(deprecated_in="3.0") def get_aggfunc( - metric: str, df: pd.DataFrame, form_data: Dict[str, Any] - ) -> Union[str, Callable[[Any], Any]]: + metric: str, df: pd.DataFrame, form_data: dict[str, Any] + ) -> str | Callable[[Any], Any]: aggfunc = form_data.get("pandas_aggfunc") or "sum" if pd.api.types.is_numeric_dtype(df[metric]): # Ensure that Pandas's sum function mimics that of SQL. @@ -985,7 +971,7 @@ def get_aggfunc( @staticmethod @deprecated(deprecated_in="3.0") - def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str: + def _format_datetime(value: pd.Timestamp | datetime | date | str) -> str: """ Format a timestamp in such a way that the viz will be able to apply the correct formatting in the frontend. 
@@ -994,7 +980,7 @@ def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str: :return: formatted timestamp if it is a valid timestamp, otherwise the original value """ - tstamp: Optional[pd.Timestamp] = None + tstamp: pd.Timestamp | None = None if isinstance(value, pd.Timestamp): tstamp = value if isinstance(value, (date, datetime)): @@ -1018,7 +1004,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: del df[DTTM_ALIAS] metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]] - aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {} + aggfuncs: dict[str, str | Callable[[Any], Any]] = {} for metric in metrics: aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data) @@ -1088,7 +1074,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]: + def _nest(self, metric: str, df: pd.DataFrame) -> list[dict[str, Any]]: nlevels = df.index.nlevels if nlevels == 1: result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])] @@ -1200,7 +1186,7 @@ class NVD3Viz(BaseViz): """Base class for all nvd3 vizs""" credits = 'NVD3.org' - viz_type: Optional[str] = None + viz_type: str | None = None verbose_name = "Base NVD3 Viz" is_timeseries = False @@ -1249,7 +1235,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: df["shape"] = "circle" df["group"] = df[[get_column_name(self.series)]] # type: ignore - series: Dict[Any, List[Any]] = defaultdict(list) + series: dict[Any, list[Any]] = defaultdict(list) for row in df.to_dict(orient="records"): series[row["group"]].append(row) chart_data = [] @@ -1357,7 +1343,7 @@ class NVD3TimeSeriesViz(NVD3Viz): verbose_name = _("Time Series - Line Chart") sort_series = False is_timeseries = True - pivot_fill_value: Optional[int] = None + pivot_fill_value: int | None = None @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: @@ -1376,7 +1362,7 @@ def query_obj(self) -> QueryObjectDict: @deprecated(deprecated_in="3.0") def to_series( # pylint: disable=too-many-branches self, df: pd.DataFrame, classed: str = "", title_suffix: str = "" - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1393,7 +1379,7 @@ def to_series( # pylint: disable=too-many-branches ys = series[name] if df[name].dtype.kind not in "biufc": continue - series_title: Union[List[str], str, Tuple[str, ...]] + series_title: list[str] | str | tuple[str, ...] 
if isinstance(name, list): series_title = [str(title) for title in name] elif isinstance(name, tuple): @@ -1510,7 +1496,7 @@ def run_extra_queries(self) -> None: dttm_series = df2[DTTM_ALIAS] + delta df2 = df2.drop(DTTM_ALIAS, axis=1) df2 = pd.concat([dttm_series, df2], axis=1) - label = "{} offset".format(option) + label = f"{option} offset" df2 = self.process_data(df2) self._extra_chart_data.append((label, df2)) @@ -1524,9 +1510,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: for i, (label, df2) in enumerate(self._extra_chart_data): chart_data.extend( - self.to_series( - df2, classed="time-shift-{}".format(i), title_suffix=label - ) + self.to_series(df2, classed=f"time-shift-{i}", title_suffix=label) ) else: chart_data = [] @@ -1547,16 +1531,14 @@ def get_data(self, df: pd.DataFrame) -> VizData: diff = df / df2 else: raise QueryObjectValidationError( - "Invalid `comparison_type`: {0}".format(comparison_type) + f"Invalid `comparison_type`: {comparison_type}" ) # remove leading/trailing NaNs from the time shift difference diff = diff[diff.first_valid_index() : diff.last_valid_index()] chart_data.extend( - self.to_series( - diff, classed="time-shift-{}".format(i), title_suffix=label - ) + self.to_series(diff, classed=f"time-shift-{i}", title_suffix=label) ) if not self.sort_series: @@ -1670,7 +1652,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]: + def to_series(self, df: pd.DataFrame, classed: str = "") -> list[dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1823,7 +1805,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def labelify(self, keys: Union[List[str], str], column: str) -> str: + def labelify(self, keys: list[str] | str, column: str) -> str: if isinstance(keys, str): keys = [keys] # removing undesirable characters @@ -2033,17 +2015,17 @@ def get_data(self, df: pd.DataFrame) -> VizData: df["target"] = df["target"].astype(str) recs = df.to_dict(orient="records") - hierarchy: Dict[str, Set[str]] = defaultdict(set) + hierarchy: dict[str, set[str]] = defaultdict(set) for row in recs: hierarchy[row["source"]].add(row["target"]) @deprecated(deprecated_in="3.0") - def find_cycle(graph: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]: + def find_cycle(graph: dict[str, set[str]]) -> tuple[str, str] | None: """Whether there's a cycle in a directed graph""" path = set() @deprecated(deprecated_in="3.0") - def visit(vertex: str) -> Optional[Tuple[str, str]]: + def visit(vertex: str) -> tuple[str, str] | None: path.add(vertex) for neighbour in graph.get(vertex, ()): if neighbour in path or visit(neighbour): @@ -2214,7 +2196,7 @@ class FilterBoxViz(BaseViz): """A multi filter, multi-choice filter box to make dashboards interactive""" - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None viz_type = "filter_box" verbose_name = _("Filters") is_timeseries = False @@ -2581,20 +2563,20 @@ class BaseDeckGLViz(BaseViz): is_timeseries = False credits = 'deck.gl' - spatial_control_keys: List[str] = [] + spatial_control_keys: list[str] = [] @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: # pylint: disable=attribute-defined-outside-init self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] @deprecated(deprecated_in="3.0") - def 
process_spatial_query_obj(self, key: str, group_by: List[str]) -> None: + def process_spatial_query_obj(self, key: str, group_by: list[str]) -> None: group_by.extend(self.get_spatial_columns(key)) @deprecated(deprecated_in="3.0") - def get_spatial_columns(self, key: str) -> List[str]: + def get_spatial_columns(self, key: str) -> list[str]: spatial = self.form_data.get(key) if spatial is None: raise ValueError(_("Bad spatial key")) @@ -2611,7 +2593,7 @@ def get_spatial_columns(self, key: str) -> List[str]: @staticmethod @deprecated(deprecated_in="3.0") - def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]: + def parse_coordinates(latlog: Any) -> tuple[float, float] | None: if not latlog: return None try: @@ -2624,7 +2606,7 @@ def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]: @staticmethod @deprecated(deprecated_in="3.0") - def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]: + def reverse_geohash_decode(geohash_code: str) -> tuple[str, str]: lat, lng = geohash.decode(geohash_code) return (lng, lat) @@ -2692,7 +2674,7 @@ def query_obj(self) -> QueryObjectDict: self.add_null_filters() query_obj = super().query_obj() - group_by: List[str] = [] + group_by: list[str] = [] for key in self.spatial_control_keys: self.process_spatial_query_obj(key, group_by) @@ -2720,7 +2702,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def get_js_columns(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_js_columns(self, data: dict[str, Any]) -> dict[str, Any]: cols = self.form_data.get("js_columns") or [] return {col: data.get(col) for col in cols} @@ -2748,7 +2730,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: } @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: raise NotImplementedError() @@ -2774,7 +2756,7 @@ def query_obj(self) -> QueryObjectDict: return super().query_obj() @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: # pylint: disable=attribute-defined-outside-init self.metric = None if self.point_radius_fixed.get("type") == "metric": @@ -2783,7 +2765,7 @@ def get_metrics(self) -> List[str]: return [] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "metric": data.get(self.metric_label) if self.metric_label else None, "radius": self.fixed_value @@ -2825,7 +2807,7 @@ def query_obj(self) -> QueryObjectDict: return super().query_obj() @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2849,7 +2831,7 @@ class DeckGrid(BaseDeckGLViz): spatial_control_keys = ["spatial"] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2864,7 +2846,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: @deprecated(deprecated_in="3.0") -def geohash_to_json(geohash_code: str) -> List[List[float]]: +def geohash_to_json(geohash_code: str) -> 
list[list[float]]: bbox = geohash.bbox(geohash_code) return [ [bbox.get("w"), bbox.get("n")], @@ -2907,7 +2889,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: line_type = self.form_data["line_type"] deser = self.deser_map[line_type] line_column = self.form_data["line_column"] @@ -2946,14 +2928,14 @@ def query_obj(self) -> QueryObjectDict: return super().query_obj() @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: metrics = [self.form_data.get("metric")] if self.elevation.get("type") == "metric": metrics.append(self.elevation.get("value")) return [metric for metric in metrics if metric] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: super().get_properties(data) elevation = self.form_data["point_radius_fixed"]["value"] type_ = self.form_data["point_radius_fixed"]["type"] @@ -2974,7 +2956,7 @@ class DeckHex(BaseDeckGLViz): spatial_control_keys = ["spatial"] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2996,7 +2978,7 @@ class DeckHeatmap(BaseDeckGLViz): verbose_name = _("Deck.gl - Heatmap") spatial_control_keys = ["spatial"] - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -3025,7 +3007,7 @@ def query_obj(self) -> QueryObjectDict: return query_obj @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: geojson = data[get_column_name(self.form_data["geojson"])] return json.loads(geojson) @@ -3047,7 +3029,7 @@ def query_obj(self) -> QueryObjectDict: return super().query_obj() @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: dim = self.form_data.get("dimension") return { "sourcePosition": data.get("start_spatial"), @@ -3153,7 +3135,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: else: cols.append(col) df.columns = cols - data: Dict[str, List[Dict[str, Any]]] = {} + data: dict[str, list[dict[str, Any]]] = {} series = df.to_dict("series") for name_set in df.columns: # If no groups are defined, nameSet will be the metric name @@ -3188,7 +3170,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: return None data = super().get_data(df) - result: Dict[str, List[Dict[str, str]]] = {} + result: dict[str, list[dict[str, str]]] = {} for datum in data: key = datum["key"] for val in datum["values"]: @@ -3227,8 +3209,8 @@ def query_obj(self) -> QueryObjectDict: @staticmethod @deprecated(deprecated_in="3.0") def levels_for( - time_op: str, groups: List[str], df: pd.DataFrame - ) -> Dict[int, pd.Series]: + time_op: str, groups: list[str], df: pd.DataFrame + ) -> dict[int, pd.Series]: """ Compute the partition at each `level` from the dataframe. 
""" @@ -3245,8 +3227,8 @@ def levels_for( @staticmethod @deprecated(deprecated_in="3.0") def levels_for_diff( - time_op: str, groups: List[str], df: pd.DataFrame - ) -> Dict[int, pd.DataFrame]: + time_op: str, groups: list[str], df: pd.DataFrame + ) -> dict[int, pd.DataFrame]: # Obtain a unique list of the time grains times = list(set(df[DTTM_ALIAS])) times.sort() @@ -3282,8 +3264,8 @@ def levels_for_diff( @deprecated(deprecated_in="3.0") def levels_for_time( - self, groups: List[str], df: pd.DataFrame - ) -> Dict[int, VizData]: + self, groups: list[str], df: pd.DataFrame + ) -> dict[int, VizData]: procs = {} for i in range(0, len(groups) + 1): self.form_data["groupby"] = groups[:i] @@ -3295,11 +3277,11 @@ def levels_for_time( @deprecated(deprecated_in="3.0") def nest_values( self, - levels: Dict[int, pd.DataFrame], + levels: dict[int, pd.DataFrame], level: int = 0, - metric: Optional[str] = None, - dims: Optional[List[str]] = None, - ) -> List[Dict[str, Any]]: + metric: str | None = None, + dims: list[str] | None = None, + ) -> list[dict[str, Any]]: """ Nest values at each level on the back-end with access and setting, instead of summing from the bottom. @@ -3340,11 +3322,11 @@ def nest_values( @deprecated(deprecated_in="3.0") def nest_procs( self, - procs: Dict[int, pd.DataFrame], + procs: dict[int, pd.DataFrame], level: int = -1, - dims: Optional[Tuple[str, ...]] = None, + dims: tuple[str, ...] | None = None, time: Any = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: if dims is None: dims = () if level == -1: @@ -3395,7 +3377,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: @deprecated(deprecated_in="3.0") -def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]: +def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]: return set(cls.__subclasses__()).union( [sc for c in cls.__subclasses__() for sc in get_subclasses(c)] ) diff --git a/tests/common/logger_utils.py b/tests/common/logger_utils.py index 98471342b7f1a..8cb443cac8d37 100644 --- a/tests/common/logger_utils.py +++ b/tests/common/logger_utils.py @@ -29,7 +29,7 @@ Signature, ) from logging import Logger -from typing import Any, Callable, cast, Optional, Type, Union +from typing import Any, Callable, cast, Union _DEFAULT_ENTER_MSG_PREFIX = "enter to " _DEFAULT_ENTER_MSG_SUFFIX = "" @@ -48,11 +48,11 @@ Function = Callable[..., Any] -Decorated = Union[Type[Any], Function] +Decorated = Union[type[Any], Function] def log( - decorated: Optional[Decorated] = None, + decorated: Decorated | None = None, *, prefix_enter_msg: str = _DEFAULT_ENTER_MSG_PREFIX, suffix_enter_msg: str = _DEFAULT_ENTER_MSG_SUFFIX, @@ -85,11 +85,11 @@ def _make_decorator( def decorator(decorated: Decorated): decorated_logger = _get_logger(decorated) - def decorator_class(clazz: Type[Any]) -> Type[Any]: + def decorator_class(clazz: type[Any]) -> type[Any]: _decorate_class_members_with_logs(clazz) return clazz - def _decorate_class_members_with_logs(clazz: Type[Any]) -> None: + def _decorate_class_members_with_logs(clazz: type[Any]) -> None: members = getmembers( clazz, predicate=lambda val: ismethod(val) or isfunction(val) ) @@ -160,7 +160,7 @@ def _log_exit_of_function(return_value: Any) -> None: return _wrapper_func if isclass(decorated): - return decorator_class(cast(Type[Any], decorated)) + return decorator_class(cast(type[Any], decorated)) return decorator_func(cast(Function, decorated)) return decorator diff --git a/tests/common/query_context_generator.py b/tests/common/query_context_generator.py index 
15b013dc845c2..32b40639742e6 100644 --- a/tests/common/query_context_generator.py +++ b/tests/common/query_context_generator.py @@ -16,7 +16,7 @@ # under the License. import copy import dataclasses -from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset.common.chart_data import ChartDataResultType from superset.utils.core import AnnotationType, DTTM_ALIAS @@ -42,7 +42,7 @@ "where": "", } -QUERY_OBJECTS: Dict[str, Dict[str, object]] = { +QUERY_OBJECTS: dict[str, dict[str, object]] = { "birth_names": query_birth_names, # `:suffix` are overrides only "birth_names:include_time": { @@ -205,7 +205,7 @@ def get_query_object( query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: if query_name not in QUERY_OBJECTS: raise Exception(f"QueryObject fixture not defined for datasource: {query_name}") obj = QUERY_OBJECTS[query_name] @@ -227,7 +227,7 @@ def get_query_object( return query_object -def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]: +def _get_postprocessing_operation(query_name: str) -> list[dict[str, Any]]: if query_name not in QUERY_OBJECTS: raise Exception( f"Post-processing fixture not defined for datasource: {query_name}" @@ -250,8 +250,8 @@ def generate( add_time_offsets: bool = False, table_id=1, table_type="table", - form_data: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + form_data: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: form_data = form_data or {} table_name = query_name.split(":")[0] table = self.get_table(table_name, table_id, table_type) diff --git a/tests/example_data/data_generator/base_generator.py b/tests/example_data/data_generator/base_generator.py index 023b929091439..38ab2e5413d0a 100644 --- a/tests/example_data/data_generator/base_generator.py +++ b/tests/example_data/data_generator/base_generator.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any class ExampleDataGenerator(ABC): @abstractmethod - def generate(self) -> Iterable[Dict[Any, Any]]: + def generate(self) -> Iterable[dict[Any, Any]]: ... diff --git a/tests/example_data/data_generator/birth_names/birth_names_generator.py b/tests/example_data/data_generator/birth_names/birth_names_generator.py index 2b68abbd4f12b..a8e8c45e280a4 100644 --- a/tests/example_data/data_generator/birth_names/birth_names_generator.py +++ b/tests/example_data/data_generator/birth_names/birth_names_generator.py @@ -16,9 +16,10 @@ # under the License. 
from __future__ import annotations +from collections.abc import Iterable from datetime import datetime from random import choice, randint -from typing import Any, Dict, Iterable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.consts.birth_names import ( BOY, @@ -58,7 +59,7 @@ def __init__( self._until_not_include_year = start_year + years_amount self._rows_per_year = rows_per_year - def generate(self) -> Iterable[Dict[Any, Any]]: + def generate(self) -> Iterable[dict[Any, Any]]: for year in range(self._start_year, self._until_not_include_year): ds = self._make_year(year) for _ in range(self._rows_per_year): @@ -67,7 +68,7 @@ def generate(self) -> Iterable[Dict[Any, Any]]: def _make_year(self, year: int): return datetime(year, 1, 1, 0, 0, 0) - def generate_row(self, dt: datetime) -> Dict[Any, Any]: + def generate_row(self, dt: datetime) -> dict[Any, Any]: gender = choice([BOY, GIRL]) num = randint(1, 100000) return { diff --git a/tests/example_data/data_loading/data_definitions/types.py b/tests/example_data/data_loading/data_definitions/types.py index e393019e0192f..a1ed1043489f8 100644 --- a/tests/example_data/data_loading/data_definitions/types.py +++ b/tests/example_data/data_loading/data_definitions/types.py @@ -24,8 +24,9 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Dict, Iterable, Optional +from typing import Any, Optional from sqlalchemy.types import TypeEngine @@ -33,14 +34,14 @@ @dataclass class TableMetaData: table_name: str - types: Optional[Dict[str, TypeEngine]] + types: Optional[dict[str, TypeEngine]] @dataclass class Table: table_name: str table_metadata: TableMetaData - data: Iterable[Dict[Any, Any]] + data: Iterable[dict[Any, Any]] class TableMetaDataFactory(ABC): @@ -48,6 +49,6 @@ class TableMetaDataFactory(ABC): def make(self) -> TableMetaData: ... 
- def make_table(self, data: Iterable[Dict[Any, Any]]) -> Table: + def make_table(self, data: Iterable[dict[Any, Any]]) -> Table: metadata = self.make() return Table(metadata.table_name, metadata, data) diff --git a/tests/example_data/data_loading/pandas/pandas_data_loader.py b/tests/example_data/data_loading/pandas/pandas_data_loader.py index 7f41602054e18..49dcf3b2db725 100644 --- a/tests/example_data/data_loading/pandas/pandas_data_loader.py +++ b/tests/example_data/data_loading/pandas/pandas_data_loader.py @@ -17,7 +17,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from pandas import DataFrame from sqlalchemy.inspection import inspect @@ -63,10 +63,10 @@ def load_table(self, table: Table) -> None: schema=self._detect_schema_name(), ) - def _detect_schema_name(self) -> Optional[str]: + def _detect_schema_name(self) -> str | None: return inspect(self._db_engine).default_schema_name - def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]: + def _take_data_types(self, table: Table) -> dict[str, str] | None: if metadata_table := table.table_metadata: types = metadata_table.types if types: diff --git a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py index 1c43adc9316e9..8de12b39eff48 100644 --- a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py +++ b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict +from typing import Any default_pandas_data_loader_config = { "if_exists": "replace", @@ -54,7 +54,7 @@ def __init__( self.support_datetime_type = support_datetime_type @classmethod - def make_from_dict(cls, _dict: Dict[str, Any]) -> PandasLoaderConfigurations: + def make_from_dict(cls, _dict: dict[str, Any]) -> PandasLoaderConfigurations: copy_dict = default_pandas_data_loader_config.copy() copy_dict.update(_dict) return PandasLoaderConfigurations(**copy_dict) # type: ignore diff --git a/tests/example_data/data_loading/pandas/table_df_convertor.py b/tests/example_data/data_loading/pandas/table_df_convertor.py index e801c8464e9e8..aad1077ce5bf7 100644 --- a/tests/example_data/data_loading/pandas/table_df_convertor.py +++ b/tests/example_data/data_loading/pandas/table_df_convertor.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from pandas import DataFrame @@ -30,10 +30,10 @@ @log class TableToDfConvertorImpl(TableToDfConvertor): convert_datetime_to_str: bool - _time_format: Optional[str] + _time_format: str | None def __init__( - self, convert_ds_to_datetime: bool, time_format: Optional[str] = None + self, convert_ds_to_datetime: bool, time_format: str | None = None ) -> None: self.convert_datetime_to_str = convert_ds_to_datetime self._time_format = time_format diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 38fd10524019f..79fdff634623e 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -516,7 +516,7 @@ def test_request_access(self): ) self.assertEqual( access_request3.roles_with_datasource, - "
<ul><li>{}</li></ul>".format(approve_link_3), + f"<ul><li>{approve_link_3}</li></ul>
    ", ) # cleanup diff --git a/tests/integration_tests/advanced_data_type/api_tests.py b/tests/integration_tests/advanced_data_type/api_tests.py index 5bfe308e1683b..e865069462e22 100644 --- a/tests/integration_tests/advanced_data_type/api_tests.py +++ b/tests/integration_tests/advanced_data_type/api_tests.py @@ -24,7 +24,7 @@ from tests.integration_tests.utils.get_dashboards import get_dashboards_ids from unittest import mock from sqlalchemy import Column -from typing import Any, List +from typing import Any from superset.advanced_data_type.types import ( AdvancedDataType, AdvancedDataTypeRequest, @@ -52,7 +52,7 @@ def translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse: return target_resp -def translate_filter_func(col: Column, op: FilterOperator, values: List[Any]): +def translate_filter_func(col: Column, op: FilterOperator, values: list[Any]): pass diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index f70f0f63bde36..fec66f88d2da6 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -20,7 +20,7 @@ import imp import json from contextlib import contextmanager -from typing import Any, Dict, Union, List, Optional +from typing import Any, Union, Optional from unittest.mock import Mock, patch, MagicMock import pandas as pd @@ -67,12 +67,12 @@ def get_resp( else: resp = client.get(url, follow_redirects=follow_redirects) if raise_on_error and resp.status_code > 400: - raise Exception("http request failed with code {}".format(resp.status_code)) + raise Exception(f"http request failed with code {resp.status_code}") return resp.data.decode("utf-8") def post_assert_metric( - client: Any, uri: str, data: Dict[str, Any], func_name: str + client: Any, uri: str, data: dict[str, Any], func_name: str ) -> Response: """ Simple client post with an extra assertion for statsd metrics @@ -121,7 +121,7 @@ def get_birth_names_dataset() -> SqlaTable: @staticmethod def create_user_with_roles( - username: str, roles: List[str], should_create_roles: bool = False + username: str, roles: list[str], should_create_roles: bool = False ): user_to_create = security_manager.find_user(username) if not user_to_create: @@ -485,12 +485,12 @@ def delete_assert_metric(self, uri: str, func_name: str) -> Response: return rv def post_assert_metric( - self, uri: str, data: Dict[str, Any], func_name: str + self, uri: str, data: dict[str, Any], func_name: str ) -> Response: return post_assert_metric(self.client, uri, data, func_name) def put_assert_metric( - self, uri: str, data: Dict[str, Any], func_name: str + self, uri: str, data: dict[str, Any], func_name: str ) -> Response: """ Simple client put with an extra assertion for statsd metrics diff --git a/tests/integration_tests/cachekeys/api_tests.py b/tests/integration_tests/cachekeys/api_tests.py index d3552bfc8df26..c867ce7f5135a 100644 --- a/tests/integration_tests/cachekeys/api_tests.py +++ b/tests/integration_tests/cachekeys/api_tests.py @@ -16,7 +16,7 @@ # under the License. 
# isort:skip_file """Unit tests for Superset""" -from typing import Dict, Any +from typing import Any import pytest @@ -31,7 +31,7 @@ @pytest.fixture def invalidate(test_client, login_as_admin): - def _invalidate(params: Dict[str, Any]): + def _invalidate(params: dict[str, Any]): return post_assert_metric( test_client, "api/v1/cachekey/invalidate", params, "invalidate" ) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 99b32752814ba..fa09e56675547 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -17,7 +17,6 @@ # isort:skip_file """Unit tests for Superset""" import json -import logging from io import BytesIO from zipfile import is_zipfile, ZipFile diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index b02ccb5b965f1..f9e6b5e3b1545 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -21,7 +21,7 @@ import copy from datetime import datetime from io import BytesIO -from typing import Any, Dict, Optional, List +from typing import Any, Optional from unittest import mock from zipfile import ZipFile @@ -740,11 +740,11 @@ def test_with_series_limit(self): data = rv.json["result"][0]["data"] - unique_names = set(row["name"] for row in data) + unique_names = {row["name"] for row in data} self.maxDiff = None self.assertEqual(len(unique_names), SERIES_LIMIT) self.assertEqual( - set(column for column in data[0].keys()), {"state", "name", "sum__num"} + {column for column in data[0].keys()}, {"state", "name", "sum__num"} ) @pytest.mark.usefixtures( @@ -1124,7 +1124,7 @@ def test_chart_data_with_incompatible_adhoc_column(self): @pytest.fixture() -def physical_query_context(physical_dataset) -> Dict[str, Any]: +def physical_query_context(physical_dataset) -> dict[str, Any]: return { "datasource": { "type": physical_dataset.type, @@ -1218,7 +1218,7 @@ def test_data_cache_default_timeout( def test_chart_cache_timeout( - load_energy_table_with_slice: List[Slice], + load_energy_table_with_slice: list[Slice], test_client, login_as_admin, physical_query_context, diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 0ea5bb5106b15..28da7b79133b1 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -19,7 +19,7 @@ import contextlib import functools import os -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from unittest.mock import patch import pytest @@ -55,7 +55,7 @@ def test_client(app_context: AppContext): @pytest.fixture -def login_as(test_client: "FlaskClient[Any]"): +def login_as(test_client: FlaskClient[Any]): """Fixture with app context and logged in admin user.""" def _login_as(username: str, password: str = "general"): @@ -160,7 +160,7 @@ def drop_from_schema(engine: Engine, schema_name: str): @pytest.fixture(scope="session") def example_db_provider() -> Callable[[], Database]: # type: ignore class _example_db_provider: - _db: Optional[Database] = None + _db: Database | None = None def __call__(self) -> Database: with app.app_context(): @@ -257,7 +257,7 @@ def wrapper(*args, **kwargs): return decorate -def with_config(override_config: Dict[str, Any]): +def with_config(override_config: dict[str, Any]): """ Use this decorator to mock specific config keys. 
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 2e9e287620417..f0c72b068036b 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -23,7 +23,6 @@ import io import json import logging -from typing import Dict, List from urllib.parse import quote import superset.utils.database @@ -37,7 +36,6 @@ import pytest import pytz import random -import re import unittest from unittest import mock @@ -496,7 +494,7 @@ def test_testconn_failed_conn(self, username="admin"): assert response.headers["Content-Type"] == "application/json" response_body = json.loads(response.data.decode("utf-8")) expected_body = {"error": "Could not load database driver: broken"} - assert response_body == expected_body, "%s != %s" % ( + assert response_body == expected_body, "{} != {}".format( response_body, expected_body, ) @@ -515,7 +513,7 @@ def test_testconn_failed_conn(self, username="admin"): assert response.headers["Content-Type"] == "application/json" response_body = json.loads(response.data.decode("utf-8")) expected_body = {"error": "Could not load database driver: mssql+pymssql"} - assert response_body == expected_body, "%s != %s" % ( + assert response_body == expected_body, "{} != {}".format( response_body, expected_body, ) @@ -563,7 +561,7 @@ def test_databaseview_edit(self, username="admin"): self.login(username=username) database = superset.utils.database.get_example_database() sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted - url = "databaseview/edit/{}".format(database.id) + url = f"databaseview/edit/{database.id}" data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns} data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() self.client.post(url, data=data) @@ -582,7 +580,7 @@ def test_databaseview_edit(self, username="admin"): def test_warm_up_cache(self): self.login() slc = self.get_slice("Girls", db.session) - data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id)) + data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") self.assertEqual( data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] ) @@ -609,7 +607,7 @@ def test_cache_logging(self): store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True girls_slice = self.get_slice("Girls", db.session) - self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id)) + self.get_json_resp(f"/superset/warm_up_cache?slice_id={girls_slice.id}") ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() assert ck.datasource_uid == f"{girls_slice.table.id}__table" app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys @@ -650,7 +648,7 @@ def test_kv_enabled(self): kv_value = kv.value self.assertEqual(json.loads(value), json.loads(kv_value)) - resp = self.client.get("/kv/{}/".format(kv.id)) + resp = self.client.get(f"/kv/{kv.id}/") self.assertEqual(resp.status_code, 200) self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8"))) @@ -662,7 +660,7 @@ def test_gamma(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_csv_endpoint(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] get_name_sql = """ SELECT name FROM birth_names @@ -676,17 +674,17 @@ def test_csv_endpoint(self): WHERE name = '{name}' LIMIT 1 """ - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = 
f"{random.getrandbits(64)}"[:10] self.run_sql(sql, client_id, raise_on_error=True) - resp = self.get_resp("/superset/csv/{}".format(client_id)) + resp = self.get_resp(f"/superset/csv/{client_id}") data = csv.reader(io.StringIO(resp)) expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] self.run_sql(sql, client_id, raise_on_error=True) - resp = self.get_resp("/superset/csv/{}".format(client_id)) + resp = self.get_resp(f"/superset/csv/{client_id}") data = csv.reader(io.StringIO(resp)) expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) @@ -704,7 +702,7 @@ def test_extra_table_metadata(self): def test_required_params_in_sql_json(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"client_id": client_id} rv = self.client.post( @@ -876,12 +874,12 @@ def test_slice_id_is_always_logged_correctly_on_web_request(self): self.get_resp(slc.slice_url) self.assertEqual(1, qry.count()) - def create_sample_csvfile(self, filename: str, content: List[str]) -> None: + def create_sample_csvfile(self, filename: str, content: list[str]) -> None: with open(filename, "w+") as test_file: for l in content: test_file.write(f"{l}\n") - def create_sample_excelfile(self, filename: str, content: Dict[str, str]) -> None: + def create_sample_excelfile(self, filename: str, content: dict[str, str]) -> None: pd.DataFrame(content).to_excel(filename) def enable_csv_upload(self, database: models.Database) -> None: diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 91a76f97cf298..9bc204ff06b45 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -20,7 +20,7 @@ import logging import os import shutil -from typing import Dict, Optional, Union +from typing import Optional, Union from unittest import mock @@ -132,7 +132,7 @@ def get_upload_db(): def upload_csv( filename: str, table_name: str, - extra: Optional[Dict[str, str]] = None, + extra: Optional[dict[str, str]] = None, dtype: Union[str, None] = None, ): csv_upload_db_id = get_upload_db().id @@ -155,7 +155,7 @@ def upload_csv( def upload_excel( - filename: str, table_name: str, extra: Optional[Dict[str, str]] = None + filename: str, table_name: str, extra: Optional[dict[str, str]] = None ): excel_upload_db_id = get_upload_db().id form_data = { @@ -175,7 +175,7 @@ def upload_excel( def upload_columnar( - filename: str, table_name: str, extra: Optional[Dict[str, str]] = None + filename: str, table_name: str, extra: Optional[dict[str, str]] = None ): columnar_upload_db_id = get_upload_db().id form_data = { @@ -218,7 +218,7 @@ def mock_upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: def escaped_double_quotes(text): - return f"\"{text}\"" + return rf"\"{text}\"" def escaped_parquet(text): diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index d54151db83c2d..669bc936934e5 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -115,7 +115,7 @@ def load_dashboard(self): def get_mock_positions(self, dash): positions = {"DASHBOARD_VERSION_KEY": "v2"} for i, slc in enumerate(dash.slices): - id = "DASHBOARD_CHART_TYPE-{}".format(i) + id = f"DASHBOARD_CHART_TYPE-{i}" d = { "type": "CHART", "id": id, @@ -167,7 +167,7 @@ def test_save_dash(self, username="admin"): # 
set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -189,7 +189,7 @@ def test_save_dash_with_filter(self, username="admin"): "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -217,7 +217,7 @@ def test_save_dash_with_invalid_filters(self, username="admin"): "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -239,7 +239,7 @@ def test_save_dash_with_dashboard_title(self, username="admin"): # set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" self.get_resp(url, data=dict(data=json.dumps(data))) updatedDash = db.session.query(Dashboard).filter_by(slug="births").first() self.assertEqual(updatedDash.dashboard_title, "new title") @@ -264,7 +264,7 @@ def test_save_dash_with_colors(self, username="admin"): # set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" self.get_resp(url, data=dict(data=json.dumps(data))) updatedDash = db.session.query(Dashboard).filter_by(slug="births").first() self.assertIn("color_namespace", updatedDash.json_metadata) @@ -301,13 +301,13 @@ def test_copy_dash(self, username="admin"): # Save changes to Births dashboard and retrieve updated dash dash_id = dash.id - url = "/superset/save_dash/{}/".format(dash_id) + url = f"/superset/save_dash/{dash_id}/" self.client.post(url, data=dict(data=json.dumps(data))) dash = db.session.query(Dashboard).filter_by(id=dash_id).first() orig_json_data = dash.data # Verify that copy matches original - url = "/superset/copy_dash/{}/".format(dash_id) + url = f"/superset/copy_dash/{dash_id}/" resp = self.get_json_resp(url, data=dict(data=json.dumps(data))) self.assertEqual(resp["dashboard_title"], "Copy Of Births") self.assertEqual(resp["position_json"], orig_json_data["position_json"]) @@ -334,7 +334,7 @@ def test_add_slices(self, username="admin"): data = { "slice_ids": [new_slice.data["slice_id"], existing_slice.data["slice_id"]] } - url = "/superset/add_slices/{}/".format(dash.id) + url = f"/superset/add_slices/{dash.id}/" resp = self.client.post(url, data=dict(data=json.dumps(data))) assert "SLICES ADDED" in resp.data.decode("utf-8") @@ -375,7 +375,7 @@ def test_remove_slices(self, username="admin"): # save dash dash_id = dash.id - url = "/superset/save_dash/{}/".format(dash_id) + url = f"/superset/save_dash/{dash_id}/" self.client.post(url, data=dict(data=json.dumps(data))) dash = db.session.query(Dashboard).filter_by(id=dash_id).first() diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index bea724dafc95d..6c3d000051f35 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -17,7 +17,7 @@ """Utils to provide dashboards for tests""" import json 
-from typing import Any, Dict, List, Optional +from typing import Optional from pandas import DataFrame @@ -65,7 +65,7 @@ def create_table_metadata( def create_slice( - title: str, viz_type: str, table: SqlaTable, slices_dict: Dict[str, str] + title: str, viz_type: str, table: SqlaTable, slices_dict: dict[str, str] ) -> Slice: return Slice( slice_name=title, @@ -77,7 +77,7 @@ def create_slice( def create_dashboard( - slug: str, title: str, position: str, slices: List[Slice] + slug: str, title: str, position: str, slices: list[Slice] ) -> Dashboard: dash = db.session.query(Dashboard).filter_by(slug=slug).one_or_none() if dash: diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index f57afd95a60fe..49a6bbecbc85f 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -19,7 +19,7 @@ import json from io import BytesIO from time import sleep -from typing import List, Optional +from typing import Optional from unittest.mock import ANY, patch from zipfile import is_zipfile, ZipFile @@ -66,7 +66,7 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): resource_name = "dashboard" - dashboards: List[Dashboard] = [] + dashboards: list[Dashboard] = [] dashboard_data = { "dashboard_title": "title1_changed", "slug": "slug1_changed", @@ -80,10 +80,10 @@ def insert_dashboard( self, dashboard_title: str, slug: Optional[str], - owners: List[int], - roles: List[int] = [], + owners: list[int], + roles: list[int] = [], created_by=None, - slices: Optional[List[Slice]] = None, + slices: Optional[list[Slice]] = None, position_json: str = "", css: str = "", json_metadata: str = "", @@ -211,9 +211,9 @@ def test_get_dashboard_datasets(self): self.assertEqual(response.status_code, 200) data = json.loads(response.data.decode("utf-8")) dashboard = Dashboard.get("world_health") - expected_dataset_ids = set([s.datasource_id for s in dashboard.slices]) + expected_dataset_ids = {s.datasource_id for s in dashboard.slices} result = data["result"] - actual_dataset_ids = set([dataset["id"] for dataset in result]) + actual_dataset_ids = {dataset["id"] for dataset in result} self.assertEqual(actual_dataset_ids, expected_dataset_ids) expected_values = [0, 1] if backend() == "presto" else [0, 1, 2] self.assertEqual(result[0]["column_types"], expected_values) @@ -927,7 +927,7 @@ def create_invalid_dashboard_import(self): buf = BytesIO() with ZipFile(buf, "w") as bundle: with bundle.open("sql/dump.sql", "w") as fp: - fp.write("CREATE TABLE foo (bar INT)".encode()) + fp.write(b"CREATE TABLE foo (bar INT)") buf.seek(0) return buf diff --git a/tests/integration_tests/dashboards/base_case.py b/tests/integration_tests/dashboards/base_case.py index a0a1ff630f08d..db85cd6409a12 100644 --- a/tests/integration_tests/dashboards/base_case.py +++ b/tests/integration_tests/dashboards/base_case.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import json -from typing import Any, Dict, Union +from typing import Any, Union import prison from flask import Response @@ -49,13 +49,13 @@ def get_dashboards_api_response(self) -> Response: return self.client.get(DASHBOARDS_API_URL) def save_dashboard_via_view( - self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any] ) -> Response: save_dash_url = SAVE_DASHBOARD_URL_FORMAT.format(dashboard_id) return self.get_resp(save_dash_url, data=dict(data=json.dumps(dashboard_data))) def save_dashboard( - self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any] ) -> Response: return self.save_dashboard_via_view(dashboard_id, dashboard_data) diff --git a/tests/integration_tests/dashboards/dashboard_test_utils.py b/tests/integration_tests/dashboards/dashboard_test_utils.py index df2687fba939f..ee8001cdba78f 100644 --- a/tests/integration_tests/dashboards/dashboard_test_utils.py +++ b/tests/integration_tests/dashboards/dashboard_test_utils.py @@ -17,7 +17,7 @@ import logging import random import string -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from sqlalchemy import func @@ -32,10 +32,10 @@ session = appbuilder.get_session -def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]: +def get_mock_positions(dashboard: Dashboard) -> dict[str, Any]: positions = {"DASHBOARD_VERSION_KEY": "v2"} for i, slc in enumerate(dashboard.slices): - id_ = "DASHBOARD_CHART_TYPE-{}".format(i) + id_ = f"DASHBOARD_CHART_TYPE-{i}" position_data: Any = { "type": "CHART", "id": id_, @@ -48,7 +48,7 @@ def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]: def build_save_dash_parts( dashboard_slug: Optional[str] = None, dashboard_to_edit: Optional[Dashboard] = None -) -> Tuple[Dashboard, Dict[str, Any], Dict[str, Any]]: +) -> tuple[Dashboard, dict[str, Any], dict[str, Any]]: if not dashboard_to_edit: dashboard_slug = ( dashboard_slug if dashboard_slug else DEFAULT_DASHBOARD_SLUG_TO_TEST @@ -68,7 +68,7 @@ def build_save_dash_parts( return dashboard_to_edit, data_before_change, data_after_change -def get_all_dashboards() -> List[Dashboard]: +def get_all_dashboards() -> list[Dashboard]: return db.session.query(Dashboard).all() diff --git a/tests/integration_tests/dashboards/filter_sets/conftest.py b/tests/integration_tests/dashboards/filter_sets/conftest.py index b7a28273b0a7e..b19e929f9d1f7 100644 --- a/tests/integration_tests/dashboards/filter_sets/conftest.py +++ b/tests/integration_tests/dashboards/filter_sets/conftest.py @@ -17,7 +17,8 @@ from __future__ import annotations import json -from typing import Any, Dict, Generator, List, TYPE_CHECKING +from collections.abc import Generator +from typing import Any, TYPE_CHECKING import pytest @@ -67,7 +68,7 @@ @pytest.fixture(autouse=True, scope="module") -def test_users() -> Generator[Dict[str, int], None, None]: +def test_users() -> Generator[dict[str, int], None, None]: usernames = [ ADMIN_USERNAME_FOR_TEST, DASHBOARD_OWNER_USERNAME, @@ -82,16 +83,16 @@ def test_users() -> Generator[Dict[str, int], None, None]: delete_users(usernames_to_ids) -def delete_users(usernames_to_ids: Dict[str, int]) -> None: +def delete_users(usernames_to_ids: dict[str, int]) -> None: for username in usernames_to_ids.keys(): db.session.delete(security_manager.find_user(username)) db.session.commit() def create_test_users( - admin_role: Role, filter_set_role: Role, usernames: List[str] 
-) -> Dict[str, int]: - users: List[User] = [] + admin_role: Role, filter_set_role: Role, usernames: list[str] +) -> dict[str, int]: + users: list[User] = [] for username in usernames: user = build_user(username, filter_set_role, admin_role) users.append(user) @@ -108,7 +109,7 @@ def build_user(username: str, filter_set_role: Role, admin_role: Role) -> User: if not user: user = security_manager.find_user(username) if user is None: - raise Exception("Failed to build the user {}".format(username)) + raise Exception(f"Failed to build the user {username}") return user @@ -118,7 +119,7 @@ def build_filter_set_role() -> Role: all_datasource_view_name: ViewMenu = security_manager.find_view_menu( "all_datasource_access" ) - pvms: List[PermissionView] = security_manager.find_permissions_view_menu( + pvms: list[PermissionView] = security_manager.find_permissions_view_menu( filterset_view_name ) + security_manager.find_permissions_view_menu(all_datasource_view_name) for pvm in pvms: @@ -167,8 +168,8 @@ def dashboard_id(dashboard: Dashboard) -> Generator[int, None, None]: @pytest.fixture def filtersets( - dashboard_id: int, test_users: Dict[str, int], dumped_valid_json_metadata: str -) -> Generator[Dict[str, List[FilterSet]], None, None]: + dashboard_id: int, test_users: dict[str, int], dumped_valid_json_metadata: str +) -> Generator[dict[str, list[FilterSet]], None, None]: first_filter_set = FilterSet( name="filter_set_1_of_" + str(dashboard_id), dashboard_id=dashboard_id, @@ -216,17 +217,17 @@ def filtersets( @pytest.fixture -def filterset_id(filtersets: Dict[str, List[FilterSet]]) -> int: +def filterset_id(filtersets: dict[str, list[FilterSet]]) -> int: return filtersets["Dashboard"][0].id @pytest.fixture -def valid_json_metadata() -> Dict[str, Any]: +def valid_json_metadata() -> dict[str, Any]: return {"nativeFilters": {}} @pytest.fixture -def dumped_valid_json_metadata(valid_json_metadata: Dict[str, Any]) -> str: +def dumped_valid_json_metadata(valid_json_metadata: dict[str, Any]) -> str: return json.dumps(valid_json_metadata) @@ -238,7 +239,7 @@ def exists_user_id() -> int: @pytest.fixture def valid_filter_set_data_for_create( dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int -) -> Dict[str, Any]: +) -> dict[str, Any]: name = "test_filter_set_of_dashboard_" + str(dashboard_id) return { NAME_FIELD: name, @@ -252,7 +253,7 @@ def valid_filter_set_data_for_create( @pytest.fixture def valid_filter_set_data_for_update( dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int -) -> Dict[str, Any]: +) -> dict[str, Any]: name = "name_changed_test_filter_set_of_dashboard_" + str(dashboard_id) return { NAME_FIELD: name, @@ -273,13 +274,13 @@ def not_exists_user_id() -> int: @pytest.fixture() def dashboard_based_filter_set_dict( - filtersets: Dict[str, List[FilterSet]] -) -> Dict[str, Any]: + filtersets: dict[str, list[FilterSet]] +) -> dict[str, Any]: return filtersets["Dashboard"][0].to_dict() @pytest.fixture() def user_based_filter_set_dict( - filtersets: Dict[str, List[FilterSet]] -) -> Dict[str, Any]: + filtersets: dict[str, list[FilterSet]] +) -> dict[str, Any]: return filtersets[FILTER_SET_OWNER_USERNAME][0].to_dict() diff --git a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py index b5d1919dd430a..9891266101677 100644 --- a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py @@ -16,7 +16,7 
@@ # under the License. from __future__ import annotations -from typing import Any, Dict +from typing import Any from flask.testing import FlaskClient @@ -42,11 +42,11 @@ from tests.integration_tests.test_app import login -def assert_filterset_was_not_created(filter_set_data: Dict[str, Any]) -> None: +def assert_filterset_was_not_created(filter_set_data: dict[str, Any]) -> None: assert get_filter_set_by_name(str(filter_set_data["name"])) is None -def assert_filterset_was_created(filter_set_data: Dict[str, Any]) -> None: +def assert_filterset_was_created(filter_set_data: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_data["name"]) is not None @@ -54,7 +54,7 @@ class TestCreateFilterSetsApi: def test_with_extra_field__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -74,7 +74,7 @@ def test_with_extra_field__400( def test_with_id_field__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -94,7 +94,7 @@ def test_with_id_field__400( def test_with_dashboard_not_exists__404( self, not_exists_dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # act @@ -110,7 +110,7 @@ def test_with_dashboard_not_exists__404( def test_without_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -129,7 +129,7 @@ def test_without_name__400( def test_with_none_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -148,7 +148,7 @@ def test_with_none_name__400( def test_with_int_as_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -167,7 +167,7 @@ def test_with_int_as_name__400( def test_without_description__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -186,7 +186,7 @@ def test_without_description__201( def test_with_none_description__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -205,7 +205,7 @@ def test_with_none_description__201( def test_with_int_as_description__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -224,7 +224,7 @@ def test_with_int_as_description__400( def test_without_json_metadata__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -243,7 +243,7 @@ def test_without_json_metadata__400( def test_with_invalid_json_metadata__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -262,7 +262,7 @@ def test_with_invalid_json_metadata__400( 
def test_without_owner_type__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -281,7 +281,7 @@ def test_without_owner_type__400( def test_with_invalid_owner_type__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -300,7 +300,7 @@ def test_with_invalid_owner_type__400( def test_without_owner_id_when_owner_type_is_user__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -320,7 +320,7 @@ def test_without_owner_id_when_owner_type_is_user__400( def test_without_owner_id_when_owner_type_is_dashboard__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -340,7 +340,7 @@ def test_without_owner_id_when_owner_type_is_dashboard__201( def test_with_not_exists_owner__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], not_exists_user_id: int, client: FlaskClient[Any], ): @@ -361,8 +361,8 @@ def test_with_not_exists_owner__400( def test_when_caller_is_admin_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -384,8 +384,8 @@ def test_when_caller_is_admin_and_owner_is_admin__201( def test_when_caller_is_admin_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -407,8 +407,8 @@ def test_when_caller_is_admin_and_owner_is_dashboard_owner__201( def test_when_caller_is_admin_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -430,8 +430,8 @@ def test_when_caller_is_admin_and_owner_is_regular_user__201( def test_when_caller_is_admin_and_owner_type_is_dashboard__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -451,8 +451,8 @@ def test_when_caller_is_admin_and_owner_type_is_dashboard__201( def test_when_caller_is_dashboard_owner_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -474,8 +474,8 @@ def test_when_caller_is_dashboard_owner_and_owner_is_admin__201( def test_when_caller_is_dashboard_owner_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: 
FlaskClient[Any], ): # arrange @@ -497,8 +497,8 @@ def test_when_caller_is_dashboard_owner_and_owner_is_dashboard_owner__201( def test_when_caller_is_dashboard_owner_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -520,8 +520,8 @@ def test_when_caller_is_dashboard_owner_and_owner_is_regular_user__201( def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -541,8 +541,8 @@ def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__201( def test_when_caller_is_regular_user_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -564,8 +564,8 @@ def test_when_caller_is_regular_user_and_owner_is_admin__201( def test_when_caller_is_regular_user_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -587,8 +587,8 @@ def test_when_caller_is_regular_user_and_owner_is_dashboard_owner__201( def test_when_caller_is_regular_user_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -610,8 +610,8 @@ def test_when_caller_is_regular_user_and_owner_is_regular_user__201( def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py index 7011cb5781282..41d7ea59f71f7 100644 --- a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, @@ -36,11 +36,11 @@ from superset.models.filter_set import FilterSet -def assert_filterset_was_not_deleted(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_was_not_deleted(filter_set_dict: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_dict["name"]) is not None -def assert_filterset_deleted(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_deleted(filter_set_dict: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_dict["name"]) is None @@ -48,7 +48,7 @@ class TestDeleteFilterSet: def test_with_dashboard_exists_filterset_not_exists__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -62,7 +62,7 @@ def test_with_dashboard_exists_filterset_not_exists__200( def test_with_dashboard_not_exists_filterset_not_exists__404( self, not_exists_dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -78,7 +78,7 @@ def test_with_dashboard_not_exists_filterset_not_exists__404( def test_with_dashboard_not_exists_filterset_exists__404( self, not_exists_dashboard_id: int, - dashboard_based_filter_set_dict: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -94,9 +94,9 @@ def test_with_dashboard_not_exists_filterset_exists__404( def test_when_caller_is_admin_and_owner_type_is_user__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -110,9 +110,9 @@ def test_when_caller_is_admin_and_owner_type_is_user__200( def test_when_caller_is_admin_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -126,9 +126,9 @@ def test_when_caller_is_admin_and_owner_type_is_dashboard__200( def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -143,9 +143,9 @@ def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -160,9 +160,9 @@ def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( def test_when_caller_is_filterset_owner__200( self, - 
test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -177,9 +177,9 @@ def test_when_caller_is_filterset_owner__200( def test_when_caller_is_regular_user_and_owner_type_is_user__403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -194,9 +194,9 @@ def test_when_caller_is_regular_user_and_owner_type_is_user__403( def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py index ad40d0e33c859..71c985310d0a0 100644 --- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Set, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, @@ -66,12 +66,12 @@ def test_dashboards_without_filtersets__200( def test_when_caller_admin__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange login(client, "admin") - expected_ids: Set[int] = collect_all_ids(filtersets) + expected_ids: set[int] = collect_all_ids(filtersets) # act response = call_get_filter_sets(client, dashboard_id) @@ -83,7 +83,7 @@ def test_when_caller_admin__200( def test_when_caller_dashboard_owner__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -100,7 +100,7 @@ def test_when_caller_dashboard_owner__200( def test_when_caller_filterset_owner__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -117,12 +117,12 @@ def test_when_caller_filterset_owner__200( def test_when_caller_regular_user__200( self, dashboard_id: int, - filtersets: Dict[str, List[int]], + filtersets: dict[str, list[int]], client: FlaskClient[Any], ): # arrange login(client, REGULAR_USER) - expected_ids: Set[int] = set() + expected_ids: set[int] = set() # act response = call_get_filter_sets(client, dashboard_id) diff --git a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py index 07db98f617815..a6e895a460732 100644 --- a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py @@ -17,7 +17,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, 
TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset.dashboards.filter_sets.consts import ( DESCRIPTION_FIELD, @@ -45,8 +45,8 @@ def merge_two_filter_set_dict( - first: Dict[Any, Any], second: Dict[Any, Any] -) -> Dict[Any, Any]: + first: dict[Any, Any], second: dict[Any, Any] +) -> dict[Any, Any]: for d in [first, second]: if JSON_METADATA_FIELD in d: if PARAMS_PROPERTY not in d: @@ -55,12 +55,12 @@ def merge_two_filter_set_dict( return {**first, **second} -def assert_filterset_was_not_updated(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_was_not_updated(filter_set_dict: dict[str, Any]) -> None: assert filter_set_dict == get_filter_set_by_name(filter_set_dict["name"]).to_dict() def assert_filterset_updated( - filter_set_dict_before: Dict[str, Any], data_updated: Dict[str, Any] + filter_set_dict_before: dict[str, Any], data_updated: dict[str, Any] ) -> None: expected_data = merge_two_filter_set_dict(filter_set_dict_before, data_updated) assert expected_data == get_filter_set_by_name(expected_data["name"]).to_dict() @@ -70,7 +70,7 @@ class TestUpdateFilterSet: def test_with_dashboard_exists_filterset_not_exists__404( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -86,7 +86,7 @@ def test_with_dashboard_exists_filterset_not_exists__404( def test_with_dashboard_not_exists_filterset_not_exists__404( self, not_exists_dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -102,7 +102,7 @@ def test_with_dashboard_not_exists_filterset_not_exists__404( def test_with_dashboard_not_exists_filterset_exists__404( self, not_exists_dashboard_id: int, - dashboard_based_filter_set_dict: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -118,8 +118,8 @@ def test_with_dashboard_not_exists_filterset_exists__404( def test_with_extra_field__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -138,8 +138,8 @@ def test_with_extra_field__400( def test_with_id_field__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -158,8 +158,8 @@ def test_with_id_field__400( def test_with_none_name__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -177,8 +177,8 @@ def test_with_none_name__400( def test_with_int_as_name__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -196,8 +196,8 @@ def test_with_int_as_name__400( def test_without_name__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + 
valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -217,8 +217,8 @@ def test_without_name__200( def test_with_none_description__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -236,8 +236,8 @@ def test_with_none_description__400( def test_with_int_as_description__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -255,8 +255,8 @@ def test_with_int_as_description__400( def test_without_description__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -276,8 +276,8 @@ def test_without_description__200( def test_with_invalid_json_metadata__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -295,9 +295,9 @@ def test_with_invalid_json_metadata__400( def test_with_json_metadata__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], - valid_json_metadata: Dict[Any, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], + valid_json_metadata: dict[Any, Any], client: FlaskClient[Any], ): # arrange @@ -320,8 +320,8 @@ def test_with_json_metadata__200( def test_with_invalid_owner_type__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -339,8 +339,8 @@ def test_with_invalid_owner_type__400( def test_with_user_owner_type__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -358,8 +358,8 @@ def test_with_user_owner_type__400( def test_with_dashboard_owner_type__200( self, - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -382,9 +382,9 @@ def test_with_dashboard_owner_type__200( def test_when_caller_is_admin_and_owner_type_is_user__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -402,9 +402,9 @@ def test_when_caller_is_admin_and_owner_type_is_user__200( def test_when_caller_is_admin_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: 
Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -422,9 +422,9 @@ def test_when_caller_is_admin_and_owner_type_is_dashboard__200( def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -441,9 +441,9 @@ def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -462,9 +462,9 @@ def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( def test_when_caller_is_filterset_owner__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -483,9 +483,9 @@ def test_when_caller_is_filterset_owner__200( def test_when_caller_is_regular_user_and_owner_type_is_user__403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -502,9 +502,9 @@ def test_when_caller_is_regular_user_and_owner_type_is_user__403( def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/utils.py b/tests/integration_tests/dashboards/filter_sets/utils.py index a63e4164d8959..d728bf6fc3d71 100644 --- a/tests/integration_tests/dashboards/filter_sets/utils.py +++ b/tests/integration_tests/dashboards/filter_sets/utils.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from superset.models.filter_set import FilterSet from tests.integration_tests.dashboards.filter_sets.consts import FILTER_SET_URI @@ -28,7 +28,7 @@ def call_create_filter_set( - client: FlaskClient[Any], dashboard_id: int, data: Dict[str, Any] + client: FlaskClient[Any], dashboard_id: int, data: dict[str, Any] ) -> Response: uri = FILTER_SET_URI.format(dashboard_id=dashboard_id) return client.post(uri, json=data) @@ -41,8 +41,8 @@ def call_get_filter_sets(client: FlaskClient[Any], dashboard_id: int) -> Respons def call_delete_filter_set( client: FlaskClient[Any], - filter_set_dict_to_update: Dict[str, Any], - dashboard_id: Optional[int] = None, + filter_set_dict_to_update: dict[str, Any], + dashboard_id: int | None = None, ) -> Response: dashboard_id = ( dashboard_id @@ -58,9 +58,9 @@ def call_delete_filter_set( def call_update_filter_set( client: FlaskClient[Any], - filter_set_dict_to_update: Dict[str, Any], - data: Dict[str, Any], - dashboard_id: Optional[int] = None, + filter_set_dict_to_update: dict[str, Any], + data: dict[str, Any], + dashboard_id: int | None = None, ) -> Response: dashboard_id = ( dashboard_id @@ -90,12 +90,12 @@ def get_filter_set_by_dashboard_id(dashboard_id: int) -> FilterSet: def collect_all_ids( - filtersets: Union[Dict[str, List[FilterSet]], List[FilterSet]] -) -> Set[int]: + filtersets: dict[str, list[FilterSet]] | list[FilterSet] +) -> set[int]: if isinstance(filtersets, dict): - filtersets_lists: List[List[FilterSet]] = list(filtersets.values()) - ids: Set[int] = set() - lst: List[FilterSet] + filtersets_lists: list[list[FilterSet]] = list(filtersets.values()) + ids: set[int] = set() + lst: list[FilterSet] for lst in filtersets_lists: ids.update(set(map(lambda fs: fs.id, lst))) return ids diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index b20112334d1d8..3c560a4469d4b 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json -from typing import Iterator +from collections.abc import Iterator from unittest.mock import patch from uuid import uuid3 diff --git a/tests/integration_tests/dashboards/security/base_case.py b/tests/integration_tests/dashboards/security/base_case.py index bbb5fad831166..e60fa96d44798 100644 --- a/tests/integration_tests/dashboards/security/base_case.py +++ b/tests/integration_tests/dashboards/security/base_case.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import List, Optional +from typing import Optional import pytest from flask import escape, Response @@ -37,8 +37,8 @@ def assert_dashboards_api_response( self, response: Response, expected_counts: int, - expected_dashboards: Optional[List[Dashboard]] = None, - not_expected_dashboards: Optional[List[Dashboard]] = None, + expected_dashboards: Optional[list[Dashboard]] = None, + not_expected_dashboards: Optional[list[Dashboard]] = None, ) -> None: self.assert200(response) response_data = response.json diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py index b160a56a33fbf..88495b03b45cc 100644 --- a/tests/integration_tests/dashboards/superset_factory_util.py +++ b/tests/integration_tests/dashboards/superset_factory_util.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User @@ -50,8 +50,8 @@ def create_dashboard_to_db( dashboard_title: Optional[str] = None, slug: Optional[str] = None, published: bool = False, - owners: Optional[List[User]] = None, - slices: Optional[List[Slice]] = None, + owners: Optional[list[User]] = None, + slices: Optional[list[Slice]] = None, css: str = "", json_metadata: str = "", position_json: str = "", @@ -76,8 +76,8 @@ def create_dashboard( dashboard_title: Optional[str] = None, slug: Optional[str] = None, published: bool = False, - owners: Optional[List[User]] = None, - slices: Optional[List[Slice]] = None, + owners: Optional[list[User]] = None, + slices: Optional[list[Slice]] = None, css: str = "", json_metadata: str = "", position_json: str = "", @@ -107,7 +107,7 @@ def insert_model(dashboard: Model) -> None: def create_slice_to_db( name: Optional[str] = None, datasource_id: Optional[int] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> Slice: slice_ = create_slice(datasource_id, name=name, owners=owners) insert_model(slice_) @@ -119,7 +119,7 @@ def create_slice( datasource_id: Optional[int] = None, datasource: Optional[SqlaTable] = None, name: Optional[str] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> Slice: name = name if name is not None else random_str() owners = owners if owners is not None else [] @@ -149,7 +149,7 @@ def create_slice( def create_datasource_table_to_db( name: Optional[str] = None, db_id: Optional[int] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> SqlaTable: sqltable = create_datasource_table(name, db_id, owners=owners) insert_model(sqltable) @@ -161,7 +161,7 @@ def create_datasource_table( name: Optional[str] = None, db_id: Optional[int] = None, database: Optional[Database] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> SqlaTable: name = name if name is not None else random_str() owners = owners if owners is not None else [] @@ -192,7 +192,7 @@ def delete_all_inserted_objects() -> None: def delete_all_inserted_dashboards(): try: - dashboards_to_delete: List[Dashboard] = ( + dashboards_to_delete: list[Dashboard] = ( session.query(Dashboard) .filter(Dashboard.id.in_(inserted_dashboards_ids)) .all() @@ -241,7 +241,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: def delete_all_inserted_slices(): try: - slices_to_delete: List[Slice] = ( + 
slices_to_delete: list[Slice] = ( session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all() ) for slice in slices_to_delete: @@ -272,7 +272,7 @@ def delete_slice_users_associations(slice_: Slice) -> None: def delete_all_inserted_tables(): try: - tables_to_delete: List[SqlaTable] = ( + tables_to_delete: list[SqlaTable] = ( session.query(SqlaTable) .filter(SqlaTable.id.in_(inserted_sqltables_ids)) .all() @@ -307,7 +307,7 @@ def delete_table_users_associations(table: SqlaTable) -> None: def delete_all_inserted_dbs(): try: - dbs_to_delete: List[Database] = ( + dbs_to_delete: list[Database] = ( session.query(Database) .filter(Database.id.in_(inserted_databases_ids)) .all() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index a30a951884497..6fa1288067e87 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -23,7 +23,6 @@ from unittest import mock from unittest.mock import patch, MagicMock from zipfile import is_zipfile, ZipFile -from operator import itemgetter import prison import pytest @@ -52,7 +51,6 @@ load_birth_names_dashboard_with_slices, load_birth_names_data, ) -from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_with_slice, load_energy_table_data, @@ -1805,10 +1803,10 @@ def test_database_tables(self): schemas = [ s[0] for s in database.get_all_table_names_in_schema(schema_name) ] - self.assertEquals(response["count"], len(schemas)) + self.assertEqual(response["count"], len(schemas)) for option in response["result"]: - self.assertEquals(option["extra"], None) - self.assertEquals(option["type"], "table") + self.assertEqual(option["extra"], None) + self.assertEqual(option["type"], "table") self.assertTrue(option["value"] in schemas) def test_database_tables_not_found(self): diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 553fae4fbf730..b47d3d89fe108 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from unittest import mock, skip +from unittest import skip from unittest.mock import patch import pytest diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py index 86c280b9bb1c4..64bc0d85725ff 100644 --- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from unittest import mock, skip +from unittest import mock from unittest.mock import patch import pytest diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2c358d7114092..6c99efd358c93 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -18,7 +18,7 @@ import json import unittest from io import BytesIO -from typing import List, Optional +from typing import Optional from unittest.mock import ANY, patch from zipfile import is_zipfile, ZipFile @@ -68,7 +68,7 @@ class TestDatasetApi(SupersetTestCase): @staticmethod def insert_dataset( table_name: str, - owners: List[int], + owners: list[int], database: Database, sql: Optional[str] = None, schema: Optional[str] = None, @@ -94,7 +94,7 @@ def insert_default_dataset(self): "ab_permission", [self.get_user("admin").id], get_main_database() ) - def get_fixture_datasets(self) -> List[SqlaTable]: + def get_fixture_datasets(self) -> list[SqlaTable]: return ( db.session.query(SqlaTable) .options(joinedload(SqlaTable.database)) @@ -102,7 +102,7 @@ def get_fixture_datasets(self) -> List[SqlaTable]: .all() ) - def get_fixture_virtual_datasets(self) -> List[SqlaTable]: + def get_fixture_virtual_datasets(self) -> list[SqlaTable]: return ( db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(self.fixture_virtual_table_names)) @@ -410,13 +410,11 @@ def pg_test_query_parameter(query_parameter, expected_response): ) all_datasets = db.session.query(SqlaTable).all() schema_values = sorted( - set( - [ - dataset.schema - for dataset in all_datasets - if dataset.schema is not None - ] - ) + { + dataset.schema + for dataset in all_datasets + if dataset.schema is not None + } ) expected_response = { "count": len(schema_values), diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 8753b1d273a25..953c34059fdd8 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from operator import itemgetter -from typing import Any, List +from typing import Any from unittest.mock import patch import pytest @@ -312,7 +312,7 @@ def test_import_v0_dataset_ui_export(self): assert len(dataset.metrics) == 2 assert dataset.main_dttm_col == "ds" assert dataset.filter_select_enabled - assert set(col.column_name for col in dataset.columns) == { + assert {col.column_name for col in dataset.columns} == { "num_california", "ds", "state", @@ -526,7 +526,7 @@ def test_import_v1_dataset_existing_database(self, mock_g): db.session.commit() -def _get_table_from_list_by_name(name: str, tables: List[Any]): +def _get_table_from_list_by_name(name: str, tables: list[Any]): for table in tables: if table.table_name == name: return table diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py index e20ea35ae4131..2d4f72c4f47e6 100644 --- a/tests/integration_tests/db_engine_specs/base_tests.py +++ b/tests/integration_tests/db_engine_specs/base_tests.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
# isort:skip_file -from datetime import datetime -from typing import Tuple, Type from tests.integration_tests.test_app import app from tests.integration_tests.base_tests import SupersetTestCase diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 2f4f1c70cc82f..c4f04584fa2bb 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import unittest.mock as mock import pytest @@ -95,7 +94,7 @@ def test_fetch_data(self): """ # Mock a google.cloud.bigquery.table.Row - class Row(object): + class Row: def __init__(self, value): self._value = value diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index b63f64ab03cb8..341b494927004 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -from datetime import datetime from unittest import mock -from typing import List import pytest import pandas as pd @@ -377,7 +375,7 @@ def test_where_latest_partition_no_columns_no_values(mock_method): def test__latest_partition_from_df(): - def is_correct_result(data: List, result: List) -> bool: + def is_correct_result(data: list, result: list) -> bool: df = pd.DataFrame({"partition": data}) return HiveEngineSpec._latest_partition_from_df(df) == result diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index de0aa832626ac..6018e59a926e8 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -61,7 +61,7 @@ def create_table( self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[] ): database_name = "main" - name = "{0}{1}".format(NAME_PREFIX, name) + name = f"{NAME_PREFIX}{name}" params = {DBREF: id, "database_name": database_name} if cols_uuids is None: @@ -100,12 +100,12 @@ def assert_table_equals(self, expected_ds, actual_ds): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_datasource_equals(self, expected_ds, actual_ds): @@ -114,12 +114,12 @@ def assert_datasource_equals(self, expected_ds, actual_ds): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in 
actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def test_import_table_no_metadata(self): diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py index 381b8cda1b771..7c7cc1683089f 100644 --- a/tests/integration_tests/email_tests.py +++ b/tests/integration_tests/email_tests.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py index fa965ebd7d208..3b20f6a91887a 100644 --- a/tests/integration_tests/event_logger_tests.py +++ b/tests/integration_tests/event_logger_tests.py @@ -17,8 +17,8 @@ import logging import time import unittest -from datetime import datetime, timedelta -from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union +from datetime import timedelta +from typing import Any, Optional from unittest.mock import patch from flask import current_app diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index b9b1bfd0fbcd0..81be2f0de8b6c 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. import json -from typing import Any, Dict, Iterator +from collections.abc import Iterator +from typing import Any from uuid import uuid3 import pytest @@ -43,7 +44,7 @@ def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice: @pytest.fixture -def form_data(chart) -> Dict[str, Any]: +def form_data(chart) -> dict[str, Any]: datasource = f"{chart.datasource.id}__{chart.datasource.type}" return { "chart_id": chart.id, @@ -68,7 +69,7 @@ def permalink_salt() -> Iterator[str]: def test_post( - form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin + form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin ): resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) assert resp.status_code == 201 @@ -125,7 +126,7 @@ def test_post_invalid_schema(test_client, login_as_admin) -> None: def test_get( - form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin + form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin ) -> None: resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) data = json.loads(resp.data.decode("utf-8")) diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 63ed02cd7bd91..eace978d78f26 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import json from unittest.mock import patch import pytest diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index be680a720dd84..d9a4a5d9e02d7 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -14,7 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional +from typing import Callable, Optional import pytest @@ -93,7 +93,7 @@ def _create_table( return table -def _cleanup(dash_id: int, slice_ids: List[int]) -> None: +def _cleanup(dash_id: int, slice_ids: list[int]) -> None: schema = get_example_default_schema() for datasource in db.session.query(SqlaTable).filter_by( table_name="birth_names", schema=schema diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index f394d68a0e76b..279b67eda0ccf 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. """Fixtures for test_datasource.py""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any import pytest from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table @@ -31,7 +32,7 @@ from tests.integration_tests.test_app import app -def get_datasource_post() -> Dict[str, Any]: +def get_datasource_post() -> dict[str, Any]: schema = get_example_default_schema() return { diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index effe59a75544a..8b597bf3be0d4 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import random -from typing import Dict, List, Set import pandas as pd import pytest @@ -29,7 +28,7 @@ from tests.integration_tests.dashboard_utils import create_slice, create_table_metadata from tests.integration_tests.test_app import app -misc_dash_slices: Set[str] = set() +misc_dash_slices: set[str] = set() ENERGY_USAGE_TBL_NAME = "energy_usage" @@ -70,7 +69,7 @@ def _get_dataframe(): return pd.DataFrame.from_dict(data) -def _create_energy_table() -> List[Slice]: +def _create_energy_table() -> list[Slice]: table = create_table_metadata( table_name=ENERGY_USAGE_TBL_NAME, database=get_example_database(), @@ -100,7 +99,7 @@ def _create_energy_table() -> List[Slice]: def _create_and_commit_energy_slice( - table: SqlaTable, title: str, viz_type: str, param: Dict[str, str] + table: SqlaTable, title: str, viz_type: str, param: dict[str, str] ): slice = create_slice(title, viz_type, table, param) existing_slice = ( diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index d5c898eba2c9a..5fddb071e2235 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List +from typing import Any # example V0 import/export format -dataset_ui_export: List[Dict[str, Any]] = [ +dataset_ui_export: list[dict[str, Any]] = [ { "columns": [ { @@ -48,7 +48,7 @@ } ] -dataset_cli_export: Dict[str, Any] = { +dataset_cli_export: dict[str, Any] = { "databases": [ { "allow_run_async": True, @@ -59,7 +59,7 @@ ] } -dashboard_export: Dict[str, Any] = { +dashboard_export: dict[str, Any] = { "dashboards": [ { "__Dashboard__": { @@ -318,35 +318,35 @@ } # example V1 import/export format -database_metadata_config: Dict[str, Any] = { +database_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Database", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -dataset_metadata_config: Dict[str, Any] = { +dataset_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "SqlaTable", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -chart_metadata_config: Dict[str, Any] = { +chart_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Slice", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -dashboard_metadata_config: Dict[str, Any] = { +dashboard_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Dashboard", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -saved_queries_metadata_config: Dict[str, Any] = { +saved_queries_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "SavedQuery", "timestamp": "2021-03-30T20:37:54.791187+00:00", } -database_config: Dict[str, Any] = { +database_config: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -361,7 +361,7 @@ "version": "1.0.0", } -database_with_ssh_tunnel_config_private_key: Dict[str, Any] = { +database_with_ssh_tunnel_config_private_key: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -383,7 +383,7 @@ "version": "1.0.0", } -database_with_ssh_tunnel_config_password: Dict[str, Any] = { +database_with_ssh_tunnel_config_password: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -404,7 +404,7 @@ "version": "1.0.0", } -database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = { +database_with_ssh_tunnel_config_no_credentials: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -424,7 +424,7 @@ "version": "1.0.0", } -database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = { +database_with_ssh_tunnel_config_mix_credentials: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -446,7 +446,7 @@ "version": "1.0.0", } -database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = { +database_with_ssh_tunnel_config_private_pass_only: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -468,7 +468,7 @@ } -dataset_config: Dict[str, Any] = { +dataset_config: dict[str, Any] = { "table_name": "imported_dataset", "main_dttm_col": None, "description": "This is a dataset that was exported", @@ -513,7 +513,7 @@ "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", } -chart_config: Dict[str, Any] = { +chart_config: dict[str, Any] = { "slice_name": "Deck Path", "viz_type": "deck_path", "params": { @@ -557,7 +557,7 @@ "dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb", } -dashboard_config: Dict[str, Any] = { +dashboard_config: dict[str, Any] = { "dashboard_title": "Test dash", "description": None, "css": "", diff --git a/tests/integration_tests/fixtures/query_context.py 
b/tests/integration_tests/fixtures/query_context.py index 00a3036e01c25..9efa589ba82b2 100644 --- a/tests/integration_tests/fixtures/query_context.py +++ b/tests/integration_tests/fixtures/query_context.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from tests.common.query_context_generator import QueryContextGenerator from tests.integration_tests.base_tests import SupersetTestCase @@ -29,8 +29,8 @@ def get_query_context( query_name: str, add_postprocessing_operations: bool = False, add_time_offsets: bool = False, - form_data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + form_data: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: """ Create a request payload for retrieving a QueryContext object via the `api/v1/chart/data` endpoint. By default returns a payload corresponding to one diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 561bbe10b2709..18ceba9af20bb 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -17,7 +17,7 @@ import json import string from random import choice, randint, random, uniform -from typing import Any, Dict, List +from typing import Any import pandas as pd import pytest @@ -94,7 +94,7 @@ def create_dashboard_for_loaded_data(): return dash_id_to_delete, slices_ids_to_delete -def _create_world_bank_slices(table: SqlaTable) -> List[Slice]: +def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: from superset.examples.world_bank import create_slices slices = create_slices(table) @@ -102,7 +102,7 @@ def _create_world_bank_slices(table: SqlaTable) -> List[Slice]: return slices -def _commit_slices(slices: List[Slice]): +def _commit_slices(slices: list[Slice]): for slice in slices: o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none() if o: @@ -128,7 +128,7 @@ def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard: return dash -def _cleanup(dash_id: int, slices_ids: List[int]) -> None: +def _cleanup(dash_id: int, slices_ids: list[int]) -> None: dash = db.session.query(Dashboard).filter_by(id=dash_id).first() db.session.delete(dash) for slice_id in slices_ids: @@ -148,7 +148,7 @@ def _get_dataframe(database: Database) -> DataFrame: return df -def _get_world_bank_data() -> List[Dict[Any, Any]]: +def _get_world_bank_data() -> list[dict[Any, Any]]: data = [] for _ in range(100): data.append( diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 5bbc985a36672..d44745377f562 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -115,7 +115,7 @@ def create_dashboard(self, title, id=0, slcs=[]): dashboard_title=title, slices=slcs, position_json='{"size_y": 2, "size_x": 2}', - slug="{}_imported".format(title.lower()), + slug=f"{title.lower()}_imported", json_metadata=json.dumps(json_metadata), ) @@ -160,12 +160,12 @@ def assert_table_equals(self, expected_ds, actual_ds): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in 
expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_datasource_equals(self, expected_ds, actual_ds): @@ -174,12 +174,12 @@ def assert_datasource_equals(self, expected_ds, actual_ds): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_slice_equals(self, expected_slc, actual_slc): @@ -404,8 +404,8 @@ def test_import_dashboard_2_slices(self): { "remote_id": 10003, "expanded_slices": { - "{}".format(e_slc.id): True, - "{}".format(b_slc.id): False, + f"{e_slc.id}": True, + f"{b_slc.id}": False, }, # mocked filter_scope metadata "filter_scopes": { @@ -437,8 +437,8 @@ def test_import_dashboard_2_slices(self): } }, "expanded_slices": { - "{}".format(i_e_slc.id): True, - "{}".format(i_b_slc.id): False, + f"{i_e_slc.id}": True, + f"{i_b_slc.id}": False, }, } self.assertEqual( diff --git a/tests/integration_tests/insert_chart_mixin.py b/tests/integration_tests/insert_chart_mixin.py index da05d0c49d043..722e387a543a0 100644 --- a/tests/integration_tests/insert_chart_mixin.py +++ b/tests/integration_tests/insert_chart_mixin.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import List, Optional +from typing import Optional from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable @@ -29,7 +29,7 @@ class InsertChartMixin: def insert_chart( self, slice_name: str, - owners: List[int], + owners: list[int], datasource_id: int, created_by=None, datasource_type: str = "table", diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py index 66aea8a4edd27..ac33d003e0013 100644 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ b/tests/integration_tests/key_value/commands/fixtures.py @@ -18,7 +18,8 @@ from __future__ import annotations import json -from typing import Generator, TYPE_CHECKING +from collections.abc import Generator +from typing import TYPE_CHECKING from uuid import UUID import pytest diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index d5684b1b62109..c4bc7aa89bd9f 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -417,14 +417,14 @@ def test_get_timestamp_expression(self): assert str(sqla_literal.compile()) == "ds" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": assert compiled == "DATE(ds)" prev_ds_expr = ds_col.expression ds_col.expression = "DATE_ADD(ds, 1)" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": assert compiled == "DATE(DATE_ADD(ds, 1))" ds_col.expression = prev_ds_expr @@ -437,20 +437,20 @@ def test_get_timestamp_expression_epoch(self): ds_col.expression = None ds_col.python_date_format = "epoch_s" sqla_literal = ds_col.get_timestamp_expression(None) - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "from_unixtime(ds)") ds_col.python_date_format = "epoch_s" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "DATE(from_unixtime(ds))") prev_ds_expr = ds_col.expression ds_col.expression = "DATE_ADD(ds, 1)" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))") ds_col.expression = prev_ds_expr diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 5e5beae345b86..7a3d4e4a1e873 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -16,7 +16,7 @@ # under the License. 
import re import time -from typing import Any, Dict +from typing import Any import numpy as np import pandas as pd @@ -49,7 +49,7 @@ from tests.integration_tests.fixtures.query_context import get_query_context -def get_sql_text(payload: Dict[str, Any]) -> str: +def get_sql_text(payload: dict[str, Any]) -> str: payload["result_type"] = ChartDataResultType.QUERY.value query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() diff --git a/tests/integration_tests/reports/alert_tests.py b/tests/integration_tests/reports/alert_tests.py index 32cc2dcefb572..4920a96283d7e 100644 --- a/tests/integration_tests/reports/alert_tests.py +++ b/tests/integration_tests/reports/alert_tests.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel from contextlib import nullcontext -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import pandas as pd import pytest @@ -56,10 +56,10 @@ ], ) def test_execute_query_as_report_executor( - owner_names: List[str], + owner_names: list[str], creator_name: Optional[str], - config: List[ExecutorType], - expected_result: Union[Tuple[ExecutorType, str], Exception], + config: list[ExecutorType], + expected_result: Union[tuple[ExecutorType, str], Exception], mocker: MockFixture, app_context: None, get_user, diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index a81bc6fa66adc..db80079d77079 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -17,7 +17,7 @@ import json from contextlib import contextmanager from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import Optional from unittest.mock import call, Mock, patch from uuid import uuid4 @@ -100,7 +100,7 @@ ) -def get_target_from_report_schedule(report_schedule: ReportSchedule) -> List[str]: +def get_target_from_report_schedule(report_schedule: ReportSchedule) -> list[str]: return [ json.loads(recipient.recipient_config_json)["target"] for recipient in report_schedule.recipients @@ -1976,9 +1976,7 @@ def test__send_with_client_errors(notification_mock, logger_mock): assert excinfo.errisinstance(SupersetException) logger_mock.warning.assert_called_with( - ( - "SupersetError(message='', error_type=, level=, extra=None)" - ) + "SupersetError(message='', error_type=, level=, extra=None)" ) @@ -2021,7 +2019,5 @@ def test__send_with_server_errors(notification_mock, logger_mock): assert excinfo.errisinstance(SupersetException) # it logs the error logger_mock.warning.assert_called_with( - ( - "SupersetError(message='', error_type=, level=, extra=None)" - ) + "SupersetError(message='', error_type=, level=, extra=None)" ) diff --git a/tests/integration_tests/reports/scheduler_tests.py b/tests/integration_tests/reports/scheduler_tests.py index 4b8968592b7e4..3284ee9772469 100644 --- a/tests/integration_tests/reports/scheduler_tests.py +++ b/tests/integration_tests/reports/scheduler_tests.py @@ -16,7 +16,6 @@ # under the License. 
from random import randint -from typing import List from unittest.mock import patch import pytest @@ -32,7 +31,7 @@ @pytest.fixture -def owners(get_user) -> List[User]: +def owners(get_user) -> list[User]: return [get_user("admin")] diff --git a/tests/integration_tests/reports/utils.py b/tests/integration_tests/reports/utils.py index 3801beb1a328e..7672c5c94046a 100644 --- a/tests/integration_tests/reports/utils.py +++ b/tests/integration_tests/reports/utils.py @@ -17,7 +17,7 @@ import json from contextlib import contextmanager -from typing import Any, Dict, List, Optional +from typing import Any, Optional from uuid import uuid4 from flask_appbuilder.security.sqla.models import User @@ -49,7 +49,7 @@ def insert_report_schedule( type: str, name: str, crontab: str, - owners: List[User], + owners: list[User], timezone: Optional[str] = None, sql: Optional[str] = None, description: Optional[str] = None, @@ -61,10 +61,10 @@ def insert_report_schedule( log_retention: Optional[int] = None, last_state: Optional[ReportState] = None, grace_period: Optional[int] = None, - recipients: Optional[List[ReportRecipients]] = None, + recipients: Optional[list[ReportRecipients]] = None, report_format: Optional[ReportDataFormat] = None, - logs: Optional[List[ReportExecutionLog]] = None, - extra: Optional[Dict[Any, Any]] = None, + logs: Optional[list[ReportExecutionLog]] = None, + extra: Optional[dict[Any, Any]] = None, force_screenshot: bool = False, ) -> ReportSchedule: owners = owners or [] @@ -113,9 +113,9 @@ def create_report_notification( grace_period: Optional[int] = None, report_format: Optional[ReportDataFormat] = None, name: Optional[str] = None, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, force_screenshot: bool = False, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> ReportSchedule: if not owners: owners = [ diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py index a541f00952773..ae89fea068661 100644 --- a/tests/integration_tests/security/migrate_roles_tests.py +++ b/tests/integration_tests/security/migrate_roles_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Unit tests for alerting in Superset""" -import json import logging from contextlib import contextmanager from unittest.mock import patch diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 51aa76ee2781a..2a28089c3e51e 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -16,7 +16,7 @@ # under the License. 
# isort:skip_file import re -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest import mock import pytest @@ -55,7 +55,7 @@ class TestRowLevelSecurity(SupersetTestCase): """ rls_entry = None - query_obj: Dict[str, Any] = dict( + query_obj: dict[str, Any] = dict( groupby=[], metrics=None, filter=[], @@ -542,8 +542,8 @@ def test_rls_tables_related_api(self): db_tables = db.session.query(SqlaTable).all() - db_table_names = set([t.name for t in db_tables]) - received_tables = set([table["text"] for table in result]) + db_table_names = {t.name for t in db_tables} + received_tables = {table["text"] for table in result} assert data["count"] == len(db_tables) assert len(result) == len(db_tables) @@ -558,8 +558,8 @@ def test_rls_roles_related_api(self): data = json.loads(rv.data.decode("utf-8")) result = data["result"] - db_role_names = set([r.name for r in security_manager.get_all_roles()]) - received_roles = set([role["text"] for role in result]) + db_role_names = {r.name for r in security_manager.get_all_roles()} + received_roles = {role["text"] for role in result} assert data["count"] == len(db_role_names) assert len(result) == len(db_role_names) @@ -580,7 +580,7 @@ def test_table_related_filter(self): self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) result = data["result"] - received_tables = set([table["text"].split(".")[-1] for table in result]) + received_tables = {table["text"].split(".")[-1] for table in result} assert data["count"] == 1 assert len(result) == 1 @@ -615,7 +615,7 @@ def _base_filter(query): EMBEDDED_SUPERSET=True, ) class GuestTokenRowLevelSecurityTests(SupersetTestCase): - query_obj: Dict[str, Any] = dict( + query_obj: dict[str, Any] = dict( groupby=[], metrics=None, filter=[], @@ -633,7 +633,7 @@ def default_rls_rule(self): "clause": "name = 'Alice'", } - def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser: + def guest_user_with_rls(self, rules: Optional[list[Any]] = None) -> GuestUser: if rules is None: rules = [self.default_rls_rule()] return security_manager.get_guest_user_from_token( diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index a57d24c3e45ed..89aefdfd0979d 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -112,7 +112,7 @@ def test_estimate_valid_request(self): @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_execute_required_params(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"client_id": client_id} rv = self.client.post( @@ -157,7 +157,7 @@ def test_execute_valid_request(self) -> None: core.results_backend.get.return_value = {} self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id} rv = self.client.post( diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 3d505ee2f544a..d76924a8fb1cc 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from unittest import mock, skip +from unittest import mock from unittest.mock import Mock, patch import pandas as pd diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 4003913516fee..854a0c9be020b 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -17,7 +17,8 @@ # isort:skip_file import re from datetime import datetime -from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union +from typing import Any, NamedTuple, Optional, Union +from re import Pattern from unittest.mock import patch import pytest @@ -50,7 +51,7 @@ from .base_tests import SupersetTestCase from .conftest import only_postgresql -VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = { +VIRTUAL_TABLE_INT_TYPES: dict[str, Pattern[str]] = { "hive": re.compile(r"^INT_TYPE$"), "mysql": re.compile("^LONGLONG$"), "postgresql": re.compile(r"^INTEGER$"), @@ -58,7 +59,7 @@ "sqlite": re.compile(r"^INT$"), } -VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = { +VIRTUAL_TABLE_STRING_TYPES: dict[str, Pattern[str]] = { "hive": re.compile(r"^STRING_TYPE$"), "mysql": re.compile(r"^VAR_STRING$"), "postgresql": re.compile(r"^STRING$"), @@ -70,8 +71,8 @@ class FilterTestCase(NamedTuple): column: str operator: str - value: Union[float, int, List[Any], str] - expected: Union[str, List[str]] + value: Union[float, int, list[Any], str] + expected: Union[str, list[str]] class TestDatabaseModel(SupersetTestCase): @@ -101,7 +102,7 @@ def test_temporal_varchar(self): assert col.is_temporal is True def test_db_column_types(self): - test_cases: Dict[str, GenericDataType] = { + test_cases: dict[str, GenericDataType] = { # string "CHAR": GenericDataType.STRING, "VARCHAR": GenericDataType.STRING, @@ -291,7 +292,7 @@ def test_adhoc_metrics_and_calc_columns(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_where_operators(self): - filters: Tuple[FilterTestCase, ...] = ( + filters: tuple[FilterTestCase, ...] 
= ( FilterTestCase("num", FilterOperator.IS_NULL, "", "IS NULL"), FilterTestCase("num", FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"), # Some db backends translate true/false to 1/0 @@ -493,7 +494,7 @@ def test_fetch_metadata_for_updated_virtual_table(self): "mycase", "expr", } - cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns} + cols: dict[str, TableColumn] = {col.column_name: col for col in table.columns} # assert that the type for intcol has been updated (asserting CI types) backend = table.database.backend assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type) @@ -802,7 +803,7 @@ def test__normalize_prequery_result_type( result: Any, ) -> None: def _convert_dttm( - target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: if target_type.upper() == "TIMESTAMP": return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')""" diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 16cc16d264540..e9892b1d36c4b 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -359,7 +359,7 @@ def test_queries_endpoint(self): db.session.commit() data = self.get_json_resp( - "/superset/queries/{}".format(float(datetime_to_epoch(now)) - 1000) + f"/superset/queries/{float(datetime_to_epoch(now)) - 1000}" ) self.assertEqual(1, len(data)) @@ -391,13 +391,13 @@ def test_search_query_on_user(self): # Test search queries on user Id user_id = security_manager.find_user("admin").id - data = self.get_json_resp("/superset/search_queries?user_id={}".format(user_id)) + data = self.get_json_resp(f"/superset/search_queries?user_id={user_id}") self.assertEqual(2, len(data)) user_ids = {k["userId"] for k in data} - self.assertEqual(set([user_id]), user_ids) + self.assertEqual({user_id}, user_ids) user_id = security_manager.find_user("gamma_sqllab").id - resp = self.get_resp("/superset/search_queries?user_id={}".format(user_id)) + resp = self.get_resp(f"/superset/search_queries?user_id={user_id}") data = json.loads(resp) self.assertEqual(1, len(data)) self.assertEqual(data[0]["userId"], user_id) @@ -451,7 +451,7 @@ def test_search_query_only_owned(self) -> None: self.assertEqual(1, len(data)) user_ids = {k["userId"] for k in data} - self.assertEqual(set([user_id]), user_ids) + self.assertEqual({user_id}, user_ids) def test_alias_duplicate(self): self.run_sql( @@ -593,7 +593,7 @@ def test_sql_limit(self): self.assertEqual(len(data["data"]), test_limit) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit), + f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_3", query_limit=test_limit + 1, ) @@ -601,7 +601,7 @@ def test_sql_limit(self): self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1), + f"SELECT * FROM birth_names LIMIT {test_limit + 1}", client_id="sql_limit_4", query_limit=test_limit, ) @@ -609,7 +609,7 @@ def test_sql_limit(self): self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit), + f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_5", query_limit=test_limit, ) diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index e54ae865e3c15..f6d664c649971 100644 
--- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -16,8 +16,6 @@ # under the License. # isort:skip_file """Unit tests for Superset cache warmup""" -import datetime -import json from unittest.mock import MagicMock from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py index c3f9b350f8e81..77e007a2ddbd8 100644 --- a/tests/integration_tests/superset_test_config.py +++ b/tests/integration_tests/superset_test_config.py @@ -130,7 +130,7 @@ def GET_FEATURE_FLAGS_FUNC(ff): ALERT_REPORTS_QUERY_EXECUTION_MAX_TRIES = 3 -class CeleryConfig(object): +class CeleryConfig: BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" CELERY_IMPORTS = ("superset.sql_lab",) CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}" diff --git a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py index 9f6dd2ead1fa2..31d14ef71b357 100644 --- a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py +++ b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py @@ -16,8 +16,6 @@ # under the License. # flake8: noqa # type: ignore -import os -from copy import copy from .superset_test_config import * diff --git a/tests/integration_tests/superset_test_config_thumbnails.py b/tests/integration_tests/superset_test_config_thumbnails.py index 9f621efabbf4d..5bd02e7b0fd5b 100644 --- a/tests/integration_tests/superset_test_config_thumbnails.py +++ b/tests/integration_tests/superset_test_config_thumbnails.py @@ -61,7 +61,7 @@ def GET_FEATURE_FLAGS_FUNC(ff): REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3) -class CeleryConfig(object): +class CeleryConfig: BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks.thumbnails") CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}} diff --git a/tests/integration_tests/tagging_tests.py b/tests/integration_tests/tagging_tests.py index 71fb7e4e4e89d..72ba577d9ffc0 100644 --- a/tests/integration_tests/tagging_tests.py +++ b/tests/integration_tests/tagging_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock import pytest diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 7bf21da4fcd71..b047388a68cb0 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -16,10 +16,7 @@ # under the License. # isort:skip_file """Unit tests for Superset""" -from datetime import datetime, timedelta import json -import random -import string import pytest import prison diff --git a/tests/integration_tests/tags/commands_tests.py b/tests/integration_tests/tags/commands_tests.py index 8f44d2ebda0dd..cd5a024840b1c 100644 --- a/tests/integration_tests/tags/commands_tests.py +++ b/tests/integration_tests/tags/commands_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
import itertools -import json from unittest.mock import MagicMock, patch import pytest diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py index f46abaa723bda..49b22d260b048 100644 --- a/tests/integration_tests/tags/dao_tests.py +++ b/tests/integration_tests/tags/dao_tests.py @@ -15,10 +15,7 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -import copy -import json from operator import and_ -import time from unittest.mock import patch import pytest from superset.dao.exceptions import DAOCreateFailedError, DAOException diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py index 228da6de79c95..eb2be859ba3ac 100644 --- a/tests/integration_tests/thumbnails_tests.py +++ b/tests/integration_tests/thumbnails_tests.py @@ -20,7 +20,6 @@ import json import urllib.request from io import BytesIO -from typing import Tuple from unittest import skipUnless from unittest.mock import ANY, call, MagicMock, patch @@ -203,7 +202,7 @@ class TestThumbnails(SupersetTestCase): digest_return_value = "foo_bar" digest_hash = "5c7d96a3dd7a87850a2ef34087565a6e" - def _get_id_and_thumbnail_url(self, url: str) -> Tuple[int, str]: + def _get_id_and_thumbnail_url(self, url: str) -> tuple[int, str]: rv = self.client.get(url) resp = json.loads(rv.data.decode("utf-8")) obj = resp["result"][0] diff --git a/tests/integration_tests/users/__init__.py b/tests/integration_tests/users/__init__.py index fd9417fe5c1e9..13a83393a9124 100644 --- a/tests/integration_tests/users/__init__.py +++ b/tests/integration_tests/users/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/integration_tests/utils/csv_tests.py b/tests/integration_tests/utils/csv_tests.py index e514efb1d2108..38c1dd51ac257 100644 --- a/tests/integration_tests/utils/csv_tests.py +++ b/tests/integration_tests/utils/csv_tests.py @@ -43,13 +43,13 @@ def test_escape_value(): assert result == "'=value" result = csv.escape_value("|value") - assert result == "'\|value" + assert result == r"'\|value" result = csv.escape_value("%value") assert result == "'%value" result = csv.escape_value("=cmd|' /C calc'!A0") - assert result == "'=cmd\|' /C calc'!A0" + assert result == r"'=cmd\|' /C calc'!A0" result = csv.escape_value('""=10+2') assert result == '\'""=10+2' @@ -74,7 +74,7 @@ def test_df_to_escaped_csv(): assert escaped_csv_rows == [ ["col_a", "'=func()"], - ["-10", "'=cmd\|' /C calc'!A0"], + ["-10", r"'=cmd\|' /C calc'!A0"], ["a", "'=b"], # pandas seems to be removing the leading "" ["' =a", "b"], ] diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py index 2199783529b88..45fd291ee8523 100644 --- a/tests/integration_tests/utils/encrypt_tests.py +++ b/tests/integration_tests/utils/encrypt_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy import String, TypeDecorator from sqlalchemy_utils import EncryptedType @@ -28,9 +28,9 @@ class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter): def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]] + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if app_config: return StringEncryptedType(*args, app_config["SECRET_KEY"], **kwargs) diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py index 03260fb94d07f..7012bf08a054f 100644 --- a/tests/integration_tests/utils/get_dashboards.py +++ b/tests/integration_tests/utils/get_dashboards.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List from flask_appbuilder import SQLA from superset.models.dashboard import Dashboard -def get_dashboards_ids(db: SQLA, dashboard_slugs: List[str]) -> List[int]: +def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]: result = ( db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all() ) diff --git a/tests/integration_tests/utils/public_interfaces_test.py b/tests/integration_tests/utils/public_interfaces_test.py index 7b5d6712464df..af67bb6ca3cc6 100644 --- a/tests/integration_tests/utils/public_interfaces_test.py +++ b/tests/integration_tests/utils/public_interfaces_test.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict +from typing import Any, Callable import pytest @@ -23,7 +23,7 @@ # These are public interfaces exposed by Superset. Make sure # to only change the interfaces and update the hashes in new # major versions of Superset. 
-hashes: Dict[Callable[..., Any], str] = {} +hashes: dict[Callable[..., Any], str] = {} @pytest.mark.parametrize("interface,expected_hash", list(hashes.items())) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 2db008fdb72e6..b4d750c8d0bcd 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -21,7 +21,7 @@ import json import os import re -from typing import Any, Tuple, List, Optional +from typing import Any, Optional from unittest.mock import Mock, patch from superset.databases.commands.exceptions import DatabaseInvalidError @@ -121,12 +121,12 @@ def test_base_json_conv(self): assert isinstance(base_json_conv(np.int64(1)), int) assert isinstance(base_json_conv(np.array([1, 2, 3])), list) assert base_json_conv(np.array(None)) is None - assert isinstance(base_json_conv(set([1])), list) + assert isinstance(base_json_conv({1}), list) assert isinstance(base_json_conv(Decimal("1.0")), float) assert isinstance(base_json_conv(uuid.uuid4()), str) assert isinstance(base_json_conv(time()), str) assert isinstance(base_json_conv(timedelta(0)), str) - assert isinstance(base_json_conv(bytes()), str) + assert isinstance(base_json_conv(b""), str) assert base_json_conv(bytes("", encoding="utf-16")) == "[bytes]" with pytest.raises(TypeError): @@ -1054,7 +1054,7 @@ def test_get_form_data_token(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_extract_dataframe_dtypes(self): slc = self.get_slice("Girls", db.session) - cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = ( + cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = ( ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]), ( "dttm", diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py index 137e2a474c344..30d4d1e1830a9 100644 --- a/tests/integration_tests/viz_tests.py +++ b/tests/integration_tests/viz_tests.py @@ -19,7 +19,6 @@ import logging from math import nan from unittest.mock import Mock, patch -from typing import Any, Dict, List, Set import numpy as np import pandas as pd @@ -1009,7 +1008,7 @@ def test_get_data_metrics(self): test_viz = viz.TimeTableViz(datasource, form_data) data = test_viz.get_data(df) # Check method correctly transforms data - self.assertEqual(set(["count", "sum__A"]), set(data["columns"])) + self.assertEqual({"count", "sum__A"}, set(data["columns"])) time_format = "%Y-%m-%d %H:%M:%S" expected = { t1.strftime(time_format): {"sum__A": 15, "count": 6}, @@ -1030,7 +1029,7 @@ def test_get_data_group_by(self): test_viz = viz.TimeTableViz(datasource, form_data) data = test_viz.get_data(df) # Check method correctly transforms data - self.assertEqual(set(["a1", "a2", "a3"]), set(data["columns"])) + self.assertEqual({"a1", "a2", "a3"}, set(data["columns"])) time_format = "%Y-%m-%d %H:%M:%S" expected = { t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25}, diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py index 72ae9dbba7fef..b1d5cc64881da 100644 --- a/tests/unit_tests/charts/dao/dao_tests.py +++ b/tests/unit_tests/charts/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py index 84496bf1cfdeb..b7cdda6e68f07 100644 --- a/tests/unit_tests/charts/test_post_processing.py +++ b/tests/unit_tests/charts/test_post_processing.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import json import pandas as pd import pytest diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py index 4fd906f648ee9..02304828dca82 100644 --- a/tests/unit_tests/common/test_query_object_factory.py +++ b/tests/unit_tests/common/test_query_object_factory.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import Mock, patch from pytest import fixture, mark @@ -23,7 +23,7 @@ from tests.common.query_context_generator import QueryContextGenerator -def create_app_config() -> Dict[str, Any]: +def create_app_config() -> dict[str, Any]: return { "ROW_LIMIT": 5000, "DEFAULT_RELATIVE_START_TIME": "today", @@ -34,7 +34,7 @@ def create_app_config() -> Dict[str, Any]: @fixture -def app_config() -> Dict[str, Any]: +def app_config() -> dict[str, Any]: return create_app_config().copy() @@ -58,7 +58,7 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int: @fixture def query_object_factory( - app_config: Dict[str, Any], connector_registry: Mock, session_factory: Mock + app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock ) -> QueryObjectFactory: import superset.common.query_object_factory as mod @@ -67,7 +67,7 @@ def query_object_factory( @fixture -def raw_query_context() -> Dict[str, Any]: +def raw_query_context() -> dict[str, Any]: return QueryContextGenerator().generate("birth_names") @@ -75,7 +75,7 @@ class TestQueryObjectFactory: def test_query_context_limit_and_offset_defaults( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object.pop("row_limit", None) @@ -89,7 +89,7 @@ def test_query_context_limit_and_offset_defaults( def test_query_context_limit( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object["row_limit"] = 100 @@ -104,7 +104,7 @@ def test_query_context_limit( def test_query_context_null_post_processing_op( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object["post_processing"] = [None] diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py index 021193a6cd36e..4a62f26e6f85e 100644 --- a/tests/unit_tests/config_test.py +++ b/tests/unit_tests/config_test.py @@ -17,7 +17,7 @@ # pylint: disable=import-outside-toplevel, unused-argument, redefined-outer-name, invalid-name from functools import partial -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import pytest from pytest_mock import MockerFixture @@ -44,7 +44,7 @@ } -def 
apply_dttm_defaults(table: "SqlaTable", dttm_defaults: Dict[str, Any]) -> None: +def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: dict[str, Any]) -> None: """Applies dttm defaults to the table, mutates in place.""" for dbcol in table.columns: # Set is_dttm is column is listed in dttm_columns. diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 6740a8b6e280b..6a4f1e550cd26 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -19,7 +19,8 @@ import importlib import os import unittest.mock -from typing import Any, Callable, Iterator +from collections.abc import Iterator +from typing import Any, Callable import pytest from _pytest.fixtures import SubRequest diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py index 62eeff31065e0..d0ab3ec8a51f3 100644 --- a/tests/unit_tests/dao/queries_test.py +++ b/tests/unit_tests/dao/queries_test.py @@ -14,9 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json from datetime import datetime, timedelta -from typing import Any, Iterator +from typing import Any import pytest from pytest_mock import MockFixture diff --git a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py index 0392acb31596a..60a659159a332 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=import-outside-toplevel, unused-argument -from typing import Any, Dict +from typing import Any def test_update_id_refs_immune_missing( # pylint: disable=invalid-name @@ -59,7 +59,7 @@ def test_update_id_refs_immune_missing( # pylint: disable=invalid-name }, } chart_ids = {"uuid1": 1, "uuid2": 2} - dataset_info: Dict[str, Dict[str, Any]] = {} # not used + dataset_info: dict[str, dict[str, Any]] = {} # not used fixed = update_id_refs(config, chart_ids, dataset_info) assert fixed == { @@ -103,7 +103,7 @@ def test_update_native_filter_config_scope_excluded(): }, } chart_ids = {"uuid1": 1, "uuid2": 2} - dataset_info: Dict[str, Dict[str, Any]] = {} # not used + dataset_info: dict[str, dict[str, Any]] = {} # not used fixed = update_id_refs(config, chart_ids, dataset_info) assert fixed == { diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py index a8f93e7513906..c94d2ab15750b 100644 --- a/tests/unit_tests/dashboards/dao_tests.py +++ b/tests/unit_tests/dashboards/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index 47db402670dee..f085cb53c7913 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 2a5738ebd396a..fbad104c1da00 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index b5adf765fa5ab..de0b70db9cbf3 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from pytest_mock import MockFixture diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 58f90054ccd1f..544cf3434a47c 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py index ae5b6e9bd3c39..27f9c3b8ad548 100644 --- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 839374425b4d7..9e8690e6e35fc 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -20,7 +20,7 @@ import json import re import uuid -from typing import Any, Dict +from typing import Any from unittest.mock import Mock, patch import pytest @@ -296,7 +296,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) -> session.flush() dataset_uuid = uuid.uuid4() - yaml_config: Dict[str, Any] = { + yaml_config: dict[str, Any] = { "version": "1.0.0", "table_name": "my_table", "main_dttm_col": "ds", @@ -388,7 +388,7 @@ def test_import_column_allowed_data_url( session.flush() dataset_uuid = uuid.uuid4() - yaml_config: Dict[str, Any] = { + yaml_config: dict[str, Any] = { "version": "1.0.0", "table_name": "my_table", "main_dttm_col": "ds", diff --git a/tests/unit_tests/datasets/conftest.py b/tests/unit_tests/datasets/conftest.py index 8d217ae27a7e7..8bef6945a6931 100644 --- a/tests/unit_tests/datasets/conftest.py +++ b/tests/unit_tests/datasets/conftest.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import pytest @@ -23,7 +23,7 @@ @pytest.fixture -def columns_default() -> Dict[str, Any]: +def columns_default() -> dict[str, Any]: """Default props for new columns""" return { "changed_by": 1, @@ -49,7 +49,7 @@ def columns_default() -> Dict[str, Any]: @pytest.fixture -def sample_columns() -> Dict["TableColumn", Dict[str, Any]]: +def sample_columns() -> dict["TableColumn", dict[str, Any]]: from superset.connectors.sqla.models import TableColumn return { @@ -93,7 +93,7 @@ def sample_columns() -> Dict["TableColumn", Dict[str, Any]]: @pytest.fixture -def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]: +def sample_metrics() -> dict["SqlMetric", dict[str, Any]]: from superset.connectors.sqla.models import SqlMetric return { diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 350425d08e897..4eb43cd9de1bc 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index 16334066d7ba1..99a485030195f 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py index 51ec6656aa7f0..f0811a3e1471f 100644 --- a/tests/unit_tests/db_engine_specs/test_athena.py +++ b/tests/unit_tests/db_engine_specs/test_athena.py @@ -81,7 +81,7 @@ def test_get_text_clause_with_colon() -> None: from superset.db_engine_specs.athena import AthenaEngineSpec query = ( - "SELECT foo FROM tbl WHERE " "abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'" + "SELECT foo FROM tbl WHERE " r"abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'" ) text_clause = AthenaEngineSpec.get_text_clause(query) assert text_clause.text == query diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 868a6bbdc3fe6..33083f03997c3 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -17,7 +17,7 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access from textwrap import dedent -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -130,8 +130,8 @@ def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None: ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 0c437bc00998c..6dfeddaf37cfd 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -16,7 +16,7 @@ # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest.mock import Mock import pytest @@ -189,8 +189,8 @@ def test_connect_convert_dttm( ) def test_connect_get_column_spec( native_type: str, - sqla_type: Type[TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py index de55c63426b70..0c1597766948b 100644 --- a/tests/unit_tests/db_engine_specs/test_elasticsearch.py +++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import MagicMock import pytest @@ -49,7 +49,7 @@ ) def test_elasticsearch_convert_dttm( target_type: str, - db_extra: Optional[Dict[str, Any]], + db_extra: Optional[dict[str, Any]], expected_result: Optional[str], dttm: datetime, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index acd35a4ecf476..673f4817be735 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -17,7 +17,7 @@ import unittest.mock as mock from datetime import datetime from textwrap import dedent -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import column, table @@ -50,8 +50,8 @@ ) def test_get_column_spec( native_type: str, - sqla_type: Type[TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 07ce6838fc20b..89abf2321d79b 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -16,7 +16,7 @@ # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Optional from unittest.mock import Mock, patch import pytest @@ -71,8 +71,8 @@ ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -166,7 +166,7 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: ], ) def test_adjust_engine_params( - sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any] + sqlalchemy_uri: str, connect_args: dict[str, Any], returns: dict[str, Any] ) -> None: from superset.db_engine_specs.mysql import MySQLEngineSpec diff --git a/tests/unit_tests/db_engine_specs/test_ocient.py b/tests/unit_tests/db_engine_specs/test_ocient.py index af9fd2ad1681f..a58f31d242a7e 100644 --- a/tests/unit_tests/db_engine_specs/test_ocient.py +++ b/tests/unit_tests/db_engine_specs/test_ocient.py @@ -17,7 +17,7 @@ # pylint: disable=import-outside-toplevel -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable import pytest @@ -33,7 +33,7 @@ def ocient_is_installed() -> bool: # (msg,expected) -MARSHALED_OCIENT_ERRORS: List[Tuple[str, SupersetError]] = [ +MARSHALED_OCIENT_ERRORS: list[tuple[str, SupersetError]] = [ ( "The referenced user does not exist (User 'mj' not found)", SupersetError( @@ -224,7 +224,7 @@ def test_connection_errors(msg: str, expected: SupersetError) -> None: def _generate_gis_type_sanitization_test_cases() -> ( - List[Tuple[str, int, Any, Dict[str, Any]]] + list[tuple[str, int, Any, dict[str, Any]]] ): if not ocient_is_installed(): return [] diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index fef864795962e..145d398898d13 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -82,8 +82,8 @@ def test_convert_dttm( ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index df2ed58c376ed..7739361cf3f09 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest import mock import pytest @@ -77,8 +77,8 @@ def test_convert_dttm( ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py index 7812a16830b25..ac246e3d5bc86 100644 --- a/tests/unit_tests/db_engine_specs/test_starrocks.py +++ b/tests/unit_tests/db_engine_specs/test_starrocks.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -45,8 +45,8 @@ ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -74,9 +74,9 @@ def test_get_column_spec( ) def test_adjust_engine_params( sqlalchemy_uri: str, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], return_schema: str, - return_connect_args: Dict[str, Any], + return_connect_args: dict[str, Any], ) -> None: from superset.db_engine_specs.starrocks import StarRocksEngineSpec diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 0ea296a075e71..963953d18b48e 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -17,7 +17,7 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access import json from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest.mock import Mock, patch import pandas as pd @@ -57,7 +57,7 @@ ), ], ) -def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None: +def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() @@ -95,7 +95,7 @@ def test_auth_basic(mock_auth: Mock) -> None: {"auth_method": "basic", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -117,7 +117,7 @@ def test_auth_kerberos(mock_auth: Mock) -> None: {"auth_method": "kerberos", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -134,7 +134,7 @@ def test_auth_certificate(mock_auth: Mock) -> None: {"auth_method": "certificate", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -152,7 +152,7 @@ def test_auth_jwt(mock_auth: Mock) -> None: {"auth_method": "jwt", "auth_params": auth_params} ) - params: Dict[str, 
Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -176,7 +176,7 @@ def test_auth_custom_auth() -> None: {"trino": {"custom_auth": auth_class}}, clear=True, ): - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) @@ -243,8 +243,8 @@ def test_auth_custom_auth_denied() -> None: ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -324,8 +324,8 @@ def test_cancel_query_failed(engine_mock: Mock) -> None: ], ) def test_prepare_cancel_query( - initial_extra: Dict[str, Any], - final_extra: Dict[str, Any], + initial_extra: dict[str, Any], + final_extra: dict[str, Any], mocker: MockerFixture, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec diff --git a/tests/unit_tests/db_engine_specs/utils.py b/tests/unit_tests/db_engine_specs/utils.py index 13ae7a34d2931..774ca3eaf20db 100644 --- a/tests/unit_tests/db_engine_specs/utils.py +++ b/tests/unit_tests/db_engine_specs/utils.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types @@ -28,11 +28,11 @@ def assert_convert_dttm( - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], target_type: str, - expected_result: Optional[str], + expected_result: str | None, dttm: datetime, - db_extra: Optional[Dict[str, Any]] = None, + db_extra: dict[str, Any] | None = None, ) -> None: for target in ( target_type, @@ -50,10 +50,10 @@ def assert_convert_dttm( def assert_column_spec( - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: dict[str, Any] | None, generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/extensions/ssh_test.py b/tests/unit_tests/extensions/ssh_test.py index 0e997729d96fe..4538d719697ea 100644 --- a/tests/unit_tests/extensions/ssh_test.py +++ b/tests/unit_tests/extensions/ssh_test.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any from unittest.mock import Mock, patch import pytest diff --git a/tests/unit_tests/fixtures/assets_configs.py b/tests/unit_tests/fixtures/assets_configs.py index 73bc5921ec42f..bda84c13356b5 100644 --- a/tests/unit_tests/fixtures/assets_configs.py +++ b/tests/unit_tests/fixtures/assets_configs.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict +from typing import Any -databases_config: Dict[str, Any] = { +databases_config: dict[str, Any] = { "databases/examples.yaml": { "database_name": "examples", "sqlalchemy_uri": "sqlite:///test.db", @@ -32,7 +32,7 @@ "allow_csv_upload": False, }, } -datasets_config: Dict[str, Any] = { +datasets_config: dict[str, Any] = { "datasets/examples/video_game_sales.yaml": { "table_name": "video_game_sales", "main_dttm_col": None, @@ -80,7 +80,7 @@ "database_uuid": "a2dc77af-e654-49bb-b321-40f6b559a1ee", }, } -charts_config_1: Dict[str, Any] = { +charts_config_1: dict[str, Any] = { "charts/Games_per_Genre_over_time_95.yaml": { "slice_name": "Games per Genre over time", "viz_type": "line", @@ -100,7 +100,7 @@ "dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283", }, } -dashboards_config_1: Dict[str, Any] = { +dashboards_config_1: dict[str, Any] = { "dashboards/Video_Game_Sales_11.yaml": { "dashboard_title": "Video Game Sales", "description": None, @@ -182,7 +182,7 @@ }, } -charts_config_2: Dict[str, Any] = { +charts_config_2: dict[str, Any] = { "charts/Games_per_Genre_131.yaml": { "slice_name": "Games per Genre", "viz_type": "treemap", @@ -193,7 +193,7 @@ "dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283", }, } -dashboards_config_2: Dict[str, Any] = { +dashboards_config_2: dict[str, Any] = { "dashboards/Video_Game_Sales_11.yaml": { "dashboard_title": "Video Game Sales", "description": None, diff --git a/tests/unit_tests/fixtures/datasets.py b/tests/unit_tests/fixtures/datasets.py index 5d5466a5e8135..7bddae0b8138b 100644 --- a/tests/unit_tests/fixtures/datasets.py +++ b/tests/unit_tests/fixtures/datasets.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from unittest.mock import Mock -def get_column_mock(params: Dict[str, Any]) -> Mock: +def get_column_mock(params: dict[str, Any]) -> Mock: mock = Mock() mock.id = params["id"] mock.column_name = params["column_name"] @@ -32,7 +32,7 @@ def get_column_mock(params: Dict[str, Any]) -> Mock: return mock -def get_metric_mock(params: Dict[str, Any]) -> Mock: +def get_metric_mock(params: dict[str, Any]) -> Mock: mock = Mock() mock.id = params["id"] mock.metric_name = params["metric_name"] diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index bf8f589913086..d37296447ad63 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -18,7 +18,7 @@ # pylint: disable=import-outside-toplevel import json from datetime import datetime -from typing import List, Optional +from typing import Optional import pytest from pytest_mock import MockFixture @@ -54,7 +54,7 @@ def get_metrics( inspector: Inspector, table_name: str, schema: Optional[str], - ) -> List[MetricType]: + ) -> list[MetricType]: return [ { "expression": "COUNT(DISTINCT user_id)", diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py index 07366b15774d1..fa9fa30d36a10 100644 --- a/tests/unit_tests/pandas_postprocessing/utils.py +++ b/tests/unit_tests/pandas_postprocessing/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import math -from typing import Any, List, Optional +from typing import Any, Optional from pandas import Series @@ -26,7 +26,7 @@ } -def series_to_list(series: Series) -> List[Any]: +def series_to_list(series: Series) -> list[Any]: """ Converts a `Series` to a regular list, and replaces non-numeric values to Nones. @@ -43,8 +43,8 @@ def series_to_list(series: Series) -> List[Any]: def round_floats( - floats: List[Optional[float]], precision: int -) -> List[Optional[float]]: + floats: list[Optional[float]], precision: int +) -> list[Optional[float]]: """ Round list of floats to certain precision diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index cfe6e213b297e..e00dc3166e024 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -16,8 +16,7 @@ # under the License. # pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines -import unittest -from typing import Optional, Set +from typing import Optional import pytest import sqlparse @@ -40,7 +39,7 @@ ) -def extract_tables(query: str) -> Set[Table]: +def extract_tables(query: str) -> set[Table]: """ Helper function to extract tables referenced in a query. """ diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py index 282dc99860f33..5bc22273f544e 100644 --- a/tests/unit_tests/tasks/test_cron_util.py +++ b/tests/unit_tests/tasks/test_cron_util.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from datetime import datetime -from typing import List import pytest import pytz @@ -49,7 +47,7 @@ ], ) def test_cron_schedule_window_los_angeles( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Los_Angeles" @@ -86,7 +84,7 @@ def test_cron_schedule_window_los_angeles( ], ) def test_cron_schedule_window_invalid_timezone( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "invalid timezone" @@ -124,7 +122,7 @@ def test_cron_schedule_window_invalid_timezone( ], ) def test_cron_schedule_window_new_york( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/New_York" @@ -161,7 +159,7 @@ def test_cron_schedule_window_new_york( ], ) def test_cron_schedule_window_chicago( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" @@ -198,7 +196,7 @@ def test_cron_schedule_window_chicago( ], ) def test_cron_schedule_window_chicago_daylight( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" diff --git a/tests/unit_tests/tasks/test_utils.py b/tests/unit_tests/tasks/test_utils.py index 7854717201229..b3fbfca8a24fd 100644 --- a/tests/unit_tests/tasks/test_utils.py +++ b/tests/unit_tests/tasks/test_utils.py @@ -18,7 +18,7 @@ from contextlib import nullcontext from dataclasses import 
dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union import pytest from flask_appbuilder.security.sqla.models import User @@ -31,8 +31,8 @@ def _get_users( - params: Optional[Union[int, List[int]]] -) -> Optional[Union[User, List[User]]]: + params: Optional[Union[int, list[int]]] +) -> Optional[Union[User, list[User]]]: if params is None: return None if isinstance(params, int): @@ -42,7 +42,7 @@ def _get_users( @dataclass class ModelConfig: - owners: List[int] + owners: list[int] creator: Optional[int] = None modifier: Optional[int] = None @@ -268,18 +268,18 @@ class ModelType(int, Enum): ) def test_get_executor( model_type: ModelType, - executor_types: List[ExecutorType], + executor_types: list[ExecutorType], model_config: ModelConfig, current_user: Optional[int], - expected_result: Tuple[int, ExecutorNotFoundError], + expected_result: tuple[int, ExecutorNotFoundError], ) -> None: from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.reports.models import ReportSchedule from superset.tasks.utils import get_executor - model: Type[Union[Dashboard, ReportSchedule, Slice]] - model_kwargs: Dict[str, Any] = {} + model: type[Union[Dashboard, ReportSchedule, Slice]] + model_kwargs: dict[str, Any] = {} if model_type == ModelType.REPORT_SCHEDULE: model = ReportSchedule model_kwargs = { diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py index 04f244e629b59..68bd7a58f79aa 100644 --- a/tests/unit_tests/thumbnails/test_digest.py +++ b/tests/unit_tests/thumbnails/test_digest.py @@ -17,7 +17,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from unittest.mock import patch import pytest @@ -31,7 +31,7 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice -_DEFAULT_DASHBOARD_KWARGS: Dict[str, Any] = { +_DEFAULT_DASHBOARD_KWARGS: dict[str, Any] = { "id": 1, "dashboard_title": "My Title", "slices": [{"id": 1, "slice_name": "My Chart"}], @@ -150,11 +150,11 @@ def CUSTOM_CHART_FUNC( ], ) def test_dashboard_digest( - dashboard_overrides: Optional[Dict[str, Any]], - execute_as: List[ExecutorType], + dashboard_overrides: dict[str, Any] | None, + execute_as: list[ExecutorType], has_current_user: bool, use_custom_digest: bool, - expected_result: Union[str, Exception], + expected_result: str | Exception, ) -> None: from superset import app from superset.models.dashboard import Dashboard @@ -167,7 +167,7 @@ def test_dashboard_digest( } slices = [Slice(**slice_kwargs) for slice_kwargs in kwargs.pop("slices")] dashboard = Dashboard(**kwargs, slices=slices) - user: Optional[User] = None + user: User | None = None if has_current_user: user = User(id=1, username="1") func = CUSTOM_DASHBOARD_FUNC if use_custom_digest else None @@ -222,11 +222,11 @@ def test_dashboard_digest( ], ) def test_chart_digest( - chart_overrides: Optional[Dict[str, Any]], - execute_as: List[ExecutorType], + chart_overrides: dict[str, Any] | None, + execute_as: list[ExecutorType], has_current_user: bool, use_custom_digest: bool, - expected_result: Union[str, Exception], + expected_result: str | Exception, ) -> None: from superset import app from superset.models.slice import Slice @@ -237,7 +237,7 @@ def test_chart_digest( **(chart_overrides or {}), } chart = Slice(**kwargs) - user: Optional[User] = None 
+ user: User | None = None if has_current_user: user = User(id=1, username="1") func = CUSTOM_CHART_FUNC if use_custom_digest else None diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py index 53650e1d20324..bd6179957e408 100644 --- a/tests/unit_tests/utils/cache_test.py +++ b/tests/unit_tests/utils/cache_test.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/unit_tests/utils/date_parser_tests.py b/tests/unit_tests/utils/date_parser_tests.py index fb0fe07d29902..a2ec20901ab43 100644 --- a/tests/unit_tests/utils/date_parser_tests.py +++ b/tests/unit_tests/utils/date_parser_tests.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import date, datetime, timedelta -from typing import Optional, Tuple +from typing import Optional from unittest.mock import Mock, patch import pytest @@ -74,8 +74,8 @@ def mock_parse_human_datetime(s: str) -> Optional[datetime]: @patch("superset.utils.date_parser.parse_human_datetime", mock_parse_human_datetime) def test_get_since_until() -> None: - result: Tuple[Optional[datetime], Optional[datetime]] - expected: Tuple[Optional[datetime], Optional[datetime]] + result: tuple[Optional[datetime], Optional[datetime]] + expected: tuple[Optional[datetime], Optional[datetime]] result = get_since_until() expected = None, datetime(2016, 11, 7) diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index 3636983156fdb..996bd1948f385 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import os -from typing import Any, Dict, Optional +from typing import Any, Optional import pytest @@ -86,7 +85,7 @@ ], ) def test_remove_extra_adhoc_filters( - original: Dict[str, Any], expected: Dict[str, Any] + original: dict[str, Any], expected: dict[str, Any] ) -> None: remove_extra_adhoc_filters(original) assert expected == original diff --git a/tests/unit_tests/utils/test_file.py b/tests/unit_tests/utils/test_file.py index de20402e5c21c..a2168a7d9277b 100644 --- a/tests/unit_tests/utils/test_file.py +++ b/tests/unit_tests/utils/test_file.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/unit_tests/utils/urls_tests.py b/tests/unit_tests/utils/urls_tests.py index 208d6caea4375..287f346c3d99d 100644 --- a/tests/unit_tests/utils/urls_tests.py +++ b/tests/unit_tests/utils/urls_tests.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information