Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inject query comments (#1643) #1864

Merged
merged 8 commits into from
Nov 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from multiprocessing import RLock
import os
from multiprocessing import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List
Expand All @@ -11,7 +11,10 @@
import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, HasCredentials
Connection, Identifier, ConnectionState, AdapterRequiredConfig
)
from dbt.adapters.base.query_headers import (
QueryStringSetter, MacroQueryStringSetter,
)
from dbt.logger import GLOBAL_LOGGER as logger

Expand All @@ -31,10 +34,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
"""
TYPE: str = NotImplemented

def __init__(self, profile: HasCredentials):
def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.query_header = QueryStringSetter(self.profile)

def set_query_header(self, manifest=None) -> None:
    """Replace ``self.query_header`` with a freshly built setter.

    :param manifest: when given, a ``MacroQueryStringSetter`` is built from
        the profile and this manifest; when ``None``, a plain
        ``QueryStringSetter`` is built from the profile alone.
    """
    if manifest is None:
        self.query_header = QueryStringSetter(self.profile)
        return
    self.query_header = MacroQueryStringSetter(self.profile, manifest)

@staticmethod
def get_thread_identifier() -> Hashable:
Expand Down Expand Up @@ -91,6 +101,10 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
# named 'master'
conn_name = 'master'
else:
if not isinstance(name, str):
raise dbt.exceptions.CompilerException(
f'For connection name, got {name} - not a string!'
)
assert isinstance(name, str)
conn_name = name

Expand Down Expand Up @@ -221,7 +235,10 @@ def _close_handle(cls, connection: Connection) -> None:
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand All @@ -236,7 +253,10 @@ def _rollback(cls, connection: Connection) -> None:
@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
)

# if the connection is in closed or init, there's nothing to do
if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
Expand All @@ -257,6 +277,9 @@ def commit_if_has_connection(self) -> None:
if connection:
self.commit()

def _add_query_comment(self, sql: str) -> str:
return self.query_header.add(sql)

@abc.abstractmethod
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
Expand Down
37 changes: 25 additions & 12 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping
Mapping, Iterator,
)

import agate
Expand All @@ -13,12 +13,13 @@
import dbt.flags

from dbt.clients.agate_helper import empty_table
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values

from dbt.adapters.base.connections import BaseConnectionManager
from dbt.adapters.base.connections import BaseConnectionManager, Connection
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import ComponentName, BaseRelation
from dbt.adapters.base import Column as BaseColumn
Expand Down Expand Up @@ -203,20 +204,20 @@ def __init__(self, config):
###
# Methods that pass through to the connection manager
###
def acquire_connection(self, name=None):
def acquire_connection(self, name=None) -> Connection:
return self.connections.set_connection_name(name)

def release_connection(self):
return self.connections.release()
def release_connection(self) -> None:
self.connections.release()

def cleanup_connections(self):
return self.connections.cleanup_all()
def cleanup_connections(self) -> None:
self.connections.cleanup_all()

def clear_transaction(self):
def clear_transaction(self) -> None:
self.connections.clear_transaction()

def commit_if_has_connection(self):
return self.connections.commit_if_has_connection()
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()

def nice_connection_name(self):
conn = self.connections.get_if_exists()
Expand All @@ -225,11 +226,23 @@ def nice_connection_name(self):
return conn.name

@contextmanager
def connection_named(self, name):
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
try:
yield self.acquire_connection(name)
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
yield conn
finally:
self.release_connection()
self.connections.query_header.reset()

@contextmanager
def connection_for(
    self, node: CompileResultNode
) -> Iterator[Connection]:
    """Context manager yielding a connection for ``node``.

    Delegates to ``connection_named``, using the node's ``unique_id`` as
    the connection name and passing the node itself along as well.
    """
    with self.connection_named(node.unique_id, node) as conn:
        yield conn

@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
Expand Down
131 changes: 131 additions & 0 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from threading import local
from typing import Optional, Callable

from dbt.clients.jinja import QueryStringGenerator

# this generates an import cycle, as usual
from dbt.context.base import QueryHeaderContext
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.exceptions import RuntimeException
from dbt.helper_types import NoValue


# Default body of the query comment macro (a Jinja template). It collects
# app, dbt_version, profile_name and target_name into a dict, adds node_id
# when a node is supplied or connection_name otherwise, and returns the dict
# serialized with tojson.
DEFAULT_QUERY_COMMENT = '''
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
dbt_version=dbt_version,
profile_name=target.get('profile_name'),
target_name=target.get('target_name'),
) -%}
{%- if node is not none -%}
{%- do comment_dict.update(
node_id=node.unique_id,
) -%}
{% else %}
{# in the node context, the connection name is the node_id #}
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''


class NodeWrapper:
def __init__(self, node):
self._inner_node = node

def __getattr__(self, name):
return getattr(self._inner_node, name, '')


class _QueryComment(local):
"""A thread-local class storing thread-specific state information for
connection management, namely:
- the current thread's query comment.
- a source_name indicating what set the current thread's query comment
"""
def __init__(self, initial):
self.query_comment: Optional[str] = initial

def add(self, sql: str) -> str:
if not self.query_comment:
return sql
else:
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str]):
if '*/' in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
f'query comment contains illegal value "*/": {comment}'
)
self.query_comment = comment


# Type of a comment generator: called with the connection name and an
# optional node, it returns the comment text.
QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]


class QueryStringSetter:
    """The base query string setter.

    Builds a comment generator from the configured comment macro (if any)
    and keeps the rendered comment in a thread-local ``_QueryComment``.
    Subclasses customize behavior by overriding ``_get_comment_macro`` and
    ``_get_context``.
    """
    def __init__(self, config: AdapterRequiredConfig):
        self.config = config

        comment_macro = self._get_comment_macro()
        # with no macro, the generator produces an empty comment
        generator: QueryStringFunc = lambda name, model: ''
        # if the comment value was None or the empty string, just skip it
        if comment_macro:
            # reuse the macro text fetched above rather than asking for it
            # a second time
            macro = '\n'.join((
                '{%- macro query_comment_macro(connection_name, node) -%}',
                comment_macro,
                '{% endmacro %}'
            ))
            ctx = self._get_context()
            generator = QueryStringGenerator(macro, ctx)
        self.generator: QueryStringFunc = generator
        self.comment = _QueryComment(None)
        self.reset()

    def _get_context(self):
        """Return the context dict the comment macro is rendered with."""
        return QueryHeaderContext(self.config).to_dict()

    def _get_comment_macro(self) -> Optional[str]:
        """Return the macro body to use, or ``None`` for no comment."""
        # if the query comment is null/empty string, there is no comment at all
        if not self.config.query_comment:
            return None
        # else, the default
        return DEFAULT_QUERY_COMMENT

    def add(self, sql: str) -> str:
        """Return ``sql`` with the current comment (if any) prepended."""
        return self.comment.add(sql)

    def reset(self):
        """Set the comment for the connection name 'master' with no node."""
        self.set('master', None)

    def set(self, name: str, node: Optional[CompileResultNode]):
        """Generate the comment for ``name``/``node`` and store it.

        A non-``None`` node is wrapped in ``NodeWrapper`` before being
        handed to the generator.
        """
        wrapped = NodeWrapper(node) if node is not None else None
        comment_str = self.generator(name, wrapped)
        self.comment.set(comment_str)


class MacroQueryStringSetter(QueryStringSetter):
    """Query string setter that honors a user-configured comment macro and
    renders it with the manifest's macros in the context.
    """
    def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
        # the manifest is stored first: the base __init__ calls
        # _get_context(), which reads self.manifest
        self.manifest = manifest
        super().__init__(config)

    def _get_comment_macro(self):
        """Return the configured comment if it is a set, non-empty value;
        otherwise fall back to the base class's choice.
        """
        configured = self.config.query_comment
        if configured != NoValue() and configured:
            return configured
        return super()._get_comment_macro()

    def _get_context(self):
        """Return the header context dict built with the manifest's macros."""
        header_context = QueryHeaderContext(self.config)
        return header_context.to_dict(self.manifest.macros)
