diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index 59aa9774366..1491e60018d 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -239,30 +239,33 @@ def __init__(self, model, context, overrides): def pretty_dict(self, data): return json.dumps(data, sort_keys=True, indent=4) + def get_missing_var(self, var_name): + pretty_vars = self.pretty_dict(self.local_vars) + msg = self.UndefinedVarError.format( + var_name, self.model_name, pretty_vars + ) + dbt.exceptions.raise_compiler_error(msg, self.model) + def assert_var_defined(self, var_name, default): if var_name not in self.local_vars and default is self._VAR_NOTSET: - pretty_vars = self.pretty_dict(self.local_vars) - dbt.exceptions.raise_compiler_error( - self.UndefinedVarError.format( - var_name, self.model_name, pretty_vars - ), - self.model - ) - - def __call__(self, var_name, default=_VAR_NOTSET): - self.assert_var_defined(var_name, default) - - if var_name not in self.local_vars: - return default + return self.get_missing_var(var_name) + def get_rendered_var(self, var_name): raw = self.local_vars[var_name] - # if bool/int/float/etc are passed in, don't compile anything if not isinstance(raw, basestring): return raw return dbt.clients.jinja.get_rendered(raw, self.context) + def __call__(self, var_name, default=_VAR_NOTSET): + if var_name in self.local_vars: + return self.get_rendered_var(var_name) + elif default is not self._VAR_NOTSET: + return default + else: + return self.get_missing_var(var_name) + def write(node, target_path, subdirectory): def fn(payload): @@ -395,7 +398,8 @@ def generate_base(model, model_dict, config, manifest, source_config, return context -def modify_generated_context(context, model, model_dict, config, manifest): +def modify_generated_context(context, model, model_dict, config, manifest, + provider): cli_var_overrides = config.cli_vars context = _add_tracking(context) @@ -408,7 +412,8 @@ def modify_generated_context(context, model, model_dict, config, manifest): context["write"] = write(model_dict, config.target_path, 'run') context["render"] = render(context, model_dict) - context["var"] = Var(model, context=context, overrides=cli_var_overrides) + context["var"] = provider.Var(model, context=context, + overrides=cli_var_overrides) context['context'] = context return context @@ -427,7 +432,7 @@ def generate_execute_macro(model, config, manifest, provider): provider) return modify_generated_context(context, model, model_dict, config, - manifest) + manifest, provider) def generate_model(model, config, manifest, source_config, provider): @@ -448,7 +453,7 @@ def generate_model(model, config, manifest, source_config, provider): }) return modify_generated_context(context, model, model_dict, config, - manifest) + manifest, provider) def generate(model, config, manifest, source_config=None, provider=None): diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 3d2a8da5d78..5759ad5105a 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -97,6 +97,12 @@ def get(self, name, validator=None, default=None): return '' +class Var(dbt.context.common.Var): + def get_missing_var(self, var_name): + # in the parser, just always return None. + return None + + def generate(model, runtime_config, manifest, source_config): # during parsing, we don't have a connection, but we might need one, so we # have to acquire it. diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index 2fc7b32cddb..62ba94283e1 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -118,6 +118,10 @@ def get(self, name, validator=None, default=None): return to_return +class Var(dbt.context.common.Var): + pass + + def generate(model, runtime_config, manifest): return dbt.context.common.generate( model, runtime_config, manifest, None, dbt.context.runtime) diff --git a/test/unit/test_context.py b/test/unit/test_context.py index bf286dc79fa..dd2f516f736 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -3,8 +3,10 @@ from dbt.contracts.graph.parsed import ParsedNode from dbt.context.common import Var +from dbt.context.parser import Var as ParserVar import dbt.exceptions + class TestVar(unittest.TestCase): def setUp(self): self.model = ParsedNode( @@ -59,3 +61,21 @@ def test_var_not_defined(self): self.assertEqual(var('foo', 'bar'), 'bar') with self.assertRaises(dbt.exceptions.CompilationException): var('foo') + + def test_parser_var_default_something(self): + var = ParserVar(self.model, self.context, overrides={'foo': 'baz'}) + self.assertEqual(var('foo'), 'baz') + self.assertEqual(var('foo', 'bar'), 'baz') + + def test_parser_var_default_none(self): + var = ParserVar(self.model, self.context, overrides={'foo': None}) + self.assertEqual(var('foo'), None) + self.assertEqual(var('foo', 'bar'), None) + + def test_parser_var_not_defined(self): + # at parse-time, we should not raise if we encounter a missing var + # that way disabled models don't get parse errors + var = ParserVar(self.model, self.context, overrides={}) + + self.assertEqual(var('foo', 'bar'), 'bar') + self.assertEqual(var('foo'), None)