Skip to content

Commit

Permalink
Merge pull request #2085 from fishtown-analytics/feature/contexts-redux
Browse files Browse the repository at this point in the history
Feature: contexts cleanup
  • Loading branch information
beckjake authored Feb 7, 2020
2 parents 140cfd7 + 2e5fdbf commit 9df123a
Show file tree
Hide file tree
Showing 58 changed files with 2,375 additions and 1,650 deletions.
14 changes: 7 additions & 7 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle
)
from dbt.contracts.graph.manifest import Manifest
from dbt.adapters.base.query_headers import (
QueryStringSetter, MacroQueryStringSetter,
MacroQueryStringSetter,
)
from dbt.logger import GLOBAL_LOGGER as logger

Expand All @@ -39,13 +40,10 @@ def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.query_header = QueryStringSetter(self.profile)
self.query_header: Optional[MacroQueryStringSetter] = None

def set_query_header(self, manifest=None) -> None:
if manifest is not None:
self.query_header = MacroQueryStringSetter(self.profile, manifest)
else:
self.query_header = QueryStringSetter(self.profile)
def set_query_header(self, manifest: Manifest) -> None:
self.query_header = MacroQueryStringSetter(self.profile, manifest)

@staticmethod
def get_thread_identifier() -> Hashable:
Expand Down Expand Up @@ -285,6 +283,8 @@ def commit_if_has_connection(self) -> None:
self.commit()

def _add_query_comment(self, sql: str) -> str:
if self.query_header is None:
return sql
return self.query_header.add(sql)

@abc.abstractmethod
Expand Down
16 changes: 9 additions & 7 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,14 @@ def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
) -> Iterator[None]:
try:
self.connections.query_header.set(name, node)
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
self.acquire_connection(name)
yield
finally:
self.release_connection()
self.connections.query_header.reset()
if self.connections.query_header is not None:
self.connections.query_header.reset()

@contextmanager
def connection_for(
Expand Down Expand Up @@ -983,12 +985,12 @@ def execute_macro(
'dbt could not find a macro with the name "{}" in {}'
.format(macro_name, package_name)
)
# This causes a reference cycle, as dbt.context.runtime.generate()
# This causes a reference cycle, as generate_runtime_macro()
# ends up calling get_adapter, so the import has to be here.
import dbt.context.operation
macro_context = dbt.context.operation.generate(
model=macro,
runtime_config=self.config,
from dbt.context.providers import generate_runtime_macro
macro_context = generate_runtime_macro(
macro=macro,
config=self.config,
manifest=manifest,
package_name=project
)
Expand Down
9 changes: 7 additions & 2 deletions core/dbt/adapters/base/plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Type

from dbt.adapters.base import BaseAdapter, Credentials
from dbt.exceptions import CompilationException


class AdapterPlugin:
Expand All @@ -23,8 +24,12 @@ def __init__(
self.adapter: Type[BaseAdapter] = adapter
self.credentials: Type[Credentials] = credentials
self.include_path: str = include_path
project = Project.from_project_root(include_path, {})
self.project_name: str = project.project_name
partial = Project.partial_load(include_path)
if partial.project_name is None:
raise CompilationException(
f'Invalid project at {include_path}: name not set!'
)
self.project_name: str = partial.project_name
self.dependencies: List[str]
if dependencies is None:
self.dependencies = []
Expand Down
42 changes: 14 additions & 28 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from dbt.clients.jinja import QueryStringGenerator

# this generates an import cycle, as usual
from dbt.context.base import QueryHeaderContext
from dbt.context.configured import generate_query_header_context
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
Expand Down Expand Up @@ -68,9 +67,9 @@ def set(self, comment: Optional[str]):
QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str]


class QueryStringSetter:
"""The base query string setter. This is only used once."""
def __init__(self, config: AdapterRequiredConfig):
class MacroQueryStringSetter:
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
self.manifest = manifest
self.config = config

comment_macro = self._get_comment_macro()
Expand All @@ -88,17 +87,22 @@ def __init__(self, config: AdapterRequiredConfig):
self.comment = _QueryComment(None)
self.reset()

def _get_context(self):
return QueryHeaderContext(self.config).to_dict()

def _get_comment_macro(self) -> Optional[str]:
def _get_comment_macro(self):
if (
self.config.query_comment != NoValue() and
self.config.query_comment
):
return self.config.query_comment
# if the query comment is null/empty string, there is no comment at all
if not self.config.query_comment:
elif not self.config.query_comment:
return None
else:
# else, the default
return DEFAULT_QUERY_COMMENT

def _get_context(self) -> Dict[str, Any]:
return generate_query_header_context(self.config, self.manifest)

def add(self, sql: str) -> str:
return self.comment.add(sql)

Expand All @@ -111,21 +115,3 @@ def set(self, name: str, node: Optional[CompileResultNode]):
wrapped = NodeWrapper(node)
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)


class MacroQueryStringSetter(QueryStringSetter):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
self.manifest = manifest
super().__init__(config)

def _get_comment_macro(self):
if (
self.config.query_comment != NoValue() and
self.config.query_comment
):
return self.config.query_comment
else:
return super()._get_comment_macro()

def _get_context(self) -> Dict[str, Any]:
return QueryHeaderContext(self.config).to_dict(self.manifest.macros)
21 changes: 17 additions & 4 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
List, Union, Set, Optional, Dict, Any, Callable, Iterator, Type
List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn
)

