Skip to content

Commit

Permalink
Merge pull request #1864 from fishtown-analytics/feature/query-comments
Browse files Browse the repository at this point in the history
inject query comments (#1643)
  • Loading branch information
beckjake authored Nov 4, 2019
2 parents f985902 + b56d93b commit c4cd4fc
Show file tree
Hide file tree
Showing 43 changed files with 1,121 additions and 348 deletions.
33 changes: 28 additions & 5 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from multiprocessing import RLock
import os
from multiprocessing import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List
Expand All @@ -11,7 +11,10 @@
import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, HasCredentials
Connection, Identifier, ConnectionState, AdapterRequiredConfig
)
from dbt.adapters.base.query_headers import (
QueryStringSetter, MacroQueryStringSetter,
)
from dbt.logger import GLOBAL_LOGGER as logger

Expand All @@ -31,10 +34,17 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
"""
TYPE: str = NotImplemented

def __init__(self, profile: HasCredentials):
def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.query_header = QueryStringSetter(self.profile)

def set_query_header(self, manifest=None) -> None:
if manifest is not None:
self.query_header = MacroQueryStringSetter(self.profile, manifest)
else:
self.query_header = QueryStringSetter(self.profile)

@staticmethod
def get_thread_identifier() -> Hashable:
Expand Down Expand Up @@ -91,6 +101,10 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
# named 'master'
conn_name = 'master'
else:
if not isinstance(name, str):
raise dbt.exceptions.CompilerException(
f'For connection name, got {name} - not a string!'
)
assert isinstance(name, str)
conn_name = name

Expand Down Expand Up @@ -221,7 +235,10 @@ def _close_handle(cls, connection: Connection) -> None:
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand All @@ -236,7 +253,10 @@ def _rollback(cls, connection: Connection) -> None:
@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
)

# if the connection is in closed or init, there's nothing to do
if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
Expand All @@ -257,6 +277,9 @@ def commit_if_has_connection(self) -> None:
if connection:
self.commit()

def _add_query_comment(self, sql: str) -> str:
return self.query_header.add(sql)

@abc.abstractmethod
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
Expand Down
37 changes: 25 additions & 12 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping
Mapping, Iterator,
)

import agate
Expand All @@ -13,12 +13,13 @@
import dbt.flags

from dbt.clients.agate_helper import empty_table
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values

from dbt.adapters.base.connections import BaseConnectionManager
from dbt.adapters.base.connections import BaseConnectionManager, Connection
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import ComponentName, BaseRelation
from dbt.adapters.base import Column as BaseColumn
Expand Down Expand Up @@ -203,20 +204,20 @@ def __init__(self, config):
###
# Methods that pass through to the connection manager
###
def acquire_connection(self, name=None):
def acquire_connection(self, name=None) -> Connection:
return self.connections.set_connection_name(name)

def release_connection(self):
return self.connections.release()
def release_connection(self) -> None:
self.connections.release()

def cleanup_connections(self):
return self.connections.cleanup_all()
def cleanup_connections(self) -> None:
self.connections.cleanup_all()

def clear_transaction(self):
def clear_transaction(self) -> None:
self.connections.clear_transaction()

def commit_if_has_connection(self):
return self.connections.commit_if_has_connection()
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()

def nice_connection_name(self):
conn = self.connections.get_if_exists()
Expand All @@ -225,11 +226,23 @@ def nice_connection_name(self):
return conn.name

@contextmanager
def connection_named(self, name):
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
try:
yield self.acquire_connection(name)
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
yield conn
finally:
self.release_connection()
self.connections.query_header.reset()

@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[Connection]:
with self.connection_named(node.unique_id, node) as conn:
yield conn

@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
Expand Down
131 changes: 131 additions & 0 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from threading import local
from typing import Optional, Callable

from dbt.clients.jinja import QueryStringGenerator

# this generates an import cycle, as usual
from dbt.context.base import QueryHeaderContext
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.exceptions import RuntimeException
from dbt.helper_types import NoValue


DEFAULT_QUERY_COMMENT = '''
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
dbt_version=dbt_version,
profile_name=target.get('profile_name'),
target_name=target.get('target_name'),
) -%}
{%- if node is not none -%}
{%- do comment_dict.update(
node_id=node.unique_id,
) -%}
{% else %}
{# in the node context, the connection name is the node_id #}
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''


class NodeWrapper:
def __init__(self, node):
self._inner_node = node

def __getattr__(self, name):
return getattr(self._inner_node, name, '')


class _QueryComment(local):
"""A thread-local class storing thread-specific state information for
connection management, namely:
- the current thread's query comment.
- a source_name indicating what set the current thread's query comment
"""
def __init__(self, initial):
self.query_comment: Optional[str] = initial

def add(self, sql: str) -> str:
if not self.query_comment:
return sql
else:
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str]):
if '*/' in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
f'query comment contains illegal value "*/": {comment}'
)
self.query_comment = comment


QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]


class QueryStringSetter:
"""The base query string setter. This is only used once."""
def __init__(self, config: AdapterRequiredConfig):
self.config = config

comment_macro = self._get_comment_macro()
self.generator: QueryStringFunc = lambda name, model: ''
# if the comment value was None or the empty string, just skip it
if comment_macro:
macro = '\n'.join((
'{%- macro query_comment_macro(connection_name, node) -%}',
self._get_comment_macro(),
'{% endmacro %}'
))
ctx = self._get_context()
self.generator: QueryStringFunc = QueryStringGenerator(macro, ctx)
self.comment = _QueryComment(None)
self.reset()

def _get_context(self):
return QueryHeaderContext(self.config).to_dict()

def _get_comment_macro(self) -> Optional[str]:
# if the query comment is null/empty string, there is no comment at all
if not self.config.query_comment:
return None
else:
# else, the default
return DEFAULT_QUERY_COMMENT

def add(self, sql: str) -> str:
return self.comment.add(sql)

def reset(self):
self.set('master', None)

def set(self, name: str, node: Optional[CompileResultNode]):
if node is not None:
wrapped = NodeWrapper(node)
else:
wrapped = None
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)


class MacroQueryStringSetter(QueryStringSetter):
def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
self.manifest = manifest
super().__init__(config)

def _get_comment_macro(self):
if (
self.config.query_comment != NoValue() and
self.config.query_comment
):
return self.config.query_comment
else:
return super()._get_comment_macro()

def _get_context(self):
return QueryHeaderContext(self.config).to_dict(self.manifest.macros)
13 changes: 11 additions & 2 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dbt.contracts.util import Replaceable
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.exceptions import InternalException
from dbt import deprecations


Expand Down Expand Up @@ -322,10 +323,18 @@ def create_from(
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
assert isinstance(node, ParsedSourceDefinition)
if not isinstance(node, ParsedSourceDefinition):
raise InternalException(
'type mismatch, expected ParsedSourceDefinition but got {}'
.format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
assert isinstance(node, (ParsedNode, CompiledNode))
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
'type mismatch, expected ParsedNode or CompiledNode but '
'got {}'.format(type(node))
)
return cls.create_from_node(config, node, **kwargs)

@classmethod
Expand Down
9 changes: 5 additions & 4 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from dbt.exceptions import RuntimeException
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.connection import Credentials, HasCredentials
from dbt.contracts.connection import Credentials, AdapterRequiredConfig

from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.base.plugin import AdapterPlugin


# TODO: we can't import these because they cause an import cycle.
# Profile has to call into load_plugin to get credentials, so adapter/relation
# don't work
Expand Down Expand Up @@ -74,7 +75,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:

return plugin.credentials

def register_adapter(self, config: HasCredentials) -> None:
def register_adapter(self, config: AdapterRequiredConfig) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)

Expand Down Expand Up @@ -109,11 +110,11 @@ def cleanup_connections(self):
FACTORY: AdpaterContainer = AdpaterContainer()


def register_adapter(config: HasCredentials) -> None:
def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config)


def get_adapter(config: HasCredentials):
def get_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type)


Expand Down
11 changes: 9 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[str, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
status = self.get_status(cursor)
if fetch:
Expand All @@ -130,7 +131,10 @@ def begin(self):
connection = self.get_thread_connection()

if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
)

if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
Expand All @@ -145,7 +149,10 @@ def begin(self):
def commit(self):
connection = self.get_thread_connection()
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand Down
Loading

0 comments on commit c4cd4fc

Please sign in to comment.