Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
Added a custom table merge implementation that tracks if a row is all null and merges those as "any type".
 - added unit tests for that!
Removed some schema casing things
fixed pluralization (it was reversed)
  • Loading branch information
Jacob Beck committed Feb 4, 2020
1 parent c1af3ab commit 04bc2a8
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 16 deletions.
16 changes: 6 additions & 10 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dbt.flags

from dbt import deprecations
from dbt.clients.agate_helper import empty_table
from dbt.clients.agate_helper import empty_table, merge_tables
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
Expand Down Expand Up @@ -1024,15 +1024,10 @@ def _get_one_catalog(
'information_schema'
])

# calculate the possible schemas for a given schema name
all_schema_names: Set[str] = set()
for schema in schemas:
all_schema_names.update({schema, schema.lower(), schema.upper()})

with self.connection_named(name):
kwargs = {
'information_schema': information_schema,
'schemas': all_schema_names
'schemas': schemas
}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
Expand Down Expand Up @@ -1140,15 +1135,16 @@ def catch_as_completed(
futures # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:

catalogs: agate.Table = agate.Table(rows=[])
# catalogs: agate.Table = agate.Table(rows=[])
tables: List[agate.Table] = []
exceptions: List[Exception] = []

for future in as_completed(futures):
exc = future.exception()
# we want to re-raise on ctrl+c and BaseException
if exc is None:
catalog = future.result()
catalogs = agate.Table.merge([catalogs, catalog])
tables.append(catalog)
elif (
isinstance(exc, KeyboardInterrupt) or
not isinstance(exc, Exception)
Expand All @@ -1160,4 +1156,4 @@ def catch_as_completed(
)
# exc is not None, derives from Exception, and isn't ctrl+c
exceptions.append(exc)
return catalogs, exceptions
return merge_tables(tables), exceptions
86 changes: 85 additions & 1 deletion core/dbt/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import datetime
import isodate
import json
from typing import Iterable
from typing import Iterable, List, Dict, Union

from dbt.exceptions import RuntimeException


BOM = BOM_UTF8.decode('utf-8') # '\ufeff'
Expand Down Expand Up @@ -104,3 +106,85 @@ def from_csv(abspath, text_columns):
if fp.read(1) != BOM:
fp.seek(0)
return agate.Table.from_csv(fp, column_types=type_tester)


class _NullMarker:
pass


NullableAgateType = Union[agate.data_types.DataType, _NullMarker]


class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def __init__(self):
super().__init__()

def __setitem__(self, key, value):
if key not in self:
super().__setitem__(key, value)
return

existing_type = self[key]
if isinstance(existing_type, _NullMarker):
# overwrite
super().__setitem__(key, value)
elif isinstance(value, _NullMarker):
# use the existing value
return
elif not isinstance(value, type(existing_type)):
# actual type mismatch!
raise RuntimeException(
f'Tables contain columns with the same names ({key}), '
f'but different types ({value} vs {existing_type})'
)

def finalize(self) -> Dict[str, agate.data_types.DataType]:
result: Dict[str, agate.data_types.DataType] = {}
for key, value in self.items():
if isinstance(value, _NullMarker):
# this is what agate would do.
result[key] = agate.data_types.Number()
else:
result[key] = value
return result


def _merged_column_types(
tables: List[agate.Table]
) -> Dict[str, agate.data_types.DataType]:
# this is a lot like agate.Table.merge, but with handling for all-null
# rows being "any type".
new_columns: ColumnTypeBuilder = ColumnTypeBuilder()
for table in tables:
for i in range(len(table.columns)):
column_name: str = table.column_names[i]
column_type: NullableAgateType = table.column_types[i]
# avoid over-sensitive type inference
if all(x is None for x in table.columns[column_name]):
column_type = _NullMarker()
new_columns[column_name] = column_type

return new_columns.finalize()


def merge_tables(tables: List[agate.Table]) -> agate.Table:
"""This is similar to agate.Table.merge, but it handles rows of all 'null'
values more gracefully during merges.
"""
new_columns = _merged_column_types(tables)
column_names = tuple(new_columns.keys())
column_types = tuple(new_columns.values())

rows: List[agate.Row] = []
for table in tables:
if (
table.column_names == column_names and
table.column_types == column_types
):
rows.extend(table.rows)
else:
for row in table.rows:
data = [row.get(name, None) for name in column_names]
rows.append(agate.Row(data, column_names))
# _is_fork to tell agate that we already made things into `Row`s.
return agate.Table(rows, column_names, column_types, _is_fork=True)
10 changes: 5 additions & 5 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,16 @@ def run(self):
results.write(path)
write_manifest(self.config, self.manifest)

dbt.ui.printer.print_timestamped_line(
'Catalog written to {}'.format(os.path.abspath(path))
)

if exceptions:
logger.error(
'dbt encountered {} failure{} while writing the catalog'
.format(len(exceptions), (len(exceptions) == 1) * 's')
.format(len(exceptions), (len(exceptions) != 1) * 's')
)

dbt.ui.printer.print_timestamped_line(
'Catalog written to {}'.format(os.path.abspath(path))
)

return results

def get_catalog_results(
Expand Down
38 changes: 38 additions & 0 deletions test/unit/test_agate_helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

import agate

from datetime import datetime
from decimal import Decimal
from isodate import tzinfo
Expand Down Expand Up @@ -95,3 +97,39 @@ def test_datetime_formats(self):
fp.write('a\n{}'.format(dt).encode('utf-8'))
tbl = agate_helper.from_csv(path, ())
self.assertEqual(tbl[0][0], expected)

def test_merge_allnull(self):
t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c'))
t2 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c'))
result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Number)
self.assertEqual(len(result), 4)

def test_merge_mixed(self):
t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c'))
t2 = agate.Table([(3, 'c', 'dog'), (4, 'd', 'cat')], ('a', 'b', 'c'))
t3 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c'))

result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)

result = agate_helper.merge_tables([t2, t3])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)

result = agate_helper.merge_tables([t1, t2, t3])
self.assertEqual(result.column_names, ('a', 'b', 'c'))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 6)

0 comments on commit 04bc2a8

Please sign in to comment.