import jinja2
Expand Down Expand Up @@ -164,6 +165,18 @@ def call_macro(self, *args, **kwargs):
return e.value


@dataclass
class MacroProxy:
generator: 'MacroGenerator'

@property
def node(self):
return self.generator.node

def __call__(self, *args, **kwargs):
return self.generator.call_macro(*args, **kwargs)


class MacroGenerator(BaseMacroGenerator):
def __init__(self, node, context: Optional[Dict[str, Any]] = None) -> None:
super().__init__(context)
Expand All @@ -185,9 +198,9 @@ def exception_handler(self) -> Iterator[None]:
e.stack.append(self.node)
raise e

def __call__(self, context: Dict[str, Any]) -> Callable:
def __call__(self, context: Dict[str, Any]) -> MacroProxy:
self.context = context
return self.call_macro
return MacroProxy(self)


class QueryStringGenerator(BaseMacroGenerator):
Expand Down Expand Up @@ -374,7 +387,7 @@ def get_rendered(string, ctx, node=None,
return render_template(template, ctx, node)


def undefined_error(msg):
def undefined_error(msg) -> NoReturn:
raise jinja2.exceptions.UndefinedError(msg)


Expand Down
15 changes: 7 additions & 8 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dbt.node_types import NodeType
from dbt.linker import Linker

import dbt.context.runtime
from dbt.context.providers import generate_runtime_model
import dbt.contracts.project
import dbt.exceptions
import dbt.flags
Expand Down Expand Up @@ -146,8 +146,9 @@ def compile_node(self, node, manifest, extra_context=None):
})
compiled_node = _compiled_type_for(node).from_dict(data)

context = dbt.context.runtime.generate(
compiled_node, self.config, manifest)
context = generate_runtime_model(
compiled_node, self.config, manifest
)
context.update(extra_context)

