Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT to 1.5.latest] Use events.contextvar because of multiprocessing unable to pickle ContextVar (#7949) #7981

Merged
merged 1 commit into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230626-115838.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Move project_root contextvar into events.contextvars
time: 2023-06-26T11:58:38.965299-04:00
custom:
Author: gshank
Issue: "7937"
4 changes: 2 additions & 2 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
SeedExceedsLimitChecksumChanged,
ValidationWarning,
)
from dbt.events.contextvars import set_contextvars
from dbt.events.contextvars import set_log_contextvars
from dbt.flags import get_flags
from dbt.node_types import ModelLanguage, NodeType, AccessType

Expand Down Expand Up @@ -303,7 +303,7 @@ def node_info(self):
def update_event_status(self, **kwargs):
for k, v in kwargs.items():
self._event_status[k] = v
set_contextvars(node_info=self.node_info)
set_log_contextvars(node_info=self.node_info)

def clear_event_status(self):
self._event_status = dict()
Expand Down
77 changes: 54 additions & 23 deletions core/dbt/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,110 @@


LOG_PREFIX = "log_"
LOG_PREFIX_LEN = len(LOG_PREFIX)
TASK_PREFIX = "task_"

_log_context_vars: Dict[str, contextvars.ContextVar] = {}
_context_vars: Dict[str, contextvars.ContextVar] = {}


def get_contextvars() -> Dict[str, Any]:
def get_contextvars(prefix: str) -> Dict[str, Any]:
rv = {}
ctx = contextvars.copy_context()

prefix_len = len(prefix)
for k in ctx:
if k.name.startswith(LOG_PREFIX) and ctx[k] is not Ellipsis:
rv[k.name[LOG_PREFIX_LEN:]] = ctx[k]
if k.name.startswith(prefix) and ctx[k] is not Ellipsis:
rv[k.name[prefix_len:]] = ctx[k]

return rv


def get_node_info():
cvars = get_contextvars()
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
else:
return {}


def clear_contextvars() -> None:
def get_project_root():
cvars = get_contextvars(TASK_PREFIX)
if "project_root" in cvars:
return cvars["project_root"]
else:
return None


def clear_contextvars(prefix: str) -> None:
ctx = contextvars.copy_context()
for k in ctx:
if k.name.startswith(LOG_PREFIX):
if k.name.startswith(prefix):
k.set(Ellipsis)


def set_log_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(LOG_PREFIX, **kwargs)


def set_task_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
return set_contextvars(TASK_PREFIX, **kwargs)


# Put keys and values into context. Returns a mapping of key to contextvars.Token;
# save it and pass it to reset_contextvars to restore the previous values.
def set_contextvars(**kwargs: Any) -> Mapping[str, contextvars.Token]:
def set_contextvars(prefix: str, **kwargs: Any) -> Mapping[str, contextvars.Token]:
cvar_tokens = {}
for k, v in kwargs.items():
log_key = f"{LOG_PREFIX}{k}"
log_key = f"{prefix}{k}"
try:
var = _log_context_vars[log_key]
var = _context_vars[log_key]
except KeyError:
var = contextvars.ContextVar(log_key, default=Ellipsis)
_log_context_vars[log_key] = var
_context_vars[log_key] = var

cvar_tokens[k] = var.set(v)

return cvar_tokens


# reset by Tokens
def reset_contextvars(**kwargs: contextvars.Token) -> None:
def reset_contextvars(prefix: str, **kwargs: contextvars.Token) -> None:
for k, v in kwargs.items():
log_key = f"{LOG_PREFIX}{k}"
var = _log_context_vars[log_key]
log_key = f"{prefix}{k}"
var = _context_vars[log_key]
var.reset(v)


# remove from contextvars
def unset_contextvars(*keys: str) -> None:
def unset_contextvars(prefix: str, *keys: str) -> None:
for k in keys:
if k in _log_context_vars:
log_key = f"{LOG_PREFIX}{k}"
_log_context_vars[log_key].set(Ellipsis)
if k in _context_vars:
log_key = f"{prefix}{k}"
_context_vars[log_key].set(Ellipsis)


