Skip to content

Commit

Permalink
Fix tests to ignore rollbacks for bq, fix redshift tests
Browse files Browse the repository at this point in the history
Trim down the "target" context value to use the opt-in connection_info
Make sure it contains a superset of the documented stuff
Make sure it does not contain any blacklisted items
Change some asserts to raise InternalExceptions because assert error messages are bad
  • Loading branch information
Jacob Beck committed Oct 25, 2019
1 parent 542a053 commit 82bd794
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 31 deletions.
14 changes: 12 additions & 2 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection:
# named 'master'
conn_name = 'master'
else:
if not isinstance(name, str):
raise dbt.exceptions.CompilerException(
f'For connection name, got {name} - not a string!'
)
assert isinstance(name, str)
conn_name = name

Expand Down Expand Up @@ -223,7 +227,10 @@ def _close_handle(cls, connection: Connection) -> None:
def _rollback(cls, connection: Connection) -> None:
"""Roll back the given connection."""
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In _rollback, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand All @@ -238,7 +245,10 @@ def _rollback(cls, connection: Connection) -> None:
@classmethod
def close(cls, connection: Connection) -> Connection:
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In close, got {connection} - not a Connection!'
)

# if the connection is in closed or init, there's nothing to do
if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
Expand Down
13 changes: 11 additions & 2 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dbt.contracts.util import Replaceable
from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.exceptions import InternalException
from dbt import deprecations


Expand Down Expand Up @@ -330,10 +331,18 @@ def create_from(
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
assert isinstance(node, ParsedSourceDefinition)
if not isinstance(node, ParsedSourceDefinition):
raise InternalException(
'type mismatch, expected ParsedSourceDefinition but got {}'
.format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
assert isinstance(node, (ParsedNode, CompiledNode))
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
'type mismatch, expected ParsedNode or CompiledNode but '
'got {}'.format(type(node))
)
return cls.create_from_node(config, node, **kwargs)

@classmethod
Expand Down
10 changes: 8 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def begin(self):
connection = self.get_thread_connection()

if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In begin, got {connection} - not a Connection!'
)

if connection.transaction_open is True:
raise dbt.exceptions.InternalException(
Expand All @@ -146,7 +149,10 @@ def begin(self):
def commit(self):
connection = self.get_thread_connection()
if dbt.flags.STRICT_MODE:
assert isinstance(connection, Connection)
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'In commit, got {connection} - not a Connection!'
)

if connection.transaction_open is False:
raise dbt.exceptions.InternalException(
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def exception_handler(self) -> Iterator[None]:
dbt.exceptions.raise_compiler_error(str(e))

def call_macro(self, *args, **kwargs):
if self.context is None:
raise dbt.exceptions.InternalException(
'Context is still None in call_macro!'
)
assert self.context is not None

macro = self.get_macro()
Expand Down
6 changes: 4 additions & 2 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def recursively_prepend_ctes(model, manifest):
return (model, model.extra_ctes, manifest)

if dbt.flags.STRICT_MODE:
assert isinstance(model, tuple(COMPILED_TYPES.values())), \
'Bad model type: {}'.format(type(model))
if not isinstance(model, tuple(COMPILED_TYPES.values())):
raise dbt.exceptions.InternalException(
'Bad model type: {}'.format(type(model))
)

prepended_ctes = []

Expand Down
20 changes: 12 additions & 8 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,18 @@ def __init__(self, config):
self.config = config

def get_target(self) -> Dict[str, Any]:
target_name = self.config.target_name
target = self.config.to_profile_info()
del target['credentials']
target.update(self.config.credentials.to_dict(with_aliases=True))
target['type'] = self.config.credentials.type
target.pop('pass', None)
target.pop('password', None)
target['name'] = target_name
target = dict(
self.config.credentials.connection_info(with_aliases=True)
)
target.update({
'type': self.config.credentials.type,
'threads': self.config.threads,
'name': self.config.target_name,
# not specified, but present for compatibility
'target_name': self.config.target_name,
'profile_name': self.config.profile_name,
'config': self.config.config.to_dict(),
})
return target

@property
Expand Down
19 changes: 14 additions & 5 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import abc
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List
)
from typing_extensions import Protocol

Expand Down Expand Up @@ -88,11 +89,19 @@ def type(self) -> str:
'type not implemented for base credentials class'
)

def connection_info(self) -> Iterable[Tuple[str, Any]]:
def connection_info(
self, *, with_aliases: bool = False
) -> Iterable[Tuple[str, Any]]:
"""Return an ordered iterator of key/value pairs for pretty-printing.
"""
as_dict = self.to_dict()
for key in self._connection_keys():
as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases)
connection_keys = set(self._connection_keys())
aliases: List[str] = []
if with_aliases:
aliases = [
k for k, v in self._ALIASES.items() if v in connection_keys
]
for key in itertools.chain(self._connection_keys(), aliases):
if key in as_dict:
yield key, as_dict[key]

