Feature: faster snowflake catalogs (#2009) #2037

Merged (4 commits) on Feb 6, 2020
93 changes: 73 additions & 20 deletions core/dbt/adapters/base/impl.py
@@ -1,4 +1,6 @@
import abc
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import Future # noqa - we use this for typing only
from contextlib import contextmanager
from datetime import datetime
from typing import (
@@ -17,10 +19,11 @@
import dbt.flags

from dbt import deprecations
from dbt.clients.agate_helper import empty_table
from dbt.clients.agate_helper import empty_table, merge_tables
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.exceptions import warn_or_error
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values
@@ -117,7 +120,7 @@ def _relation_name(rel: Optional[BaseRelation]) -> str:
return str(rel)


class SchemaSearchMap(dict):
class SchemaSearchMap(Dict[InformationSchema, Set[Optional[str]]]):
"""A utility class to keep track of what information_schema tables to
search for what schemas
"""
@@ -1009,29 +1012,51 @@ def _catalog_filter_table(
"""
return table.where(_catalog_filter_schemas(manifest))

def _get_catalog_information_schemas(
self, manifest: Manifest
) -> List[InformationSchema]:
return list(self._get_cache_schemas(manifest).keys())
def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Set[str],
manifest: Manifest,
) -> agate.Table:

def get_catalog(self, manifest: Manifest) -> agate.Table:
"""Get the catalog for this manifest by running the get catalog macro.
Returns an agate.Table of catalog information.
"""
information_schemas = self._get_catalog_information_schemas(manifest)
# make it a list so macros can index into it.
kwargs = {'information_schemas': information_schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
release=True,
# pass in the full manifest so we get any local project overrides
manifest=manifest,
)
name = '.'.join([
str(information_schema.database),
'information_schema'
])

with self.connection_named(name):
kwargs = {
'information_schema': information_schema,
'schemas': schemas
}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
kwargs=kwargs,
release=True,
# pass in the full manifest so we get any local project
# overrides
manifest=manifest,
)

results = self._catalog_filter_table(table, manifest)
return results

def get_catalog(
self, manifest: Manifest
) -> Tuple[agate.Table, List[Exception]]:
# snowflake is super slow. split it out into the specified threads
num_threads = self.config.threads
schema_map = self._get_cache_schemas(manifest)

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [
executor.submit(self._get_one_catalog, info, schemas, manifest)
for info, schemas in schema_map.items() if len(schemas) > 0
]
catalogs, exceptions = catch_as_completed(futures)

return catalogs, exceptions

def cancel_open_connections(self):
"""Cancel all open connections."""
return self.connections.cancel_open()
@@ -1104,3 +1129,31 @@ def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
The second parameter is the value returned by pre_model_hook.
"""
pass


def catch_as_completed(
futures # typing: List[Future[agate.Table]]
) -> Tuple[agate.Table, List[Exception]]:

# catalogs: agate.Table = agate.Table(rows=[])
tables: List[agate.Table] = []
exceptions: List[Exception] = []

for future in as_completed(futures):
exc = future.exception()
# we want to re-raise on ctrl+c and BaseException
if exc is None:
catalog = future.result()
tables.append(catalog)
elif (
isinstance(exc, KeyboardInterrupt) or
not isinstance(exc, Exception)
):
raise exc
else:
warn_or_error(
f'Encountered an error while generating catalog: {str(exc)}'
)
# exc is not None, derives from Exception, and isn't ctrl+c
exceptions.append(exc)
return merge_tables(tables), exceptions
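
Note on the threading change above: `get_catalog` now issues one catalog query per information schema on the adapter's configured thread count, and `catch_as_completed` collects per-future exceptions instead of aborting the whole run, re-raising only ctrl+c and other `BaseException`s. A minimal standalone sketch of the same pattern, not part of this diff (the `fetch_catalog` helper and the simulated failure are invented for illustration):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple


def fetch_catalog(database: str) -> str:
    # stand-in for the per-information-schema get_catalog macro call
    if database == 'broken_db':
        raise RuntimeError('simulated catalog failure')
    return f'catalog rows for {database}'


def run_all(databases: List[str], threads: int) -> Tuple[List[str], List[Exception]]:
    tables: List[str] = []
    exceptions: List[Exception] = []
    with ThreadPoolExecutor(max_workers=threads) as executor:
        futures = [executor.submit(fetch_catalog, db) for db in databases]
        for future in as_completed(futures):
            exc = future.exception()
            if exc is None:
                tables.append(future.result())
            elif isinstance(exc, KeyboardInterrupt) or not isinstance(exc, Exception):
                # fatal errors (ctrl+c, SystemExit, ...) still propagate
                raise exc
            else:
                # recoverable errors are collected so a partial catalog can still be built
                exceptions.append(exc)
    return tables, exceptions


tables, errors = run_all(['db_one', 'db_two', 'broken_db'], threads=4)
```

Here the two good databases land in `tables` and the `broken_db` failure lands in `errors`, mirroring how the real code merges the successful tables and surfaces the rest as catalog errors.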
86 changes: 85 additions & 1 deletion core/dbt/clients/agate_helper.py
@@ -4,7 +4,9 @@
import datetime
import isodate
import json
from typing import Iterable
from typing import Iterable, List, Dict, Union

from dbt.exceptions import RuntimeException


BOM = BOM_UTF8.decode('utf-8') # '\ufeff'
@@ -104,3 +106,85 @@ def from_csv(abspath, text_columns):
if fp.read(1) != BOM:
fp.seek(0)
return agate.Table.from_csv(fp, column_types=type_tester)


class _NullMarker:
pass


NullableAgateType = Union[agate.data_types.DataType, _NullMarker]


class ColumnTypeBuilder(Dict[str, NullableAgateType]):
def __init__(self):
super().__init__()

def __setitem__(self, key, value):
if key not in self:
super().__setitem__(key, value)
return

existing_type = self[key]
if isinstance(existing_type, _NullMarker):
# overwrite
super().__setitem__(key, value)
elif isinstance(value, _NullMarker):
# use the existing value
return
elif not isinstance(value, type(existing_type)):
# actual type mismatch!
raise RuntimeException(
f'Tables contain columns with the same names ({key}), '
f'but different types ({value} vs {existing_type})'
)

def finalize(self) -> Dict[str, agate.data_types.DataType]:
result: Dict[str, agate.data_types.DataType] = {}
for key, value in self.items():
if isinstance(value, _NullMarker):
# this is what agate would do.
result[key] = agate.data_types.Number()
else:
result[key] = value
return result


def _merged_column_types(
tables: List[agate.Table]
) -> Dict[str, agate.data_types.DataType]:
# this is a lot like agate.Table.merge, but with handling for all-null
# rows being "any type".
new_columns: ColumnTypeBuilder = ColumnTypeBuilder()
for table in tables:
for i in range(len(table.columns)):
column_name: str = table.column_names[i]
column_type: NullableAgateType = table.column_types[i]
# avoid over-sensitive type inference
if all(x is None for x in table.columns[column_name]):
column_type = _NullMarker()
new_columns[column_name] = column_type

return new_columns.finalize()


def merge_tables(tables: List[agate.Table]) -> agate.Table:
"""This is similar to agate.Table.merge, but it handles rows of all 'null'
values more gracefully during merges.
"""
new_columns = _merged_column_types(tables)
column_names = tuple(new_columns.keys())
column_types = tuple(new_columns.values())

rows: List[agate.Row] = []
for table in tables:
if (
table.column_names == column_names and
table.column_types == column_types
):
rows.extend(table.rows)
else:
for row in table.rows:
data = [row.get(name, None) for name in column_names]
rows.append(agate.Row(data, column_names))
# _is_fork to tell agate that we already made things into `Row`s.
return agate.Table(rows, column_names, column_types, _is_fork=True)
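
The `merge_tables` helper above exists because each per-database catalog table gets its column types inferred independently, so a column that is entirely null in one shard can carry a different type than the populated copy of the same column in another shard. A small sketch of the behaviour, not part of this diff and assuming this branch of dbt is installed (the table contents are invented for illustration):

```python
import agate

from dbt.clients.agate_helper import merge_tables

text = agate.data_types.Text()
number = agate.data_types.Number()

# 'num_rows' is entirely null in this shard and happened to be typed as Text
views = agate.Table(
    [('my_view', None), ('other_view', None)],
    column_names=['name', 'num_rows'],
    column_types=[text, text],
)
# the same column is populated and numeric in another shard
tables = agate.Table(
    [('my_table', 1024)],
    column_names=['name', 'num_rows'],
    column_types=[text, number],
)

# merge_tables treats the all-null column as "no type yet" and lets the
# populated side decide, instead of reporting a type mismatch
merged = merge_tables([views, tables])
print(len(merged.rows))                        # 3
print(type(merged.column_types[1]).__name__)   # Number
```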
1 change: 1 addition & 0 deletions core/dbt/contracts/results.py
@@ -291,4 +291,5 @@ def key(self) -> CatalogKey:
class CatalogResults(JsonSchemaMixin, Writable):
nodes: Dict[str, CatalogTable]
generated_at: datetime
errors: Optional[List[str]]
_compile_results: Optional[Any] = None
1 change: 1 addition & 0 deletions core/dbt/contracts/rpc.py
@@ -510,6 +510,7 @@ def from_result(
return cls(
nodes=base.nodes,
generated_at=base.generated_at,
errors=base.errors,
_compile_results=base._compile_results,
logs=logs,
tags=tags,
6 changes: 3 additions & 3 deletions core/dbt/include/global_project/macros/adapters/common.sql
@@ -94,11 +94,11 @@
{% endmacro %}


{% macro get_catalog(information_schemas) -%}
{{ return(adapter_macro('get_catalog', information_schemas)) }}
{% macro get_catalog(information_schema, schemas) -%}
{{ return(adapter_macro('get_catalog', information_schema, schemas)) }}
{%- endmacro %}

{% macro default__get_catalog(information_schemas) -%}
{% macro default__get_catalog(information_schema, schemas) -%}

{% set typename = adapter.type() %}
{% set msg -%}
26 changes: 23 additions & 3 deletions core/dbt/task/generate.py
@@ -14,6 +14,7 @@
)
from dbt.exceptions import InternalException
from dbt.include.global_project import DOCS_INDEX_FILE_PATH
from dbt.logger import GLOBAL_LOGGER as logger
import dbt.ui.printer
import dbt.utils
import dbt.compilation
@@ -194,7 +195,9 @@ def run(self):
dbt.ui.printer.print_timestamped_line(
'compile failed, cannot generate docs'
)
return CatalogResults({}, datetime.utcnow(), compile_results)
return CatalogResults(
{}, datetime.utcnow(), compile_results, None
)

shutil.copyfile(
DOCS_INDEX_FILE_PATH,
@@ -208,42 +211,59 @@
adapter = get_adapter(self.config)
with adapter.connection_named('generate_catalog'):
dbt.ui.printer.print_timestamped_line("Building catalog")
catalog_table = adapter.get_catalog(self.manifest)
catalog_table, exceptions = adapter.get_catalog(self.manifest)

catalog_data: List[PrimitiveDict] = [
dict(zip(catalog_table.column_names, map(_coerce_decimal, row)))
for row in catalog_table
]

catalog = Catalog(catalog_data)

errors: Optional[List[str]] = None
if exceptions:
errors = [str(e) for e in exceptions]

results = self.get_catalog_results(
nodes=catalog.make_unique_id_map(self.manifest),
generated_at=datetime.utcnow(),
compile_results=compile_results,
errors=errors,
)

path = os.path.join(self.config.target_path, CATALOG_FILENAME)
results.write(path)
write_manifest(self.config, self.manifest)

if exceptions:
logger.error(
'dbt encountered {} failure{} while writing the catalog'
.format(len(exceptions), (len(exceptions) != 1) * 's')
)

dbt.ui.printer.print_timestamped_line(
'Catalog written to {}'.format(os.path.abspath(path))
)

return results

def get_catalog_results(
self,
nodes: Dict[str, CatalogTable],
generated_at: datetime,
compile_results: Optional[Any]
compile_results: Optional[Any],
errors: Optional[List[str]]
) -> CatalogResults:
return CatalogResults(
nodes=nodes,
generated_at=generated_at,
_compile_results=compile_results,
errors=errors,
)

def interpret_results(self, results):
if results.errors:
return False
compile_results = results._compile_results
if compile_results is None:
return True
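
The `generate.py` changes above stringify any collected exceptions onto `CatalogResults.errors`, log a failure count, and make `interpret_results` treat a non-empty `errors` list as a failed invocation even though a partial catalog artifact is still written. A tiny sketch of that flow, not part of this diff (the exception text is invented):

```python
from typing import List, Optional

exceptions: List[Exception] = [RuntimeError('permission denied for schema analytics')]

errors: Optional[List[str]] = None
if exceptions:
    errors = [str(e) for e in exceptions]

# same pluralization trick as the log line above: (n != 1) * 's' is '' for one failure
print('dbt encountered {} failure{} while writing the catalog'.format(
    len(exceptions), (len(exceptions) != 1) * 's'
))  # dbt encountered 1 failure while writing the catalog

# interpret_results: any recorded error means the task failed overall
succeeded = not errors
print(succeeded)  # False
```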
3 changes: 2 additions & 1 deletion core/dbt/task/rpc/project_commands.py
@@ -110,12 +110,13 @@ def set_args(self, params: RPCDocsGenerateParameters) -> None:
self.args.compile = params.compile

def get_catalog_results(
self, nodes, generated_at, compile_results
self, nodes, generated_at, compile_results, errors
) -> RemoteCatalogResults:
return RemoteCatalogResults(
nodes=nodes,
generated_at=datetime.utcnow(),
_compile_results=compile_results,
errors=errors,
logs=[],
)

22 changes: 11 additions & 11 deletions plugins/bigquery/dbt/adapters/bigquery/impl.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Set

import dbt.deprecations
import dbt.exceptions
@@ -7,6 +7,7 @@
import dbt.clients.agate_helper

from dbt.adapters.base import BaseAdapter, available, RelationType
from dbt.adapters.base.impl import SchemaSearchMap
from dbt.adapters.bigquery.relation import (
BigQueryRelation, BigQueryInformationSchema
)
@@ -492,23 +493,22 @@ def _catalog_filter_table(
})
return super()._catalog_filter_table(table, manifest)

def _get_catalog_information_schemas(
self, manifest: Manifest
) -> List[BigQueryInformationSchema]:
def _get_cache_schemas(
self, manifest: Manifest, exec_only: bool = False
) -> SchemaSearchMap:
candidates = super()._get_cache_schemas(manifest, exec_only)
db_schemas: Dict[str, Set[str]] = {}
result = SchemaSearchMap()

candidates = super()._get_catalog_information_schemas(manifest)
information_schemas = []
db_schemas = {}
for candidate in candidates:
for candidate, schemas in candidates.items():
database = candidate.database
if database not in db_schemas:
db_schemas[database] = set(self.list_schemas(database))
if candidate.schema in db_schemas[database]:
information_schemas.append(candidate)
result[candidate] = schemas
else:
logger.debug(
'Skipping catalog for {}.{} - schema does not exist'
.format(database, candidate.schema)
)

return information_schemas
return result