Skip to content

Commit

Permalink
Refactor template context checking to support engine-specific methods
Browse files Browse the repository at this point in the history
  • Loading branch information
robdiciuccio committed Nov 17, 2020
1 parent a9461e4 commit f7fbd59
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
23 changes: 18 additions & 5 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return return_value


def validate_template_context(context: Dict[str, Any]) -> Dict[str, Any]:
def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]:
for key in context:
arg_type = type(context[key]).__name__
if arg_type not in ALLOWED_TYPES and key not in context_addons():
Expand All @@ -253,12 +253,25 @@ def validate_template_context(context: Dict[str, Any]) -> Dict[str, Any]:
context[key] = json.loads(json.dumps(context[key]))
except TypeError:
raise SupersetTemplateException(
_("Unsupported template value for key %(key)s", key=key,)
_("Unsupported template value for key %(key)s", key=key)
)

return context


def validate_template_context(
engine: Optional[str], context: Dict[str, Any]
) -> Dict[str, Any]:
if engine and engine in context:
# validate engine context separately to allow for engine-specific methods
engine_context = validate_context_types(context.pop(engine))
valid_context = validate_context_types(context)
valid_context[engine] = engine_context
return valid_context

return validate_context_types(context)


class BaseTemplateProcessor: # pylint: disable=too-few-public-methods
"""
Base class for database-specific jinja context
Expand Down Expand Up @@ -300,7 +313,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
template = self._env.from_string(sql)
kwargs.update(self._context)

context = validate_template_context(kwargs)
context = validate_template_context(self.engine, kwargs)
return template.render(context)


Expand Down Expand Up @@ -404,7 +417,7 @@ class HiveTemplateProcessor(PrestoTemplateProcessor):


@memoized
def template_processors() -> Dict[str, Any]:
def get_template_processors() -> Dict[str, Any]:
processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {})
for engine in DEFAULT_PROCESSORS:
# do not overwrite engine-specific CUSTOM_TEMPLATE_PROCESSORS
Expand All @@ -421,7 +434,7 @@ def get_template_processor(
**kwargs: Any,
) -> BaseTemplateProcessor:
if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
template_processor = template_processors().get(
template_processor = get_template_processors().get(
database.backend, JinjaTemplateProcessor
)
else:
Expand Down
10 changes: 10 additions & 0 deletions tests/jinja_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ def test_template_kwarg_nested_module(self) -> None:
with pytest.raises(SupersetTemplateException):
tp.process_template(s, foo={"bar": datetime})

@mock.patch("superset.jinja_context.HiveTemplateProcessor.latest_partition")
def test_template_hive(self, lp_mock) -> None:
lp_mock.return_value = "the_latest"
db = mock.Mock()
db.backend = "hive"
s = "{{ hive.latest_partition('my_table') }}"
tp = get_template_processor(database=db)
rendered = tp.process_template(s)
self.assertEqual("the_latest", rendered)

@mock.patch("superset.jinja_context.context_addons")
def test_template_context_addons(self, addons_mock) -> None:
addons_mock.return_value = {"datetime": datetime}
Expand Down

0 comments on commit f7fbd59

Please sign in to comment.