Skip to content

Commit

Permalink
Address comments from code review.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarp committed Mar 4, 2021
1 parent 7b0c74c commit 7b4b728
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
4 changes: 2 additions & 2 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, HyphenatedDbtClassMixin,
ValidatedStringMixin, register_pattern
)
from dbt.contracts.util import Replaceable
Expand Down Expand Up @@ -212,7 +212,7 @@ def to_target_dict(self):


@dataclass
class QueryComment(dbtClassMixin):
class QueryComment(HyphenatedDbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False
Expand Down
28 changes: 15 additions & 13 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,23 +307,15 @@ def raw_execute(self, sql, fetch=False, *, use_legacy_sql=False):

logger.debug('On {}: {}', conn.name, sql)

labels = {}
if self.profile.query_comment.job_label:
query_comment = self.query_header.comment.query_comment
labels = self._labels_from_query_comment(query_comment)
else:
labels = {}

if active_user:
labels['dbt_invocation_id'] = active_user.invocation_id

if self.profile.query_comment.job_label:
try:
comment_labels = json.loads(
self.query_header.comment.query_comment
)
labels.update({
_sanitize_label(key): _sanitize_label(str(value))
for key, value in comment_labels.items()
})
except (TypeError, ValueError):
pass

job_params = {'use_legacy_sql': use_legacy_sql, 'labels': labels}

priority = conn.credentials.priority
Expand Down Expand Up @@ -558,6 +550,16 @@ def _retry_generator(self):
initial=self.DEFAULT_INITIAL_DELAY,
maximum=self.DEFAULT_MAXIMUM_DELAY)

def _labels_from_query_comment(self, comment: str) -> Dict:
try:
comment_labels = json.loads(comment)
except (TypeError, ValueError):
return {'query_comment': _sanitize_label(comment)}
return {
_sanitize_label(key): _sanitize_label(str(value))
for key, value in comment_labels.items()
}


class _ErrorCounter(object):
"""Counts errors seen up to a threshold then raises the next error."""
Expand Down
12 changes: 10 additions & 2 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import agate
import decimal
import json
import re
import unittest
from contextlib import contextmanager
Expand Down Expand Up @@ -588,7 +589,6 @@ def test_query_and_results(self, mock_bq):
self.mock_client.query.assert_called_once_with(
'sql', job_config=mock_bq.QueryJobConfig())


def test_copy_bq_table_appends(self):
self._copy_table(
write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
Expand All @@ -615,12 +615,20 @@ def test_copy_bq_table_truncates(self):
kwargs['job_config'].write_disposition,
dbt.adapters.bigquery.impl.WRITE_TRUNCATE)

def test_job_labels_valid_json(self):
expected = {"key": "value"}
labels = self.connections._labels_from_query_comment(json.dumps(expected))
self.assertEqual(labels, expected)

def test_job_labels_invalid_json(self):
labels = self.connections._labels_from_query_comment("not json")
self.assertEqual(labels, {"query_comment": "not_json"})

def _table_ref(self, proj, ds, table, conn):
return google.cloud.bigquery.table.TableReference.from_string(
'{}.{}.{}'.format(proj, ds, table))

def _copy_table(self, write_disposition):

self.connections.table_ref = self._table_ref
source = BigQueryRelation.create(
database='project', schema='dataset', identifier='table1')
Expand Down

0 comments on commit 7b4b728

Please sign in to comment.