Merge pull request #1338 from fishtown-analytics/fix/jeb-snowflake-source-quoting

Fix snowflake source quoting / information_schema uses
beckjake authored Mar 6, 2019
2 parents 1a700c1 + a335857 commit 03aa086
Showing 26 changed files with 366 additions and 132 deletions (only the first five files' diffs are reproduced below).
101 changes: 85 additions & 16 deletions core/dbt/adapters/base/impl.py
@@ -1,7 +1,4 @@
import abc
import copy
import multiprocessing
import time

import agate
import pytz
@@ -13,11 +10,11 @@
import dbt.clients.agate_helper

from dbt.compat import abstractclassmethod, classmethod
from dbt.contracts.connection import Connection
from dbt.node_types import NodeType
from dbt.loader import GraphLoader
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.schema import Column
from dbt.utils import filter_null_values, translate_aliases
from dbt.utils import filter_null_values

from dbt.adapters.base.meta import AdapterMeta, available, available_raw, \
available_deprecated
@@ -94,6 +91,51 @@ def _utc(dt, source, field_name):
return dt.replace(tzinfo=pytz.UTC)


class SchemaSearchMap(dict):
"""A utility class to keep track of what information_schema tables to
search for what schemas
"""
def add(self, relation):
key = relation.information_schema_only()
if key not in self:
self[key] = set()
self[key].add(relation.schema.lower())

def search(self):
for information_schema_name, schemas in self.items():
for schema in schemas:
yield information_schema_name, schema

def schemas_searched(self):
result = set()
for information_schema_name, schemas in self.items():
result.update(
(information_schema_name.database, schema)
for schema in schemas
)
return result

def flatten(self):
new = self.__class__()

database = None
# iterate once to look for a database name
seen = {r.database.lower() for r in self if r.database}
if len(seen) > 1:
dbt.exceptions.raise_compiler_error(str(seen))
elif len(seen) == 1:
database = list(seen)[0]

for information_schema_name, schema in self.search():
new.add(information_schema_name.incorporate(
path={'database': database, 'schema': schema},
quote_policy={'database': False},
include_policy={'database': False},
))

return new


@six.add_metaclass(AdapterMeta)
class BaseAdapter(object):
"""The BaseAdapter provides an abstract base class for adapters.
@@ -237,24 +279,44 @@ def _relations_filter_table(cls, table, schemas):
"""
return table.where(_relations_filter_schemas(schemas))

def _get_cache_schemas(self, manifest, exec_only=False):
"""Get a mapping of each node's "information_schema" relations to a
set of all schemas expected in that information_schema.
There may be keys that are technically duplicates on the database side,
for example all of '"foo"', 'foo', '"FOO"' and 'FOO' could coexist as
databases, and values could overlap as appropriate. All values are
lowercase strings.
"""
info_schema_name_map = SchemaSearchMap()
for node in manifest.nodes.values():
if exec_only and node.resource_type not in NodeType.executable():
continue
relation = self.Relation.create_from(self.config, node)
info_schema_name_map.add(relation)
# result is a map whose keys are information_schema Relations without
# identifiers that have appropriate database prefixes, and whose values
# are sets of lowercase schema names that are valid members of those
# databases
return info_schema_name_map

def _relations_cache_for_schemas(self, manifest):
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
return

schemas = manifest.get_used_schemas()

relations = []
# add all relations
for db, schema in schemas:
info_schema_name_map = self._get_cache_schemas(manifest,
exec_only=True)
for db, schema in info_schema_name_map.search():
for relation in self.list_relations_without_caching(db, schema):
self.cache.add(relation)

# it's possible that there were no relations in some schemas. We want
# to insert the schemas we query into the cache's `.schemas` attribute
# so we can check it later
self.cache.update_schemas(schemas)
self.cache.update_schemas(info_schema_name_map.schemas_searched())

def set_relations_cache(self, manifest, clear=False):
"""Run a query that gets a populated cache of the relations in the
@@ -415,13 +477,14 @@ def expand_column_types(self, goal, current, model_name=None):
)

@abc.abstractmethod
def list_relations_without_caching(self, database, schema,
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
"""List relations in the given schema, bypassing the cache.
This is used as the underlying behavior to fill the cache.
:param str database: The name of the database to list relations from.
:param Relation information_schema: The information schema to list
relations from.
:param str schema: The name of the schema to list relations from.
:param Optional[str] model_name: The name of the model to use for the
connection.
@@ -495,10 +558,15 @@ def list_relations(self, database, schema, model_name=None):
if self._schema_is_cached(database, schema, model_name):
return self.cache.get_relations(database, schema)

information_schema = self.Relation.create(
database=database,
schema=schema,
model_name='').information_schema()

# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
relations = self.list_relations_without_caching(
database, schema, model_name=model_name
information_schema, schema, model_name=model_name
)

logger.debug('with schema={}, model_name={}, relations={}'
@@ -802,10 +870,11 @@ def get_catalog(self, manifest):
"""Get the catalog for this manifest by running the get catalog macro.
Returns an agate.Table of catalog information.
"""
information_schemas = list(self._get_cache_schemas(manifest).keys())
# make it a list so macros can index into it.
context = {'databases': list(manifest.get_used_databases())}
kwargs = {'information_schemas': information_schemas}
table = self.execute_macro(GET_CATALOG_MACRO_NAME,
context_override=context,
kwargs=kwargs,
release=True)

results = self._catalog_filter_table(table, manifest)
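For orientation, here is a minimal, standalone sketch of how the new SchemaSearchMap groups schemas under their information_schema relations. The FakeRelation and InfoSchema namedtuples below are hypothetical stand-ins for dbt's Relation objects; only the grouping and lowercasing behaviour mirrors the class added above (flatten() is omitted).

# Standalone sketch: FakeRelation / InfoSchema are hypothetical stand-ins for
# dbt's Relation class; the map only needs a hashable information_schema key
# plus .database and .schema on each relation.
from collections import namedtuple

InfoSchema = namedtuple('InfoSchema', ['database'])
FakeRelation = namedtuple('FakeRelation', ['database', 'schema'])


class SchemaSearchMap(dict):
    # same grouping logic as the class added in impl.py above
    def add(self, relation):
        key = InfoSchema(database=relation.database)
        if key not in self:
            self[key] = set()
        self[key].add(relation.schema.lower())

    def search(self):
        for information_schema_name, schemas in self.items():
            for schema in schemas:
                yield information_schema_name, schema

    def schemas_searched(self):
        return {
            (information_schema_name.database, schema)
            for information_schema_name, schemas in self.items()
            for schema in schemas
        }


search_map = SchemaSearchMap()
search_map.add(FakeRelation(database='analytics', schema='Staging'))
search_map.add(FakeRelation(database='analytics', schema='marts'))
search_map.add(FakeRelation(database='raw', schema='events'))

# two information_schema keys, three (database, schema) pairs, all lowercased
print(sorted(search_map.schemas_searched()))
# [('analytics', 'marts'), ('analytics', 'staging'), ('raw', 'events')]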
53 changes: 46 additions & 7 deletions core/dbt/adapters/base/relation.py
@@ -1,5 +1,6 @@
from dbt.api import APIObject
from dbt.utils import filter_null_values
from dbt.node_types import NodeType

import dbt.exceptions

@@ -30,15 +31,15 @@ class BaseRelation(APIObject):
'database': True,
'schema': True,
'identifier': True
}
},
}

PATH_SCHEMA = {
'type': 'object',
'properties': {
'database': {'type': ['string', 'null']},
'schema': {'type': ['string', 'null']},
'identifier': {'type': 'string'},
'identifier': {'type': ['string', 'null']},
},
'required': ['database', 'schema', 'identifier'],
}
@@ -135,6 +136,36 @@ def include(self, database=None, schema=None, identifier=None):

return self.incorporate(include_policy=policy)

def information_schema(self, identifier=None):
include_db = self.database is not None
include_policy = filter_null_values({
'database': include_db,
'schema': True,
'identifier': identifier is not None
})
quote_policy = filter_null_values({
'database': self.quote_policy['database'],
'schema': False,
'identifier': False,
})

path_update = {
'schema': 'information_schema',
'identifier': identifier
}

return self.incorporate(
quote_policy=quote_policy,
include_policy=include_policy,
path=path_update,
table_name=identifier)

def information_schema_only(self):
return self.information_schema()

def information_schema_table(self, identifier):
return self.information_schema(identifier)

def render(self, use_table_name=True):
parts = []

@@ -174,15 +205,16 @@ def quoted(self, identifier):

@classmethod
def create_from_source(cls, source, **kwargs):
quote_policy = dbt.utils.deep_merge(
cls.DEFAULTS['quote_policy'],
source.quoting,
kwargs.get('quote_policy', {})
)
return cls.create(
database=source.database,
schema=source.schema,
identifier=source.identifier,
quote_policy={
'database': True,
'schema': True,
'identifier': True,
},
quote_policy=quote_policy,
**kwargs
)

@@ -202,6 +234,13 @@ def create_from_node(cls, config, node, table_name=None, quote_policy=None,
quote_policy=quote_policy,
**kwargs)

@classmethod
def create_from(cls, config, node, **kwargs):
if node.resource_type == NodeType.Source:
return cls.create_from_source(node, **kwargs)
else:
return cls.create_from_node(config, node, **kwargs)

@classmethod
def create(cls, database=None, schema=None,
identifier=None, table_name=None,
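As a rough illustration of the quoting change in create_from_source: the source's own quoting config is now layered on top of the adapter defaults, with any call-site quote_policy applied last, instead of hard-coding quoting for all three parts. The dictionaries below are illustrative, and dbt's deep_merge is approximated with flat dict merging since these policies have no nesting.

# Illustrative only: approximate the precedence create_from_source now applies.
# Later dictionaries win, so a source configured with quoting: {database: false}
# overrides the default of quoting everything.
DEFAULT_QUOTE_POLICY = {'database': True, 'schema': True, 'identifier': True}

source_quoting = {'database': False}     # from the source's own quoting config
call_site_overrides = {}                 # kwargs.get('quote_policy', {})

quote_policy = {**DEFAULT_QUOTE_POLICY, **source_quoting, **call_site_overrides}
print(quote_policy)
# {'database': False, 'schema': True, 'identifier': True}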
17 changes: 9 additions & 8 deletions core/dbt/adapters/sql/impl.py
@@ -1,15 +1,10 @@
import abc
import time

import agate
import six

import dbt.clients.agate_helper
import dbt.exceptions
import dbt.flags
from dbt.adapters.base import BaseAdapter, available
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.compat import abstractclassmethod


LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
@@ -196,11 +191,12 @@ def drop_schema(self, database, schema, model_name=None):
kwargs=kwargs,
connection_name=model_name)

def list_relations_without_caching(self, database, schema,
def list_relations_without_caching(self, information_schema, schema,
model_name=None):
kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
LIST_RELATIONS_MACRO_NAME,
kwargs={'database': database, 'schema': schema},
kwargs=kwargs,
connection_name=model_name,
release=True
)
@@ -236,9 +232,14 @@ def list_schemas(self, database, model_name=None):
return [row[0] for row in results]

def check_schema_exists(self, database, schema, model_name=None):
information_schema = self.Relation.create(
database=database, schema=schema
).information_schema()

kwargs = {'information_schema': information_schema, 'schema': schema}
results = self.execute_macro(
CHECK_SCHEMA_EXISTS_MACRO_NAME,
kwargs={'database': database, 'schema': schema},
kwargs=kwargs,
connection_name=model_name
)
return results[0][0] > 0
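A small sketch of the reshaped check_schema_exists flow, with hypothetical stubs in place of the Relation class and the macro runner; only the shape of the kwargs handed to the macro reflects the diff above.

# Sketch with hypothetical stubs (not dbt's real API).
def information_schema_for(database):
    # stand-in for self.Relation.create(...).information_schema(): the database
    # keeps its configured quoting, the schema becomes 'information_schema'
    return '{}.information_schema'.format(database)

def execute_macro(name, kwargs):
    # stand-in for the real macro runner; pretend one matching schema exists
    print('running', name, 'with', kwargs)
    return [[1]]

def check_schema_exists(database, schema):
    information_schema = information_schema_for(database)
    kwargs = {'information_schema': information_schema, 'schema': schema}
    results = execute_macro('check_schema_exists', kwargs)
    return results[0][0] > 0

print(check_schema_exists('analytics', 'staging'))  # True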
5 changes: 5 additions & 0 deletions core/dbt/context/common.py
@@ -36,6 +36,11 @@ def __init__(self, adapter):
def __getattr__(self, key):
return getattr(self.relation_type, key)

def create_from_source(self, *args, **kwargs):
# bypass our create when creating from source so as not to mess up
# the source quoting
return self.relation_type.create_from_source(*args, **kwargs)

def create(self, *args, **kwargs):
kwargs['quote_policy'] = dbt.utils.merge(
self.quoting_config,
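To see why the new create_from_source pass-through matters, here is a simplified sketch of the wrapper above. RelationProxy and FakeRelationType are illustrative names, not dbt's real classes: create() folds the project-wide quoting config into the policy for every call, while create_from_source() passes straight through so the policy computed from the source's own quoting config (see relation.py above) is left alone.

# Illustrative sketch only: RelationProxy / FakeRelationType are hypothetical
# stand-ins, not dbt's real API.
class RelationProxy:
    def __init__(self, relation_type, quoting_config):
        self.relation_type = relation_type
        self.quoting_config = quoting_config

    def create(self, *args, **kwargs):
        # start from the project-wide quoting config; explicit call-site
        # policies are layered on top (mirrors the dbt.utils.merge call above)
        policy = dict(self.quoting_config)
        policy.update(kwargs.pop('quote_policy', {}))
        kwargs['quote_policy'] = policy
        return self.relation_type.create(*args, **kwargs)

    def create_from_source(self, *args, **kwargs):
        # pass straight through: the source's own quoting was already merged
        # by BaseRelation.create_from_source, so don't touch it here
        return self.relation_type.create_from_source(*args, **kwargs)


class FakeRelationType:
    @classmethod
    def create(cls, **kwargs):
        return ('create', kwargs.get('quote_policy'))

    @classmethod
    def create_from_source(cls, source_quoting, **kwargs):
        return ('create_from_source', source_quoting)


proxy = RelationProxy(FakeRelationType, quoting_config={'identifier': False})
print(proxy.create(quote_policy={'database': False}))
# ('create', {'identifier': False, 'database': False})
print(proxy.create_from_source({'database': False}))
# ('create_from_source', {'database': False})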
5 changes: 3 additions & 2 deletions core/dbt/contracts/graph/manifest.py
@@ -4,7 +4,7 @@
PARSED_MACRO_CONTRACT, PARSED_DOCUMENTATION_CONTRACT, \
PARSED_SOURCE_DEFINITION_CONTRACT
from dbt.contracts.graph.compiled import COMPILED_NODE_CONTRACT, CompiledNode
from dbt.exceptions import ValidationException
from dbt.exceptions import raise_duplicate_resource_name
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt import tracking
@@ -401,10 +401,11 @@ def __getattr__(self, name):
type(self).__name__, name)
)

def get_used_schemas(self):
def get_used_schemas(self, resource_types=None):
return frozenset({
(node.database, node.schema)
for node in self.nodes.values()
if not resource_types or node.resource_type in resource_types
})

def get_used_databases(self):
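A quick sketch of the new optional resource_types filter on get_used_schemas. FakeNode is a hypothetical stand-in for dbt's parsed nodes; only the three attributes used by the comprehension are modeled.

# Sketch of the new optional filter; FakeNode and the sample values are made up.
from collections import namedtuple

FakeNode = namedtuple('FakeNode', ['database', 'schema', 'resource_type'])

nodes = [
    FakeNode('analytics', 'staging', 'model'),
    FakeNode('analytics', 'snapshots', 'seed'),
    FakeNode('raw', 'events', 'source'),
]

def get_used_schemas(nodes, resource_types=None):
    return frozenset({
        (node.database, node.schema)
        for node in nodes
        if not resource_types or node.resource_type in resource_types
    })

print(sorted(get_used_schemas(nodes)))  # all three (database, schema) pairs
print(sorted(get_used_schemas(nodes, resource_types={'model', 'seed'})))
# [('analytics', 'snapshots'), ('analytics', 'staging')]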