diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index 6361d675ad4..59aa9774366 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -206,8 +206,7 @@ def log(msg, info=False): class Var(object): UndefinedVarError = "Required var '{}' not found in config:\nVars "\ "supplied to {} = {}" - NoneVarError = "Supplied var '{}' is undefined in config:\nVars supplied "\ - "to {} = {}" + _VAR_NOTSET = object() def __init__(self, model, context, overrides): self.model = model @@ -241,7 +240,7 @@ def pretty_dict(self, data): return json.dumps(data, sort_keys=True, indent=4) def assert_var_defined(self, var_name, default): - if var_name not in self.local_vars and default is None: + if var_name not in self.local_vars and default is self._VAR_NOTSET: pretty_vars = self.pretty_dict(self.local_vars) dbt.exceptions.raise_compiler_error( self.UndefinedVarError.format( @@ -250,25 +249,12 @@ def assert_var_defined(self, var_name, default): self.model ) - def assert_var_not_none(self, var_name): - raw = self.local_vars[var_name] - if raw is None: - pretty_vars = self.pretty_dict(self.local_vars) - dbt.exceptions.raise_compiler_error( - self.NoneVarError.format( - var_name, self.model_name, pretty_vars - ), - self.model - ) - - def __call__(self, var_name, default=None): + def __call__(self, var_name, default=_VAR_NOTSET): self.assert_var_defined(var_name, default) if var_name not in self.local_vars: return default - self.assert_var_not_none(var_name) - raw = self.local_vars[var_name] # if bool/int/float/etc are passed in, don't compile anything diff --git a/test/unit/test_context.py b/test/unit/test_context.py index c5d9a5c99f4..bf286dc79fa 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -43,14 +43,19 @@ def setUp(self): ) self.context = mock.MagicMock() - def test_var_not_none_is_none(self): + def test_var_default_something(self): + var = Var(self.model, self.context, overrides={'foo': 'baz'}) + self.assertEqual(var('foo'), 'baz') + self.assertEqual(var('foo', 'bar'), 'baz') + + def test_var_default_none(self): var = Var(self.model, self.context, overrides={'foo': None}) - var.assert_var_defined('foo', None) - with self.assertRaises(dbt.exceptions.CompilationException): - var.assert_var_not_none('foo') + self.assertEqual(var('foo'), None) + self.assertEqual(var('foo', 'bar'), None) - def test_var_defined_is_missing(self): + def test_var_not_defined(self): var = Var(self.model, self.context, overrides={}) - var.assert_var_defined('foo', 'bar') + + self.assertEqual(var('foo', 'bar'), 'bar') with self.assertRaises(dbt.exceptions.CompilationException): - var.assert_var_defined('foo', None) + var('foo')