Skip to content

Commit

Permalink
chore: use contextlib.surpress instead of passing on error (#24896)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com>
  • Loading branch information
sebastianliebscher and john-bodley authored Aug 29, 2023
1 parent 72150eb commit e585db8
Show file tree
Hide file tree
Showing 18 changed files with 66 additions and 146 deletions.
12 changes: 3 additions & 9 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import json
import logging
from typing import Any, TYPE_CHECKING
Expand Down Expand Up @@ -223,11 +224,8 @@ def data(self) -> Response:
json_body = request.json
elif request.form.get("form_data"):
# CSV export submits regular form data
try:
with contextlib.suppress(TypeError, json.JSONDecodeError):
json_body = json.loads(request.form["form_data"])
except (TypeError, json.JSONDecodeError):
pass

if json_body is None:
return self.response_400(message=_("Request is not JSON"))

Expand Down Expand Up @@ -324,14 +322,10 @@ def _run_async(
Execute command as an async query.
"""
# First, look for the chart query results in the cache.
result = None
try:
with contextlib.suppress(ChartDataCacheLoadError):
result = command.run(force_cached=True)
if result is not None:
return self._send_chart_response(result)
except ChartDataCacheLoadError:
pass

# Otherwise, kick off a background job to run the chart query.
# Clients will either poll or be notified of query completion,
# at which point they will call the /data/<cache_key> endpoint
Expand Down
27 changes: 6 additions & 21 deletions superset/common/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
from typing import Any

from sqlalchemy import MetaData
Expand Down Expand Up @@ -221,14 +222,8 @@ def add_types(metadata: MetaData) -> None:
# add a tag for each object type
insert = tag.insert()
for type_ in ObjectTypes.__members__:
try:
db.session.execute(
insert,
name=f"type:{type_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"type:{type_}", type=TagTypes.type)

add_types_to_charts(metadata, tag, tagged_object, columns)
add_types_to_dashboards(metadata, tag, tagged_object, columns)
Expand Down Expand Up @@ -448,11 +443,8 @@ def add_owners(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"owner:{id_}", type=TagTypes.owner)
except IntegrityError:
pass # already exists

add_owners_to_charts(metadata, tag, tagged_object, columns)
add_owners_to_dashboards(metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(metadata, tag, tagged_object, columns)
Expand Down Expand Up @@ -489,15 +481,8 @@ def add_favorites(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
db.session.execute(
insert,
name=f"favorited_by:{id_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists

with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"favorited_by:{id_}", type=TagTypes.type)
favstars = (
select(
[
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import json
import re
import urllib
Expand Down Expand Up @@ -557,11 +558,8 @@ def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]:
except (json.JSONDecodeError, TypeError):
return encrypted_extra

try:
with contextlib.suppress(KeyError):
config["credentials_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass

return json.dumps(config)

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import contextlib
import json
import logging
import re
Expand Down Expand Up @@ -167,11 +168,8 @@ def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None:
except (TypeError, json.JSONDecodeError):
return encrypted_extra

try:
with contextlib.suppress(KeyError):
config["service_account_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass

return json.dumps(config)

@classmethod
Expand Down
5 changes: 2 additions & 3 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import re
from datetime import datetime
from re import Pattern
Expand Down Expand Up @@ -258,11 +259,9 @@ def epoch_to_dttm(cls) -> str:
def _extract_error_message(cls, ex: Exception) -> str:
"""Extract error message for queries"""
message = str(ex)
try:
with contextlib.suppress(AttributeError, KeyError):
if isinstance(ex.args, tuple) and len(ex.args) > 1:
message = ex.args[1]
except (AttributeError, KeyError):
pass
return message

@classmethod
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/ocient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import contextlib
import re
import threading
from re import Pattern
Expand All @@ -24,8 +25,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session

# Need to try-catch here because pyocient may not be installed
try:
with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed
# Ensure pyocient inherits Superset's logging level
import geojson
import pyocient
Expand All @@ -35,8 +35,6 @@

superset_log_level = app.config["LOG_LEVEL"]
pyocient.logger.setLevel(superset_log_level)
except (ImportError, RuntimeError):
pass

from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
Expand Down
10 changes: 3 additions & 7 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=too-many-lines
from __future__ import annotations

import contextlib
import logging
import re
import time
Expand Down Expand Up @@ -67,11 +68,8 @@
# prevent circular imports
from superset.models.core import Database

# need try/catch because pyhive may not be installed
try:
with contextlib.suppress(ImportError): # pyhive may not be installed
from pyhive.presto import Cursor
except ImportError:
pass

COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
"line (?P<location>.+?): .*Column '(?P<column_name>.+?)' cannot be resolved"
Expand Down Expand Up @@ -1274,12 +1272,10 @@ def get_create_view(

@classmethod
def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
with contextlib.suppress(AttributeError):
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
return f"{cursor._protocol}://{cursor._host}:{cursor._port}/ui/query.html?{cursor.last_query_id}"
except AttributeError:
pass
return None

@classmethod
Expand Down
9 changes: 3 additions & 6 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import logging
from typing import Any, TYPE_CHECKING

Expand All @@ -35,10 +36,8 @@
if TYPE_CHECKING:
from superset.models.core import Database

try:
with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
except ImportError:
pass

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -140,12 +139,10 @@ def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
return cursor.info_uri
except AttributeError:
try:
with contextlib.suppress(AttributeError):
conn = cursor.connection
# pylint: disable=protected-access, line-too-long
return f"{conn.http_scheme}://{conn.host}:{conn.port}/ui/query.html?{cursor._query.query_id}"
except AttributeError:
pass
return None

@classmethod
Expand Down
7 changes: 3 additions & 4 deletions superset/explore/commands/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
from abc import ABC
from typing import Any, cast, Optional
Expand Down Expand Up @@ -107,17 +108,15 @@ def run(self) -> Optional[dict[str, Any]]:
)
except SupersetException:
self._datasource_id = None
# fallback unkonw datasource to table type
# fallback unknown datasource to table type
self._datasource_type = SqlaTable.type

datasource: Optional[BaseDatasource] = None
if self._datasource_id is not None:
try:
with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource(
db.session, cast(str, self._datasource_type), self._datasource_id
)
except DatasourceNotFound:
pass
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint:
Expand Down
8 changes: 3 additions & 5 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import contextlib
import logging
import os
import sys
Expand All @@ -25,7 +26,7 @@
from deprecation import deprecated
from flask import Flask, redirect
from flask_appbuilder import expose, IndexView
from flask_babel import gettext as __, lazy_gettext as _
from flask_babel import gettext as __
from flask_compress import Compress
from werkzeug.middleware.proxy_fix import ProxyFix

Expand Down Expand Up @@ -594,11 +595,8 @@ def __call__(
self.superset_app.wsgi_app = ChunkedEncodingFix(self.superset_app.wsgi_app)

if self.config["UPLOAD_FOLDER"]:
try:
with contextlib.suppress(OSError):
os.makedirs(self.config["UPLOAD_FOLDER"])
except OSError:
pass

for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
self.superset_app.wsgi_app = middleware(self.superset_app.wsgi_app)

Expand Down
18 changes: 5 additions & 13 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext
from contextlib import closing, contextmanager, nullcontext, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
Expand Down Expand Up @@ -225,7 +225,6 @@ def allows_cost_estimate(self) -> bool:
@property
def allows_virtual_table_explore(self) -> bool:
extra = self.get_extra()

return bool(extra.get("allows_virtual_table_explore", True))

@property
Expand All @@ -235,9 +234,7 @@ def explore_database_id(self) -> int:
@property
def disable_data_preview(self) -> bool:
# this will prevent any 'trash value' strings from going through
if self.get_extra().get("disable_data_preview", False) is not True:
return False
return True
return self.get_extra().get("disable_data_preview", False) is True

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -285,11 +282,8 @@ def parameters(self) -> dict[str, Any]:
masked_uri = make_url_safe(self.sqlalchemy_uri)
encrypted_config = {}
if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
try:
with suppress(TypeError, json.JSONDecodeError):
encrypted_config = json.loads(masked_encrypted_extra)
except (TypeError, json.JSONDecodeError):
pass

try:
# pylint: disable=useless-suppression
parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore
Expand Down Expand Up @@ -550,7 +544,7 @@ def get_default_schema_for_query(self, query: Query) -> str | None:

@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifiter expressions if needed"""
"""Add quotes to potential identifier expressions if needed"""
return self.get_dialect().identifier_preparer.quote

def get_reserved_words(self) -> set[str]:
Expand Down Expand Up @@ -692,15 +686,14 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
"""
try:
with self.get_inspector_with_context() as inspector:
tables = {
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=inspector,
schema=schema,
)
}
return tables
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand Down Expand Up @@ -985,7 +978,6 @@ def make_sqla_column_compatible(


class Log(Model): # pylint: disable=too-few-public-methods

"""ORM object used to log Superset actions to the database"""

__tablename__ = "logs"
Expand Down
Loading

0 comments on commit e585db8

Please sign in to comment.