13 changes: 11 additions & 2 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dbt.contracts.util import Replaceable
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.exceptions import InternalException
from dbt import deprecations


Expand Down Expand Up @@ -322,10 +323,18 @@ def create_from(
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
assert isinstance(node, ParsedSourceDefinition)
if not isinstance(node, ParsedSourceDefinition):
raise InternalException(
'type mismatch, expected ParsedSourceDefinition but got {}'
.format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
assert isinstance(node, (ParsedNode, CompiledNode))
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
'type mismatch, expected ParsedNode or CompiledNode but '
'got {}'.format(type(node))
)
return cls.create_from_node(config, node, **kwargs)

@classmethod
Expand Down
9 changes: 5 additions & 4 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from dbt.exceptions import RuntimeException
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.connection import Credentials, HasCredentials
from dbt.contracts.connection import Credentials, AdapterRequiredConfig

from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.base.plugin import AdapterPlugin


# TODO: we can't import these because they cause an import cycle.
# Profile has to call into load_plugin to get credentials, so adapter/relation
# don't work
Expand Down Expand Up @@ -74,7 +75,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:

return plugin.credentials

def register_adapter(self, config: HasCredentials) -> None:
def register_adapter(self, config: AdapterRequiredConfig) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)

Expand Down Expand Up @@ -109,11 +110,11 @@ def cleanup_connections(self):
FACTORY: AdpaterContainer = AdpaterContainer()


def register_adapter(config: HasCredentials) -> None:
def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config)


def get_adapter(config: HasCredentials):
def get_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type)


Expand Down
11 changes: 9 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[str, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
status = self.get_status(cursor)
if fetch:
Expand All @@ -130,7 +131,10 @@ def begin(self):
connection = self.get_thread_connection()

if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
)

if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
Expand All @@ -145,7 +149,10 @@ def begin(self):
def commit(self):
connection = self.get_thread_connection()
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand Down
Loading