Skip to content

Commit

Permalink
inject query comments
Browse files Browse the repository at this point in the history
Make a fake "macro" that we parse specially with a single global context
Macro takes an argument (the node, may be none)
Users supply the text of the macro in their 'user_config' under a new 'query_comment'
No macros available
query generator is an attribute on the connection manager
 - has a thread-local comment str
 - when acquiring a connection, set the comment str
new 'connection_for' context manager: like connection_named, except also use the node to set the query string
Updated unit tests to account for query comments
Added a hacky, brittle integration test
  - log to a custom stream and read that
Trim down the "target" context value to use the opt-in connection_info
 - Make sure it contains a superset of the documented stuff
 - Make sure it does not contain any blacklisted items
Change some asserts to raise InternalExceptions because assert error messages in threads are useless
  • Loading branch information
Jacob Beck committed Oct 25, 2019
1 parent b3ef028 commit fdef2f3
Show file tree
Hide file tree
Showing 32 changed files with 808 additions and 291 deletions.
21 changes: 18 additions & 3 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from multiprocessing import RLock
import os
from multiprocessing import RLock
from threading import get_ident
from typing import (
Dict, Tuple, Hashable, Optional, ContextManager, List
Expand All @@ -13,6 +13,7 @@
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, HasCredentials
)
from dbt.adapters.base.query_headers import QueryStringSetter
from dbt.logger import GLOBAL_LOGGER as logger


Expand All @@ -35,6 +36,7 @@ def __init__(self, profile: HasCredentials):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
self.query_header = QueryStringSetter(profile)

@staticmethod
def get_thread_identifier() -> Hashable:
Expand Down Expand Up @@ -91,6 +93,10 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
# named 'master'
conn_name = 'master'
else:
if not isinstance(name, str):
raise dbt.exceptions.CompilerException(
f'For connection name, got {name} - not a string!'
)
assert isinstance(name, str)
conn_name = name

Expand Down Expand Up @@ -221,7 +227,10 @@ def _close_handle(cls, connection: Connection) -> None:
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand All @@ -236,7 +245,10 @@ def _rollback(cls, connection: Connection) -> None:
@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
)

# if the connection is in closed or init, there's nothing to do
if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
Expand All @@ -257,6 +269,9 @@ def commit_if_has_connection(self) -> None:
if connection:
self.commit()

def _add_query_comment(self, sql: str) -> str:
return self.query_header.add(sql)

@abc.abstractmethod
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
Expand Down
37 changes: 25 additions & 12 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping
Mapping, Iterator,
)

import agate
Expand All @@ -13,12 +13,13 @@
import dbt.flags

from dbt.clients.agate_helper import empty_table
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values

from dbt.adapters.base.connections import BaseConnectionManager
from dbt.adapters.base.connections import BaseConnectionManager, Connection
from dbt.adapters.base.meta import AdapterMeta, available
from dbt.adapters.base.relation import ComponentName, BaseRelation
from dbt.adapters.base import Column as BaseColumn
Expand Down Expand Up @@ -203,20 +204,20 @@ def __init__(self, config):
###
# Methods that pass through to the connection manager
###
def acquire_connection(self, name=None):
def acquire_connection(self, name=None) -> Connection:
return self.connections.set_connection_name(name)

def release_connection(self):
return self.connections.release()
def release_connection(self) -> None:
self.connections.release()

def cleanup_connections(self):
return self.connections.cleanup_all()
def cleanup_connections(self) -> None:
self.connections.cleanup_all()

def clear_transaction(self):
def clear_transaction(self) -> None:
self.connections.clear_transaction()

def commit_if_has_connection(self):
return self.connections.commit_if_has_connection()
def commit_if_has_connection(self) -> None:
self.connections.commit_if_has_connection()

def nice_connection_name(self):
conn = self.connections.get_if_exists()
Expand All @@ -225,11 +226,23 @@ def nice_connection_name(self):
return conn.name