compiled_node.compiled_sql = dbt.clients.jinja.get_rendered(
Expand Down Expand Up @@ -253,13 +254,11 @@ def compile_node(adapter, config, node, manifest, extra_context, write=True):
logger.debug('Writing injected SQL for node "{}"'.format(
node.unique_id))

written_path = dbt.writer.write_node(
node,
node.build_path = node.write_node(
config.target_path,
'compiled',
node.injected_sql)

node.build_path = written_path
node.injected_sql
)

return node

Expand Down
37 changes: 17 additions & 20 deletions core/dbt/config/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from dbt.clients.system import load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import Credentials
from dbt.contracts.connection import Credentials, HasCredentials
from dbt.contracts.project import ProfileConfig, UserConfig
from dbt.exceptions import DbtProfileError
from dbt.exceptions import DbtProjectError
from dbt.exceptions import ValidationException
from dbt.exceptions import RuntimeException
from dbt.exceptions import validator_error_message
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import parse_cli_vars, coerce_dict_str
from dbt.utils import coerce_dict_str

from .renderer import ConfigRenderer

Expand Down Expand Up @@ -73,7 +73,7 @@ def read_user_config(directory: str) -> UserConfig:


@dataclass
class Profile:
class Profile(HasCredentials):
profile_name: str
target_name: str
config: UserConfig
Expand Down Expand Up @@ -147,7 +147,8 @@ def _credentials_from_profile(

@staticmethod
def pick_profile_name(
args_profile_name: str, project_profile_name: Optional[str] = None,
args_profile_name: Optional[str],
project_profile_name: Optional[str] = None,
) -> str:
profile_name = project_profile_name
if args_profile_name is not None:
Expand Down Expand Up @@ -217,13 +218,11 @@ def render_profile(
raw_profile: Dict[str, Any],
profile_name: str,
target_override: Optional[str],
cli_vars: Dict[str, Any],
renderer: ConfigRenderer,
) -> Tuple[str, Dict[str, Any]]:
"""This is a containment zone for the hateful way we're rendering
profiles.
"""
renderer = ConfigRenderer(cli_vars=cli_vars)

# rendering profiles is a bit complex. Two constraints cause trouble:
# 1) users should be able to use environment/cli variables to specify
# the target in their profile.
Expand Down Expand Up @@ -255,7 +254,7 @@ def from_raw_profile_info(
cls,
raw_profile: Dict[str, Any],
profile_name: str,
cli_vars: Dict[str, Any],
renderer: ConfigRenderer,
user_cfg: Optional[Dict[str, Any]] = None,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
Expand All @@ -267,8 +266,7 @@ def from_raw_profile_info(
:param raw_profile: The profile data for a single profile, from
disk as yaml and its values rendered with jinja.
:param profile_name: The profile name used.
:param cli_vars: The command-line variables passed as arguments,
as a dict.
:param renderer: The config renderer.
:param user_cfg: The global config for the user, if it
was present.
:param target_override: The target to use, if provided on
Expand All @@ -285,7 +283,7 @@ def from_raw_profile_info(

# TODO: should it be, and the values coerced to bool?
target_name, profile_data = cls.render_profile(
raw_profile, profile_name, target_override, cli_vars
raw_profile, profile_name, target_override, renderer
)

# valid connections never include the number of threads, but it's
Expand All @@ -311,15 +309,14 @@ def from_raw_profiles(
cls,
raw_profiles: Dict[str, Any],
profile_name: str,
cli_vars: Dict[str, Any],
renderer: ConfigRenderer,
target_override: Optional[str] = None,
threads_override: Optional[int] = None,
) -> 'Profile':
"""
:param raw_profiles: The profile data, from disk as yaml.
:param profile_name: The profile name to use.
:param cli_vars: The command-line variables passed as arguments, as a
dict.
:param renderer: The config renderer.
:param target_override: The target to use, if provided on the command
line.
:param threads_override: The thread count to use, if provided on the
Expand All @@ -344,17 +341,18 @@ def from_raw_profiles(
return cls.from_raw_profile_info(
raw_profile=raw_profile,
profile_name=profile_name,
cli_vars=cli_vars,
renderer=renderer,
user_cfg=user_cfg,
target_override=target_override,
threads_override=threads_override,
)

@classmethod
def from_args(
def render_from_args(
cls,
args: Any,
project_profile_name: Optional[str] = None,
renderer: ConfigRenderer,
project_profile_name: Optional[str],
) -> 'Profile':
"""Given the raw profiles as read from disk and the name of the desired
profile if specified, return the profile component of the runtime
Expand All @@ -370,17 +368,16 @@ def from_args(
target could not be found.
:returns Profile: The new Profile object.
"""
cli_vars = parse_cli_vars(getattr(args, 'vars', '{}'))
threads_override = getattr(args, 'threads', None)
target_override = getattr(args, 'target', None)
raw_profiles = read_profile(args.profiles_dir)
profile_name = cls.pick_profile_name(args.profile,
profile_name = cls.pick_profile_name(getattr(args, 'profile', None),
project_profile_name)

return cls.from_raw_profiles(
raw_profiles=raw_profiles,
profile_name=profile_name,
cli_vars=cli_vars,
renderer=renderer,
target_override=target_override,
threads_override=threads_override
)
Loading

0 comments on commit 9df123a

Please sign in to comment.