Expand All @@ -109,7 +118,7 @@ def from_dict(cls, data):
def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
return translate_aliases(kwargs, cls._ALIASES)

def to_dict(self, omit_none=True, validate=False, with_aliases=False):
def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):
serialized = super().to_dict(omit_none=omit_none, validate=validate)
if with_aliases:
serialized.update({
Expand Down
3 changes: 2 additions & 1 deletion plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def type(self):
return 'bigquery'

def _connection_keys(self):
return ('method', 'database', 'schema', 'location')
return ('method', 'database', 'schema', 'location', 'priority',
'timeout_seconds')


class BigQueryConnectionManager(BaseConnectionManager):
Expand Down
12 changes: 10 additions & 2 deletions plugins/bigquery/dbt/adapters/bigquery/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,16 @@ def execute_model(self, model, materialization, sql_override=None,

if flags.STRICT_MODE:
connection = self.connections.get_thread_connection()
assert isinstance(connection, Connection)
assert(connection.name == model.get('name'))
if not isinstance(connection, Connection):
raise dbt.exceptions.CompilerException(
f'Got {connection} - not a Connection!'
)
model_uid = model.get('unique_id')
if connection.name != model_uid:
raise dbt.exceptions.InternalException(
f'Connection had name "{connection.name}", expected model '
f'unique id of "{model_uid}"'
)

if materialization == 'view':
res = self._materialize_as_view(model)
Expand Down
3 changes: 2 additions & 1 deletion plugins/postgres/dbt/adapters/postgres/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def type(self):
return 'postgres'

def _connection_keys(self):
return ('host', 'port', 'user', 'database', 'schema', 'search_path')
return ('host', 'port', 'user', 'database', 'schema', 'search_path',
'keepalives_idle')


class PostgresConnectionManager(SQLConnectionManager):
Expand Down
5 changes: 2 additions & 3 deletions plugins/redshift/dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def type(self):
return 'redshift'

def _connection_keys(self):
return (
'host', 'port', 'user', 'database', 'schema', 'method',
'search_path')
keys = super()._connection_keys()
return keys + ('method', 'cluster_id', 'iam_duration_seconds')


class RedshiftConnectionManager(PostgresConnectionManager):
Expand Down
3 changes: 2 additions & 1 deletion plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def type(self):
return 'snowflake'

def _connection_keys(self):
return ('account', 'user', 'database', 'schema', 'warehouse', 'role')
return ('account', 'user', 'database', 'schema', 'warehouse', 'role',
'client_session_keep_alive')

def auth_args(self):
# Pull all of the optional authentication args for the connector,
Expand Down
30 changes: 30 additions & 0 deletions test/integration/051_query_comments_test/models/x.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
{% set blacklist = ['pass', 'password', 'keyfile', 'keyfile.json', 'password', 'private_key_passphrase'] %}
{% for key in blacklist %}
{% if key in blacklist and blacklist[key] %}
{% do exceptions.raise_compiler_error('invalid target, found banned key "' ~ key ~ '"') %}
{% endif %}
{% endfor %}

{% if 'type' not in target %}
{% do exceptions.raise_compiler_error('invalid target, missing "type"') %}
{% endif %}

{% set required = ['name', 'schema', 'type', 'threads'] %}

{# Require what we docuement at https://docs.getdbt.com/docs/target #}
{% if target.type == 'postgres' or target.type == 'redshift' %}
{% do required.extend(['dbname', 'host', 'user', 'port']) %}
{% elif target.type == 'snowflake' %}
{% do required.extend(['database', 'warehouse', 'user', 'role', 'account']) %}
{% elif target.type == 'bigquery' %}
{% do required.extend(['project']) %}
{% else %}
{% do exceptions.raise_compiler_error('invalid target, got unknown type "' ~ target.type ~ '"') %}

{% endif %}

{% for value in required %}
{% if value not in target %}
{% do exceptions.raise_compiler_error('invalid target, missing "' ~ value ~ '"') %}
{% endif %}
{% endfor %}

{% do run_query('select 2 as inner_id') %}
select 1 as outer_id
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def query_comment(self, model_name, log):

if log['message'].startswith(prefix):
msg = log['message'][len(prefix):]
if msg in {'COMMIT', 'BEGIN'}:
if msg in {'COMMIT', 'BEGIN', 'ROLLBACK'}:
return None
return msg
return None
Expand All @@ -78,7 +78,10 @@ def profile_config(self):
return {'config': {'query_comment': 'dbt\nrules!\n'}}

def matches_comment(self, msg) -> bool:
self.assertTrue(msg.startswith('-- dbt\n-- rules!\n'))
self.assertTrue(
msg.startswith('-- dbt\n-- rules!\n'),
f'{msg} did not start with query comment'
)

@use_profile('postgres')
def test_postgres_comments(self):
Expand Down

0 comments on commit 82bd794

Please sign in to comment.