Skip to content

Commit

Permalink
Merge pull request #2199 from ilkinulas/query_comment_append
Browse files Browse the repository at this point in the history
Append query comment
  • Loading branch information
beckjake authored Mar 24, 2020
2 parents bdcf10e + 0eed51d commit efcf78b
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 86 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
## dbt next (release TBD)

### Features
- Support for appending query comments to SQL queries. ([#2138](https://github.com/fishtown-analytics/dbt/issues/2138))

### Fixes
- When a jinja value is undefined, give a helpful error instead of failing with cryptic "cannot pickle ParserMacroCapture" errors ([#2110](https://github.com/fishtown-analytics/dbt/issues/2110), [#2184](https://github.com/fishtown-analytics/dbt/pull/2184))

Contributers:
- [@ilkinulas](https://github.com/ilkinulas) [#2199](https://github.com/fishtown-analytics/dbt/pull/2199)

## dbt 0.16.0 (Release date TBD)

### Fixes
Expand Down
60 changes: 22 additions & 38 deletions core/dbt/adapters/base/query_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,10 @@
from dbt.clients.jinja import QueryStringGenerator

from dbt.context.configured import generate_query_header_context
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.connection import AdapterRequiredConfig, QueryComment
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.exceptions import RuntimeException
from dbt.helper_types import NoValue


DEFAULT_QUERY_COMMENT = '''
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
dbt_version=dbt_version,
profile_name=target.get('profile_name'),
target_name=target.get('target_name'),
) -%}
{%- if node is not none -%}
{%- do comment_dict.update(
node_id=node.unique_id,
) -%}
{% else %}
{# in the node context, the connection name is the node_id #}
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''


class NodeWrapper:
Expand All @@ -47,21 +26,32 @@ class _QueryComment(local):
"""
def __init__(self, initial):
self.query_comment: Optional[str] = initial
self.append = False

def add(self, sql: str) -> str:
if not self.query_comment:
return sql
else:
return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str]):
if self.append:
# replace last ';' with '<comment>;'
sql = sql.rstrip()
if sql[-1] == ';':
sql = sql[:-1]
return '{}\n/* {} */;'.format(sql, self.query_comment.strip())

return '{}\n/* {} */'.format(sql, self.query_comment.strip())

return '/* {} */\n{}'.format(self.query_comment.strip(), sql)

def set(self, comment: Optional[str], append: bool):
if isinstance(comment, str) and '*/' in comment:
# tell the user "no" so they don't hurt themselves by writing
# garbage
raise RuntimeException(
f'query comment contains illegal value "*/": {comment}'
)
self.query_comment = comment
self.append = append


QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str]
Expand All @@ -87,18 +77,8 @@ def __init__(self, config: AdapterRequiredConfig, manifest: Manifest):
self.comment = _QueryComment(None)
self.reset()

def _get_comment_macro(self):
if (
self.config.query_comment != NoValue() and
self.config.query_comment
):
return self.config.query_comment
# if the query comment is null/empty string, there is no comment at all
elif not self.config.query_comment:
return None
else:
# else, the default
return DEFAULT_QUERY_COMMENT
def _get_comment_macro(self) -> Optional[str]:
return self.config.query_comment.comment

def _get_context(self) -> Dict[str, Any]:
return generate_query_header_context(self.config, self.manifest)
Expand All @@ -114,4 +94,8 @@ def set(self, name: str, node: Optional[CompileResultNode]):
if node is not None:
wrapped = NodeWrapper(node)
comment_str = self.generator(name, wrapped)
self.comment.set(comment_str)

append = False
if isinstance(self.config.query_comment, QueryComment):
append = self.config.query_comment.append
self.comment.set(comment_str, append)
26 changes: 22 additions & 4 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt.clients.system import path_exists
from dbt.clients.system import load_file_contents
from dbt.clients.yaml_helper import load_yaml_text
from dbt.contracts.connection import QueryComment
from dbt.exceptions import DbtProjectError
from dbt.exceptions import RecursionException
from dbt.exceptions import SemverException
Expand Down Expand Up @@ -202,6 +203,21 @@ def _raw_project_from(project_root: str) -> Dict[str, Any]:
return project_dict


def _query_comment_from_cfg(
cfg_query_comment: Union[QueryComment, NoValue, str]
) -> QueryComment:
if not cfg_query_comment:
return QueryComment(comment='')

if isinstance(cfg_query_comment, str):
return QueryComment(comment=cfg_query_comment)

if isinstance(cfg_query_comment, NoValue):
return QueryComment()

return cfg_query_comment


@dataclass
class PartialProject:
profile_name: Optional[str]
Expand Down Expand Up @@ -244,7 +260,7 @@ class Project:
snapshots: Dict[str, Any]
dbt_version: List[VersionSpecifier]
packages: Dict[str, Any]
query_comment: Optional[Union[str, NoValue]]
query_comment: QueryComment

@property
def all_source_paths(self) -> List[str]:
Expand Down Expand Up @@ -356,7 +372,8 @@ def from_project_config(
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
if cfg.require_dbt_version is not None:
dbt_raw_version = cfg.require_dbt_version
query_comment = cfg.query_comment

query_comment = _query_comment_from_cfg(cfg.query_comment)

try:
dbt_version = _parse_versions(dbt_raw_version)
Expand Down Expand Up @@ -442,10 +459,11 @@ def to_project_config(self, with_packages=False):
v.to_version_string() for v in self.dbt_version
],
})
if self.query_comment:
result['query-comment'] = self.query_comment.to_dict()

if with_packages:
result.update(self.packages.to_dict())
if self.query_comment != NoValue():
result['query-comment'] = self.query_comment

return result

Expand Down
33 changes: 28 additions & 5 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable,
Union
)
Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable)
from typing_extensions import Protocol

from hologram import JsonSchemaMixin
Expand All @@ -14,7 +12,6 @@

from dbt.contracts.util import Replaceable
from dbt.exceptions import InternalException
from dbt.helper_types import NoValue
from dbt.utils import translate_aliases


Expand Down Expand Up @@ -175,8 +172,34 @@ class HasCredentials(Protocol):
threads: int


DEFAULT_QUERY_COMMENT = '''
{%- set comment_dict = {} -%}
{%- do comment_dict.update(
app='dbt',
dbt_version=dbt_version,
profile_name=target.get('profile_name'),
target_name=target.get('target_name'),
) -%}
{%- if node is not none -%}
{%- do comment_dict.update(
node_id=node.unique_id,
) -%}
{% else %}
{# in the node context, the connection name is the node_id #}
{%- do comment_dict.update(connection_name=connection_name) -%}
{%- endif -%}
{{ return(tojson(comment_dict)) }}
'''


@dataclass
class QueryComment(JsonSchemaMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False


class AdapterRequiredConfig(HasCredentials, Protocol):
project_name: str
query_comment: Optional[Union[str, NoValue]]
query_comment: QueryComment
cli_vars: Dict[str, Any]
target_path: str
6 changes: 3 additions & 3 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dbt.contracts.util import Replaceable, Mergeable, list_str
from dbt.contracts.connection import UserConfigContract
from dbt.contracts.connection import UserConfigContract, QueryComment
from dbt.helper_types import NoValue
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt import tracking
from dbt.ui import printer
from dbt.helper_types import NoValue

from hologram import JsonSchemaMixin, ValidationError
from hologram.helpers import HyphenatedJsonSchemaMixin, register_pattern, \
Expand Down Expand Up @@ -162,7 +162,7 @@ class Project(HyphenatedJsonSchemaMixin, Replaceable):
seeds: Dict[str, Any] = field(default_factory=dict)
snapshots: Dict[str, Any] = field(default_factory=dict)
packages: List[PackageSpec] = field(default_factory=list)
query_comment: Optional[Union[str, NoValue]] = NoValue()
query_comment: Optional[Union[QueryComment, NoValue, str]] = NoValue()

@classmethod
def from_dict(cls, data, validate=True):
Expand Down
42 changes: 11 additions & 31 deletions core/dbt/helper_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# never name this package "types", or mypy will crash in ugly ways
from dataclasses import dataclass
from datetime import timedelta
from typing import NewType, Dict
from typing import NewType

from hologram import (
FieldEncoder, JsonSchemaMixin, JsonDict, ValidationError
)

from hologram.helpers import StrEnum

Port = NewType('Port', int)

Expand Down Expand Up @@ -37,41 +38,20 @@ def json_schema(self) -> JsonDict:
return {'type': 'number'}


class NoValue:
"""Sometimes, you want a way to say none that isn't None"""
def __eq__(self, other):
return isinstance(other, NoValue)

class NVEnum(StrEnum):
novalue = 'novalue'

class NoValueEncoder(FieldEncoder):
# the FieldEncoder class specifies a narrow range that only includes value
# types (str, float, None) but we want to support something extra
def to_wire(self, value: NoValue) -> Dict[str, str]: # type: ignore
return {'novalue': 'novalue'}
def __eq__(self, other):
return isinstance(other, NVEnum)

def to_python(self, value) -> NoValue:
if (
not isinstance(value, dict) or
'novalue' not in value or
value['novalue'] != 'novalue'
):
raise ValidationError('Got invalid NoValue: {}'.format(value))
return NoValue()

@property
def json_schema(self):
return {
'type': 'object',
'properties': {
'novalue': {
'enum': ['novalue'],
}
}
}
@dataclass
class NoValue(JsonSchemaMixin):
"""Sometimes, you want a way to say none that isn't None"""
novalue: NVEnum = NVEnum.novalue


JsonSchemaMixin.register_field_encoders({
Port: PortEncoder(),
timedelta: TimeDeltaFieldEncoder(),
NoValue: NoValueEncoder(),
})
2 changes: 1 addition & 1 deletion core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def read(fname):
'json-rpc>=1.12,<2',
'werkzeug>=0.15,<0.17',
'dataclasses==0.6;python_version<"3.7"',
'hologram==0.0.5',
'hologram==0.0.6',
'logbook>=1.5,<1.6',
'pytest-logbook>=1.2.0,<1.3',
'typing-extensions>=3.7.4,<3.8',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class TestNullQueryComments(TestDefaultQueryComments):
@property
def project_config(self):
cfg = super().project_config
cfg.update({'query-comment': None})
cfg.update({'query-comment': ''})
return cfg

def matches_comment(self, msg) -> bool:
Expand Down
43 changes: 41 additions & 2 deletions test/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from dbt.adapters.postgres import PostgresCredentials
from dbt.adapters.redshift import RedshiftCredentials
from dbt.context.base import generate_base_context
from dbt.contracts.connection import QueryComment, DEFAULT_QUERY_COMMENT
from dbt.contracts.project import PackageConfig, LocalPackage, GitPackage
from dbt.semver import VersionSpecifier
from dbt.task.run_operation import RunOperationTask

from .utils import normalize

from .utils import normalize, config_from_parts_or_dicts

INITIAL_ROOT = os.getcwd()

Expand Down Expand Up @@ -844,6 +844,45 @@ def test_cycle(self):
self.default_project_data, None
)

def test_query_comment_disabled(self):
self.default_project_data.update({
'query-comment': None,
})
project = dbt.config.Project.from_project_config(self.default_project_data, None)
self.assertEqual(project.query_comment.comment, '')
self.assertEqual(project.query_comment.append, False)

self.default_project_data.update({
'query-comment': '',
})
project = dbt.config.Project.from_project_config(self.default_project_data, None)
self.assertEqual(project.query_comment.comment, '')
self.assertEqual(project.query_comment.append, False)

def test_default_query_comment(self):
project = dbt.config.Project.from_project_config(self.default_project_data, None)
self.assertEqual(project.query_comment, QueryComment())

def test_default_query_comment_append(self):
self.default_project_data.update({
'query-comment': {
'append': True
},
})
project = dbt.config.Project.from_project_config(self.default_project_data, None)
self.assertEqual(project.query_comment.comment, DEFAULT_QUERY_COMMENT)
self.assertEqual(project.query_comment.append, True)

def test_custom_query_comment_append(self):
self.default_project_data.update({
'query-comment': {
'comment': 'run by user test',
'append': True
},
})
project = dbt.config.Project.from_project_config(self.default_project_data, None)
self.assertEqual(project.query_comment.comment, 'run by user test')
self.assertEqual(project.query_comment.append, True)

class TestProjectWithConfigs(BaseConfigTest):
def setUp(self):
Expand Down
Loading

0 comments on commit efcf78b

Please sign in to comment.