Skip to content

Commit

Permalink
(#2984) Prevent Agate from coercing values in query result sets
Browse files Browse the repository at this point in the history
  • Loading branch information
drewbanin committed Jun 26, 2021
1 parent 41610b8 commit 30a9c87
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 9 deletions.
4 changes: 2 additions & 2 deletions core/dbt/adapters/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
data = cls.process_results(column_names, rows)

return dbt.clients.agate_helper.table_from_data_flat(
data,
column_names
data=data,
column_names=column_names,
)

def execute(
Expand Down
55 changes: 48 additions & 7 deletions core/dbt/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ def cast(self, d):
)


def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
def build_type_tester(
text_columns: Iterable[str],
string_null_values: Optional[List[str]] = None
) -> agate.TypeTester:

if string_null_values is None:
string_null_values = ('null', '')

types = [
agate.data_types.Number(null_values=('null', '')),
agate.data_types.Date(null_values=('null', ''),
Expand All @@ -46,10 +53,10 @@ def build_type_tester(text_columns: Iterable[str]) -> agate.TypeTester:
agate.data_types.Boolean(true_values=('true',),
false_values=('false',),
null_values=('null', '')),
agate.data_types.Text(null_values=('null', ''))
agate.data_types.Text(null_values=string_null_values)
]
force = {
k: agate.data_types.Text(null_values=('null', ''))
k: agate.data_types.Text(null_values=string_null_values)
for k in text_columns
}
return agate.TypeTester(force=force, types=types)
Expand All @@ -66,7 +73,9 @@ def table_from_rows(
if text_only_columns is None:
column_types = DEFAULT_TYPE_TESTER
else:
column_types = build_type_tester(text_only_columns)
# If text_only_columns are present, prevent coercing empty string or
# literal 'null' strings to a None representation.
column_types = build_type_tester(text_only_columns, string_null_values=[])
return agate.Table(rows, column_names, column_types=column_types)


Expand All @@ -84,9 +93,13 @@ def table_from_data(data, column_names: Iterable[str]) -> agate.Table:
table = agate.Table.from_object(data, column_types=DEFAULT_TYPE_TESTER)
return table.select(column_names)


def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
"Convert list of dictionaries into an Agate table"
"""
Convert a list of dictionaries into an Agate table. This method does not
coerce string values into more specific types (eg. '005' will not be
coerced to '5'). Additionally, this method does not coerce values to
None (eg. '' or 'null' will retain their string literal representations).
"""

rows = []
for _row in data:
Expand All @@ -98,7 +111,14 @@ def table_from_data_flat(data, column_names: Iterable[str]) -> agate.Table:
row.append(value)
rows.append(row)

return table_from_rows(rows=rows, column_names=column_names)
text_only_columns = _get_string_columns(rows, column_names)

return table_from_rows(
rows=rows,
column_names=column_names,
text_only_columns=text_only_columns
)



def empty_table():
Expand All @@ -121,6 +141,27 @@ def from_csv(abspath, text_columns):
return agate.Table.from_csv(fp, column_types=type_tester)


def _get_string_columns(rows, column_names):
"""
Detect string columns by peeking at the first row in a result set.
This method can be used to bypass type coercion for stringy columns
that _could_ be coerced into a non-string type, but should not be.
Example: '00005': This value should not be cast to <int> 5.
Example: 'false': This value should not be cast to <bool> False.
Exampel: 'null': This value should not be cast to None
Example: 17: This value does not need to be cast (it is already <int>)
Note: This implementation assumes that column types will not vary across
rows in a query result set.
"""
if len(rows) == 0:
return {}

first_row = rows[0]
return {col for col, value in zip(column_names, first_row) if type(value) is str}


class _NullMarker:
pass

Expand Down
29 changes: 29 additions & 0 deletions test/integration/100_rpc_test/sql/generic.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

/*
Every column returned by this query should be interpreted
as a string. If any result is coerced into a non-string (eg.
an int, float, bool, date, NoneType, etc) then dbt did the
wrong thing.

Each column alias names the non-string type that its literal could be
mistaken for.
*/

select
-- empty string: must come back as '' and not as a null value
'' as str_empty_string,

-- the literal word 'null': must come back as a string and not as a null value
'null' as str_null,

-- integer-looking strings, including ones with leading zeros
'1' as str_int,
'00005' as str_int_2,
'00' as str_int_3,

-- float-looking strings, with and without leading zeros
'1.1' as str_float,
'00001.1' as str_float_2,

-- boolean-looking strings in lower case and capitalized form
'true' as str_bool,
'True' as str_bool_2,

-- strings shaped like a date and like a timestamp
'2021-01-01' as str_date,
'2021-01-01T12:00:00Z' as str_datetime,

-- this is obviously not a date... but Agate used to think it was!
-- see: https://github.com/fishtown-analytics/dbt/issues/2984
'0010T00000aabbccdd' as str_obviously_not_date
21 changes: 21 additions & 0 deletions test/integration/100_rpc_test/test_execute_fetch_and_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def do_test_file(self, filename):
self.assertTrue(len(table.rows) > 0, "agate table had no rows")

self.do_test_pickle(table)
return table

def assert_all_columns_are_strings(self, table):
    """Fail the test unless every value in every row of ``table`` is exactly a ``str``.

    :param table: iterable of rows, each row an iterable of values
    """
    for row in table:
        for cell in row:
            failure_message = f'Found a not-string: {cell} in row {row}'
            # Exact type comparison: subclasses of str are rejected as well.
            self.assertEqual(type(cell), str, failure_message)

@use_profile('bigquery')
def test__bigquery_fetch_and_serialize(self):
Expand All @@ -51,3 +57,18 @@ def test__snowflake_fetch_and_serialize(self):
@use_profile('redshift')
def test__redshift_fetch_and_serialize(self):
self.do_test_file('redshift.sql')

@use_profile('bigquery')
def test__bigquery_type_coercion(self):
    """Run generic.sql under the 'bigquery' profile and require all-string results."""
    self.assert_all_columns_are_strings(self.do_test_file('generic.sql'))

@use_profile('snowflake')
def test__snowflake_type_coercion(self):
    """Run generic.sql under the 'snowflake' profile and require all-string results."""
    self.assert_all_columns_are_strings(self.do_test_file('generic.sql'))

@use_profile('redshift')
def test__redshift_type_coercion(self):
    """Run generic.sql under the 'redshift' profile and require all-string results."""
    self.assert_all_columns_are_strings(self.do_test_file('generic.sql'))
26 changes: 26 additions & 0 deletions test/unit/test_agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,29 @@ def test_merge_mixed(self):
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 6)

def test_nocast_string_types(self):
    """String fields must not be coerced into a representative type.

    Builds a table from rows that mix string values which look like
    numbers, dates and booleans with genuine int and bool values, and
    checks that the strings come back unchanged.

    See: https://github.com/fishtown-analytics/dbt/issues/2984
    """
    column_names = ['a', 'b', 'c', 'd', 'e']
    result_set = [
        {'a': '0005', 'b': '01T00000aabbccdd', 'c': 'true', 'd': 10, 'e': False},
        {'a': '0006', 'b': '01T00000aabbccde', 'c': 'false', 'd': 11, 'e': True},
    ]

    # table_from_data_flat accepts only (data, column_names) and works out
    # the string columns itself; passing a text_only_columns keyword would
    # raise TypeError before anything is tested.
    tbl = agate_helper.table_from_data_flat(
        data=result_set,
        column_names=column_names,
    )
    # Checked first so the pairwise comparison below cannot silently skip
    # rows if the table came back shorter than the input.
    self.assertEqual(len(tbl), len(result_set))

    expected = [
        ['0005', '01T00000aabbccdd', 'true', Decimal(10), False],
        ['0006', '01T00000aabbccde', 'false', Decimal(11), True],
    ]

    for row, expected_row in zip(tbl, expected):
        self.assertEqual(list(row), expected_row)

0 comments on commit 30a9c87

Please sign in to comment.