fix(chart-data-api): assert referenced columns are present in datasource #10451

Merged · 5 commits · Aug 14, 2020
superset/charts/schemas.py (2 changes: 1 addition & 1 deletion)

@@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema):
         deprecated=True,
     )
     having_filters = fields.List(
-        fields.Dict(),
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource queries. "
         "This field is deprecated and should be passed to `extras` "
         "as `having_druid`.",
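
For context, the switch from fields.Dict() to fields.Nested means each entry in having_filters is now validated against the filter schema instead of being accepted as an arbitrary mapping. A minimal sketch of the difference, using a simplified stand-in for ChartDataFilterSchema (the field definitions below are assumptions for illustration, not Superset's real schema):

# Simplified stand-in schemas; only the Dict-vs-Nested contrast is the point.
from marshmallow import Schema, ValidationError, fields

class ChartDataFilterSchema(Schema):  # assumed shape: col/op/val
    col = fields.String(required=True)
    op = fields.String(required=True)
    val = fields.Raw()

class WithDict(Schema):
    having_filters = fields.List(fields.Dict())

class WithNested(Schema):
    having_filters = fields.List(fields.Nested(ChartDataFilterSchema))

payload = {"having_filters": [{"column": "a"}]}  # wrong key, no "col"/"op"
WithDict().load(payload)  # accepted: any dict passes fields.Dict
try:
    WithNested().load(payload)
except ValidationError as err:
    print(err.messages)  # flags the missing required keys per entry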
superset/common/query_context.py (16 changes: 16 additions & 0 deletions)

@@ -22,6 +22,7 @@

 import numpy as np
 import pandas as pd
+from flask_babel import gettext as _

 from superset import app, cache, db, security_manager
 from superset.common.query_object import QueryObject

@@ -235,6 +236,21 @@ def get_df_payload(  # pylint: disable=too-many-locals,too-many-statements

         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in query_obj.columns
+                    + query_obj.groupby
+                    + [flt["col"] for flt in query_obj.filter]
+                    + utils.get_column_names_from_metrics(query_obj.metrics)
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 query_result = self.get_query_result(query_obj)
                 status = query_result["status"]
                 query = query_result["query"]
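
The validation above collects every column referenced by the query object — columns, groupby, filter columns, and the columns behind simple metrics — and rejects the request if any of them is unknown to the datasource. A rough standalone sketch of the same idea (plain dicts stand in for QueryObject and the datasource, and metric handling is omitted):

# Hedged sketch of the guard added above; not Superset's actual code.
from typing import Any, Dict, List

def find_invalid_columns(
    query_obj: Dict[str, Any], datasource_column_names: List[str]
) -> List[str]:
    # Gather every plain column reference the query makes.
    referenced = (
        (query_obj.get("columns") or [])
        + (query_obj.get("groupby") or [])
        + [flt["col"] for flt in query_obj.get("filter", [])]
    )
    # Anything not defined on the datasource is invalid.
    return [col for col in referenced if col not in datasource_column_names]

# An injected expression in groupby is caught before any SQL is built.
print(find_invalid_columns({"groupby": ["currentDatabase()"]}, ["name", "num"]))
# ['currentDatabase()']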
superset/connectors/sqla/models.py (17 changes: 15 additions & 2 deletions)

@@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
     cache_timeout = 0
     changed_on = None
     type = "annotation"
+    column_names = [
+        "created_on",
+        "changed_on",
+        "id",
+        "start_dttm",
+        "end_dttm",
+        "layer_id",
+        "short_descr",
+        "long_descr",
+        "json_metadata",
+        "created_by_fk",
+        "changed_by_fk",
+    ]

Member (author), on lines +93 to +105: These had to be added, as AnnotationDatasources don't have any defined columns.

     def query(self, query_obj: QueryObjectDict) -> QueryResult:
         error_message = None

@@ -721,15 +734,15 @@ def adhoc_metric_to_sqla(
     expression_type = metric.get("expressionType")
     label = utils.get_metric_name(metric)

-    if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
+    if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
         column_name = metric["column"].get("column_name")
         table_column = columns_by_name.get(column_name)
         if table_column:
             sqla_column = table_column.get_sqla_col()
         else:
             sqla_column = column(column_name)
         sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
-    elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
+    elif expression_type == utils.AdhocMetricExpressionType.SQL:
         sqla_metric = literal_column(metric.get("sqlExpression"))
     else:
         return None
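
The comparisons above still work on raw payload strings because AdhocMetricExpressionType (added in superset/utils/core.py below) subclasses str: a str-based enum member compares equal to its plain string value. A quick self-contained check:

# str-subclass enums compare equal to their raw string values, which is
# what lets expression_type (a plain string from the payload) match the
# enum members directly.
from enum import Enum

class AdhocMetricExpressionType(str, Enum):
    SIMPLE = "SIMPLE"
    SQL = "SQL"

assert AdhocMetricExpressionType.SIMPLE == "SIMPLE"
assert "SQL" == AdhocMetricExpressionType.SQL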
superset/utils/core.py (57 changes: 48 additions & 9 deletions)

@@ -42,6 +42,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     Iterator,

@@ -102,7 +103,6 @@
 logger = logging.getLogger(__name__)

 DTTM_ALIAS = "__timestamp"
-ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}

 JS_MAX_INTEGER = 9007199254740991  # Largest int JavaScript can handle 2^53-1

@@ -1030,20 +1030,23 @@ def get_main_database() -> "Database":


 def is_adhoc_metric(metric: Metric) -> bool:
+    if not isinstance(metric, dict):
+        return False
+    metric = cast(Dict[str, Any], metric)
     return bool(
-        isinstance(metric, dict)
-        and (
+        (
             (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]
-                and metric["column"]
-                and metric["aggregate"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
+                and metric.get("column")
+                and cast(Dict[str, Any], metric["column"]).get("column_name")
+                and metric.get("aggregate")
             )
             or (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"]
-                and metric["sqlExpression"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SQL
+                and metric.get("sqlExpression")
             )
         )
-        and metric["label"]
+        and metric.get("label")
    )

@@ -1390,6 +1393,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
     return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]


+def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+    """
+    Extract the column that a metric is referencing. If the metric isn't
+    a simple metric, always returns `None`.
+
+    :param metric: Ad-hoc metric
+    :return: column name if simple metric, otherwise None
+    """
+    if is_adhoc_metric(metric):
+        metric = cast(Dict[str, Any], metric)
+        if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
+            return cast(Dict[str, Any], metric["column"])["column_name"]
+    return None
+
+
+def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+    """
+    Extract the columns that a list of metrics are referencing. Excludes all
+    SQL metrics.
+
+    :param metrics: list of metrics
+    :return: column names referenced by simple metrics
+    """
+    columns: List[str] = []
+    for metric in metrics:
+        column_name = get_column_name_from_metric(metric)
+        if column_name:
+            columns.append(column_name)
+    return columns
+
+
 class LenientEnum(Enum):
     """Enums that do not raise ValueError when value is invalid"""

@@ -1515,3 +1549,8 @@ class PostProcessingContributionOrientation(str, Enum):

     ROW = "row"
     COLUMN = "column"
+
+
+class AdhocMetricExpressionType(str, Enum):
+    SIMPLE = "SIMPLE"
+    SQL = "SQL"
superset/viz.py (18 changes: 18 additions & 0 deletions)

@@ -473,6 +473,24 @@ def get_df_payload(

         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in (query_obj.get("columns") or [])
+                    + (query_obj.get("groupby") or [])
+                    + utils.get_column_names_from_metrics(
+                        cast(
+                            List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
+                        )
+                    )
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 df = self.get_df(query_obj)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr("loaded_from_source")
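
The cast(...) wrapped around query_obj.get("metrics") here is purely a static-typing aid: typing.cast performs no runtime conversion or validation, it only tells mypy what type to assume.

# typing.cast is a no-op at runtime: it returns its argument unchanged.
from typing import Any, Dict, List, Union, cast

metrics: Any = ["count", {"label": "x"}]
typed = cast(List[Union[str, Dict[str, Any]]], metrics)
assert typed is metrics  # same object; nothing converted or checked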
tests/core_tests.py (19 changes: 19 additions & 0 deletions)

@@ -1418,6 +1418,25 @@ def test_explore_database_id(self):
         database.extra = json.dumps(extra)
         self.assertEqual(database.explore_database_id, explore_database.id)

+    def test_get_column_names_from_metric(self):
+        simple_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
+            "column": {"column_name": "my_col"},
+            "aggregate": "SUM",
+            "label": "My Simple Label",
+        }
+        assert utils.get_column_name_from_metric(simple_metric) == "my_col"
+
+        sql_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SQL.value,
+            "sqlExpression": "SUM(my_label)",
+            "label": "My SQL Label",
+        }
+        assert utils.get_column_name_from_metric(sql_metric) is None
+        assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [
+            "my_col"
+        ]
+

 if __name__ == "__main__":
     unittest.main()
tests/query_context_tests.py (91 changes: 79 additions & 12 deletions)

@@ -17,11 +17,12 @@
 import tests.test_app
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.utils.core import (
+    AdhocMetricExpressionType,
     ChartDataResultFormat,
     ChartDataResultType,
+    FilterOperator,
     TimeRangeEndpoint,
 )
 from tests.base_tests import SupersetTestCase

@@ -75,7 +76,7 @@ def test_cache_key_changes_when_datasource_is_updated(self):
         payload = get_query_context(table.name, table.id, table.type)

         # construct baseline cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)

@@ -92,7 +93,7 @@ def test_cache_key_changes_when_datasource_is_updated(self):
         db.session.commit()

         # create new QueryContext with unchanged attributes and extract new cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_new = query_context.cache_key(query_object)

@@ -108,20 +109,20 @@ def test_cache_key_changes_when_post_processing_is_updated(self):
         )

         # construct baseline cache_key from query_context with post processing operation
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)

         # ensure added None post_processing operation doesn't change cache_key
         payload["queries"][0]["post_processing"].append(None)
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_with_null = query_context.cache_key(query_object)
         self.assertEqual(cache_key_original, cache_key_with_null)

         # ensure query without post processing operation is different
         payload["queries"][0].pop("post_processing")
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_without_post_processing = query_context.cache_key(query_object)
         self.assertNotEqual(cache_key_original, cache_key_without_post_processing)

@@ -136,7 +137,7 @@ def test_query_context_time_range_endpoints(self):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         del payload["queries"][0]["extras"]["time_range_endpoints"]
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         extras = query_object.to_dict()["extras"]
         self.assertTrue("time_range_endpoints" in extras)

@@ -155,8 +156,8 @@ def test_convert_deprecated_fields(self):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["queries"][0]["granularity_sqla"] = "timecol"
-        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
-        query_context = QueryContext(**payload)
+        payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
+        query_context = ChartDataQueryContextSchema().load(payload)
         self.assertEqual(len(query_context.queries), 1)
         query_object = query_context.queries[0]
         self.assertEqual(query_object.granularity, "timecol")
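
Note that having_filters must now be a list, since the schema validates each entry via fields.Nested. A rough sketch of the deprecated-field conversion this test exercises, as suggested by the test name and the schema docstrings (the helper body below is an assumption, not Superset's actual implementation):

# Hypothetical sketch of deprecated-field conversion; mapping inferred from
# the schema description ("should be passed to `extras` as `having_druid`").
def convert_deprecated_fields(query: dict) -> dict:
    if "granularity_sqla" in query:
        query["granularity"] = query.pop("granularity_sqla")
    if "having_filters" in query:
        query.setdefault("extras", {})["having_druid"] = query.pop("having_filters")
    return query

q = {
    "granularity_sqla": "timecol",
    "having_filters": [{"col": "a", "op": "==", "val": "b"}],
}
assert convert_deprecated_fields(q)["granularity"] == "timecol"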
@@ -172,13 +173,79 @@ def test_csv_response_format(self):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_format"] = ChartDataResultFormat.CSV.value
         payload["queries"][0]["row_limit"] = 10
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]
         self.assertIn("name,sum__num\n", data)
         self.assertEqual(len(data.split("\n")), 12)

+    def test_sql_injection_via_groupby(self):
+        """
+        Ensure that invalid column names in groupby are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["currentDatabase()"]

Member: If currentDatabase() is a defined metric, will it run OK?

Member (author): It shouldn't, as it isn't an aggregate expression, hence it will be missing from the GROUP BY, causing an invalid query.

+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None

Member: Is assert preferable to self.assertEqual or self.assertIsNotNone?

Member (author): Yes, this is the pytest way.

+    def test_sql_injection_via_columns(self):
+        """
+        Ensure that invalid column names in columns are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = []
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["columns"] = ["*, 'extra'"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None

+    def test_sql_injection_via_filters(self):
+        """
+        Ensure that invalid column names in filters are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = []

Member: Are we able to deny an injected metric?

Member (author): I haven't been able to create one yet, would be interested to see if someone is able to do one.

Member: Not easy, that's for sure.

+        payload["queries"][0]["filters"] = [
+            {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None

+    def test_sql_injection_via_metrics(self):
+        """
+        Ensure that invalid column names in metrics are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = [
+            {
+                "expressionType": AdhocMetricExpressionType.SIMPLE.value,
+                "column": {"column_name": "invalid_col"},
+                "aggregate": "SUM",
+                "label": "My Simple Label",
+            }
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None

     def test_samples_response_type(self):
         """
         Ensure that samples result type works
@@ -189,7 +256,7 @@ def test_samples_response_type(self):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.SAMPLES.value
         payload["queries"][0]["row_limit"] = 5
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]

@@ -206,7 +273,7 @@ def test_query_response_type(self):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.QUERY.value
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         response = responses[0]
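
A recurring change in these tests is constructing the context via ChartDataQueryContextSchema().load(payload) rather than QueryContext(**payload), so each test now exercises the same validation and deserialization path the chart-data API uses. A toy sketch of the load-through-schema pattern (the fields below are made up and far simpler than Superset's real ChartDataQueryContextSchema):

# Toy sketch of marshmallow's load-through-schema pattern; fields assumed.
from marshmallow import Schema, fields, post_load

class QueryContext:
    def __init__(self, result_format: str, queries: list):
        self.result_format = result_format
        self.queries = queries

class ChartDataQueryContextSchema(Schema):
    result_format = fields.String(required=True)
    queries = fields.List(fields.Dict(), required=True)

    @post_load
    def make_query_context(self, data, **kwargs):
        # load() validates the raw payload, then builds the domain object,
        # mirroring how the API constructs QueryContext from a request body.
        return QueryContext(**data)

ctx = ChartDataQueryContextSchema().load(
    {"result_format": "csv", "queries": [{"row_limit": 10}]}
)
assert isinstance(ctx, QueryContext)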