From 5659c87ed2da1ebafe3578cac9c3c52aeb256c5d Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 30 Oct 2023 09:50:44 -0400 Subject: [PATCH] fix: DB-specific quoting in Jinja macro (#25779) --- superset/jinja_context.py | 45 ++++++++++++++++++-------- tests/unit_tests/jinja_context_test.py | 9 ++++-- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 71ebf0d29a46e..c159a667ee4ed 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -25,6 +25,7 @@ from jinja2 import DebugUndefined from jinja2.sandbox import SandboxedEnvironment from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql.expression import bindparam from sqlalchemy.types import String from superset.constants import LRU_CACHE_MAX_SIZE @@ -396,23 +397,39 @@ def validate_template_context( return validate_context_types(context) -def where_in(values: list[Any], mark: str = "'") -> str: - """ - Given a list of values, build a parenthesis list suitable for an IN expression. +class WhereInMacro: # pylint: disable=too-few-public-methods + def __init__(self, dialect: Dialect): + self.dialect = dialect - >>> where_in([1, "b", 3]) - (1, 'b', 3) + def __call__(self, values: list[Any], mark: Optional[str] = None) -> str: + """ + Given a list of values, build a parenthesis list suitable for an IN expression. - """ + >>> from sqlalchemy.dialects import mysql + >>> where_in = WhereInMacro(dialect=mysql.dialect()) + >>> where_in([1, "Joe's", 3]) + (1, 'Joe''s', 3) - def quote(value: Any) -> str: - if isinstance(value, str): - value = value.replace(mark, mark * 2) - return f"{mark}{value}{mark}" - return str(value) + """ + binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)] + string_representations = [ + str( + bind.compile( + dialect=self.dialect, compile_kwargs={"literal_binds": True} + ) + ) + for bind in binds + ] + joined_values = ", ".join(string_representations) + result = f"({joined_values})" + + if mark: + result += ( + "\n-- WARNING: the `mark` parameter was removed from the `where_in` " + "macro for security reasons\n" + ) - joined_values = ", ".join(quote(value) for value in values) - return f"({joined_values})" + return result class BaseTemplateProcessor: @@ -448,7 +465,7 @@ def __init__( self.set_context(**kwargs) # custom filters - self._env.filters["where_in"] = where_in + self._env.filters["where_in"] = WhereInMacro(database.get_dialect()) def set_context(self, **kwargs: Any) -> None: self._context.update(kwargs) diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index fe4b144d2fd7a..114f046300169 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -20,17 +20,22 @@ import pytest from pytest_mock import MockFixture +from sqlalchemy.dialects import mysql from superset.datasets.commands.exceptions import DatasetNotFoundError -from superset.jinja_context import dataset_macro, where_in +from superset.jinja_context import dataset_macro, WhereInMacro def test_where_in() -> None: """ Test the ``where_in`` Jinja2 filter. """ + where_in = WhereInMacro(mysql.dialect()) assert where_in([1, "b", 3]) == "(1, 'b', 3)" - assert where_in([1, "b", 3], '"') == '(1, "b", 3)' + assert where_in([1, "b", 3], '"') == ( + "(1, 'b', 3)\n-- WARNING: the `mark` parameter was removed from the " + "`where_in` macro for security reasons\n" + ) assert where_in(["O'Malley's"]) == "('O''Malley''s')"