# Context manager or decorator to set and unset the context vars
@contextlib.contextmanager
def log_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars()
context = get_contextvars(LOG_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}

set_contextvars(LOG_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(LOG_PREFIX, *kwargs.keys())
set_contextvars(LOG_PREFIX, **saved)


# Context manager for earlier in task.run
@contextlib.contextmanager
def task_contextvars(**kwargs: Any) -> Generator[None, None, None]:
context = get_contextvars(TASK_PREFIX)
saved = {k: context[k] for k in context.keys() & kwargs.keys()}

set_contextvars(**kwargs)
set_contextvars(TASK_PREFIX, **kwargs)
try:
yield
finally:
unset_contextvars(*kwargs.keys())
set_contextvars(**saved)
unset_contextvars(TASK_PREFIX, *kwargs.keys())
set_contextvars(TASK_PREFIX, **saved)
8 changes: 6 additions & 2 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DbtRuntimeError,
)
from dbt.node_types import NodeType
from dbt.task.contextvars import cv_project_root
from dbt.events.contextvars import get_project_root


SELECTOR_GLOB = "*"
Expand Down Expand Up @@ -326,7 +326,11 @@ class PathSelectorMethod(SelectorMethod):
def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]:
"""Yields nodes from included that match the given path."""
# get project root from contextvar
root = Path(cv_project_root.get())
project_root = get_project_root()
if project_root:
root = Path(project_root)
else:
root = Path.cwd()
paths = set(p.relative_to(root) for p in root.glob(selector))
for node, real_node in self.all_nodes(included_nodes):
ofp = Path(real_node.original_file_path)
Expand Down
3 changes: 0 additions & 3 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from dbt.graph import Graph
from dbt.logger import log_manager
from .printer import print_run_result_error
from dbt.task.contextvars import cv_project_root


class NoneConfig:
Expand Down Expand Up @@ -76,8 +75,6 @@ def __init__(self, args, config, project=None):
self.args = args
self.config = config
self.project = config if isinstance(config, Project) else project
if self.config:
cv_project_root.set(self.config.project_root)

@classmethod
def pre_init_hook(cls, args):
Expand Down
6 changes: 0 additions & 6 deletions core/dbt/task/contextvars.py

This file was deleted.

41 changes: 23 additions & 18 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
EndRunResult,
NothingToDo,
)
from dbt.events.contextvars import log_contextvars
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.contracts.graph.nodes import SourceDefinition, ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
Expand Down Expand Up @@ -422,25 +422,30 @@ def run(self):
"""
Run dbt for the query, based on the graph.
"""
self._runtime_initialize()
# We set up a context manager here with "task_contextvars" because
# we need the project_root in runtime_initialize.
with task_contextvars(project_root=self.config.project_root):
self._runtime_initialize()

if self._flattened_nodes is None:
raise DbtInternalError("after _runtime_initialize, _flattened_nodes was still None")
if self._flattened_nodes is None:
raise DbtInternalError(
"after _runtime_initialize, _flattened_nodes was still None"
)

if len(self._flattened_nodes) == 0:
with TextOnly():
fire_event(Formatting(""))
warn_or_error(NothingToDo())
result = self.get_result(
results=[],
generated_at=datetime.utcnow(),
elapsed_time=0.0,
)
else:
with TextOnly():
fire_event(Formatting(""))
selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
result = self.execute_with_hooks(selected_uids)
if len(self._flattened_nodes) == 0:
with TextOnly():
fire_event(Formatting(""))
warn_or_error(NothingToDo())
result = self.get_result(
results=[],
generated_at=datetime.utcnow(),
elapsed_time=0.0,
)
else:
with TextOnly():
fire_event(Formatting(""))
selected_uids = frozenset(n.unique_id for n in self._flattened_nodes)
result = self.execute_with_hooks(selected_uids)

# We have other result types here too, including FreshnessResult
if isinstance(result, RunExecutionResult):
Expand Down