Skip to content

Commit

Permalink
Remove bare invocations of @contextmember and @contextproperty,…
Browse files Browse the repository at this point in the history
… and add typing to them

Previously `contextmember` and `contextproperty` were 2-in-1 decorators.
This meant they could be invoked either as `@contextmember` or
`@contextmember('some_string')`. This was fine until we wanted to return
typing to the functions. In the instance where the bare decorator was used
(i.e. no `(...)` were present) an object was expected to be returned. However
in the instance where parameters were passed on the invocation, a callable
was expected to be returned. Putting a union of both in the return type
made the invocations complain about each others' return type. To get around this
we've dropped the bare invocation as acceptable. The parenthesis are now always
required, but passing a string in them is optional.
  • Loading branch information
QMalcolm committed Aug 31, 2023
1 parent 4c98b18 commit 909b377
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 76 deletions.
52 changes: 24 additions & 28 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import os
from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
from typing import Any, Callable, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
import threading

from dbt.flags import get_flags
Expand Down Expand Up @@ -98,16 +98,12 @@ def key(self, default: str) -> str:
return self.name


def contextmember(value):
if isinstance(value, str):
return lambda v: ContextMember(v, name=value)
return ContextMember(value)
def contextmember(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(v, name=value)


def contextproperty(value):
if isinstance(value, str):
return lambda v: ContextMember(property(v), name=value)
return ContextMember(property(value))
def contextproperty(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(property(v), name=value)


class ContextMeta(type):
Expand Down Expand Up @@ -208,7 +204,7 @@ def to_dict(self) -> Dict[str, Any]:
self._ctx.update(builtins)
return self._ctx

@contextproperty
@contextproperty()
def dbt_version(self) -> str:
"""The `dbt_version` variable returns the installed version of dbt that
is currently running. It can be used for debugging or auditing
Expand All @@ -228,7 +224,7 @@ def dbt_version(self) -> str:
"""
return dbt_version

@contextproperty
@contextproperty()
def var(self) -> Var:
"""Variables can be passed from your `dbt_project.yml` file into models
during compilation. These variables are useful for configuring packages
Expand Down Expand Up @@ -297,7 +293,7 @@ def var(self) -> Var:
"""
return Var(self._ctx, self.cli_vars)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the environment variable named 'var'.
If there is no such environment variable set, return the default.
Expand Down Expand Up @@ -325,7 +321,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:

if os.environ.get("DBT_MACRO_DEBUGGING"):

@contextmember
@contextmember()
@staticmethod
def debug():
"""Enter a debugger at this line in the compiled jinja code."""
Expand Down Expand Up @@ -364,7 +360,7 @@ def _return(data: Any) -> NoReturn:
"""
raise MacroReturn(data)

@contextmember
@contextmember()
@staticmethod
def fromjson(string: str, default: Any = None) -> Any:
"""The `fromjson` context method can be used to deserialize a json
Expand All @@ -385,7 +381,7 @@ def fromjson(string: str, default: Any = None) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
"""The `tojson` context method can be used to serialize a Python
Expand All @@ -408,7 +404,7 @@ def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def fromyaml(value: str, default: Any = None) -> Any:
"""The fromyaml context method can be used to deserialize a yaml string
Expand Down Expand Up @@ -439,7 +435,7 @@ def fromyaml(value: str, default: Any = None) -> Any:

# safe_dump defaults to sort_keys=True, but we act like json.dumps (the
# opposite)
@contextmember
@contextmember()
@staticmethod
def toyaml(
value: Any, default: Optional[str] = None, sort_keys: bool = False
Expand Down Expand Up @@ -484,7 +480,7 @@ def _set(value: Iterable[Any], default: Any = None) -> Optional[Set[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def set_strict(value: Iterable[Any]) -> Set[Any]:
"""The `set_strict` context method can be used to convert any iterable
Expand Down Expand Up @@ -526,7 +522,7 @@ def _zip(*args: Iterable[Any], default: Any = None) -> Optional[Iterable[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
"""The `zip_strict` context method can be used to used to return
Expand All @@ -548,7 +544,7 @@ def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
except TypeError as e:
raise ZipStrictWrongTypeError(e)

@contextmember
@contextmember()
@staticmethod
def log(msg: str, info: bool = False) -> str:
"""Logs a line to either the log file or stdout.
Expand All @@ -569,7 +565,7 @@ def log(msg: str, info: bool = False) -> str:
fire_event(JinjaLogDebug(msg=msg, node_info=get_node_info()))
return ""

@contextproperty
@contextproperty()
def run_started_at(self) -> Optional[datetime.datetime]:
"""`run_started_at` outputs the timestamp that this run started, e.g.
`2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python
Expand Down Expand Up @@ -597,19 +593,19 @@ def run_started_at(self) -> Optional[datetime.datetime]:
else:
return None

@contextproperty
@contextproperty()
def invocation_id(self) -> Optional[str]:
"""invocation_id outputs a UUID generated for this dbt run (useful for
auditing)
"""
return get_invocation_id()

@contextproperty
@contextproperty()
def thread_id(self) -> str:
"""thread_id outputs an ID for the current thread (useful for auditing)"""
return threading.current_thread().name

@contextproperty
@contextproperty()
def modules(self) -> Dict[str, Any]:
"""The `modules` variable in the Jinja context contains useful Python
modules for operating on data.
Expand All @@ -634,7 +630,7 @@ def modules(self) -> Dict[str, Any]:
""" # noqa
return get_context_modules()

@contextproperty
@contextproperty()
def flags(self) -> Any:
"""The `flags` variable contains true/false values for flags provided
on the command line.
Expand All @@ -651,7 +647,7 @@ def flags(self) -> Any:
"""
return flags_module.get_flag_obj()

@contextmember
@contextmember()
@staticmethod
def print(msg: str) -> str:
"""Prints a line to stdout.
Expand All @@ -669,7 +665,7 @@ def print(msg: str) -> str:
print(msg)
return ""

@contextmember
@contextmember()
@staticmethod
def diff_of_two_dicts(
dict_a: Dict[str, List[str]], dict_b: Dict[str, List[str]]
Expand Down Expand Up @@ -698,7 +694,7 @@ def diff_of_two_dicts(
dict_diff.update({k: dict_a[k]})
return dict_diff

@contextmember
@contextmember()
@staticmethod
def local_md5(value: str) -> str:
"""Calculates an MD5 hash of the given string.
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, config: AdapterRequiredConfig) -> None:
super().__init__(config.to_target_dict(), config.cli_vars)
self.config = config

@contextproperty
@contextproperty()
def project_name(self) -> str:
return self.config.project_name

Expand Down Expand Up @@ -80,11 +80,11 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
Expand Down Expand Up @@ -113,7 +113,7 @@ class MacroResolvingContext(ConfiguredContext):
def __init__(self, config):
super().__init__(config)

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self.config.project_name)

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self.node = node
self.manifest = manifest

@contextmember
@contextmember()
def doc(self, *args: str) -> str:
"""The `doc` function is used to reference docs blocks in schema.yml
files. It is analogous to the `ref` function. For more information,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def to_dict(self):
dct.update(self.namespace)
return dct

@contextproperty
@contextproperty()
def context_macro_stack(self):
return self.macro_stack

Expand Down
Loading

0 comments on commit 909b377

Please sign in to comment.