Skip to content

Commit

Permalink
Move query comments into the project config
Browse files Browse the repository at this point in the history
Add special handling to 'dbt debug' for this behavior
Rework the dependencies/requirements for adapters since they now require more of a config object
tests...
  • Loading branch information
Jacob Beck committed Oct 29, 2019
1 parent 178f7f1 commit ae58199
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 37 deletions.
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dbt.exceptions
import dbt.flags
from dbt.contracts.connection import (
Connection, Identifier, ConnectionState, HasCredentials
Connection, Identifier, ConnectionState, AdapterRequiredConfig
)
from dbt.adapters.base.query_headers import QueryStringSetter
from dbt.logger import GLOBAL_LOGGER as logger
Expand All @@ -32,7 +32,7 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
"""
TYPE: str = NotImplemented

def __init__(self, profile: HasCredentials):
def __init__(self, profile: AdapterRequiredConfig):
self.profile = profile
self.thread_connections: Dict[Hashable, Connection] = {}
self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from dbt.clients.jinja import QueryStringGenerator

from dbt.contracts.connection import HasCredentials
# this generates an import cycle, as usual
from dbt.context.base import QueryHeaderContext
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.compiled import CompileResultNode


Expand Down Expand Up @@ -70,9 +70,9 @@ def set(self, comment: str):


class QueryStringSetter:
def __init__(self, config: HasCredentials):
if config.config.query_comment is not None:
comment = config.config.query_comment
def __init__(self, config: AdapterRequiredConfig):
if config.query_comment is not None:
comment = config.query_comment
else:
comment = default_query_comment
macro = '\n'.join((
Expand Down
9 changes: 5 additions & 4 deletions core/dbt/adapters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from dbt.exceptions import RuntimeException
from dbt.include.global_project import PACKAGES
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.connection import Credentials, HasCredentials
from dbt.contracts.connection import Credentials, AdapterRequiredConfig

from dbt.adapters.base.impl import BaseAdapter
from dbt.adapters.base.plugin import AdapterPlugin


# TODO: we can't import these because they cause an import cycle.
# Profile has to call into load_plugin to get credentials, so adapter/relation
# don't work
Expand Down Expand Up @@ -74,7 +75,7 @@ def load_plugin(self, name: str) -> Type[Credentials]:

return plugin.credentials

def register_adapter(self, config: HasCredentials) -> None:
def register_adapter(self, config: AdapterRequiredConfig) -> None:
adapter_name = config.credentials.type
adapter_type = self.get_adapter_class_by_name(adapter_name)

Expand Down Expand Up @@ -109,11 +110,11 @@ def cleanup_connections(self):
FACTORY: AdpaterContainer = AdpaterContainer()


def register_adapter(config: HasCredentials) -> None:
def register_adapter(config: AdapterRequiredConfig) -> None:
FACTORY.register_adapter(config)


def get_adapter(config: HasCredentials):
def get_adapter(config: AdapterRequiredConfig):
return FACTORY.lookup_adapter(config.credentials.type)


Expand Down
8 changes: 6 additions & 2 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, project_name, version, project_root, profile_name,
analysis_paths, docs_paths, target_path, snapshot_paths,
clean_targets, log_path, modules_path, quoting, models,
on_run_start, on_run_end, seeds, snapshots, dbt_version,
packages):
packages, query_comment):
self.project_name = project_name
self.version = version
self.project_root = project_root
Expand All @@ -173,6 +173,7 @@ def __init__(self, project_name, version, project_root, profile_name,
self.snapshots = snapshots
self.dbt_version = dbt_version
self.packages = packages
self.query_comment = query_comment

@staticmethod
def _preprocess(project_dict):
Expand Down Expand Up @@ -257,6 +258,7 @@ def from_project_config(cls, project_dict, packages_dict=None):
seeds = project_dict.get('seeds', {})
snapshots = project_dict.get('snapshots', {})
dbt_raw_version = project_dict.get('require-dbt-version', '>=0.0.0')
query_comment = project_dict.get('query-comment')

try:
dbt_version = _parse_versions(dbt_raw_version)
Expand Down Expand Up @@ -291,7 +293,8 @@ def from_project_config(cls, project_dict, packages_dict=None):
seeds=seeds,
snapshots=snapshots,
dbt_version=dbt_version,
packages=packages
packages=packages,
query_comment=query_comment,
)
# sanity check - this means an internal issue
project.validate()
Expand Down Expand Up @@ -340,6 +343,7 @@ def to_project_config(self, with_packages=False):
'require-dbt-version': [
v.to_version_string() for v in self.dbt_version
],
'query-comment': self.query_comment,
})
if with_packages:
result.update(self.packages.to_dict())
Expand Down
7 changes: 5 additions & 2 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self, project_name, version, project_root, source_paths,
docs_paths, target_path, snapshot_paths, clean_targets,
log_path, modules_path, quoting, models, on_run_start,
on_run_end, seeds, snapshots, dbt_version, profile_name,
target_name, config, threads, credentials, packages, args):
target_name, config, threads, credentials, packages,
query_comment, args):
# 'vars'
self.args = args
self.cli_vars = parse_cli_vars(getattr(args, 'vars', '{}'))
Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(self, project_name, version, project_root, source_paths,
seeds=seeds,
snapshots=snapshots,
dbt_version=dbt_version,
packages=packages
packages=packages,
query_comment=query_comment,
)
# 'profile'
Profile.__init__(
Expand Down Expand Up @@ -101,6 +103,7 @@ def from_parts(cls, project, profile, args):
snapshots=project.snapshots,
dbt_version=project.dbt_version,
packages=project.packages,
query_comment=project.query_comment,
profile_name=profile.profile_name,
target_name=profile.target_name,
config=profile.config,
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,7 @@ def to_dict(self, omit_none=True, validate=False, *, with_aliases=False):

class HasCredentials(Protocol):
credentials: Credentials


class AdapterRequiredConfig(HasCredentials):
query_comment: Optional[str]
2 changes: 1 addition & 1 deletion core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
seeds: Dict[str, Any] = field(default_factory=dict)
snapshots: Dict[str, Any] = field(default_factory=dict)
packages: List[PackageSpec] = field(default_factory=list)
query_comment: Optional[str] = None

@classmethod
def from_dict(cls, data, validate=True):
Expand All @@ -176,7 +177,6 @@ class UserConfig(ExtensibleJsonSchemaMixin, Replaceable):
use_colors: bool = DEFAULT_USE_COLORS
partial_parse: Optional[bool] = None
printer_width: Optional[int] = None
query_comment: Optional[str] = None

def set_values(self, cookie_dir):
if self.send_anonymous_usage_stats:
Expand Down
10 changes: 9 additions & 1 deletion core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@
FILE_NOT_FOUND = 'file not found'


class QueryCommentedProfile(Profile):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.query_comment = None


class DebugTask(BaseTask):
def __init__(self, args, config):
super().__init__(args, config)
Expand Down Expand Up @@ -209,7 +215,9 @@ def _load_profile(self):
self.profile_name = self._choose_profile_name()
self.target_name = self._choose_target_name()
try:
self.profile = Profile.from_args(self.args, self.profile_name)
self.profile = QueryCommentedProfile.from_args(
self.args, self.profile_name
)
except dbt.exceptions.DbtConfigError as exc:
self.profile_fail_details = str(exc)
return red('ERROR invalid')
Expand Down
32 changes: 32 additions & 0 deletions test/integration/049_dbt_debug_test/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,35 @@ def test_postgres_nopass(self):
def test_postgres_wronguser(self):
self.run_dbt(['debug', '--target', 'wronguser'])
self.assertGotValue(re.compile(r'\s+Connection test'), 'ERROR')


class TestDebugInvalidProject(DBTIntegrationTest):
@property
def schema(self):
return 'dbt_debug_049'

@staticmethod
def dir(value):
return os.path.normpath(value)

@property
def models(self):
return self.dir('models')

@pytest.fixture(autouse=True)
def capsys(self, capsys):
self.capsys = capsys

@use_profile('postgres')
def test_postgres_badproject(self):
# load a special project that is an error
self.use_default_project(overrides={
'invalid-key': 'not a valid key so this is bad project',
})
self.run_dbt(['debug', '--profile', 'test'])
splitout = self.capsys.readouterr().out.split('\n')
for line in splitout:
if line.strip().startswith('dbt_project.yml file'):
self.assertIn('ERROR invalid', line)
elif line.strip().startswith('profiles.yml file'):
self.assertNotIn('ERROR invalid', line)
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def run_assert_comments(self):

class TestQueryComments(QueryComments):
@property
def profile_config(self):
return {'config': {'query_comment': 'dbt\nrules!\n'}}
def project_config(self):
return {'query-comment': 'dbt\nrules!\n'}

def matches_comment(self, msg) -> bool:
self.assertTrue(
Expand Down
3 changes: 3 additions & 0 deletions test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def set_packages(self):
with open('packages.yml', 'w') as f:
yaml.safe_dump(self.packages_config, f, default_flow_style=True)

def test_only_config(self):
return None

def load_config(self):
# we've written our profile and project. Now we want to instantiate a
# fresh adapter for the tests.
Expand Down
4 changes: 1 addition & 3 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def setUp(self):
flags.STRICT_MODE = True

self.raw_profile = {
'config': {
'query_comment': 'dbt'
},
'outputs': {
'oauth': {
'type': 'bigquery',
Expand Down Expand Up @@ -62,6 +59,7 @@ def setUp(self):
'version': '0.1',
'project-root': '/tmp/dbt/does-not-exist',
'profile': 'default',
'query-comment': 'dbt',
}

def get_adapter(self, target):
Expand Down
10 changes: 3 additions & 7 deletions test/unit/test_postgres_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@ def setUp(self):
'version': '0.1',
'profile': 'test',
'project-root': '/tmp/dbt/does-not-exist',
'query-comment': 'dbt',
}
profile_cfg = {
'config': {
'query_comment': 'dbt'
},
'outputs': {
'test': {
'type': 'postgres',
Expand Down Expand Up @@ -207,9 +205,6 @@ def setUp(self):
}

profile_cfg = {
'config': {
'query_comment': 'dbt'
},
'outputs': {
'test': self.target_dict,
},
Expand All @@ -223,7 +218,8 @@ def setUp(self):
'quoting': {
'identifier': False,
'schema': True,
}
},
'query-comment': 'dbt',
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
Expand Down
6 changes: 2 additions & 4 deletions test/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ def setUp(self):
flags.STRICT_MODE = True

profile_cfg = {
'config': {
'query_comment': 'dbt'
},
'outputs': {
'test': {
'type': 'redshift',
Expand All @@ -51,7 +48,8 @@ def setUp(self):
'quoting': {
'identifier': False,
'schema': True,
}
},
'query-comment': 'dbt',
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
Expand Down
8 changes: 3 additions & 5 deletions test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def setUp(self):
flags.STRICT_MODE = False

profile_cfg = {
'config': {
'query_comment': 'dbt'
},
'outputs': {
'test': {
'type': 'snowflake',
Expand All @@ -42,10 +39,11 @@ def setUp(self):
'quoting': {
'identifier': False,
'schema': True,
}
},
'query-comment': 'dbt',
}
self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.assertEqual(self.config.config.query_comment, 'dbt')
self.assertEqual(self.config.query_comment, 'dbt')

self.handle = mock.MagicMock(
spec=snowflake_connector.SnowflakeConnection)
Expand Down

0 comments on commit ae58199

Please sign in to comment.