Skip to content

Commit

Permalink
Merge pull request #1881 from fishtown-analytics/fix/bigquery-case-sensitive
Browse files Browse the repository at this point in the history

Fix bigquery case sensitive caching issue (#1810)
  • Loading branch information
beckjake authored Nov 4, 2019
2 parents c4cd4fc + 670c26b commit 31ca9a1
Show file tree
Hide file tree
Showing 18 changed files with 310 additions and 133 deletions.
2 changes: 1 addition & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _get_cache_schemas(
# result is a map whose keys are information_schema Relations without
# identifiers that have appropriate database prefixes, and whose values
# are sets of lowercase schema names that are valid members of those
# schemas
# databases
return info_schema_name_map

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
Expand Down
141 changes: 106 additions & 35 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Mapping, Hashable
from dataclasses import dataclass, fields
from typing import (
Optional, TypeVar, Generic, Any, Type, Dict, Union, List
Optional, TypeVar, Generic, Any, Type, Dict, Union, List, Iterator, Tuple
)
from typing_extensions import Protocol

Expand Down Expand Up @@ -106,6 +106,21 @@ class Path(_ComponentObject[Optional[str]]):
schema: Optional[str]
identifier: Optional[str]

def __post_init__(self):
    """Validate that every path component is either ``None`` or a ``str``.

    Raises:
        dbt.exceptions.CompilationException: for the first component
            (checked in database, schema, identifier order) holding any
            other kind of value.
    """
    # Values coming from jinja can be jinja2.Undefined; catch them here
    # rather than letting them break rendering later.
    checks = (
        (self.database, 'Got an invalid path database: {}'),
        (self.schema, 'Got an invalid path schema: {}'),
        (self.identifier, 'Got an invalid path identifier: {}'),
    )
    for value, template in checks:
        if value is None or isinstance(value, str):
            continue
        raise dbt.exceptions.CompilationException(template.format(value))

def get_lowered_part(self, key: ComponentName) -> Optional[str]:
part = self.get_part(key)
if part is not None:
Expand Down Expand Up @@ -193,6 +208,9 @@ def matches(

return exact_match

def replace_path(self, **kwargs):
    """Return a copy of this relation whose path has ``kwargs`` replaced.

    The keyword arguments are forwarded unchanged to ``self.path.replace``;
    the resulting path is then swapped in through ``self.replace``.
    """
    updated_path = self.path.replace(**kwargs)
    return self.replace(path=updated_path)

def quote(
self: Self,
database: Optional[bool] = None,
Expand Down Expand Up @@ -223,46 +241,32 @@ def include(
new_include_policy = self.include_policy.replace_dict(policy)
return self.replace(include_policy=new_include_policy)

def information_schema(self: Self, identifier=None) -> Self:
include_policy = self.include_policy.replace(
database=self.database is not None,
schema=True,
identifier=identifier is not None
)
quote_policy = self.quote_policy.replace(
schema=False,
identifier=False,
)

path = self.path.replace(
schema='information_schema',
identifier=identifier,
)
def information_schema(self, view_name=None) -> 'InformationSchema':
# some of our data comes from jinja, where things can be `Undefined`.
if not isinstance(view_name, str):
view_name = None

return self.replace(
quote_policy=quote_policy,
include_policy=include_policy,
path=path,
)
return InformationSchema.from_relation(self, view_name)

def information_schema_only(self: Self) -> Self:
def information_schema_only(self) -> 'InformationSchema':
return self.information_schema()

def information_schema_table(self: Self, identifier: str) -> Self:
return self.information_schema(identifier)
def _render_iterator(
self
) -> Iterator[Tuple[Optional[ComponentName], Optional[str]]]:

def render(self) -> str:
parts: List[str] = []
for key in ComponentName:
path_part: Optional[str] = None
if self.include_policy.get_part(key):
path_part = self.path.get_part(key)
if path_part is not None and self.quote_policy.get_part(key):
path_part = self.quoted(path_part)
yield key, path_part

for k in ComponentName:
if self.include_policy.get_part(k):
path_part = self.path.get_part(k)

if path_part is not None:
part: str = path_part
if self.quote_policy.get_part(k):
part = self.quoted(path_part)
parts.append(part)
def render(self) -> str:
parts: List[str] = [
part for _, part in self._render_iterator() if part is not None
]

if len(parts) == 0:
raise dbt.exceptions.RuntimeException(
Expand Down Expand Up @@ -417,3 +421,70 @@ def External(cls) -> str:
@classproperty
def RelationType(cls) -> Type[RelationType]:
return RelationType


@dataclass(frozen=True, eq=False, repr=False)
class InformationSchema(BaseRelation):
    """Relation whose identifier slot holds ``INFORMATION_SCHEMA``.

    Instances are normally built with :meth:`from_relation`, which derives
    the path and the include/quote policies from an existing relation.
    """

    # Optional view name; emitted as the final part by _render_iterator.
    information_schema_view: Optional[str] = None

    def __post_init__(self):
        """Reject a view name that is neither ``None`` nor a ``str``.

        Raises:
            dbt.exceptions.CompilationException: for any other value.
        """
        view = self.information_schema_view
        if view is None or isinstance(view, str):
            return
        raise dbt.exceptions.CompilationException(
            'Got an invalid name: {}'.format(view)
        )

    @classmethod
    def get_path(
        cls, relation: BaseRelation, information_schema_view: Optional[str]
    ) -> Path:
        """Build the path for the information schema of ``relation``.

        Database and schema are copied from ``relation``; the identifier is
        always the literal ``INFORMATION_SCHEMA``. The view name is accepted
        but not used by this implementation.
        """
        components = {
            'database': relation.database,
            'schema': relation.schema,
            'identifier': 'INFORMATION_SCHEMA',
        }
        return Path(**components)

    @classmethod
    def get_include_policy(
        cls,
        relation,
        information_schema_view: Optional[str],
    ) -> Policy:
        """Derive the include policy from ``relation``'s policy.

        The database is included only when ``relation`` has one, the schema
        is never included, and the identifier always is.
        """
        base_policy = relation.include_policy
        has_database = relation.database is not None
        return base_policy.replace(
            database=has_database, schema=False, identifier=True,
        )

    @classmethod
    def get_quote_policy(
        cls,
        relation,
        information_schema_view: Optional[str],
    ) -> Policy:
        """Return ``relation``'s quote policy with identifier quoting off."""
        base_policy = relation.quote_policy
        return base_policy.replace(identifier=False)

    @classmethod
    def from_relation(
        cls: Self,
        relation: BaseRelation,
        information_schema_view: Optional[str],
    ) -> Self:
        """Create an instance of ``cls`` describing ``relation``'s
        information schema, typed as a view.

        The include policy, quote policy and path are computed (in that
        order) by the ``get_*`` classmethods.
        """
        view = information_schema_view
        init_kwargs = {
            'include_policy': cls.get_include_policy(relation, view),
            'quote_policy': cls.get_quote_policy(relation, view),
            'path': cls.get_path(relation, view),
        }
        return cls(
            type=RelationType.View,
            information_schema_view=view,
            **init_kwargs,
        )

    def _render_iterator(self):
        """Yield the parent's ``(key, part)`` pairs, then the view name.

        The trailing pair uses ``None`` as its key and may carry ``None`` as
        its value when no view name is set.
        """
        yield from super()._render_iterator()
        yield None, self.information_schema_view
5 changes: 3 additions & 2 deletions core/dbt/adapters/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,13 @@ def get_relations(self, database, schema):
:return List[BaseRelation]: The list of relations with the given
schema
"""
database = _lower(database)
schema = _lower(schema)
with self.lock:
results = [
r.inner for r in self.relations.values()
if (r.schema == _lower(schema) and
r.database == _lower(database))
if (_lower(r.schema) == schema and
_lower(r.database) == database)
]

if None in results:
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ def list_schemas(self, database):

def check_schema_exists(self, database, schema):
information_schema = self.Relation.create(
database=database, schema=schema,
database=database,
schema=schema,
identifier='INFORMATION_SCHEMA',
quote_policy=self.config.quoting
).information_schema()

Expand Down
20 changes: 10 additions & 10 deletions core/dbt/include/global_project/macros/adapters/common.sql
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@

{% macro default__information_schema_name(database) -%}
{%- if database -%}
{{ adapter.quote_as_configured(database, 'database') }}.information_schema
{{ adapter.quote_as_configured(database, 'database') }}.INFORMATION_SCHEMA
{%- else -%}
information_schema
INFORMATION_SCHEMA
{%- endif -%}
{%- endmacro %}

Expand All @@ -194,12 +194,12 @@
{% endmacro %}

{% macro default__list_schemas(database) -%}
{% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
{% set sql %}
select distinct schema_name
from {{ information_schema_name(database) }}.schemata
from {{ information_schema_name(database) }}.SCHEMATA
where catalog_name ilike '{{ database }}'
{% endcall %}
{{ return(load_result('list_schemas').table) }}
{% endset %}
{{ return(run_query(sql)) }}
{% endmacro %}


Expand All @@ -208,13 +208,13 @@
{% endmacro %}

{% macro default__check_schema_exists(information_schema, schema) -%}
{% call statement('check_schema_exists', fetch_result=True, auto_begin=False) -%}
{% set sql -%}
select count(*)
from {{ information_schema }}.schemata
from {{ information_schema.replace(information_schema_view='SCHEMATA') }}
where catalog_name='{{ information_schema.database }}'
and schema_name='{{ schema }}'
{%- endcall %}
{{ return(load_result('check_schema_exists').table) }}
{%- endset %}
{{ return(run_query(sql)) }}
{% endmacro %}


Expand Down
15 changes: 4 additions & 11 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,21 +296,14 @@ def drop_dataset(self, database, schema):
client = conn.handle

with self.exception_handler('drop dataset'):
for table in client.list_tables(dataset):
client.delete_table(table.reference)
client.delete_dataset(dataset)
client.delete_dataset(
dataset, delete_contents=True, not_found_ok=True
)

def create_dataset(self, database, schema):
conn = self.get_thread_connection()
client = conn.handle
dataset = self.dataset(database, schema, conn)

# Emulate 'create schema if not exists ...'
try:
client.get_dataset(dataset)
return
except google.api_core.exceptions.NotFound:
pass

with self.exception_handler('create dataset'):
client.create_dataset(dataset)
client.create_dataset(dataset, exists_ok=True)
Loading

0 comments on commit 31ca9a1

Please sign in to comment.