chore: Embrace the walrus operator (#24127)
john-bodley authored May 19, 2023
1 parent 6b54591 commit d583ca9
Showing 54 changed files with 100 additions and 185 deletions.
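The refactor is mechanical: wherever a value was assigned and then immediately tested for truthiness, the assignment is folded into the condition with an assignment expression (the "walrus operator", PEP 572). A minimal before/after sketch of the pattern, modeled on the VARCHAR-matching hunk further down:

import re

VARCHAR = re.compile(r"VARCHAR\((\d+)\)")

# Before: assign, then test on the following line.
match = VARCHAR.match("VARCHAR(255)")
if match:
    size = int(match.group(1))  # 255

# After: bind and test in a single expression.
if match := VARCHAR.match("VARCHAR(255)"):
    size = int(match.group(1))  # 255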
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -19,6 +19,10 @@ repos:
rev: 5.12.0
hooks:
- id: isort
+ - repo: https://github.com/MarcoGorelli/auto-walrus
+ rev: v0.2.2
+ hooks:
+ - id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
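The auto-walrus hook added above automates this rewrite. Judging by the hunks in this commit, it folds a simple name assignment into a directly following if that tests that name (even across a blank line), and leaves code alone when other statements consume the name in between. A hedged sketch of the distinction, with hypothetical names:

from typing import Optional

def get_report(cache: dict, key: str) -> Optional[str]:
    # Foldable: the assigned name feeds only the test below it.
    if report := cache.get(key):
        return report
    return None

def get_report_logged(cache: dict, key: str) -> Optional[str]:
    report = cache.get(key)
    print(f"cache lookup for {key!r}")  # intervening statement: left as-is
    if report:
        return report
    return None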
6 changes: 2 additions & 4 deletions superset/charts/api.py
@@ -649,8 +649,7 @@ def screenshot(self, pk: int, digest: str) -> WerkzeugResponse:
return self.response_404()

# fetch the chart screenshot using the current user and cache if set
- img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest)
- if img:
+ if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True
)
@@ -783,7 +782,6 @@ def export(self, **kwargs: Any) -> Response:
500:
$ref: '#/components/responses/500'
"""
- token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"chart_export_{timestamp}"
@@ -805,7 +803,7 @@
as_attachment=True,
download_name=filename,
)
- if token:
+ if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

3 changes: 1 addition & 2 deletions superset/charts/commands/bulk_delete.py
@@ -55,8 +55,7 @@ def validate(self) -> None:
if not self._models or len(self._models) != len(self._model_ids):
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
- reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids)
- if reports:
+ if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids):
report_names = [report.name for report in reports]
raise ChartBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
3 changes: 1 addition & 2 deletions superset/charts/commands/delete.py
@@ -64,8 +64,7 @@ def validate(self) -> None:
if not self._model:
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
- reports = ReportScheduleDAO.find_by_chart_id(self._model_id)
- if reports:
+ if reports := ReportScheduleDAO.find_by_chart_id(self._model_id):
report_names = [report.name for report in reports]
raise ChartDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
3 changes: 1 addition & 2 deletions superset/common/query_actions.py
@@ -221,8 +221,7 @@ def get_query_results(
:raises QueryObjectValidationError: if an unsupported result type is requested
:return: JSON serializable result payload
"""
- result_func = _result_type_functions.get(result_type)
- if result_func:
+ if result_func := _result_type_functions.get(result_type):
return result_func(query_context, query_obj, force_cached)
raise QueryObjectValidationError(
_("Invalid result type: %(result_type)s", result_type=result_type)
3 changes: 1 addition & 2 deletions superset/common/query_context_factory.py
@@ -125,10 +125,9 @@ def _apply_granularity(
for column in datasource.columns
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
}
- granularity = query_object.granularity
x_axis = form_data and form_data.get("x_axis")

- if granularity:
+ if granularity := query_object.granularity:
filter_to_remove = None
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
3 changes: 1 addition & 2 deletions superset/common/query_context_processor.py
@@ -500,8 +500,7 @@ def get_payload(
return return_value

def get_cache_timeout(self) -> int:
- cache_timeout_rv = self._query_context.get_cache_timeout()
- if cache_timeout_rv:
+ if cache_timeout_rv := self._query_context.get_cache_timeout():
return cache_timeout_rv
if (
data_cache_timeout := config["DATA_CACHE_CONFIG"].get(
3 changes: 1 addition & 2 deletions superset/common/utils/query_cache_manager.py
@@ -148,8 +148,7 @@ def get(
if not key or not _cache[region] or force_query:
return query_cache

- cache_value = _cache[region].get(key)
- if cache_value:
+ if cache_value := _cache[region].get(key):
logger.debug("Cache key: %s", key)
stats_logger.incr("loading_from_cache")
try:
3 changes: 1 addition & 2 deletions superset/connectors/sqla/models.py
@@ -993,11 +993,10 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
schema=self.schema,
template_processor=template_processor,
)
- col_in_metadata = self.get_column(expression)
time_grain = col.get("timeGrain")
has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
is_dttm = False
- if col_in_metadata:
+ if col_in_metadata := self.get_column(expression):
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
)
3 changes: 1 addition & 2 deletions superset/dashboards/commands/bulk_delete.py
@@ -56,8 +56,7 @@ def validate(self) -> None:
if not self._models or len(self._models) != len(self._model_ids):
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
- reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids)
- if reports:
+ if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids):
report_names = [report.name for report in reports]
raise DashboardBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
3 changes: 1 addition & 2 deletions superset/dashboards/commands/delete.py
@@ -57,8 +57,7 @@ def validate(self) -> None:
if not self._model:
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
- reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id)
- if reports:
+ if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id):
report_names = [report.name for report in reports]
raise DashboardDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
3 changes: 1 addition & 2 deletions superset/dashboards/dao.py
@@ -200,11 +200,10 @@ def set_dash_metadata( # pylint: disable=too-many-locals
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
- positions = data.get("positions")
new_filter_scopes = {}
md = dashboard.params_dict

- if positions is not None:
+ if (positions := data.get("positions")) is not None:
# find slices in the position data
slice_ids = [
value.get("meta", {}).get("chartId")
3 changes: 1 addition & 2 deletions superset/databases/api.py
@@ -1036,7 +1036,6 @@ def export(self, **kwargs: Any) -> Response:
500:
$ref: '#/components/responses/500'
"""
- token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"database_export_{timestamp}"
@@ -1060,7 +1059,7 @@
as_attachment=True,
download_name=filename,
)
- if token:
+ if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

3 changes: 1 addition & 2 deletions superset/databases/commands/delete.py
@@ -55,9 +55,8 @@ def validate(self) -> None:
if not self._model:
raise DatabaseNotFoundError()
# Check there are no associated ReportSchedules
- reports = ReportScheduleDAO.find_by_database_id(self._model_id)

- if reports:
+ if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
report_names = [report.name for report in reports]
raise DatabaseDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
3 changes: 1 addition & 2 deletions superset/databases/commands/test_connection.py
@@ -228,6 +228,5 @@ def ping(engine: Engine) -> bool:
raise DatabaseTestConnectionUnexpectedError(errors) from ex

def validate(self) -> None:
- database_name = self._properties.get("database_name")
- if database_name is not None:
+ if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
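Note the parentheses in this rewrite: := binds more loosely than comparison operators, so without them the walrus would capture the boolean result of the is not None comparison instead of the looked-up value (PEP 572). The same applies to membership tests such as the "form_data" in (query_context := ...) rewrite further down. A small illustration with a stand-in dict:

props = {"database_name": "examples"}  # hypothetical stand-in for self._properties

# Parenthesized: database_name is bound to the looked-up value.
if (database_name := props.get("database_name")) is not None:
    assert database_name == "examples"

# Unparenthesized, the condition would parse as
#   database_name := (props.get("database_name") is not None)
# binding True or False instead of the value.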
3 changes: 1 addition & 2 deletions superset/databases/commands/validate.py
@@ -128,6 +128,5 @@ def run(self) -> None:
)

def validate(self) -> None:
- database_id = self._properties.get("id")
- if database_id is not None:
+ if (database_id := self._properties.get("id")) is not None:
self._model = DatabaseDAO.find_by_id(database_id)
3 changes: 1 addition & 2 deletions superset/datasets/api.py
@@ -977,8 +977,7 @@ def get_or_create_dataset(self) -> Response:
return self.response(400, message=ex.messages)
table_name = body["table_name"]
database_id = body["database_id"]
- table = DatasetDAO.get_table_by_name(database_id, table_name)
- if table:
+ if table := DatasetDAO.get_table_by_name(database_id, table_name):
return self.response(200, result={"table_id": table.id})

body["database"] = database_id
3 changes: 1 addition & 2 deletions superset/datasets/commands/importers/v1/utils.py
@@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
if native_type.upper() in type_map:
return type_map[native_type.upper()]

- match = VARCHAR.match(native_type)
- if match:
+ if match := VARCHAR.match(native_type):
size = int(match.group(1))
return String(size)

6 changes: 2 additions & 4 deletions superset/datasets/commands/update.py
@@ -114,13 +114,11 @@ def validate(self) -> None:
exceptions.append(DatasetEndpointUnsafeValidationError())

# Validate columns
- columns = self._properties.get("columns")
- if columns:
+ if columns := self._properties.get("columns"):
self._validate_columns(columns, exceptions)

# Validate metrics
- metrics = self._properties.get("metrics")
- if metrics:
+ if metrics := self._properties.get("metrics"):
self._validate_metrics(metrics, exceptions)

if exceptions:
6 changes: 2 additions & 4 deletions superset/db_engine_specs/base.py
@@ -1704,8 +1704,7 @@ def get_column_spec( # pylint: disable=unused-argument
:param source: Type coming from the database table or cursor description
:return: ColumnSpec object
"""
- col_types = cls.get_column_types(native_type)
- if col_types:
+ if col_types := cls.get_column_types(native_type):
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
return ColumnSpec(
@@ -1996,9 +1995,8 @@ def validate_parameters(
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
- missing = sorted(required - present)

- if missing:
+ if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
3 changes: 1 addition & 2 deletions superset/db_engine_specs/bigquery.py
@@ -384,9 +384,8 @@ def df_to_sql(
}

# Add credentials if they are set on the SQLAlchemy dialect.
- creds = engine.dialect.credentials_info

- if creds:
+ if creds := engine.dialect.credentials_info:
to_gbq_kwargs[
"credentials"
] = service_account.Credentials.from_service_account_info(creds)
3 changes: 1 addition & 2 deletions superset/db_engine_specs/databricks.py
@@ -285,9 +285,8 @@ def validate_parameters( # type: ignore
parameters["http_path"] = connect_args.get("http_path")

present = {key for key in parameters if parameters.get(key, ())}
- missing = sorted(required - present)

- if missing:
+ if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
6 changes: 2 additions & 4 deletions superset/db_engine_specs/presto.py
@@ -1213,8 +1213,7 @@ def extra_table_metadata(
) -> Dict[str, Any]:
metadata = {}

- indexes = database.get_indexes(table_name, schema_name)
- if indexes:
+ if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
@@ -1278,8 +1277,7 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information"""
- tracking_url = cls.get_tracking_url(cursor)
- if tracking_url:
+ if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
session.commit()

3 changes: 1 addition & 2 deletions superset/db_engine_specs/snowflake.py
@@ -312,9 +312,8 @@ def validate_parameters(
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
- missing = sorted(required - present)

- if missing:
+ if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
6 changes: 2 additions & 4 deletions superset/db_engine_specs/trino.py
@@ -57,8 +57,7 @@ def extra_table_metadata(
) -> Dict[str, Any]:
metadata = {}

- indexes = database.get_indexes(table_name, schema_name)
- if indexes:
+ if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
@@ -150,8 +149,7 @@ def get_tracking_url(cls, cursor: Cursor) -> Optional[str]:

@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
- tracking_url = cls.get_tracking_url(cursor)
- if tracking_url:
+ if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url

# Adds the executed query id to the extra payload so the query can be cancelled
3 changes: 1 addition & 2 deletions superset/errors.py
@@ -211,8 +211,7 @@ def __post_init__(self) -> None:
Mutates the extra params with user facing error codes that map to backend
errors.
"""
- issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type)
- if issue_codes:
+ if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type):
self.extra = self.extra or {}
self.extra.update(
{
3 changes: 1 addition & 2 deletions superset/initialization/__init__.py
@@ -453,8 +453,7 @@ def init_app_in_ctx(self) -> None:

# Hook that provides administrators a handle on the Flask APP
# after initialization
- flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
- if flask_app_mutator:
+ if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]:
flask_app_mutator(self.superset_app)

if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
5 changes: 2 additions & 3 deletions superset/migrations/env.py
@@ -103,16 +103,15 @@ def process_revision_directives( # pylint: disable=redefined-outer-name, unused
kwargs = {}
if engine.name in ("sqlite", "mysql"):
kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
- configure_args = current_app.extensions["migrate"].configure_args
- if configure_args:
+ if configure_args := current_app.extensions["migrate"].configure_args:
kwargs.update(configure_args)

context.configure(
connection=connection,
target_metadata=target_metadata,
# compare_type=True,
process_revision_directives=process_revision_directives,
- **kwargs
+ **kwargs,
)

try:
6 changes: 2 additions & 4 deletions superset/migrations/shared/migrate_viz/base.py
@@ -130,17 +130,15 @@ def upgrade_slice(cls, slc: Slice) -> Slice:
# only backup params
slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak})

- query_context = try_load_json(slc.query_context)
- if "form_data" in query_context:
+ if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc

@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
form_data = try_load_json(slc.params)
- form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
- if "viz_type" in form_data_bak:
+ if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
slc.viz_type = form_data_bak.get("viz_type")
query_context = try_load_json(slc.query_context)
6 changes: 2 additions & 4 deletions superset/migrations/shared/migrate_viz/processors.py
@@ -40,16 +40,14 @@ def _pre_action(self) -> None:
if self.data.get("contribution"):
self.data["contributionMode"] = "row"

- stacked = self.data.get("stacked_style")
- if stacked:
+ if stacked := self.data.get("stacked_style"):
stacked_map = {
"expand": "Expand",
"stack": "Stack",
}
self.data["show_extra_controls"] = True
self.data["stack"] = stacked_map.get(stacked)

- x_axis_label = self.data.get("x_axis_label")
- if x_axis_label:
+ if x_axis_label := self.data.get("x_axis_label"):
self.data["x_axis_title"] = x_axis_label
self.data["x_axis_title_margin"] = 30