From eb9bfcda4ab314e4ee5bf548fab07ee4fadbee4d Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Tue, 24 Sep 2019 09:40:55 -0600
Subject: [PATCH 1/2] Convert Relation types to hologram.JsonSchemaMixin

Fix a lot of mypy things, add a number of adapter-ish modules to it
Split relations and columns into separate files
split context.common into base + common - base is all that's required for the config renderer
Move Credentials into connection contracts since that's what they really are
Removed model_name/table_name -> consolidated to identifier
 - I hope I did not break seeds, which claimed to care about render(False)
Unify shared 'external' relation type with bigquery's own
hack workarounds for some import cycles with plugin registration and config parsing
Assorted backwards compatibility fixes around types, deep_merge vs shallow merge
Remove APIObject
---
 core/dbt/adapters/base/__init__.py            |   8 +-
 core/dbt/adapters/base/column.py              |  93 +++
 core/dbt/adapters/base/connections.py         | 159 ++---
 core/dbt/adapters/base/impl.py                |  16 +-
 core/dbt/adapters/base/plugin.py              |  35 +-
 core/dbt/adapters/base/relation.py            | 573 +++++++++---------
 core/dbt/adapters/cache.py                    |   1 -
 core/dbt/adapters/factory.py                  |  51 +-
 core/dbt/adapters/sql/connections.py          |  50 +-
 core/dbt/adapters/sql/impl.py                 |  24 +-
 core/dbt/api/__init__.py                      |   5 -
 core/dbt/api/object.py                        | 125 ----
 core/dbt/config/__init__.py                   |   2 +-
 core/dbt/config/renderer.py                   |   2 +-
 core/dbt/config/runtime.py                    |  12 +-
 core/dbt/context/base.py                      | 138 +++++
 core/dbt/context/common.py                    | 216 ++-----
 core/dbt/context/parser.py                    |   2 +-
 core/dbt/context/runtime.py                   |   3 +-
 core/dbt/contracts/connection.py              |  64 +-
 core/dbt/contracts/util.py                    |   4 +-
 core/dbt/deprecations.py                      |  10 +
 .../global_project/macros/adapters/common.sql |   3 +-
 .../macros/materializations/seed/seed.sql     |   4 +-
 core/dbt/parser/schemas.py                    |   2 +-
 core/dbt/tracking.py                          |   4 +-
 core/dbt/utils.py                             |  17 +-
 .../dbt/adapters/bigquery/__init__.py         |   2 +-
 .../bigquery/dbt/adapters/bigquery/column.py  | 121 ++++
 .../dbt/adapters/bigquery/connections.py      |   2 +-
 .../bigquery/dbt/adapters/bigquery/impl.py    |  14 +-
 .../dbt/adapters/bigquery/relation.py         | 209 +------
 .../dbt/include/postgres/macros/adapters.sql  |   1 -
 .../dbt/adapters/snowflake/relation.py        |  54 +-
 .../dbt/include/snowflake/macros/adapters.sql |   4 +-
 .../test_concurrent_transaction.py            |   3 +
 test/unit/test_bigquery_adapter.py            |  24 +-
 test/unit/test_cache.py                       |   3 +-
 test/unit/utils.py                            |   1 -
 tox.ini                                       |   5 +-
 40 files changed, 1012 insertions(+), 1054 deletions(-)
 create mode 100644 core/dbt/adapters/base/column.py
 delete mode 100644 core/dbt/api/__init__.py
 delete mode 100644 core/dbt/api/object.py
 create mode 100644 core/dbt/context/base.py
 create mode 100644 plugins/bigquery/dbt/adapters/bigquery/column.py

diff --git a/core/dbt/adapters/base/__init__.py b/core/dbt/adapters/base/__init__.py
index 5edf237447b..39461477c69 100644
--- a/core/dbt/adapters/base/__init__.py
+++ b/core/dbt/adapters/base/__init__.py
@@ -1,8 +1,10 @@
 # these are all just exports, #noqa them so flake8 will be happy
+
+# TODO: Should we still include this in the `adapters` namespace?
+from dbt.contracts.connection import Credentials # noqa from dbt.adapters.base.meta import available # noqa -from dbt.adapters.base.relation import BaseRelation # noqa -from dbt.adapters.base.relation import Column # noqa from dbt.adapters.base.connections import BaseConnectionManager # noqa -from dbt.adapters.base.connections import Credentials # noqa +from dbt.adapters.base.relation import BaseRelation, RelationType # noqa +from dbt.adapters.base.column import Column # noqa from dbt.adapters.base.impl import BaseAdapter # noqa from dbt.adapters.base.plugin import AdapterPlugin # noqa diff --git a/core/dbt/adapters/base/column.py b/core/dbt/adapters/base/column.py new file mode 100644 index 00000000000..c6e6fcb3288 --- /dev/null +++ b/core/dbt/adapters/base/column.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass + +from hologram import JsonSchemaMixin + +from typing import TypeVar, Dict, ClassVar, Any, Optional, Type + +Self = TypeVar('Self', bound='Column') + + +@dataclass +class Column(JsonSchemaMixin): + TYPE_LABELS: ClassVar[Dict[str, str]] = { + 'STRING': 'TEXT', + 'TIMESTAMP': 'TIMESTAMP', + 'FLOAT': 'FLOAT', + 'INTEGER': 'INT' + } + column: str + dtype: str + char_size: Optional[int] = None + numeric_precision: Optional[Any] = None + numeric_scale: Optional[Any] = None + + @classmethod + def translate_type(cls, dtype: str) -> str: + return cls.TYPE_LABELS.get(dtype.upper(), dtype) + + @classmethod + def create(cls: Type[Self], name, label_or_dtype: str) -> Self: + column_type = cls.translate_type(label_or_dtype) + return cls(name, column_type) + + @property + def name(self) -> str: + return self.column + + @property + def quoted(self) -> str: + return '"{}"'.format(self.column) + + @property + def data_type(self) -> str: + if self.is_string(): + return Column.string_type(self.string_size()) + elif self.is_numeric(): + return Column.numeric_type(self.dtype, self.numeric_precision, + self.numeric_scale) + else: + return self.dtype + + def is_string(self) -> bool: + return self.dtype.lower() in ['text', 'character varying', 'character', + 'varchar'] + + def is_numeric(self) -> bool: + return self.dtype.lower() in ['numeric', 'number'] + + def string_size(self) -> int: + if not self.is_string(): + raise RuntimeError("Called string_size() on non-string field!") + + if self.dtype == 'text' or self.char_size is None: + # char_size should never be None. Handle it reasonably just in case + return 256 + else: + return int(self.char_size) + + def can_expand_to(self: Self, other_column: Self) -> bool: + """returns True if this column can be expanded to the size of the + other column""" + if not self.is_string() or not other_column.is_string(): + return False + + return other_column.string_size() > self.string_size() + + def literal(self, value: Any) -> str: + return "{}::{}".format(value, self.data_type) + + @classmethod + def string_type(cls, size: int) -> str: + return "character varying({})".format(size) + + @classmethod + def numeric_type(cls, dtype: str, precision: Any, scale: Any) -> str: + # This could be decimal(...), numeric(...), number(...) 
+        # Just use whatever was fed in here -- don't try to get too clever
+        if precision is None or scale is None:
+            return dtype
+        else:
+            return "{}({},{})".format(dtype, precision, scale)
+
+    def __repr__(self) -> str:
+        return "<Column {} ({})>".format(self.name, self.data_type)
diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py
index ddb7a15522e..9f2cedf4616 100644
--- a/core/dbt/adapters/base/connections.py
+++ b/core/dbt/adapters/base/connections.py
@@ -2,66 +2,17 @@
 import multiprocessing
 import os
 from threading import get_ident
+from typing import (
+    Dict, Tuple, Hashable, Optional, ContextManager, List
+)
+
+import agate
 
 import dbt.exceptions
 import dbt.flags
-from dbt.contracts.connection import Connection
-from dbt.contracts.util import Replaceable
+from dbt.config import Profile
+from dbt.contracts.connection import Connection, Identifier, ConnectionState
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.utils import translate_aliases
-
-from hologram.helpers import ExtensibleJsonSchemaMixin
-
-from dataclasses import dataclass, field
-from typing import Any, ClassVar, Dict, Tuple
-
-
-@dataclass
-class Credentials(
-    ExtensibleJsonSchemaMixin,
-    Replaceable,
-    metaclass=abc.ABCMeta
-):
-    database: str
-    schema: str
-    _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False)
-
-    @abc.abstractproperty
-    def type(self):
-        raise NotImplementedError(
-            'type not implemented for base credentials class'
-        )
-
-    def connection_info(self):
-        """Return an ordered iterator of key/value pairs for pretty-printing.
-        """
-        as_dict = self.to_dict()
-        for key in self._connection_keys():
-            if key in as_dict:
-                yield key, as_dict[key]
-
-    @abc.abstractmethod
-    def _connection_keys(self) -> Tuple[str, ...]:
-        raise NotImplementedError
-
-    @classmethod
-    def from_dict(cls, data):
-        data = cls.translate_aliases(data)
-        return super().from_dict(data)
-
-    @classmethod
-    def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
-        return translate_aliases(kwargs, cls._ALIASES)
-
-    def to_dict(self, omit_none=True, validate=False, with_aliases=False):
-        serialized = super().to_dict(omit_none=omit_none, validate=validate)
-        if with_aliases:
-            serialized.update({
-                new_name: serialized[canonical_name]
-                for new_name, canonical_name in self._ALIASES.items()
-                if canonical_name in serialized
-            })
-        return serialized
 
 
 class BaseConnectionManager(metaclass=abc.ABCMeta):
@@ -79,18 +30,18 @@ class BaseConnectionManager(metaclass=abc.ABCMeta):
     """
     TYPE: str = NotImplemented
 
-    def __init__(self, profile):
+    def __init__(self, profile: Profile):
         self.profile = profile
-        self.thread_connections = {}
+        self.thread_connections: Dict[Hashable, Connection] = {}
         self.lock = multiprocessing.RLock()
 
     @staticmethod
-    def get_thread_identifier():
+    def get_thread_identifier() -> Hashable:
         # note that get_ident() may be re-used, but we should never experience
         # that within a single process
         return (os.getpid(), get_ident())
 
-    def get_thread_connection(self):
+    def get_thread_connection(self) -> Connection:
         key = self.get_thread_identifier()
         with self.lock:
             if key not in self.thread_connections:
@@ -100,18 +51,18 @@ def get_thread_connection(self):
                 )
             return self.thread_connections[key]
 
-    def get_if_exists(self):
+    def get_if_exists(self) -> Optional[Connection]:
         key = self.get_thread_identifier()
         with self.lock:
             return self.thread_connections.get(key)
 
-    def clear_thread_connection(self):
+    def clear_thread_connection(self) -> None:
         key = self.get_thread_identifier()
         with
self.lock: if key in self.thread_connections: del self.thread_connections[key] - def clear_transaction(self): + def clear_transaction(self) -> None: """Clear any existing transactions.""" conn = self.get_thread_connection() if conn is not None: @@ -121,7 +72,7 @@ def clear_transaction(self): self.commit() @abc.abstractmethod - def exception_handler(self, sql): + def exception_handler(self, sql: str) -> ContextManager: """Create a context manager that handles exceptions caused by database interactions. @@ -133,70 +84,73 @@ def exception_handler(self, sql): raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') - def set_connection_name(self, name=None): + def set_connection_name(self, name: Optional[str] = None) -> Connection: + conn_name: str if name is None: # if a name isn't specified, we'll re-use a single handle # named 'master' - name = 'master' + conn_name = 'master' + else: + assert isinstance(name, str) + conn_name = name conn = self.get_if_exists() thread_id_key = self.get_thread_identifier() if conn is None: conn = Connection( - type=self.TYPE, + type=Identifier(self.TYPE), name=None, - state='init', + state=ConnectionState.INIT, transaction_open=False, handle=None, credentials=self.profile.credentials ) self.thread_connections[thread_id_key] = conn - if conn.name == name and conn.state == 'open': + if conn.name == conn_name and conn.state == 'open': return conn - logger.debug('Acquiring new {} connection "{}".' - .format(self.TYPE, name)) + logger.debug( + 'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name)) if conn.state == 'open': logger.debug( 'Re-using an available connection from the pool (formerly {}).' - .format(conn.name)) + .format(conn.name) + ) else: - logger.debug('Opening a new connection, currently in state {}' - .format(conn.state)) + logger.debug( + 'Opening a new connection, currently in state {}' + .format(conn.state) + ) self.open(conn) - conn.name = name + conn.name = conn_name return conn @abc.abstractmethod - def cancel_open(self): + def cancel_open(self) -> Optional[List[str]]: """Cancel all open connections on the adapter. (passable)""" raise dbt.exceptions.NotImplementedException( '`cancel_open` is not implemented for this adapter!' ) @abc.abstractclassmethod - def open(cls, connection): - """Open a connection on the adapter. + def open(cls, connection: Connection) -> Connection: + """Open the given connection on the adapter and return it. This may mutate the given connection (in particular, its state and its handle). This should be thread-safe, or hold the lock if necessary. The given connection should not be in either in_use or available. - - :param Connection connection: A connection object to open. - :return: A connection with a handle attached and an 'open' state. - :rtype: Connection """ raise dbt.exceptions.NotImplementedException( '`open` is not implemented for this adapter!' ) - def release(self): + def release(self) -> None: with self.lock: conn = self.get_if_exists() if conn is None: @@ -213,7 +167,7 @@ def release(self): self.clear_thread_connection() raise - def cleanup_all(self): + def cleanup_all(self) -> None: with self.lock: for connection in self.thread_connections.values(): if connection.state not in {'closed', 'init'}: @@ -228,24 +182,21 @@ def cleanup_all(self): self.thread_connections.clear() @abc.abstractmethod - def begin(self): - """Begin a transaction. (passable) - - :param str name: The name of the connection to use. 
- """ + def begin(self) -> None: + """Begin a transaction. (passable)""" raise dbt.exceptions.NotImplementedException( '`begin` is not implemented for this adapter!' ) @abc.abstractmethod - def commit(self): + def commit(self) -> None: """Commit a transaction. (passable)""" raise dbt.exceptions.NotImplementedException( '`commit` is not implemented for this adapter!' ) @classmethod - def _rollback_handle(cls, connection): + def _rollback_handle(cls, connection: Connection) -> None: """Perform the actual rollback operation.""" try: connection.handle.rollback() @@ -256,7 +207,7 @@ def _rollback_handle(cls, connection): ) @classmethod - def _close_handle(cls, connection): + def _close_handle(cls, connection: Connection) -> None: """Perform the actual close operation.""" # On windows, sometimes connection handles don't have a close() attr. if hasattr(connection.handle, 'close'): @@ -267,9 +218,8 @@ def _close_handle(cls, connection): .format(connection.name)) @classmethod - def _rollback(cls, connection): - """Roll back the given connection. - """ + def _rollback(cls, connection: Connection) -> None: + """Roll back the given connection.""" if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -283,15 +233,13 @@ def _rollback(cls, connection): connection.transaction_open = False - return connection - @classmethod - def close(cls, connection): + def close(cls, connection: Connection) -> Connection: if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) # if the connection is in closed or init, there's nothing to do - if connection.state in {'closed', 'init'}: + if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}: return connection if connection.transaction_open and connection.handle: @@ -299,21 +247,20 @@ def close(cls, connection): connection.transaction_open = False cls._close_handle(connection) - connection.state = 'closed' + connection.state = ConnectionState.CLOSED return connection - def commit_if_has_connection(self): - """If the named connection exists, commit the current transaction. - - :param str name: The name of the connection to use. - """ + def commit_if_has_connection(self) -> None: + """If the named connection exists, commit the current transaction.""" connection = self.get_if_exists() if connection: self.commit() @abc.abstractmethod - def execute(self, sql, auto_begin=False, fetch=False): + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> Tuple[str, agate.Table]: """Execute the given SQL. :param str sql: The sql to execute. 
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 93208be858b..a4ab29d51cc 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -21,7 +21,7 @@ from dbt.adapters.base.connections import BaseConnectionManager from dbt.adapters.base.meta import AdapterMeta, available -from dbt.adapters.base import BaseRelation +from dbt.adapters.base.relation import ComponentName, BaseRelation from dbt.adapters.base import Column as BaseColumn from dbt.adapters.cache import RelationsCache @@ -645,7 +645,7 @@ def list_relations(self, database: str, schema: str) -> List[BaseRelation]: information_schema = self.Relation.create( database=database, schema=schema, - model_name='', + identifier='', quote_policy=self.config.quoting ).information_schema() @@ -762,11 +762,13 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> str: The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ - # TODO: Convert BaseRelation to a hologram.JsonSchemaMixin so mypy - # likes this - quotes = self.Relation.DEFAULTS['quote_policy'] - default = quotes.get(quote_key) # type: ignore - if self.config.quoting.get(quote_key, default): + try: + key = ComponentName(quote_key) + except ValueError: + return identifier + + default = self.Relation.get_default_quote_policy().get_part(key) + if self.config.quoting.get(key, default): return self.quote(identifier) else: return identifier diff --git a/core/dbt/adapters/base/plugin.py b/core/dbt/adapters/base/plugin.py index d731c3493c9..c307c97d62c 100644 --- a/core/dbt/adapters/base/plugin.py +++ b/core/dbt/adapters/base/plugin.py @@ -1,23 +1,30 @@ +from typing import List, Optional, Type + from dbt.config.project import Project +from dbt.adapters.base import BaseAdapter, Credentials class AdapterPlugin: """Defines the basic requirements for a dbt adapter plugin. - :param type adapter: An adapter class, derived from BaseAdapter - :param type credentials: A credentials object, derived from Credentials - :param str project_name: The name of this adapter plugin's associated dbt - project. - :param str include_path: The path to this adapter plugin's root - :param Optional[List[str]] dependencies: A list of adapter names that this - adapter depends upon. + :param include_path: The path to this adapter plugin's root + :param dependencies: A list of adapter names that this adapter depends + upon. 
""" - def __init__(self, adapter, credentials, include_path, dependencies=None): - self.adapter = adapter - self.credentials = credentials - self.include_path = include_path + def __init__( + self, + adapter: Type[BaseAdapter], + credentials: Type[Credentials], + include_path: str, + dependencies: Optional[List[str]] = None + ): + self.adapter: Type[BaseAdapter] = adapter + self.credentials: Type[Credentials] = credentials + self.include_path: str = include_path project = Project.from_project_root(include_path, {}) - self.project_name = project.project_name + self.project_name: str = project.project_name + self.dependencies: List[str] if dependencies is None: - dependencies = [] - self.dependencies = dependencies + self.dependencies = [] + else: + self.dependencies = dependencies diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 84cd6bc7ecc..59728a55dff 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -1,102 +1,171 @@ -from dbt.api import APIObject -from dbt.utils import filter_null_values +from dbt.utils import filter_null_values, deep_merge, classproperty from dbt.node_types import NodeType import dbt.exceptions +from collections.abc import Mapping, Hashable +from dataclasses import dataclass, fields +from typing import ( + Optional, TypeVar, Generic, Any, Type, Dict, Union, List +) +from typing_extensions import Protocol -class BaseRelation(APIObject): - - Table = "table" - View = "view" - CTE = "cte" - MaterializedView = "materializedview" - ExternalTable = "externaltable" - - RelationTypes = [ - Table, - View, - CTE, - MaterializedView, - ExternalTable - ] - - DEFAULTS = { - 'metadata': { - 'type': 'BaseRelation' - }, - 'quote_character': '"', - 'quote_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'include_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'dbt_created': False, - } - - PATH_SCHEMA = { - 'type': 'object', - 'properties': { - 'database': {'type': ['string', 'null']}, - 'schema': {'type': ['string', 'null']}, - 'identifier': {'type': ['string', 'null']}, - }, - 'required': ['database', 'schema', 'identifier'], - } - - POLICY_SCHEMA = { - 'type': 'object', - 'properties': { - 'database': {'type': 'boolean'}, - 'schema': {'type': 'boolean'}, - 'identifier': {'type': 'boolean'}, - }, - 'required': ['database', 'schema', 'identifier'], - } - - SCHEMA = { - 'type': 'object', - 'properties': { - 'metadata': { - 'type': 'object', - 'properties': { - 'type': { - 'type': 'string', - 'const': 'BaseRelation', - }, - }, - }, - 'type': { - 'enum': RelationTypes + [None], - }, - 'path': PATH_SCHEMA, - 'include_policy': POLICY_SCHEMA, - 'quote_policy': POLICY_SCHEMA, - 'quote_character': {'type': 'string'}, - 'dbt_created': {'type': 'boolean'}, - }, - 'required': ['metadata', 'type', 'path', 'include_policy', - 'quote_policy', 'quote_character', 'dbt_created'] - } - - PATH_ELEMENTS = ['database', 'schema', 'identifier'] - - def _is_exactish_match(self, field, value): - if self.dbt_created and self.quote_policy.get(field) is False: - return self.get_path_part(field).lower() == value.lower() +from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum + +from dbt.contracts.util import Replaceable +from dbt.contracts.graph.compiled import CompiledNode +from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode +from dbt import deprecations + + +class RelationType(StrEnum): + Table = 'table' + View = 'view' + CTE = 
'cte' + MaterializedView = 'materializedview' + External = 'external' + + +class ComponentName(StrEnum): + Database = 'database' + Schema = 'schema' + Identifier = 'identifier' + + +class HasQuoting(Protocol): + quoting: Dict[str, bool] + + +class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping): + # override the mapping truthiness, len is always >1 + def __bool__(self): + return True + + def __getitem__(self, key): + # deprecations.warn('not-a-dictionary', obj=self) + try: + return getattr(self, key) + except AttributeError: + raise KeyError(key) from None + + def __iter__(self): + deprecations.warn('not-a-dictionary', obj=self) + for _, name in self._get_fields(): + yield name + + def __len__(self): + deprecations.warn('not-a-dictionary', obj=self) + return len(fields(self.__class__)) + + def incorporate(self, **kwargs): + value = self.to_dict() + value = deep_merge(value, kwargs) + return self.from_dict(value) + + +T = TypeVar('T') + + +@dataclass +class _ComponentObject(FakeAPIObject, Generic[T]): + database: T + schema: T + identifier: T + + def get_part(self, key: ComponentName) -> T: + if key == ComponentName.Database: + return self.database + elif key == ComponentName.Schema: + return self.schema + elif key == ComponentName.Identifier: + return self.identifier + else: + raise ValueError( + 'Got a key of {}, expected one of {}' + .format(key, list(ComponentName)) + ) + + def replace_dict(self, dct: Dict[ComponentName, T]): + kwargs: Dict[str, T] = {} + for k, v in dct.items(): + kwargs[str(k)] = v + return self.replace(**kwargs) + + +@dataclass +class Policy(_ComponentObject[bool]): + database: bool = True + schema: bool = True + identifier: bool = True + + +@dataclass +class Path(_ComponentObject[Optional[str]]): + database: Optional[str] + schema: Optional[str] + identifier: Optional[str] + + def get_lowered_part(self, key: ComponentName) -> Optional[str]: + part = self.get_part(key) + if part is not None: + part = part.lower() + return part + + +Self = TypeVar('Self', bound='BaseRelation') + + +@dataclass(frozen=True, eq=False, repr=False) +class BaseRelation(FakeAPIObject, Hashable): + type: Optional[RelationType] + path: Path + quote_character: str = '"' + include_policy: Policy = Policy() + quote_policy: Policy = Policy() + dbt_created: bool = False + + def _is_exactish_match(self, field: ComponentName, value: str) -> bool: + if self.dbt_created and self.quote_policy.get_part(field) is False: + return self.path.get_lowered_part(field) == value.lower() else: - return self.get_path_part(field) == value + return self.path.get_part(field) == value + + @classmethod + def _get_field_named(cls, field_name): + for field, _ in cls._get_fields(): + if field.name == field_name: + return field + # this should be unreachable + raise ValueError(f'BaseRelation has no {field_name} field!') + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.to_dict() == other.to_dict() + + @classmethod + def get_default_quote_policy(cls: Type[Self]) -> Policy: + return cls._get_field_named('quote_policy').default + + @classmethod + def get_default_include_policy(cls: Type[Self]) -> Policy: + return cls._get_field_named('include_policy').default - def matches(self, database=None, schema=None, identifier=None): + @classmethod + def get_relation_type_class(cls: Type[Self]) -> Type[RelationType]: + return cls._get_field_named('type') + + def matches( + self, + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + 
) -> bool: search = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) if not search: @@ -111,7 +180,7 @@ def matches(self, database=None, schema=None, identifier=None): if not self._is_exactish_match(k, v): exact_match = False - if self.get_path_part(k).lower() != v.lower(): + if self.path.get_lowered_part(k) != v.lower(): approximate_match = False if approximate_match and not exact_match: @@ -122,107 +191,100 @@ def matches(self, database=None, schema=None, identifier=None): return exact_match - def get_path_part(self, part): - return self.path.get(part) - - def should_quote(self, part): - return self.quote_policy.get(part) - - def should_include(self, part): - return self.include_policy.get(part) - - def quote(self, database=None, schema=None, identifier=None): + def quote( + self: Self, + database: Optional[bool] = None, + schema: Optional[bool] = None, + identifier: Optional[bool] = None, + ) -> Self: policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) - return self.incorporate(quote_policy=policy) + new_quote_policy = self.quote_policy.replace_dict(policy) + return self.replace(quote_policy=new_quote_policy) - def include(self, database=None, schema=None, identifier=None): + def include( + self: Self, + database: Optional[bool] = None, + schema: Optional[bool] = None, + identifier: Optional[bool] = None, + ) -> Self: policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) - return self.incorporate(include_policy=policy) + new_include_policy = self.include_policy.replace_dict(policy) + return self.replace(include_policy=new_include_policy) - def information_schema(self, identifier=None): - include_db = self.database is not None - include_policy = filter_null_values({ - 'database': include_db, - 'schema': True, - 'identifier': identifier is not None - }) - quote_policy = filter_null_values({ - 'database': self.quote_policy['database'], - 'schema': False, - 'identifier': False, - }) + def information_schema(self: Self, identifier=None) -> Self: + include_policy = self.include_policy.replace( + database=self.database is not None, + schema=True, + identifier=identifier is not None + ) + quote_policy = self.quote_policy.replace( + schema=False, + identifier=False, + ) - path_update = { - 'schema': 'information_schema', - 'identifier': identifier - } + path = self.path.replace( + schema='information_schema', + identifier=identifier, + ) - return self.incorporate( + return self.replace( quote_policy=quote_policy, include_policy=include_policy, - path=path_update, - table_name=identifier) + path=path, + ) - def information_schema_only(self): + def information_schema_only(self: Self) -> Self: return self.information_schema() - def information_schema_table(self, identifier): + def information_schema_table(self: Self, identifier: str) -> Self: return self.information_schema(identifier) - def render(self, use_table_name=True): - parts = [] - - for k in self.PATH_ELEMENTS: - if self.should_include(k): - path_part = self.get_path_part(k) + def render(self) -> str: + parts: List[str] = [] - if path_part is None: - continue - elif k == 
'identifier': - if use_table_name: - path_part = self.table - else: - path_part = self.identifier + for k in ComponentName: + if self.include_policy.get_part(k): + path_part = self.path.get_part(k) - parts.append( - self.quote_if( - path_part, - self.should_quote(k))) + if path_part is not None: + part: str = path_part + if self.quote_policy.get_part(k): + part = self.quoted(path_part) + parts.append(part) if len(parts) == 0: raise dbt.exceptions.RuntimeException( - "No path parts are included! Nothing to render.") + "No path parts are included! Nothing to render." + ) return '.'.join(parts) - def quote_if(self, identifier, should_quote): - if should_quote: - return self.quoted(identifier) - - return identifier - def quoted(self, identifier): return '{quote_char}{identifier}{quote_char}'.format( quote_char=self.quote_character, - identifier=identifier) + identifier=identifier, + ) @classmethod - def create_from_source(cls, source, **kwargs): - quote_policy = dbt.utils.deep_merge( - cls.DEFAULTS['quote_policy'], + def create_from_source( + cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any + ) -> Self: + quote_policy = deep_merge( + cls.get_default_quote_policy().to_dict(), source.quoting.to_dict(), - kwargs.get('quote_policy', {}) + kwargs.get('quote_policy', {}), ) + return cls.create( database=source.database, schema=source.schema, @@ -232,8 +294,13 @@ def create_from_source(cls, source, **kwargs): ) @classmethod - def create_from_node(cls, config, node, table_name=None, quote_policy=None, - **kwargs): + def create_from_node( + cls: Type[Self], + config: HasQuoting, + node: Union[ParsedNode, CompiledNode], + quote_policy: Optional[Dict[str, bool]] = None, + **kwargs: Any, + ) -> Self: if quote_policy is None: quote_policy = {} @@ -243,164 +310,96 @@ def create_from_node(cls, config, node, table_name=None, quote_policy=None, database=node.database, schema=node.schema, identifier=node.alias, - table_name=table_name, quote_policy=quote_policy, **kwargs) @classmethod - def create_from(cls, config, node, **kwargs): + def create_from( + cls: Type[Self], + config: HasQuoting, + node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition], + **kwargs: Any, + ) -> Self: if node.resource_type == NodeType.Source: + assert isinstance(node, ParsedSourceDefinition) return cls.create_from_source(node, **kwargs) else: + assert isinstance(node, (ParsedNode, CompiledNode)) return cls.create_from_node(config, node, **kwargs) @classmethod - def create(cls, database=None, schema=None, - identifier=None, table_name=None, - type=None, **kwargs): - if table_name is None: - table_name = identifier - - return cls(type=type, - path={ - 'database': database, - 'schema': schema, - 'identifier': identifier - }, - table_name=table_name, - **kwargs) - - def __repr__(self): + def create( + cls: Type[Self], + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + type: Optional[RelationType] = None, + **kwargs, + ) -> Self: + kwargs.update({ + 'path': { + 'database': database, + 'schema': schema, + 'identifier': identifier, + }, + 'type': type, + }) + return cls.from_dict(kwargs) + + def __repr__(self) -> str: return "<{} {}>".format(self.__class__.__name__, self.render()) - def __hash__(self): + def __hash__(self) -> int: return hash(self.render()) - def __str__(self): + def __str__(self) -> str: return self.render() @property - def path(self): - return self.get('path', {}) + def database(self) -> Optional[str]: + return self.path.database @property - def 
database(self): - return self.path.get('database') + def schema(self) -> Optional[str]: + return self.path.schema @property - def schema(self): - return self.path.get('schema') + def identifier(self) -> Optional[str]: + return self.path.identifier @property - def identifier(self): - return self.path.get('identifier') + def table(self) -> Optional[str]: + return self.path.identifier # Here for compatibility with old Relation interface @property - def name(self): + def name(self) -> Optional[str]: return self.identifier - # Here for compatibility with old Relation interface - @property - def table(self): - return self.table_name - - @property - def is_table(self): - return self.type == self.Table - @property - def is_cte(self): - return self.type == self.CTE + def is_table(self) -> bool: + return self.type == RelationType.Table @property - def is_view(self): - return self.type == self.View - - -class Column: - TYPE_LABELS = { - 'STRING': 'TEXT', - 'TIMESTAMP': 'TIMESTAMP', - 'FLOAT': 'FLOAT', - 'INTEGER': 'INT' - } - - def __init__(self, column, dtype, char_size=None, numeric_precision=None, - numeric_scale=None): - self.column = column - self.dtype = dtype - self.char_size = char_size - self.numeric_precision = numeric_precision - self.numeric_scale = numeric_scale - - @classmethod - def translate_type(cls, dtype): - return cls.TYPE_LABELS.get(dtype.upper(), dtype) - - @classmethod - def create(cls, name, label_or_dtype): - column_type = cls.translate_type(label_or_dtype) - return cls(name, column_type) + def is_cte(self) -> bool: + return self.type == RelationType.CTE @property - def name(self): - return self.column - - @property - def quoted(self): - return '"{}"'.format(self.column) - - @property - def data_type(self): - if self.is_string(): - return Column.string_type(self.string_size()) - elif self.is_numeric(): - return Column.numeric_type(self.dtype, self.numeric_precision, - self.numeric_scale) - else: - return self.dtype - - def is_string(self): - return self.dtype.lower() in ['text', 'character varying', 'character', - 'varchar'] + def is_view(self) -> bool: + return self.type == RelationType.View - def is_numeric(self): - return self.dtype.lower() in ['numeric', 'number'] + @classproperty + def Table(self) -> str: + return str(RelationType.Table) - def string_size(self): - if not self.is_string(): - raise RuntimeError("Called string_size() on non-string field!") + @classproperty + def CTE(self) -> str: + return str(RelationType.CTE) - if self.dtype == 'text' or self.char_size is None: - # char_size should never be None. Handle it reasonably just in case - return 256 - else: - return int(self.char_size) - - def can_expand_to(self, other_column): - """returns True if this column can be expanded to the size of the - other column""" - if not self.is_string() or not other_column.is_string(): - return False - - return other_column.string_size() > self.string_size() - - def literal(self, value): - return "{}::{}".format(value, self.data_type) - - @classmethod - def string_type(cls, size): - return "character varying({})".format(size) - - @classmethod - def numeric_type(cls, dtype, precision, scale): - # This could be decimal(...), numeric(...), number(...) 
-        # Just use whatever was fed in here -- don't try to get too clever
-        if precision is None or scale is None:
-            return dtype
-        else:
-            return "{}({},{})".format(dtype, precision, scale)
+    @classproperty
+    def View(self) -> str:
+        return str(RelationType.View)
 
-    def __repr__(self):
-        return "<Column {} ({})>".format(self.name, self.data_type)
+    @classproperty
+    def External(self) -> str:
+        return str(RelationType.External)
diff --git a/core/dbt/adapters/cache.py b/core/dbt/adapters/cache.py
index 521ff52d975..78472ecc674 100644
--- a/core/dbt/adapters/cache.py
+++ b/core/dbt/adapters/cache.py
@@ -130,7 +130,6 @@ def rename(self, new_relation):
                 'schema': new_relation.inner.schema,
                 'identifier': new_relation.inner.identifier
             },
-            table_name=new_relation.inner.identifier
         )
 
     def rename_key(self, old_key, new_key):
diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py
index a380e393a62..9c7a78b2e6a 100644
--- a/core/dbt/adapters/factory.py
+++ b/core/dbt/adapters/factory.py
@@ -1,50 +1,66 @@
-import dbt.exceptions
+import threading
 from importlib import import_module
+from typing import Type, Dict, TypeVar
+
+from dbt.exceptions import RuntimeException
 from dbt.include.global_project import PACKAGES
 from dbt.logger import GLOBAL_LOGGER as logger
+from dbt.contracts.connection import Credentials
 
-import threading
 
-ADAPTER_TYPES = {}
+# TODO: we can't import these because they cause an import cycle.
+# currently RuntimeConfig needs to figure out default quoting for its adapter.
+# We should push that elsewhere when we fixup project/profile stuff
+# Instead here are some import loop avoiding-hacks right now. And Profile has
+# to call into load_plugin to get credentials, so adapter/relation don't work
+RuntimeConfig = TypeVar('RuntimeConfig')
+BaseAdapter = TypeVar('BaseAdapter')
+BaseRelation = TypeVar('BaseRelation')
 
-_ADAPTERS = {}
+ADAPTER_TYPES: Dict[str, Type[BaseAdapter]] = {}
+
+_ADAPTERS: Dict[str, BaseAdapter] = {}
 _ADAPTER_LOCK = threading.Lock()
 
 
-def get_adapter_class_by_name(adapter_name):
+def get_adapter_class_by_name(adapter_name: str) -> Type[BaseAdapter]:
     with _ADAPTER_LOCK:
         if adapter_name in ADAPTER_TYPES:
            return ADAPTER_TYPES[adapter_name]
 
+    adapter_names = ", ".join(ADAPTER_TYPES.keys())
+
     message = "Invalid adapter type {}! Must be one of {}"
-    adapter_names = ", ".join(ADAPTER_TYPES.keys())
     formatted_message = message.format(adapter_name, adapter_names)
-    raise dbt.exceptions.RuntimeException(formatted_message)
+    raise RuntimeException(formatted_message)
 
 
-def get_relation_class_by_name(adapter_name):
+def get_relation_class_by_name(adapter_name: str) -> Type[BaseRelation]:
     adapter = get_adapter_class_by_name(adapter_name)
     return adapter.Relation
 
 
-def load_plugin(adapter_name):
+def load_plugin(adapter_name: str) -> Credentials:
+    # this doesn't need a lock: in the worst case we'll overwrite PACKAGES and
+    # _ADAPTER_TYPE entries with the same value, as they're all singletons
     try:
         mod = import_module('.' + adapter_name, 'dbt.adapters')
     except ImportError as e:
         logger.info("Error importing adapter: {}".format(e))
-        raise dbt.exceptions.RuntimeException(
+        raise RuntimeException(
             "Could not find adapter type {}!".format(adapter_name)
         )
     plugin = mod.Plugin
 
     if plugin.adapter.type() != adapter_name:
-        raise dbt.exceptions.RuntimeException(
+        raise RuntimeException(
             'Expected to find adapter with type named {}, got adapter with '
             'type {}'
             .format(adapter_name, plugin.adapter.type())
         )
 
     with _ADAPTER_LOCK:
+        # things do hold the lock to iterate over it, so we need it to add stuff
         ADAPTER_TYPES[adapter_name] = plugin.adapter
 
         PACKAGES[plugin.project_name] = plugin.include_path
@@ -55,19 +71,16 @@ def load_plugin(adapter_name):
     return plugin.credentials
 
 
-def get_adapter(config):
+def get_adapter(config: RuntimeConfig) -> BaseAdapter:
     adapter_name = config.credentials.type
+
+    # Atomically check to see if we already have an adapter
     if adapter_name in _ADAPTERS:
         return _ADAPTERS[adapter_name]
 
-    with _ADAPTER_LOCK:
-        if adapter_name not in ADAPTER_TYPES:
-            raise dbt.exceptions.RuntimeException(
-                "Could not find adapter type {}!".format(adapter_name)
-            )
-
-        adapter_type = ADAPTER_TYPES[adapter_name]
+    adapter_type = get_adapter_class_by_name(adapter_name)
 
+    with _ADAPTER_LOCK:
         # check again, in case something was setting it before
         if adapter_name in _ADAPTERS:
             return _ADAPTERS[adapter_name]
diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py
index 5b9c7f459cd..e96a5eae5cb 100644
--- a/core/dbt/adapters/sql/connections.py
+++ b/core/dbt/adapters/sql/connections.py
@@ -1,5 +1,8 @@
 import abc
 import time
+from typing import List, Optional, Tuple, Any, Iterable, Dict
+
+import agate
 
 import dbt.clients.agate_helper
 import dbt.exceptions
@@ -18,16 +21,13 @@ class SQLConnectionManager(BaseConnectionManager):
     - open
     """
     @abc.abstractmethod
-    def cancel(self, connection):
-        """Cancel the given connection.
-
-        :param Connection connection: The connection to cancel.
-        """
+    def cancel(self, connection: Connection):
+        """Cancel the given connection."""
         raise dbt.exceptions.NotImplementedException(
             '`cancel` is not implemented for this adapter!'
         )
 
-    def cancel_open(self):
+    def cancel_open(self) -> List[str]:
         names = []
         this_connection = self.get_if_exists()
         with self.lock:
@@ -39,11 +39,17 @@ def cancel_open(self):
             # nothing to cancel.
             if connection.handle is not None:
                 self.cancel(connection)
-                names.append(connection.name)
+                if connection.name is not None:
+                    names.append(connection.name)
         return names
 
-    def add_query(self, sql, auto_begin=True, bindings=None,
-                  abridge_sql_log=False):
+    def add_query(
+        self,
+        sql: str,
+        auto_begin: bool = True,
+        bindings: Optional[Any] = None,
+        abridge_sql_log: bool = False
+    ) -> Tuple[Connection, Any]:
         connection = self.get_thread_connection()
         if auto_begin and connection.transaction_open is False:
             self.begin()
@@ -76,25 +82,25 @@ def add_query(self, sql, auto_begin=True, bindings=None,
         return connection, cursor
 
     @abc.abstractclassmethod
-    def get_status(cls, cursor):
-        """Get the status of the cursor.
-
-        :param cursor: A database handle to get status from
-        :return: The current status
-        :rtype: str
-        """
+    def get_status(cls, cursor: Any) -> str:
+        """Get the status of the cursor."""
         raise dbt.exceptions.NotImplementedException(
             '`get_status` is not implemented for this adapter!'
) @classmethod - def process_results(cls, column_names, rows): + def process_results( + cls, + column_names: Iterable[str], + rows: Iterable[Any] + ) -> List[Dict[str, Any]]: + return [dict(zip(column_names, row)) for row in rows] @classmethod - def get_result_from_cursor(cls, cursor): - data = [] - column_names = [] + def get_result_from_cursor(cls, cursor: Any) -> agate.Table: + data: List[Any] = [] + column_names: List[str] = [] if cursor.description is not None: column_names = [col[0] for col in cursor.description] @@ -103,7 +109,9 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) - def execute(self, sql, auto_begin=False, fetch=False): + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> Tuple[str, agate.Table]: _, cursor = self.add_query(sql, auto_begin) status = self.get_status(cursor) if fetch: diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 87bbb39db20..260b17b49d8 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -1,9 +1,12 @@ import agate +from typing import Any, Optional, Tuple, Type import dbt.clients.agate_helper +from dbt.contracts.connection import Connection import dbt.exceptions import dbt.flags from dbt.adapters.base import BaseAdapter, available +from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger @@ -35,18 +38,25 @@ class SQLAdapter(BaseAdapter): - list_relations_without_caching - get_columns_in_relation """ + ConnectionManager: Type[SQLConnectionManager] + connections: SQLConnectionManager + @available.parse(lambda *a, **k: (None, None)) - def add_query(self, sql, auto_begin=True, bindings=None, - abridge_sql_log=False): + def add_query( + self, + sql: str, + auto_begin: bool = True, + bindings: Optional[Any] = None, + abridge_sql_log: bool = False, + ) -> Tuple[Connection, Any]: """Add a query to the current transaction. A thin wrapper around ConnectionManager.add_query. - :param str sql: The SQL query to add - :param bool auto_begin: If set and there is no transaction in progress, + :param sql: The SQL query to add + :param auto_begin: If set and there is no transaction in progress, begin a new one. - :param Optional[List[object]]: An optional list of bindings for the - query. - :param bool abridge_sql_log: If set, limit the raw sql logged to 512 + :param bindings: An optional list of bindings for the query. + :param abridge_sql_log: If set, limit the raw sql logged to 512 characters """ return self.connections.add_query(sql, auto_begin, bindings, diff --git a/core/dbt/api/__init__.py b/core/dbt/api/__init__.py deleted file mode 100644 index a6fe655f9c8..00000000000 --- a/core/dbt/api/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from dbt.api.object import APIObject - -__all__ = [ - 'APIObject' -] diff --git a/core/dbt/api/object.py b/core/dbt/api/object.py deleted file mode 100644 index d6408f1160e..00000000000 --- a/core/dbt/api/object.py +++ /dev/null @@ -1,125 +0,0 @@ -import copy -from collections import Mapping -from jsonschema import Draft7Validator - -from dbt.exceptions import JSONValidationException -from dbt.utils import deep_merge -from dbt.clients.system import write_json - - -class APIObject(Mapping): - """ - A serializable / deserializable object intended for - use in a future dbt API. 
- - To create a new object, you'll want to extend this - class, and then implement the SCHEMA property (a - valid JSON schema), the DEFAULTS property (default - settings for this object), and a static method that - calls this constructor. - """ - - SCHEMA = { - 'type': 'object', - 'properties': {} - } - - DEFAULTS = {} - - def __init__(self, **kwargs): - """ - Create and validate an instance. Note that if you override this, you - will want to do so by modifying kwargs and only then calling - super().__init__(**kwargs). - """ - super().__init__() - # note: deep_merge does a deep copy on its arguments. - self._contents = deep_merge(self.DEFAULTS, kwargs) - self.validate() - - def __str__(self): - return '{}(**{})'.format(self.__class__.__name__, self._contents) - - def __repr__(self): - return '{}(**{})'.format(self.__class__.__name__, self._contents) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.serialize() == other.serialize() - - def incorporate(self, **kwargs): - """ - Given a list of kwargs, incorporate these arguments - into a new copy of this instance, and return the new - instance after validating. - """ - return type(self)(**deep_merge(self._contents, kwargs)) - - def serialize(self): - """ - Return a dict representation of this object. - """ - return copy.deepcopy(self._contents) - - def write(self, path): - write_json(path, self.serialize()) - - @classmethod - def deserialize(cls, settings): - """ - Convert a dict representation of this object into - an actual object for internal use. - """ - return cls(**settings) - - def validate(self): - """ - Using the SCHEMA property, validate the attributes - of this instance. If any attributes are missing or - invalid, raise a ValidationException. - """ - validator = Draft7Validator(self.SCHEMA) - - errors = set() # make errors a set to avoid duplicates - - for error in validator.iter_errors(self.serialize()): - errors.add('.'.join( - list(map(str, error.path)) + [error.message] - )) - - if errors: - raise JSONValidationException(type(self).__name__, errors) - - # implement the Mapping protocol: - # https://docs.python.org/3/library/collections.abc.html - def __getitem__(self, key): - return self._contents[key] - - def __iter__(self): - return self._contents.__iter__() - - def __len__(self): - return self._contents.__len__() - - # implement this because everyone always expects it. - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def set(self, key, value): - self._contents[key] = value - - # most users of APIObject also expect the attributes to be available via - # dot-notation because the previous implementation assigned to __dict__. - # we should consider removing this if we fix all uses to have properties. - def __getattr__(self, name): - if name != '_contents' and name in self._contents: - return self._contents[name] - elif hasattr(self.__class__, name): - return getattr(self.__class__, name) - raise AttributeError(( - "'{}' object has no attribute '{}'" - ).format(type(self).__name__, name)) diff --git a/core/dbt/config/__init__.py b/core/dbt/config/__init__.py index 09b5523dae1..d18fd7f0790 100644 --- a/core/dbt/config/__init__.py +++ b/core/dbt/config/__init__.py @@ -1,5 +1,5 @@ # all these are just exports, they need "noqa" so flake8 will not complain. 
-from .renderer import ConfigRenderer # noqa from .profile import Profile, PROFILES_DIR, read_user_config # noqa from .project import Project # noqa from .runtime import RuntimeConfig # noqa +from .renderer import ConfigRenderer # noqa diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 2ef892f6844..4f196419ab7 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -1,5 +1,5 @@ from dbt.clients.jinja import get_rendered -from dbt.context.common import generate_config_context +from dbt.context.base import generate_config_context from dbt.exceptions import DbtProfileError from dbt.exceptions import DbtProjectError from dbt.exceptions import RecursionException diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index e89b4c3fe95..1e3bdc7dbdb 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -1,13 +1,13 @@ from copy import deepcopy +from .profile import Profile +from .project import Project from dbt.utils import parse_cli_vars from dbt.contracts.project import Configuration from dbt.exceptions import DbtProjectError from dbt.exceptions import validator_error_message from dbt.adapters.factory import get_relation_class_by_name -from .profile import Profile -from .project import Project from hologram import ValidationError @@ -72,11 +72,11 @@ def from_parts(cls, project, profile, args): :param args argparse.Namespace: The parsed command-line arguments. :returns RuntimeConfig: The new configuration. """ - quoting = deepcopy( + quoting = ( get_relation_class_by_name(profile.credentials.type) - .DEFAULTS['quote_policy'] - ) - quoting.update(project.quoting) + .get_default_quote_policy() + .replace_dict(project.quoting) + ).to_dict() return cls( project_name=project.project_name, diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py new file mode 100644 index 00000000000..3994694a13e --- /dev/null +++ b/core/dbt/context/base.py @@ -0,0 +1,138 @@ +import json +import os + +import dbt.tracking +from dbt.clients.jinja import undefined_error +from dbt.utils import merge + + +# These modules are added to the context. Consider alternative +# approaches which will extend well to potentially many modules +import pytz +import datetime + + +def add_tracking(context): + if dbt.tracking.active_user is not None: + context = merge(context, { + "run_started_at": dbt.tracking.active_user.run_started_at, + "invocation_id": dbt.tracking.active_user.invocation_id, + }) + else: + context = merge(context, { + "run_started_at": None, + "invocation_id": None + }) + + return context + + +def env_var(var, default=None): + if var in os.environ: + return os.environ[var] + elif default is not None: + return default + else: + msg = "Env var required but not provided: '{}'".format(var) + undefined_error(msg) + + +def debug_here(): + import sys + import ipdb + frame = sys._getframe(3) + ipdb.set_trace(frame) + + +class Var: + UndefinedVarError = "Required var '{}' not found in config:\nVars "\ + "supplied to {} = {}" + _VAR_NOTSET = object() + + def __init__(self, model, context, overrides): + self.model = model + self.context = context + + # These are hard-overrides (eg. 
CLI vars) that should take + # precedence over context-based var definitions + self.overrides = overrides + + if model is None: + # during config parsing we have no model and no local vars + self.model_name = '' + local_vars = {} + else: + self.model_name = model.name + local_vars = model.local_vars() + + self.local_vars = dbt.utils.merge(local_vars, overrides) + + def pretty_dict(self, data): + return json.dumps(data, sort_keys=True, indent=4) + + def get_missing_var(self, var_name): + pretty_vars = self.pretty_dict(self.local_vars) + msg = self.UndefinedVarError.format( + var_name, self.model_name, pretty_vars + ) + dbt.exceptions.raise_compiler_error(msg, self.model) + + def assert_var_defined(self, var_name, default): + if var_name not in self.local_vars and default is self._VAR_NOTSET: + return self.get_missing_var(var_name) + + def get_rendered_var(self, var_name): + raw = self.local_vars[var_name] + # if bool/int/float/etc are passed in, don't compile anything + if not isinstance(raw, str): + return raw + + return dbt.clients.jinja.get_rendered(raw, self.context) + + def __call__(self, var_name, default=_VAR_NOTSET): + if var_name in self.local_vars: + return self.get_rendered_var(var_name) + elif default is not self._VAR_NOTSET: + return default + else: + return self.get_missing_var(var_name) + + +def get_pytz_module_context(): + context_exports = pytz.__all__ + + return { + name: getattr(pytz, name) for name in context_exports + } + + +def get_datetime_module_context(): + context_exports = [ + 'date', + 'datetime', + 'time', + 'timedelta', + 'tzinfo' + ] + + return { + name: getattr(datetime, name) for name in context_exports + } + + +def get_context_modules(): + return { + 'pytz': get_pytz_module_context(), + 'datetime': get_datetime_module_context(), + } + + +def generate_config_context(cli_vars): + context = { + 'env_var': env_var, + 'modules': get_context_modules(), + } + context['var'] = Var(None, context, cli_vars) + if os.environ.get('DBT_MACRO_DEBUGGING'): + context['debug'] = debug_here + return add_tracking(context) diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index efc5b8d9f5e..b356392e56b 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -1,26 +1,21 @@ import json import os -from dbt.adapters.factory import get_adapter -from dbt.node_types import NodeType -from dbt.include.global_project import PACKAGES -from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME - -import dbt.clients.jinja import dbt.clients.agate_helper import dbt.exceptions import dbt.flags import dbt.tracking -import dbt.writer import dbt.utils - -from dbt.logger import GLOBAL_LOGGER as logger # noqa - - -# These modules are added to the context. 
Consider alternative -# approaches which will extend well to potentially many modules -import pytz -import datetime +import dbt.writer +from dbt.adapters.factory import get_adapter +from dbt.node_types import NodeType +from dbt.include.global_project import PACKAGES +from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME +from dbt.logger import GLOBAL_LOGGER as logger +from dbt.clients.jinja import get_rendered +from dbt.context.base import ( + debug_here, env_var, get_context_modules, add_tracking +) class RelationProxy: @@ -125,22 +120,29 @@ def _add_macros(context, model, manifest): return context -def _add_tracking(context): - if dbt.tracking.active_user is not None: - context = dbt.utils.merge(context, { - "run_started_at": dbt.tracking.active_user.run_started_at, - "invocation_id": dbt.tracking.active_user.invocation_id, - }) - else: - context = dbt.utils.merge(context, { - "run_started_at": None, - "invocation_id": None +def _store_result(sql_results): + def call(name, status, agate_table=None): + if agate_table is None: + agate_table = dbt.clients.agate_helper.empty_table() + + sql_results[name] = dbt.utils.AttrDict({ + 'status': status, + 'data': dbt.clients.agate_helper.as_matrix(agate_table), + 'table': agate_table }) + return '' - return context + return call -def _add_validation(context): +def _load_result(sql_results): + def call(name): + return sql_results.get(name) + + return call + + +def add_validation(context): def validate_any(*args): def inner(value): for arg in args: @@ -162,46 +164,7 @@ def inner(value): {'validation': validation_utils}) -def env_var(var, default=None): - if var in os.environ: - return os.environ[var] - elif default is not None: - return default - else: - msg = "Env var required but not provided: '{}'".format(var) - dbt.clients.jinja.undefined_error(msg) - - -def _store_result(sql_results): - def call(name, status, agate_table=None): - if agate_table is None: - agate_table = dbt.clients.agate_helper.empty_table() - - sql_results[name] = dbt.utils.AttrDict({ - 'status': status, - 'data': dbt.clients.agate_helper.as_matrix(agate_table), - 'table': agate_table - }) - return '' - - return call - - -def _load_result(sql_results): - def call(name): - return sql_results.get(name) - - return call - - -def _debug_here(): - import sys - import ipdb - frame = sys._getframe(3) - ipdb.set_trace(frame) - - -def _add_sql_handlers(context): +def add_sql_handlers(context): sql_results = {} return dbt.utils.merge(context, { '_sql_results': sql_results, @@ -210,68 +173,6 @@ def _add_sql_handlers(context): }) -def log(msg, info=False): - if info: - logger.info(msg) - else: - logger.debug(msg) - return '' - - -class Var: - UndefinedVarError = "Required var '{}' not found in config:\nVars "\ - "supplied to {} = {}" - _VAR_NOTSET = object() - - def __init__(self, model, context, overrides): - self.model = model - self.context = context - - # These are hard-overrides (eg. 
CLI vars) that should take - # precedence over context-based var definitions - self.overrides = overrides - - if model is None: - # during config parsing we have no model and no local vars - self.model_name = '' - local_vars = {} - else: - self.model_name = model.name - local_vars = model.local_vars() - - self.local_vars = dbt.utils.merge(local_vars, overrides) - - def pretty_dict(self, data): - return json.dumps(data, sort_keys=True, indent=4) - - def get_missing_var(self, var_name): - pretty_vars = self.pretty_dict(self.local_vars) - msg = self.UndefinedVarError.format( - var_name, self.model_name, pretty_vars - ) - dbt.exceptions.raise_compiler_error(msg, self.model) - - def assert_var_defined(self, var_name, default): - if var_name not in self.local_vars and default is self._VAR_NOTSET: - return self.get_missing_var(var_name) - - def get_rendered_var(self, var_name): - raw = self.local_vars[var_name] - # if bool/int/float/etc are passed in, don't compile anything - if not isinstance(raw, str): - return raw - - return dbt.clients.jinja.get_rendered(raw, self.context) - - def __call__(self, var_name, default=_VAR_NOTSET): - if var_name in self.local_vars: - return self.get_rendered_var(var_name) - elif default is not self._VAR_NOTSET: - return default - else: - return self.get_missing_var(var_name) - - def write(node, target_path, subdirectory): def fn(payload): node.build_path = dbt.writer.write_node( @@ -283,7 +184,7 @@ def fn(payload): def render(context, node): def fn(string): - return dbt.clients.jinja.get_rendered(string, context, node) + return get_rendered(string, context, node) return fn @@ -311,46 +212,17 @@ def impl(message_if_exception, func, *args, **kwargs): return impl -def _return(value): - raise dbt.exceptions.MacroReturn(value) - - -def get_pytz_module_context(): - context_exports = pytz.__all__ - - return { - name: getattr(pytz, name) for name in context_exports - } - - -def get_datetime_module_context(): - context_exports = [ - 'date', - 'datetime', - 'time', - 'timedelta', - 'tzinfo' - ] - - return { - name: getattr(datetime, name) for name in context_exports - } - - -def get_context_modules(): - return { - 'pytz': get_pytz_module_context(), - 'datetime': get_datetime_module_context(), - } +# Base context collection, used for parsing configs. 
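For reference, a minimal sketch (not part of this patch) of the precedence rules the relocated Var class implements: CLI-style overrides beat model-local vars, and an explicit default beats the missing-var error. The import path assumes the new core/dbt/context/base.py module created above.

    from dbt.context.base import generate_config_context

    # with no model, generate_config_context wires up a Var whose
    # local_vars are just the cli_vars passed in
    ctx = generate_config_context({'target_schema': 'analytics'})

    ctx['var']('target_schema')         # 'analytics' -- the override wins
    ctx['var']('missing', default='x')  # 'x' -- explicit default
    # ctx['var']('missing') would raise a compiler error via get_missing_var()
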
+def log(msg, info=False): + if info: + logger.info(msg) + else: + logger.debug(msg) + return '' -def generate_config_context(cli_vars): - context = { - 'env_var': env_var, - 'modules': get_context_modules(), - } - context['var'] = Var(None, context, cli_vars) - return _add_tracking(context) +def _return(value): + raise dbt.exceptions.MacroReturn(value) def _build_load_agate_table(model): @@ -422,7 +294,7 @@ def generate_base(model, model_dict, config, manifest, source_config, "try_or_compiler_error": try_or_compiler_error(model) }) if os.environ.get('DBT_MACRO_DEBUGGING'): - context['debug'] = _debug_here + context['debug'] = debug_here return context @@ -430,9 +302,9 @@ def generate_base(model, model_dict, config, manifest, source_config, def modify_generated_context(context, model, config, manifest, provider): cli_var_overrides = config.cli_vars - context = _add_tracking(context) - context = _add_validation(context) - context = _add_sql_handlers(context) + context = add_tracking(context) + context = add_validation(context) + context = add_sql_handlers(context) # we make a copy of the context for each of these ^^ diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 2d476c0595b..0e8d879ba4a 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -89,7 +89,7 @@ def __getattr__(self, name): ) -class Var(dbt.context.common.Var): +class Var(dbt.context.base.Var): def get_missing_var(self, var_name): # in the parser, just always return None. return None diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index de8cc730eb8..c97a75b50dc 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -1,6 +1,7 @@ from dbt.utils import get_materialization, add_ephemeral_model_prefix import dbt.clients.jinja +import dbt.context.base import dbt.context.common import dbt.flags from dbt.parser.util import ParserUtils @@ -144,7 +145,7 @@ def __getattr__(self, name): ) -class Var(dbt.context.common.Var): +class Var(dbt.context.base.Var): pass diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 2ed88ae38c1..56b4be0adf2 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -1,11 +1,17 @@ +import abc +from dataclasses import dataclass, field +from typing import ( + Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType +) + +from hologram import JsonSchemaMixin from hologram.helpers import ( StrEnum, register_pattern, ExtensibleJsonSchemaMixin ) -from hologram import JsonSchemaMixin + from dbt.contracts.util import Replaceable +from dbt.utils import translate_aliases -from dataclasses import dataclass -from typing import Any, Optional, NewType Identifier = NewType('Identifier', str) register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$') @@ -58,3 +64,55 @@ def handle(self): @handle.setter def handle(self, value): self._handle = value + + +# see https://github.com/python/mypy/issues/4717#issuecomment-373932080 +# and https://github.com/python/mypy/issues/5374 +# for why we have type: ignore. Maybe someday dataclasses + abstract classes +# will work. 
+@dataclass +class Credentials( # type: ignore + ExtensibleJsonSchemaMixin, + Replaceable, + metaclass=abc.ABCMeta +): + database: str + schema: str + _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) + + @abc.abstractproperty + def type(self) -> str: + raise NotImplementedError( + 'type not implemented for base credentials class' + ) + + def connection_info(self) -> Iterable[Tuple[str, Any]]: + """Return an ordered iterator of key/value pairs for pretty-printing. + """ + as_dict = self.to_dict() + for key in self._connection_keys(): + if key in as_dict: + yield key, as_dict[key] + + @abc.abstractmethod + def _connection_keys(self) -> Tuple[str, ...]: + raise NotImplementedError + + @classmethod + def from_dict(cls, data): + data = cls.translate_aliases(data) + return super().from_dict(data) + + @classmethod + def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return translate_aliases(kwargs, cls._ALIASES) + + def to_dict(self, omit_none=True, validate=False, with_aliases=False): + serialized = super().to_dict(omit_none=omit_none, validate=validate) + if with_aliases: + serialized.update({ + new_name: serialized[canonical_name] + for new_name, canonical_name in self._ALIASES.items() + if canonical_name in serialized + }) + return serialized diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index b2fbe834c12..842d9bbc87f 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -1,7 +1,7 @@ -from dbt.clients.system import write_json - import dataclasses +from dbt.clients.system import write_json + class Replaceable: def replace(self, **kwargs): diff --git a/core/dbt/deprecations.py b/core/dbt/deprecations.py index 5f4ff68a2d2..3a54e95395c 100644 --- a/core/dbt/deprecations.py +++ b/core/dbt/deprecations.py @@ -75,6 +75,15 @@ class MaterializationReturnDeprecation(DBTDeprecation): '''.lstrip() +class NotADictionaryDeprecation(DBTDeprecation): + _name = 'not-a-dictionary' + + _description = ''' + The object ("{obj}") was used as a dictionary. In a future version of dbt + this capability will be removed from objects of this type. + '''.lstrip() + + _adapter_renamed_description = """\ The adapter function `adapter.{old_name}` is deprecated and will be removed in a future release of dbt. Please use `adapter.{new_name}` instead. 
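To make the new Credentials contract concrete, here is a hypothetical plugin subclass (all names invented for illustration): an implementer supplies `type`, `_connection_keys`, and optionally an `_ALIASES` mapping, and `from_dict` translates the user-facing aliases before hologram validates the dict.

    from dataclasses import dataclass
    from typing import ClassVar, Dict, Tuple

    from dbt.contracts.connection import Credentials


    @dataclass
    class DummyCredentials(Credentials):
        host: str = 'localhost'
        # lets a profile spell the 'database' field as 'dbname'
        _ALIASES: ClassVar[Dict[str, str]] = {'dbname': 'database'}

        @property
        def type(self) -> str:
            return 'dummy'

        def _connection_keys(self) -> Tuple[str, ...]:
            return ('host', 'database', 'schema')


    creds = DummyCredentials.from_dict(
        {'dbname': 'raw', 'schema': 'public', 'host': 'db.internal'}
    )
    assert creds.database == 'raw'  # alias translated by from_dict
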
@@ -113,6 +122,7 @@ def warn(name, *args, **kwargs): DBTRepositoriesDeprecation(), GenerateSchemaNameSingleArgDeprecated(), MaterializationReturnDeprecation(), + NotADictionaryDeprecation(), ] deprecations: Dict[str, DBTDeprecation] = { diff --git a/core/dbt/include/global_project/macros/adapters/common.sql b/core/dbt/include/global_project/macros/adapters/common.sql index a6be4d8cbc1..7725f2e0981 100644 --- a/core/dbt/include/global_project/macros/adapters/common.sql +++ b/core/dbt/include/global_project/macros/adapters/common.sql @@ -265,8 +265,7 @@ {% macro default__make_temp_relation(base_relation, suffix) %} {% set tmp_identifier = base_relation.identifier ~ suffix %} {% set tmp_relation = base_relation.incorporate( - path={"identifier": tmp_identifier}, - table_name=tmp_identifier) -%} + path={"identifier": tmp_identifier}) -%} {% do return(tmp_relation) %} {% endmacro %} diff --git a/core/dbt/include/global_project/macros/materializations/seed/seed.sql b/core/dbt/include/global_project/macros/materializations/seed/seed.sql index ca836dc88e2..f83f845e3ea 100644 --- a/core/dbt/include/global_project/macros/materializations/seed/seed.sql +++ b/core/dbt/include/global_project/macros/materializations/seed/seed.sql @@ -15,7 +15,7 @@ {%- set column_override = model['config'].get('column_types', {}) -%} {% set sql %} - create table {{ this.render(False) }} ( + create table {{ this.render() }} ( {%- for col_name in agate_table.column_names -%} {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} {%- set type = column_override.get(col_name, inferred_type) -%} @@ -60,7 +60,7 @@ {% endfor %} {% set sql %} - insert into {{ this.render(False) }} ({{ cols_sql }}) values + insert into {{ this.render() }} ({{ cols_sql }}) values {% for row in chunk -%} ({%- for column in agate_table.column_names -%} %s diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index a94030554bf..0c9737e6056 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -3,7 +3,7 @@ from hologram import ValidationError -from dbt.context.common import generate_config_context +from dbt.context.base import generate_config_context from dbt.clients.jinja import get_rendered from dbt.clients.yaml_helper import load_yaml_text diff --git a/core/dbt/tracking.py b/core/dbt/tracking.py index 26e8eff18af..6c78d12b9b3 100644 --- a/core/dbt/tracking.py +++ b/core/dbt/tracking.py @@ -4,8 +4,6 @@ from snowplow_tracker import SelfDescribingJson from datetime import datetime -from dbt.adapters.factory import get_adapter - import pytz import platform import uuid @@ -125,6 +123,8 @@ def get_run_type(args): def get_invocation_context(user, config, args): + # put this in here to avoid an import cycle + from dbt.adapters.factory import get_adapter try: adapter_type = get_adapter(config).type() except Exception: diff --git a/core/dbt/utils.py b/core/dbt/utils.py index 10016c05999..652556cbff9 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -437,9 +437,8 @@ def parse_cli_vars(var_string): V_T = TypeVar('V_T') -def filter_null_values(input: Dict[K_T, V_T]) -> Dict[K_T, V_T]: - return dict((k, v) for (k, v) in input.items() - if v is not None) +def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]: + return {k: v for k, v in input.items() if v is not None} def add_ephemeral_model_prefix(s: str) -> str: @@ -522,3 +521,15 @@ def env_set_truthy(key: str) -> Optional[str]: def restrict_to(*restrictions): """Create the metadata for a restricted dataclass field""" return 
{'restrict': list(restrictions)} + + +# some types need to make constants available to the jinja context as +# attributes, and regular properties only work with objects. maybe this should +# be handled by the RelationProxy? + +class classproperty(object): + def __init__(self, func): + self.func = func + + def __get__(self, obj, objtype): + return self.func(objtype) diff --git a/plugins/bigquery/dbt/adapters/bigquery/__init__.py b/plugins/bigquery/dbt/adapters/bigquery/__init__.py index c456567722c..daff48a32ee 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/__init__.py +++ b/plugins/bigquery/dbt/adapters/bigquery/__init__.py @@ -1,7 +1,7 @@ from dbt.adapters.bigquery.connections import BigQueryConnectionManager # noqa from dbt.adapters.bigquery.connections import BigQueryCredentials from dbt.adapters.bigquery.relation import BigQueryRelation # noqa -from dbt.adapters.bigquery.relation import BigQueryColumn # noqa +from dbt.adapters.bigquery.column import BigQueryColumn # noqa from dbt.adapters.bigquery.impl import BigQueryAdapter from dbt.adapters.base import AdapterPlugin diff --git a/plugins/bigquery/dbt/adapters/bigquery/column.py b/plugins/bigquery/dbt/adapters/bigquery/column.py new file mode 100644 index 00000000000..8c8a442b412 --- /dev/null +++ b/plugins/bigquery/dbt/adapters/bigquery/column.py @@ -0,0 +1,121 @@ +from dataclasses import dataclass +from typing import Optional, List, TypeVar, Iterable, Type + +from dbt.adapters.base.column import Column + +from google.cloud.bigquery import SchemaField + +Self = TypeVar('Self', bound='BigQueryColumn') + + +@dataclass(init=False) +class BigQueryColumn(Column): + TYPE_LABELS = { + 'STRING': 'STRING', + 'TIMESTAMP': 'TIMESTAMP', + 'FLOAT': 'FLOAT64', + 'INTEGER': 'INT64', + 'RECORD': 'RECORD', + } + fields: List[Self] + mode: str + + def __init__( + self, + column: str, + dtype: str, + fields: Optional[Iterable[SchemaField]] = None, + mode: str = 'NULLABLE', + ) -> None: + super().__init__(column, dtype) + + if fields is None: + fields = [] + + self.fields = self.wrap_subfields(fields) + self.mode = mode + + @classmethod + def wrap_subfields( + cls: Type[Self], fields: Iterable[SchemaField] + ) -> List[Self]: + return [cls.create_from_field(field) for field in fields] + + @classmethod + def create_from_field(cls: Type[Self], field: SchemaField) -> Self: + return cls( + field.name, + cls.translate_type(field.field_type), + field.fields, + field.mode, + ) + + @classmethod + def _flatten_recursive( + cls: Type[Self], col: Self, prefix: Optional[str] = None + ) -> List[Self]: + if prefix is None: + prefix = [] + + if len(col.fields) == 0: + prefixed_name = ".".join(prefix + [col.column]) + new_col = cls(prefixed_name, col.dtype, col.fields, col.mode) + return [new_col] + + new_fields = [] + for field in col.fields: + new_prefix = prefix + [col.column] + new_fields.extend(cls._flatten_recursive(field, new_prefix)) + + return new_fields + + def flatten(self): + return self._flatten_recursive(self) + + @property + def quoted(self): + return '`{}`'.format(self.column) + + def literal(self, value): + return "cast({} as {})".format(value, self.dtype) + + @property + def data_type(self) -> str: + if self.dtype.upper() == 'RECORD': + subcols = [ + "{} {}".format(col.name, col.data_type) for col in self.fields + ] + field_type = 'STRUCT<{}>'.format(", ".join(subcols)) + + else: + field_type = self.dtype + + if self.mode.upper() == 'REPEATED': + return 'ARRAY<{}>'.format(field_type) + + else: + return field_type + + def is_string(self) -> bool: + 
return self.dtype.lower() == 'string'
+
+    def is_numeric(self) -> bool:
+        return False
+
+    def can_expand_to(self: Self, other_column: Self) -> bool:
+        """returns True if both columns are strings"""
+        return self.is_string() and other_column.is_string()
+
+    def __repr__(self) -> str:
+        return "<BigQueryColumn {} ({}, {})>".format(
+            self.name, self.data_type, self.mode)
+
+    def column_to_bq_schema(self) -> SchemaField:
+        """Convert a column to a bigquery schema object.
+        """
+        kwargs = {}
+        if len(self.fields) > 0:
+            fields = [field.column_to_bq_schema() for field in self.fields]
+            kwargs = {"fields": fields}
+
+        return SchemaField(self.name, self.dtype, self.mode, **kwargs)
diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py
index 00df1a82ace..b1bf8b44e6c 100644
--- a/plugins/bigquery/dbt/adapters/bigquery/connections.py
+++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py
@@ -92,7 +92,7 @@ def exception_handler(self, sql):
                 raise
             raise dbt.exceptions.RuntimeException(str(e))
 
-    def cancel_open(self):
+    def cancel_open(self) -> None:
         pass
 
     @classmethod
diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py
index fe793ad5fc8..596aff2a66c 100644
--- a/plugins/bigquery/dbt/adapters/bigquery/impl.py
+++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py
@@ -6,8 +6,10 @@
 import dbt.clients.gcloud
 import dbt.clients.agate_helper
 
-from dbt.adapters.base import BaseAdapter, available
-from dbt.adapters.bigquery import BigQueryRelation
+from dbt.adapters.base import BaseAdapter, available, RelationType
+from dbt.adapters.bigquery.relation import (
+    BigQueryRelation
+)
 from dbt.adapters.bigquery import BigQueryColumn
 from dbt.adapters.bigquery import BigQueryConnectionManager
 from dbt.contracts.connection import Connection
@@ -36,9 +38,9 @@ def _stub_relation(*args, **kwargs):
 class BigQueryAdapter(BaseAdapter):
 
     RELATION_TYPES = {
-        'TABLE': BigQueryRelation.Table,
-        'VIEW': BigQueryRelation.View,
-        'EXTERNAL': BigQueryRelation.External
+        'TABLE': RelationType.Table,
+        'VIEW': RelationType.View,
+        'EXTERNAL': RelationType.External
     }
 
     Relation = BigQueryRelation
@@ -102,7 +104,7 @@ def get_columns_in_relation(self, relation):
         table = self.connections.get_bq_table(
             database=relation.database,
             schema=relation.schema,
-            identifier=relation.table_name
+            identifier=relation.identifier
         )
         return self._get_dbt_columns_from_bq_table(table)
 
diff --git a/plugins/bigquery/dbt/adapters/bigquery/relation.py b/plugins/bigquery/dbt/adapters/bigquery/relation.py
index 8110adc4c7a..509ae8e4e98 100644
--- a/plugins/bigquery/dbt/adapters/bigquery/relation.py
+++ b/plugins/bigquery/dbt/adapters/bigquery/relation.py
@@ -1,60 +1,26 @@
-from dbt.adapters.base.relation import BaseRelation, Column
-from dbt.utils import filter_null_values
+from dataclasses import dataclass
+from typing import Optional
 
-import google.cloud.bigquery
+from dbt.adapters.base.relation import (
+    BaseRelation, ComponentName
+)
+from dbt.utils import filter_null_values
 
 
+@dataclass(frozen=True, eq=False, repr=False)
 class BigQueryRelation(BaseRelation):
-    External = "external"
-
-    DEFAULTS = {
-        'metadata': {
-            'type': 'BigQueryRelation'
-        },
-        'quote_character': '`',
-        'quote_policy': {
-            'database': True,
-            'schema': True,
-            'identifier': True,
-        },
-        'include_policy': {
-            'database': True,
-            'schema': True,
-            'identifier': True,
-        },
-        'dbt_created': False,
-    }
-
-    SCHEMA = {
-        'type': 'object',
-
'properties': { - 'type': { - 'type': 'string', - 'const': 'BigQueryRelation', - }, - }, - }, - 'type': { - 'enum': BaseRelation.RelationTypes + [External, None], - }, - 'path': BaseRelation.PATH_SCHEMA, - 'include_policy': BaseRelation.POLICY_SCHEMA, - 'quote_policy': BaseRelation.POLICY_SCHEMA, - 'quote_character': {'type': 'string'}, - 'dbt_created': {'type': 'boolean'}, - }, - 'required': ['metadata', 'type', 'path', 'include_policy', - 'quote_policy', 'quote_character', 'dbt_created'] - } - - def matches(self, database=None, schema=None, identifier=None): + quote_character: str = '`' + + def matches( + self, + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + ) -> bool: search = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) if not search: @@ -67,145 +33,10 @@ def matches(self, database=None, schema=None, identifier=None): return True - @classmethod - def create(cls, database=None, schema=None, - identifier=None, table_name=None, - type=None, **kwargs): - if table_name is None: - table_name = identifier - - return cls(type=type, - path={ - 'database': database, - 'schema': schema, - 'identifier': identifier - }, - table_name=table_name, - **kwargs) - - def quote(self, database=None, schema=None, identifier=None): - policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier - }) - - return self.incorporate(quote_policy=policy) - - @property - def database(self): - return self.path.get('database') - @property def project(self): - return self.path.get('database') - - @property - def schema(self): - return self.path.get('schema') + return self.database @property def dataset(self): - return self.path.get('schema') - - @property - def identifier(self): - return self.path.get('identifier') - - -class BigQueryColumn(Column): - TYPE_LABELS = { - 'STRING': 'STRING', - 'TIMESTAMP': 'TIMESTAMP', - 'FLOAT': 'FLOAT64', - 'INTEGER': 'INT64', - 'RECORD': 'RECORD', - } - - def __init__(self, column, dtype, fields=None, mode='NULLABLE'): - super().__init__(column, dtype) - - if fields is None: - fields = [] - - self.fields = self.wrap_subfields(fields) - self.mode = mode - - @classmethod - def wrap_subfields(cls, fields): - return [BigQueryColumn.create_from_field(field) for field in fields] - - @classmethod - def create_from_field(cls, field): - return BigQueryColumn(field.name, cls.translate_type(field.field_type), - field.fields, field.mode) - - @classmethod - def _flatten_recursive(cls, col, prefix=None): - if prefix is None: - prefix = [] - - if len(col.fields) == 0: - prefixed_name = ".".join(prefix + [col.column]) - new_col = BigQueryColumn(prefixed_name, col.dtype, col.fields, - col.mode) - return [new_col] - - new_fields = [] - for field in col.fields: - new_prefix = prefix + [col.column] - new_fields.extend(cls._flatten_recursive(field, new_prefix)) - - return new_fields - - def flatten(self): - return self._flatten_recursive(self) - - @property - def quoted(self): - return '`{}`'.format(self.column) - - def literal(self, value): - return "cast({} as {})".format(value, self.dtype) - - @property - def data_type(self): - if self.dtype.upper() == 'RECORD': - subcols = [ - "{} {}".format(col.name, col.data_type) for col in self.fields - ] - field_type = 'STRUCT<{}>'.format(", ".join(subcols)) - - else: - field_type = self.dtype - - if self.mode.upper() 
== 'REPEATED':
-            return 'ARRAY<{}>'.format(field_type)
-
-        else:
-            return field_type
-
-    def is_string(self):
-        return self.dtype.lower() == 'string'
-
-    def is_numeric(self):
-        return False
-
-    def can_expand_to(self, other_column):
-        """returns True if both columns are strings"""
-        return self.is_string() and other_column.is_string()
-
-    def __repr__(self):
-        return "<BigQueryColumn {} ({}, {})>".format(
-            self.name, self.data_type, self.mode)
-
-    def column_to_bq_schema(self):
-        """Convert a column to a bigquery schema object.
-        """
-        kwargs = {}
-        if len(self.fields) > 0:
-            fields = [field.column_to_bq_schema() for field in self.fields]
-            kwargs = {"fields": fields}
-
-        return google.cloud.bigquery.SchemaField(self.name, self.dtype,
-                                                 self.mode, **kwargs)
diff --git a/plugins/postgres/dbt/include/postgres/macros/adapters.sql b/plugins/postgres/dbt/include/postgres/macros/adapters.sql
index aa4852c7e56..f8892741686 100644
--- a/plugins/postgres/dbt/include/postgres/macros/adapters.sql
+++ b/plugins/postgres/dbt/include/postgres/macros/adapters.sql
@@ -111,7 +111,6 @@
 {% macro postgres__make_temp_relation(base_relation, suffix) %}
     {% set tmp_identifier = base_relation.identifier ~ suffix ~ py_current_timestring() %}
     {% do return(base_relation.incorporate(
-                                  table_name=tmp_identifier,
                                   path={
                                     "identifier": tmp_identifier,
                                     "schema": none,
diff --git a/plugins/snowflake/dbt/adapters/snowflake/relation.py b/plugins/snowflake/dbt/adapters/snowflake/relation.py
index 0c6b8555484..217292d8d17 100644
--- a/plugins/snowflake/dbt/adapters/snowflake/relation.py
+++ b/plugins/snowflake/dbt/adapters/snowflake/relation.py
@@ -1,46 +1,14 @@
-from dbt.adapters.base.relation import BaseRelation
+from dataclasses import dataclass
 
+from dbt.adapters.base.relation import BaseRelation, Policy
 
-class SnowflakeRelation(BaseRelation):
-    DEFAULTS = {
-        'metadata': {
-            'type': 'SnowflakeRelation'
-        },
-        'quote_character': '"',
-        'quote_policy': {
-            'database': False,
-            'schema': False,
-            'identifier': False,
-        },
-        'include_policy': {
-            'database': True,
-            'schema': True,
-            'identifier': True,
-        },
-        'dbt_created': False,
-    }
+@dataclass
+class SnowflakeQuotePolicy(Policy):
+    database: bool = False
+    schema: bool = False
+    identifier: bool = False
+
 
-    SCHEMA = {
-        'type': 'object',
-        'properties': {
-            'metadata': {
-                'type': 'object',
-                'properties': {
-                    'type': {
-                        'type': 'string',
-                        'const': 'SnowflakeRelation',
-                    },
-                },
-            },
-            'type': {
-                'enum': BaseRelation.RelationTypes + [None],
-            },
-            'path': BaseRelation.PATH_SCHEMA,
-            'include_policy': BaseRelation.POLICY_SCHEMA,
-            'quote_policy': BaseRelation.POLICY_SCHEMA,
-            'quote_character': {'type': 'string'},
-            'dbt_created': {'type': 'boolean'},
-        },
-        'required': ['metadata', 'type', 'path', 'include_policy',
-                     'quote_policy', 'quote_character', 'dbt_created']
-    }
+@dataclass(frozen=True, eq=False, repr=False)
+class SnowflakeRelation(BaseRelation):
+    quote_policy: SnowflakeQuotePolicy = SnowflakeQuotePolicy()
diff --git a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
index 65bfe435680..993c539638c 100644
--- a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
+++ b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql
@@ -16,7 +16,7 @@
       temporary
     {%- elif transient -%}
       transient
-    {%- endif %} table {{ relation }} {% if copy_grants and not temporary -%} copy grants {%- endif %} as
+    {%- endif %} table {{ relation }} {% if copy_grants and not temporary -%} copy grants {%- endif
%} as ( {%- if cluster_by_string is not none -%} select * from( @@ -83,7 +83,7 @@ case when table_type = 'BASE TABLE' then 'table' when table_type = 'VIEW' then 'view' when table_type = 'MATERIALIZED VIEW' then 'materializedview' - when table_type = 'EXTERNAL TABLE' then 'externaltable' + when table_type = 'EXTERNAL TABLE' then 'external' else table_type end as table_type from {{ information_schema }}.tables diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index 579772b854b..471cda24d04 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -99,6 +99,7 @@ def run_test(self): self.assertEqual(self.query_state['view_model'], 'good') self.assertEqual(self.query_state['model_1'], 'good') + class TableTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): @@ -109,6 +110,7 @@ def test__redshift__concurrent_transaction_table(self): self.reset() self.run_test() + class ViewTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): @@ -119,6 +121,7 @@ def test__redshift__concurrent_transaction_view(self): self.reset() self.run_test() + class IncrementalTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index f54050bef91..e7fd4a9228f 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -1,9 +1,10 @@ import unittest from unittest.mock import patch, MagicMock +import hologram + import dbt.flags as flags -from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery import BigQueryRelation import dbt.exceptions @@ -164,7 +165,7 @@ def setUp(self): self.mock_connection_manager = self.conn_manager_cls.return_value self.conn_manager_cls.TYPE = 'bigquery' - self.relation_cls.DEFAULTS = BigQueryRelation.DEFAULTS + self.relation_cls.get_default_quote_policy.side_effect = BigQueryRelation.get_default_quote_policy self.adapter = self.get_adapter('oauth') @@ -190,7 +191,7 @@ def test_drop_schema(self, mock_check_schema): def test_get_columns_in_relation(self): self.mock_connection_manager.get_bq_table.side_effect = ValueError self.adapter.get_columns_in_relation( - MagicMock(database='db', schema='schema', table_name='ident'), + MagicMock(database='db', schema='schema', identifier='ident'), ) self.mock_connection_manager.get_bq_table.assert_called_once_with( database='db', schema='schema', identifier='ident' @@ -209,12 +210,11 @@ def test_view_temp_relation(self): 'schema': 'test_schema', 'identifier': 'my_view' }, - 'table_name': 'my_view__dbt_tmp', 'quote_policy': { 'identifier': False } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_view_relation(self): kwargs = { @@ -224,13 +224,12 @@ def test_view_relation(self): 'schema': 'test_schema', 'identifier': 'my_view' }, - 'table_name': 'my_view', 'quote_policy': { 'identifier': True, 'schema': True } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_table_relation(self): kwargs = { @@ -240,13 +239,12 @@ def test_table_relation(self): 'schema': 'test_schema', 'identifier': 'generic_table' }, - 'table_name': 'generic_table', 'quote_policy': { 'identifier': True, 'schema': True } 
} - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_external_source_relation(self): kwargs = { @@ -256,13 +254,12 @@ def test_external_source_relation(self): 'schema': 'test_schema', 'identifier': 'sheet' }, - 'table_name': 'sheet', 'quote_policy': { 'identifier': True, 'schema': True } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_invalid_relation(self): kwargs = { @@ -272,11 +269,10 @@ def test_invalid_relation(self): 'schema': 'test_schema', 'identifier': 'my_invalid_id' }, - 'table_name': 'my_invalid_id', 'quote_policy': { 'identifier': False, 'schema': True } } - with self.assertRaises(dbt.exceptions.ValidationException): - BigQueryRelation(**kwargs) + with self.assertRaises(hologram.ValidationError): + BigQueryRelation.from_dict(kwargs) diff --git a/test/unit/test_cache.py b/test/unit/test_cache.py index 86e14915b5c..0c314350b9d 100644 --- a/test/unit/test_cache.py +++ b/test/unit/test_cache.py @@ -14,8 +14,7 @@ def make_relation(database, schema, identifier): def make_mock_relationship(database, schema, identifier): return BaseRelation.create( - database=database, schema=schema, identifier=identifier, - table_name=identifier, type='view' + database=database, schema=schema, identifier=identifier, type='view' ) diff --git a/test/unit/utils.py b/test/unit/utils.py index 51d7cf45d8e..b0675f7efa5 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -60,7 +60,6 @@ def inject_adapter(value): artisanal adapter will be available from get_adapter() as if dbt loaded it. """ from dbt.adapters import factory - from dbt.adapters.base.connections import BaseConnectionManager key = value.type() factory._ADAPTERS[key] = value factory.ADAPTER_TYPES[key] = type(value) diff --git a/tox.ini b/tox.ini index 67466513ed8..9ac72d6586c 100644 --- a/tox.ini +++ b/tox.ini @@ -12,8 +12,9 @@ deps = [testenv:mypy] basepython = python3.6 commands = /bin/bash -c '$(which mypy) \ - core/dbt/adapters/base/impl.py \ - core/dbt/adapters/base/meta.py \ + core/dbt/adapters/base \ + core/dbt/adapters/sql \ + core/dbt/adapters/cache.py \ core/dbt/clients \ core/dbt/config \ core/dbt/deprecations.py \ From dfb4b3a2c8ce4f4f3322d1040ba6fccc2b85f54c Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 2 Oct 2019 07:12:17 -0600 Subject: [PATCH 2/2] PR feedback add typing extensions module to setup.py Update changelog --- CHANGELOG.md | 3 +-- core/setup.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ab3924a5f65..c7dbf712de4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## dbt 0.15.0 (TBD) ### Breaking changes + - The 'table_name' parameter to relations has been removed - Cache management changes: - Materialization macros should now return a dictionary {"relations": [...]}, with the list containing all relations that have been added, in order to add them to the cache. The default behavior is to still add the materialization's model to the cache. - Materializations that perform drops via direct "drop" statements must call `adapter.cache_dropped` @@ -18,8 +19,6 @@ This is a bugfix release. 
### Under the hood: - Provide a programmatic method for validating profile targets ([#1754](https://github.com/fishtown-analytics/dbt/issues/1754), [#1775](https://github.com/fishtown-analytics/dbt/pull/1775)) ->>>>>>> dev/0.14.3 - ## dbt 0.14.2 (September 13, 2019) ### Overview diff --git a/core/setup.py b/core/setup.py index a3a96aa110f..c53cf6f1c97 100644 --- a/core/setup.py +++ b/core/setup.py @@ -59,5 +59,6 @@ def read(fname): 'hologram==0.0.3', 'logbook>=1.5,<1.6', 'pytest-logbook>=1.2.0,<1.3', + 'typing-extensions>=3.7.4,<3.8', ] )
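As a closing illustration of the `table_name` removal called out in the changelog: the identifier now doubles as the table name, so temp relations are derived purely through the path. A sketch using only the `create()`/`incorporate()` signatures exercised by the macros and tests above (all values are illustrative):

    from dbt.adapters.base import BaseRelation

    rel = BaseRelation.create(
        database='analytics', schema='dbt', identifier='my_model', type='view'
    )

    # relations are frozen dataclasses now; incorporate() returns a new
    # instance instead of mutating in place
    tmp = rel.incorporate(path={'identifier': rel.identifier + '__dbt_tmp'})
    assert tmp.identifier == 'my_model__dbt_tmp'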