Skip to content

Commit

Permalink
refactor: sql lab: handling command exceptions (#16852)
Browse files Browse the repository at this point in the history
* chore: support error_type in SupersetException and method to convert the exception to dictionary

* chore: support error_type in SupersetException and method to convert the exception to dictionary

* refactor handling command exceptions   fix update query status when query was not created
  • Loading branch information
ofekisr authored Sep 29, 2021
1 parent 3d8cc15 commit 3f784cc
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 70 deletions.
6 changes: 6 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,9 @@ def __post_init__(self) -> None:
]
}
)

def to_dict(self) -> Dict[str, Any]:
rv = {"message": self.message, "error_type": self.error_type}
if self.extra:
rv["extra"] = self.extra # type: ignore
return rv
23 changes: 22 additions & 1 deletion superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,35 @@ class SupersetException(Exception):
message = ""

def __init__(
self, message: str = "", exception: Optional[Exception] = None,
self,
message: str = "",
exception: Optional[Exception] = None,
error_type: Optional[SupersetErrorType] = None,
) -> None:
if message:
self.message = message
self._exception = exception
self._error_type = error_type
super().__init__(self.message)

@property
def exception(self) -> Optional[Exception]:
return self._exception

@property
def error_type(self) -> Optional[SupersetErrorType]:
return self._error_type

def to_dict(self) -> Dict[str, Any]:
rv = {}
if hasattr(self, "message"):
rv["message"] = self.message
if self.error_type:
rv["error_type"] = self.error_type
if self.exception is not None and hasattr(self.exception, "to_dict"):
rv = {**rv, **self.exception.to_dict()} # type: ignore
return rv


class SupersetErrorException(SupersetException):
"""Exceptions with a single SupersetErrorType associated with them"""
Expand All @@ -49,6 +67,9 @@ def __init__(self, error: SupersetError, status: Optional[int] = None) -> None:
if status is not None:
self.status = status

def to_dict(self) -> Dict[str, Any]:
return self.error.to_dict()


class SupersetGenericErrorException(SupersetErrorException):
"""Exceptions that are too generic to have their own type"""
Expand Down
123 changes: 61 additions & 62 deletions superset/sqllab/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
Expand All @@ -68,18 +69,18 @@


class ExecuteSqlCommand(BaseCommand):
execution_context: SqlJsonExecutionContext
log_params: Optional[Dict[str, Any]] = None
session: Session
_execution_context: SqlJsonExecutionContext
_log_params: Optional[Dict[str, Any]] = None
_session: Session

def __init__(
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
self.execution_context = execution_context
self.log_params = log_params
self.session = db.session()
self._execution_context = execution_context
self._log_params = log_params
self._session = db.session()

def validate(self) -> None:
pass
Expand All @@ -88,30 +89,29 @@ def run( # pylint: disable=too-many-statements,useless-suppression
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
try:
query = self._get_existing_query()
if self.is_query_handled(query):
self._execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}
except (SqlLabException, SupersetErrorsException) as ex:
raise ex
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex

query = self._get_existing_query(self.execution_context, self.session)

if self.is_query_handled(query):
self.execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()

return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}

@classmethod
def _get_existing_query(
cls, execution_context: SqlJsonExecutionContext, session: Session
) -> Optional[Query]:
def _get_existing_query(self) -> Optional[Query]:
query = (
session.query(Query)
self._session.query(Query)
.filter_by(
client_id=execution_context.client_id,
user_id=execution_context.user_id,
sql_editor_id=execution_context.sql_editor_id,
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
.one_or_none()
)
Expand All @@ -126,25 +126,24 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
]

def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
self.execution_context.set_database(self._get_the_query_db())
query = self.execution_context.create_query()
self._execution_context.set_database(self._get_the_query_db())
query = self._execution_context.create_query()
self._save_new_query(query)
try:
self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
self.execution_context.set_query(query)
self._execution_context.set_query(query)
rendered_query = self._render_query()

self._set_query_limit_if_required(rendered_query)

return self._execute_query(rendered_query)
except Exception as ex:
query.status = QueryStatus.FAILED
self.session.commit()
self._session.commit()
raise ex

def _get_the_query_db(self) -> Database:
mydb = self.session.query(Database).get(self.execution_context.database_id)
mydb = self._session.query(Database).get(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb

Expand All @@ -160,12 +159,12 @@ def _validate_query_db(cls, database: Optional[Database]) -> None:

def _save_new_query(self, query: Query) -> None:
try:
self.session.add(query)
self.session.flush()
self.session.commit() # shouldn't be necessary
self._session.add(query)
self._session.flush()
self._session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
self.session.rollback()
self._session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
Expand All @@ -181,7 +180,7 @@ def _validate_access(self, query: Query) -> None:
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
self.session.commit()
self._session.commit()
raise SupersetErrorException(ex.error, status=403) from ex

def _render_query(self) -> str:
Expand All @@ -205,18 +204,18 @@ def validate(
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": self.execution_context.template_params,
"template_parameters": self._execution_context.template_params,
},
)

query = self.execution_context.query
query = self._execution_context.query

try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **self.execution_context.template_params
query.sql, **self._execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
Expand All @@ -235,32 +234,32 @@ def _set_query_limit_if_required(self, rendered_query: str,) -> None:

def _is_required_to_set_limit(self) -> bool:
return not (
config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta
config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta
)

def _set_query_limit(self, rendered_query: str) -> None:
db_engine_spec = self.execution_context.database.db_engine_spec # type: ignore
db_engine_spec = self._execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
self.execution_context.limit,
self._execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
self._execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.QUERY
self._execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
self.execution_context.query.limiting_factor = (
self._execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
self.execution_context.query.limit = min(
self._execution_context.query.limit = min(
lim for lim in limits if lim is not None
)

def _execute_query(self, rendered_query: str,) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
if self.execution_context.is_run_asynchronous():
if self._execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)

return self._sql_json_sync(rendered_query)
Expand All @@ -271,7 +270,7 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
query = self.execution_context.query
query = self._execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
Expand All @@ -285,8 +284,8 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)

# Explicitly forget the task to ensure the task metadata is removed from the
Expand All @@ -312,14 +311,14 @@ def _sql_json_async(self, rendered_query: str,) -> SqlJsonExecutionStatus:
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
self.session.commit()
self._session.commit()

raise SupersetErrorException(error) from ex

# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)

self.session.commit()
self._session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING

def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
Expand All @@ -329,7 +328,7 @@ def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
query = self.execution_context.query
query = self._execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
Expand All @@ -339,7 +338,7 @@ def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
self.execution_context.set_execution_result(data)
self._execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
Expand All @@ -362,7 +361,7 @@ def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
query = self.execution_context.query
query = self._execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
Expand All @@ -373,8 +372,8 @@ def _get_sql_results_with_timeout(
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)

@classmethod
Expand All @@ -389,9 +388,9 @@ def _create_payload_from_execution_context( # pylint: disable=invalid-name

if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
self.execution_context.get_execution_result() or {}
self._execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(self.execution_context.query)
return self._to_payload_query_based(self._execution_context.query)

def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults
Expand Down
Loading

0 comments on commit 3f784cc

Please sign in to comment.