@contextmanager
def connection_named(self, name):
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
):
try:
yield self.acquire_connection(name)
self.connections.query_header.set(name, node)
conn = self.acquire_connection(name)
yield conn
finally:
self.release_connection()
self.connections.query_header.reset()

@contextmanager
def connection_for(
self, node: CompileResultNode
) -> Iterator[Connection]:
with self.connection_named(node.unique_id, node) as conn:
yield conn

@available.parse(lambda *a, **k: ('', empty_table()))
def execute(
Expand Down
101 changes: 101 additions & 0 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from threading import local
from typing import Optional, Callable

from dbt.clients.jinja import QueryStringGenerator

from dbt.contracts.connection import HasCredentials
# this generates an import cycle, as usual
from dbt.context.base import QueryHeaderContext
from dbt.contracts.graph.compiled import CompileResultNode


default_query_comment = '''
{%- set comment_dict = {} -%}
{%- do comment_dict.update(target) -%}
{%- do comment_dict.update(
app='dbt',
dbt_version=dbt_version,
) -%}
{%- if node is not none -%}
{%- do comment_dict.update(
file=node.original_file_path,
node_id=node.unique_id,
node_name=node.name,
resource_type=node.resource_type,
package_name=node.package_name,
tags=node.tags,
identifier=node.identifier,
schema=node.schema,
database=node.database,
) -%}
{% else %}
{# in the node context, the connection name is the node_id #}
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''


class NodeWrapper:
def __init__(self, node):
self._inner_node = node

def __getattr__(self, name):
return getattr(self._inner_node, name, '')


class _QueryComment(local):
"""A thread-local class storing thread-specific state information for
connection management, namely:
- the current thread's query comment.
- a source_name indicating what set the current thread's query comment
"""
def __init__(self, initial):
self.query_comment: str = initial

def add(self, sql: str) -> str:
# Make sure there are no trailing newlines.
# For every newline, add a comment after it in case query_comment
# is multiple lines.
# Then add a comment to the first line of the query comment, and
# put the sql on a fresh line.
comment_split = self.query_comment.strip().replace('\n', '\n-- ')
return '-- {}\n{}'.format(comment_split, sql)

def set(self, comment: str):
self.query_comment = comment


QueryStringFunc = Callable[[str, Optional[CompileResultNode]], str]


class QueryStringSetter:
def __init__(self, config: HasCredentials):
if config.config.query_comment is not None:
comment = config.config.query_comment
else:
comment = default_query_comment
macro = '\n'.join((
'{%- macro query_comment_macro(connection_name, node) -%}',
comment,
'{% endmacro %}'
))

ctx = QueryHeaderContext(config).to_dict()
self.generator: QueryStringFunc = QueryStringGenerator(macro, ctx)
self.comment = _QueryComment('')
self.reset()

def add(self, sql: str) -> str:
return self.comment.add(sql)

def reset(self):
self.set('master', None)

def set(self, name: str, node: Optional[CompileResultNode]):
if node is not None:
wrapped = NodeWrapper(node)
else:
wrapped = None
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)
13 changes: 11 additions & 2 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dbt.contracts.util import Replaceable
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.exceptions import InternalException
from dbt import deprecations


Expand Down Expand Up @@ -330,10 +331,18 @@ def create_from(
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
assert isinstance(node, ParsedSourceDefinition)
if not isinstance(node, ParsedSourceDefinition):
raise InternalException(
'type mismatch, expected ParsedSourceDefinition but got {}'
.format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
assert isinstance(node, (ParsedNode, CompiledNode))
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
'type mismatch, expected ParsedNode or CompiledNode but '
'got {}'.format(type(node))
)
return cls.create_from_node(config, node, **kwargs)

@classmethod
Expand Down
11 changes: 9 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[str, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
status = self.get_status(cursor)
if fetch:
Expand All @@ -130,7 +131,10 @@ def begin(self):
connection = self.get_thread_connection()

if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
)

if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
Expand All @@ -145,7 +149,10 @@ def begin(self):
def commit(self):
connection = self.get_thread_connection()
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand Down
Loading

0 comments on commit fdef2f3

Please sign in to comment.