diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 5129d60d646..41c29c53853 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -19,7 +19,7 @@ import dbt.flags from dbt import deprecations -from dbt.clients.agate_helper import empty_table +from dbt.clients.agate_helper import empty_table, merge_tables from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedSeedNode @@ -1024,15 +1024,10 @@ def _get_one_catalog( 'information_schema' ]) - # calculate the possible schemas for a given schema name - all_schema_names: Set[str] = set() - for schema in schemas: - all_schema_names.update({schema, schema.lower(), schema.upper()}) - with self.connection_named(name): kwargs = { 'information_schema': information_schema, - 'schemas': all_schema_names + 'schemas': schemas } table = self.execute_macro( GET_CATALOG_MACRO_NAME, @@ -1140,7 +1135,8 @@ def catch_as_completed( futures # typing: List[Future[agate.Table]] ) -> Tuple[agate.Table, List[Exception]]: - catalogs: agate.Table = agate.Table(rows=[]) + # catalogs: agate.Table = agate.Table(rows=[]) + tables: List[agate.Table] = [] exceptions: List[Exception] = [] for future in as_completed(futures): @@ -1148,7 +1144,7 @@ def catch_as_completed( # we want to re-raise on ctrl+c and BaseException if exc is None: catalog = future.result() - catalogs = agate.Table.merge([catalogs, catalog]) + tables.append(catalog) elif ( isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception) @@ -1160,4 +1156,4 @@ def catch_as_completed( ) # exc is not None, derives from Exception, and isn't ctrl+c exceptions.append(exc) - return catalogs, exceptions + return merge_tables(tables), exceptions diff --git a/core/dbt/clients/agate_helper.py b/core/dbt/clients/agate_helper.py index 03c059017d4..e18ac687770 100644 --- a/core/dbt/clients/agate_helper.py +++ b/core/dbt/clients/agate_helper.py @@ -4,7 +4,9 @@ import datetime import isodate import json -from typing import Iterable +from typing import Iterable, List, Dict, Union + +from dbt.exceptions import RuntimeException BOM = BOM_UTF8.decode('utf-8') # '\ufeff' @@ -104,3 +106,85 @@ def from_csv(abspath, text_columns): if fp.read(1) != BOM: fp.seek(0) return agate.Table.from_csv(fp, column_types=type_tester) + + +class _NullMarker: + pass + + +NullableAgateType = Union[agate.data_types.DataType, _NullMarker] + + +class ColumnTypeBuilder(Dict[str, NullableAgateType]): + def __init__(self): + super().__init__() + + def __setitem__(self, key, value): + if key not in self: + super().__setitem__(key, value) + return + + existing_type = self[key] + if isinstance(existing_type, _NullMarker): + # overwrite + super().__setitem__(key, value) + elif isinstance(value, _NullMarker): + # use the existing value + return + elif not isinstance(value, type(existing_type)): + # actual type mismatch! + raise RuntimeException( + f'Tables contain columns with the same names ({key}), ' + f'but different types ({value} vs {existing_type})' + ) + + def finalize(self) -> Dict[str, agate.data_types.DataType]: + result: Dict[str, agate.data_types.DataType] = {} + for key, value in self.items(): + if isinstance(value, _NullMarker): + # this is what agate would do. + result[key] = agate.data_types.Number() + else: + result[key] = value + return result + + +def _merged_column_types( + tables: List[agate.Table] +) -> Dict[str, agate.data_types.DataType]: + # this is a lot like agate.Table.merge, but with handling for all-null + # rows being "any type". + new_columns: ColumnTypeBuilder = ColumnTypeBuilder() + for table in tables: + for i in range(len(table.columns)): + column_name: str = table.column_names[i] + column_type: NullableAgateType = table.column_types[i] + # avoid over-sensitive type inference + if all(x is None for x in table.columns[column_name]): + column_type = _NullMarker() + new_columns[column_name] = column_type + + return new_columns.finalize() + + +def merge_tables(tables: List[agate.Table]) -> agate.Table: + """This is similar to agate.Table.merge, but it handles rows of all 'null' + values more gracefully during merges. + """ + new_columns = _merged_column_types(tables) + column_names = tuple(new_columns.keys()) + column_types = tuple(new_columns.values()) + + rows: List[agate.Row] = [] + for table in tables: + if ( + table.column_names == column_names and + table.column_types == column_types + ): + rows.extend(table.rows) + else: + for row in table.rows: + data = [row.get(name, None) for name in column_names] + rows.append(agate.Row(data, column_names)) + # _is_fork to tell agate that we already made things into `Row`s. + return agate.Table(rows, column_names, column_types, _is_fork=True) diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 00e1a052cfe..18a3003e7c4 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -235,16 +235,16 @@ def run(self): results.write(path) write_manifest(self.config, self.manifest) - dbt.ui.printer.print_timestamped_line( - 'Catalog written to {}'.format(os.path.abspath(path)) - ) - if exceptions: logger.error( 'dbt encountered {} failure{} while writing the catalog' - .format(len(exceptions), (len(exceptions) == 1) * 's') + .format(len(exceptions), (len(exceptions) != 1) * 's') ) + dbt.ui.printer.print_timestamped_line( + 'Catalog written to {}'.format(os.path.abspath(path)) + ) + return results def get_catalog_results( diff --git a/test/unit/test_agate_helper.py b/test/unit/test_agate_helper.py index 4f1e42ab5fc..8f15ff75984 100644 --- a/test/unit/test_agate_helper.py +++ b/test/unit/test_agate_helper.py @@ -1,5 +1,7 @@ import unittest +import agate + from datetime import datetime from decimal import Decimal from isodate import tzinfo @@ -95,3 +97,39 @@ def test_datetime_formats(self): fp.write('a\n{}'.format(dt).encode('utf-8')) tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) + + def test_merge_allnull(self): + t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c')) + t2 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c')) + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ('a', 'b', 'c')) + assert isinstance(result.column_types[0], agate.data_types.Number) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Number) + self.assertEqual(len(result), 4) + + def test_merge_mixed(self): + t1 = agate.Table([(1, 'a', None), (2, 'b', None)], ('a', 'b', 'c')) + t2 = agate.Table([(3, 'c', 'dog'), (4, 'd', 'cat')], ('a', 'b', 'c')) + t3 = agate.Table([(3, 'c', None), (4, 'd', None)], ('a', 'b', 'c')) + + result = agate_helper.merge_tables([t1, t2]) + self.assertEqual(result.column_names, ('a', 'b', 'c')) + assert isinstance(result.column_types[0], agate.data_types.Number) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t2, t3]) + self.assertEqual(result.column_names, ('a', 'b', 'c')) + assert isinstance(result.column_types[0], agate.data_types.Number) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + self.assertEqual(len(result), 4) + + result = agate_helper.merge_tables([t1, t2, t3]) + self.assertEqual(result.column_names, ('a', 'b', 'c')) + assert isinstance(result.column_types[0], agate.data_types.Number) + assert isinstance(result.column_types[1], agate.data_types.Text) + assert isinstance(result.column_types[2], agate.data_types.Text) + self.assertEqual(len(result), 6)