From 48c7de0db8bcc20cb4492ca8039d6af733c1fd9d Mon Sep 17 00:00:00 2001 From: Quigley Malcolm Date: Thu, 31 Aug 2023 12:38:26 -0700 Subject: [PATCH] Remove bare `invocations` of `@contextmember` and `@contextproperty`, and add typing to them Previously `contextmember` and `contextproperty` were 2-in-1 decorators. This meant they could be invoked either as `@contextmember` or `@contextmember('some_string')`. This was fine until we wanted to return typing to the functions. In the instance where the bare decorator was used (i.e. no `(...)` were present) an object was expected to be returned. However in the instance where parameters were passed on the invocation, a callable was expected to be returned. Putting a union of both in the return type made the invocations complain about each others' return type. To get around this we've dropped the bare invocation as acceptable. The parenthesis are now always required, but passing a string in them is optional. --- core/dbt/context/base.py | 52 ++++++++++------------ core/dbt/context/configured.py | 8 ++-- core/dbt/context/docs.py | 2 +- core/dbt/context/manifest.py | 2 +- core/dbt/context/providers.py | 80 +++++++++++++++++----------------- core/dbt/context/secret.py | 2 +- core/dbt/context/target.py | 2 +- 7 files changed, 72 insertions(+), 76 deletions(-) diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 6eacfdf9b78..35f3dd4a81d 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -2,7 +2,7 @@ import json import os -from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List +from typing import Any, Callable, Dict, NoReturn, Optional, Mapping, Iterable, Set, List import threading from dbt.flags import get_flags @@ -98,16 +98,12 @@ def key(self, default: str) -> str: return self.name -def contextmember(value): - if isinstance(value, str): - return lambda v: ContextMember(v, name=value) - return ContextMember(value) +def contextmember(value: Optional[str] = None) -> Callable: + return lambda v: ContextMember(v, name=value) -def contextproperty(value): - if isinstance(value, str): - return lambda v: ContextMember(property(v), name=value) - return ContextMember(property(value)) +def contextproperty(value: Optional[str] = None) -> Callable: + return lambda v: ContextMember(property(v), name=value) class ContextMeta(type): @@ -208,7 +204,7 @@ def to_dict(self) -> Dict[str, Any]: self._ctx.update(builtins) return self._ctx - @contextproperty + @contextproperty() def dbt_version(self) -> str: """The `dbt_version` variable returns the installed version of dbt that is currently running. It can be used for debugging or auditing @@ -228,7 +224,7 @@ def dbt_version(self) -> str: """ return dbt_version - @contextproperty + @contextproperty() def var(self) -> Var: """Variables can be passed from your `dbt_project.yml` file into models during compilation. These variables are useful for configuring packages @@ -297,7 +293,7 @@ def var(self) -> Var: """ return Var(self._ctx, self.cli_vars) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. @@ -325,7 +321,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: if os.environ.get("DBT_MACRO_DEBUGGING"): - @contextmember + @contextmember() @staticmethod def debug(): """Enter a debugger at this line in the compiled jinja code.""" @@ -364,7 +360,7 @@ def _return(data: Any) -> NoReturn: """ raise MacroReturn(data) - @contextmember + @contextmember() @staticmethod def fromjson(string: str, default: Any = None) -> Any: """The `fromjson` context method can be used to deserialize a json @@ -385,7 +381,7 @@ def fromjson(string: str, default: Any = None) -> Any: except ValueError: return default - @contextmember + @contextmember() @staticmethod def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any: """The `tojson` context method can be used to serialize a Python @@ -408,7 +404,7 @@ def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any: except ValueError: return default - @contextmember + @contextmember() @staticmethod def fromyaml(value: str, default: Any = None) -> Any: """The fromyaml context method can be used to deserialize a yaml string @@ -439,7 +435,7 @@ def fromyaml(value: str, default: Any = None) -> Any: # safe_dump defaults to sort_keys=True, but we act like json.dumps (the # opposite) - @contextmember + @contextmember() @staticmethod def toyaml( value: Any, default: Optional[str] = None, sort_keys: bool = False @@ -484,7 +480,7 @@ def _set(value: Iterable[Any], default: Any = None) -> Optional[Set[Any]]: except TypeError: return default - @contextmember + @contextmember() @staticmethod def set_strict(value: Iterable[Any]) -> Set[Any]: """The `set_strict` context method can be used to convert any iterable @@ -526,7 +522,7 @@ def _zip(*args: Iterable[Any], default: Any = None) -> Optional[Iterable[Any]]: except TypeError: return default - @contextmember + @contextmember() @staticmethod def zip_strict(*args: Iterable[Any]) -> Iterable[Any]: """The `zip_strict` context method can be used to used to return @@ -548,7 +544,7 @@ def zip_strict(*args: Iterable[Any]) -> Iterable[Any]: except TypeError as e: raise ZipStrictWrongTypeError(e) - @contextmember + @contextmember() @staticmethod def log(msg: str, info: bool = False) -> str: """Logs a line to either the log file or stdout. @@ -569,7 +565,7 @@ def log(msg: str, info: bool = False) -> str: fire_event(JinjaLogDebug(msg=msg, node_info=get_node_info())) return "" - @contextproperty + @contextproperty() def run_started_at(self) -> Optional[datetime.datetime]: """`run_started_at` outputs the timestamp that this run started, e.g. `2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python @@ -597,19 +593,19 @@ def run_started_at(self) -> Optional[datetime.datetime]: else: return None - @contextproperty + @contextproperty() def invocation_id(self) -> Optional[str]: """invocation_id outputs a UUID generated for this dbt run (useful for auditing) """ return get_invocation_id() - @contextproperty + @contextproperty() def thread_id(self) -> str: """thread_id outputs an ID for the current thread (useful for auditing)""" return threading.current_thread().name - @contextproperty + @contextproperty() def modules(self) -> Dict[str, Any]: """The `modules` variable in the Jinja context contains useful Python modules for operating on data. @@ -634,7 +630,7 @@ def modules(self) -> Dict[str, Any]: """ # noqa return get_context_modules() - @contextproperty + @contextproperty() def flags(self) -> Any: """The `flags` variable contains true/false values for flags provided on the command line. @@ -651,7 +647,7 @@ def flags(self) -> Any: """ return flags_module.get_flag_obj() - @contextmember + @contextmember() @staticmethod def print(msg: str) -> str: """Prints a line to stdout. @@ -669,7 +665,7 @@ def print(msg: str) -> str: print(msg) return "" - @contextmember + @contextmember() @staticmethod def diff_of_two_dicts( dict_a: Dict[str, List[str]], dict_b: Dict[str, List[str]] @@ -698,7 +694,7 @@ def diff_of_two_dicts( dict_diff.update({k: dict_a[k]}) return dict_diff - @contextmember + @contextmember() @staticmethod def local_md5(value: str) -> str: """Calculates an MD5 hash of the given string. diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index bb292a19565..08f5bee1143 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -19,7 +19,7 @@ def __init__(self, config: AdapterRequiredConfig) -> None: super().__init__(config.to_target_dict(), config.cli_vars) self.config = config - @contextproperty + @contextproperty() def project_name(self) -> str: return self.config.project_name @@ -80,11 +80,11 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY self._project_name = project_name self.schema_yaml_vars = schema_yaml_vars - @contextproperty + @contextproperty() def var(self) -> ConfiguredVar: return ConfiguredVar(self._ctx, self.config, self._project_name) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): @@ -113,7 +113,7 @@ class MacroResolvingContext(ConfiguredContext): def __init__(self, config): super().__init__(config) - @contextproperty + @contextproperty() def var(self) -> ConfiguredVar: return ConfiguredVar(self._ctx, self.config, self.config.project_name) diff --git a/core/dbt/context/docs.py b/core/dbt/context/docs.py index 3d5abf42e11..94f64709fc7 100644 --- a/core/dbt/context/docs.py +++ b/core/dbt/context/docs.py @@ -24,7 +24,7 @@ def __init__( self.node = node self.manifest = manifest - @contextmember + @contextmember() def doc(self, *args: str) -> str: """The `doc` function is used to reference docs blocks in schema.yml files. It is analogous to the `ref` function. For more information, diff --git a/core/dbt/context/manifest.py b/core/dbt/context/manifest.py index c6a39993d92..f2492612cc8 100644 --- a/core/dbt/context/manifest.py +++ b/core/dbt/context/manifest.py @@ -67,7 +67,7 @@ def to_dict(self): dct.update(self.namespace) return dct - @contextproperty + @contextproperty() def context_macro_stack(self): return self.macro_stack diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 6b981091682..ffc1f6d07b4 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -754,19 +754,19 @@ def _get_namespace_builder(self): self.model, ) - @contextproperty + @contextproperty() def dbt_metadata_envs(self) -> Dict[str, str]: return get_metadata_vars() - @contextproperty + @contextproperty() def invocation_args_dict(self): return args_to_dict(self.config.args) - @contextproperty + @contextproperty() def _sql_results(self) -> Dict[str, Optional[AttrDict]]: return self.sql_results - @contextmember + @contextmember() def load_result(self, name: str) -> Optional[AttrDict]: if name in self.sql_results: # handle the special case of "main" macro @@ -787,7 +787,7 @@ def load_result(self, name: str) -> Optional[AttrDict]: # Handle trying to load a result that was never stored return None - @contextmember + @contextmember() def store_result( self, name: str, response: Any, agate_table: Optional[agate.Table] = None ) -> str: @@ -803,7 +803,7 @@ def store_result( ) return "" - @contextmember + @contextmember() def store_raw_result( self, name: str, @@ -815,7 +815,7 @@ def store_raw_result( response = AdapterResponse(_message=message, code=code, rows_affected=rows_affected) return self.store_result(name, response, agate_table) - @contextproperty + @contextproperty() def validation(self): def validate_any(*args) -> Callable[[T], None]: def inner(value: T) -> None: @@ -836,7 +836,7 @@ def inner(value: T) -> None: } ) - @contextmember + @contextmember() def write(self, payload: str) -> str: # macros/source defs aren't 'writeable'. if isinstance(self.model, (Macro, SourceDefinition)): @@ -845,11 +845,11 @@ def write(self, payload: str) -> str: self.model.write_node(self.config.project_root, self.model.build_path, payload) return "" - @contextmember + @contextmember() def render(self, string: str) -> str: return get_rendered(string, self._ctx, self.model) - @contextmember + @contextmember() def try_or_compiler_error( self, message_if_exception: str, func: Callable, *args, **kwargs ) -> Any: @@ -858,7 +858,7 @@ def try_or_compiler_error( except Exception: raise CompilationError(message_if_exception, self.model) - @contextmember + @contextmember() def load_agate_table(self) -> agate.Table: if not isinstance(self.model, SeedNode): raise LoadAgateTableNotSeedError(self.model.resource_type, node=self.model) @@ -873,7 +873,7 @@ def load_agate_table(self) -> agate.Table: table.original_abspath = os.path.abspath(path) return table - @contextproperty + @contextproperty() def ref(self) -> Callable: """The most important function in dbt is `ref()`; it's impossible to build even moderately complex models without it. `ref()` is how you @@ -914,11 +914,11 @@ def ref(self) -> Callable: """ return self.provider.ref(self.db_wrapper, self.model, self.config, self.manifest) - @contextproperty + @contextproperty() def source(self) -> Callable: return self.provider.source(self.db_wrapper, self.model, self.config, self.manifest) - @contextproperty + @contextproperty() def metric(self) -> Callable: return self.provider.metric(self.db_wrapper, self.model, self.config, self.manifest) @@ -979,7 +979,7 @@ def ctx_config(self) -> Config: """ # noqa return self.provider.Config(self.model, self.context_config) - @contextproperty + @contextproperty() def execute(self) -> bool: """`execute` is a Jinja variable that returns True when dbt is in "execute" mode. @@ -1040,7 +1040,7 @@ def execute(self) -> bool: """ # noqa return self.provider.execute - @contextproperty + @contextproperty() def exceptions(self) -> Dict[str, Any]: """The exceptions namespace can be used to raise warnings and errors in dbt userspace. @@ -1078,15 +1078,15 @@ def exceptions(self) -> Dict[str, Any]: """ # noqa return wrapped_exports(self.model) - @contextproperty + @contextproperty() def database(self) -> str: return self.config.credentials.database - @contextproperty + @contextproperty() def schema(self) -> str: return self.config.credentials.schema - @contextproperty + @contextproperty() def var(self) -> ModelConfiguredVar: return self.provider.Var( context=self._ctx, @@ -1103,22 +1103,22 @@ def ctx_adapter(self) -> BaseDatabaseWrapper: """ return self.db_wrapper - @contextproperty + @contextproperty() def api(self) -> Dict[str, Any]: return { "Relation": self.db_wrapper.Relation, "Column": self.adapter.Column, } - @contextproperty + @contextproperty() def column(self) -> Type[Column]: return self.adapter.Column - @contextproperty + @contextproperty() def env(self) -> Dict[str, Any]: return self.target - @contextproperty + @contextproperty() def graph(self) -> Dict[str, Any]: """The `graph` context variable contains information about the nodes in your dbt project. Models, sources, tests, and snapshots are all @@ -1234,23 +1234,23 @@ def ctx_model(self) -> Dict[str, Any]: ret["compiled_sql"] = ret["compiled_code"] return ret - @contextproperty + @contextproperty() def pre_hooks(self) -> Optional[List[Dict[str, Any]]]: return None - @contextproperty + @contextproperty() def post_hooks(self) -> Optional[List[Dict[str, Any]]]: return None - @contextproperty + @contextproperty() def sql(self) -> Optional[str]: return None - @contextproperty + @contextproperty() def sql_now(self) -> str: return self.adapter.date_function() - @contextmember + @contextmember() def adapter_macro(self, name: str, *args, **kwargs): """This was deprecated in v0.18 in favor of adapter.dispatch""" msg = ( @@ -1262,7 +1262,7 @@ def adapter_macro(self, name: str, *args, **kwargs): ) raise CompilationError(msg) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. @@ -1306,7 +1306,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: else: raise EnvVarMissingError(var) - @contextproperty + @contextproperty() def selected_resources(self) -> List[str]: """The `selected_resources` variable contains a list of the resources selected based on the parameters provided to the dbt command. @@ -1315,7 +1315,7 @@ def selected_resources(self) -> List[str]: """ return selected_resources.SELECTED_RESOURCES - @contextmember + @contextmember() def submit_python_job(self, parsed_model: Dict, compiled_code: str) -> AdapterResponse: # Check macro_stack and that the unique id is for a materialization macro if not ( @@ -1358,7 +1358,7 @@ def __init__( class ModelContext(ProviderContext): model: ManifestNode - @contextproperty + @contextproperty() def pre_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] @@ -1367,7 +1367,7 @@ def pre_hooks(self) -> List[Dict[str, Any]]: h.to_dict(omit_none=True) for h in self.model.config.pre_hook # type: ignore[union-attr] # noqa ] - @contextproperty + @contextproperty() def post_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] @@ -1376,7 +1376,7 @@ def post_hooks(self) -> List[Dict[str, Any]]: h.to_dict(omit_none=True) for h in self.model.config.post_hook # type: ignore[union-attr] # noqa ] - @contextproperty + @contextproperty() def sql(self) -> Optional[str]: # only doing this in sql model for backward compatible if self.model.language == ModelLanguage.sql: # type: ignore[union-attr] @@ -1393,7 +1393,7 @@ def sql(self) -> Optional[str]: else: return None - @contextproperty + @contextproperty() def compiled_code(self) -> Optional[str]: if getattr(self.model, "defer_relation", None): # TODO https://github.com/dbt-labs/dbt-core/issues/7976 @@ -1404,15 +1404,15 @@ def compiled_code(self) -> Optional[str]: else: return None - @contextproperty + @contextproperty() def database(self) -> str: return getattr(self.model, "database", self.config.credentials.database) - @contextproperty + @contextproperty() def schema(self) -> str: return getattr(self.model, "schema", self.config.credentials.schema) - @contextproperty + @contextproperty() def this(self) -> Optional[RelationProxy]: """`this` makes available schema information about the currently executing model. It's is useful in any context in which you need to @@ -1447,7 +1447,7 @@ def this(self) -> Optional[RelationProxy]: return None return self.db_wrapper.Relation.create_from(self.config, self.model) - @contextproperty + @contextproperty() def defer_relation(self) -> Optional[RelationProxy]: """ For commands which add information about this node's corresponding @@ -1661,7 +1661,7 @@ def _build_test_namespace(self): ) self.namespace = macro_namespace - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): diff --git a/core/dbt/context/secret.py b/core/dbt/context/secret.py index 4d8ff342aff..2c75546c42a 100644 --- a/core/dbt/context/secret.py +++ b/core/dbt/context/secret.py @@ -14,7 +14,7 @@ class SecretContext(BaseContext): """This context is used in profiles.yml + packages.yml. It can render secret env vars that aren't usable elsewhere""" - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. diff --git a/core/dbt/context/target.py b/core/dbt/context/target.py index a6d587269d5..39c5a30ee4e 100644 --- a/core/dbt/context/target.py +++ b/core/dbt/context/target.py @@ -9,7 +9,7 @@ def __init__(self, target_dict: Dict[str, Any], cli_vars: Dict[str, Any]): super().__init__(cli_vars=cli_vars) self.target_dict = target_dict - @contextproperty + @contextproperty() def target(self) -> Dict[str, Any]: """`target` contains information about your connection to the warehouse (specified in profiles.yml). Some configs are shared between all