Handle the fallout of closing connections in release
- close() implies rollback, so do not call rollback separately before closing
- make sure not to open new connections for executors in single-threaded mode
- logging cleanups
- fix a test case that never acquired connections
- to cancel other connections, one must first acquire a connection for the master thread
Jacob Beck committed Jul 27, 2020
1 parent c1a92aa commit 08802d8
Showing 8 changed files with 119 additions and 100 deletions.
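
The first bullet is the core of the fix: close() already rolls back any open transaction before closing the handle (see the close() hunk below), so the explicit _rollback() call in release() was redundant. A minimal, self-contained sketch of the invariant, using a fake handle rather than dbt's real connection classes:

class FakeHandle:
    """Stand-in for a DBAPI-style connection handle."""
    def __init__(self):
        self.rollbacks = 0
        self.closed = False

    def rollback(self):
        self.rollbacks += 1

    def close(self):
        self.closed = True


class FakeConnection:
    def __init__(self):
        self.handle = FakeHandle()
        self.transaction_open = True
        self.state = 'open'


def close(connection):
    # close() implies rollback: roll back exactly once if a transaction
    # is open, then close the handle
    if connection.transaction_open and connection.handle:
        connection.handle.rollback()
        connection.transaction_open = False
    connection.handle.close()
    connection.state = 'closed'
    return connection


def release(connection):
    # the fixed release(): no explicit rollback, close() owns it
    close(connection)


conn = FakeConnection()
release(conn)
assert conn.handle.rollbacks == 1 and conn.handle.closed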
22 changes: 12 additions & 10 deletions core/dbt/adapters/base/connections.py
@@ -176,10 +176,8 @@ def release(self) -> None:
                 return

         try:
-            if conn.state == 'open':
-                if conn.transaction_open is True:
-                    self._rollback(conn)
-            # always close the connection
+            # always close the connection. close() calls _rollback() if there
+            # is an open transaction
             self.close(conn)
         except Exception:
             # if rollback or close failed, remove our busted connection
@@ -230,11 +228,13 @@ def _close_handle(cls, connection: Connection) -> None:
         """Perform the actual close operation."""
         # On windows, sometimes connection handles don't have a close() attr.
         if hasattr(connection.handle, 'close'):
-            logger.debug('On {}: Close'.format(connection.name))
+            import traceback
+            tbtext = ''.join(traceback.format_stack())
+            print(f'Closing {connection.name}:\n{tbtext}')
+            logger.debug(f'On {connection.name}: Close')
             connection.handle.close()
         else:
-            logger.debug('On {}: No close available on handle'
-                         .format(connection.name))
+            logger.debug(f'On {connection.name}: No close available on handle')

     @classmethod
     def _rollback(cls, connection: Connection) -> None:
@@ -247,10 +247,11 @@ def _rollback(cls, connection: Connection) -> None:

         if connection.transaction_open is False:
             raise dbt.exceptions.InternalException(
-                'Tried to rollback transaction on connection "{}", but '
-                'it does not have one open!'.format(connection.name))
+                f'Tried to rollback transaction on connection '
+                f'"{connection.name}", but it does not have one open!'
+            )

-        logger.debug('On {}: ROLLBACK'.format(connection.name))
+        logger.debug(f'On {connection.name}: ROLLBACK')
         cls._rollback_handle(connection)

         connection.transaction_open = False
@@ -268,6 +269,7 @@ def close(cls, connection: Connection) -> Connection:
             return connection

         if connection.transaction_open and connection.handle:
+            logger.debug('On {}: ROLLBACK'.format(connection.name))
             cls._rollback_handle(connection)
         connection.transaction_open = False
72 changes: 38 additions & 34 deletions core/dbt/adapters/base/impl.py
@@ -312,13 +312,6 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
         # databases
         return info_schema_name_map

-    def _list_relations_get_connection(
-        self, schema_relation: BaseRelation
-    ) -> List[BaseRelation]:
-        name = f'list_{schema_relation.database}_{schema_relation.schema}'
-        with self.connection_named(name):
-            return self.list_relations_without_caching(schema_relation)
-
     def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
         """Populate the relations cache for the given schemas. Returns an
         iterable of the schemas populated, as strings.
@@ -328,10 +321,16 @@ def _relations_cache_for_schemas(self, manifest: Manifest) -> None:

         cache_schemas = self._get_cache_schemas(manifest)
         with executor(self.config) as tpe:
-            futures: List[Future[List[BaseRelation]]] = [
-                tpe.submit(self._list_relations_get_connection, cache_schema)
-                for cache_schema in cache_schemas
-            ]
+            futures: List[Future[List[BaseRelation]]] = []
+            for cache_schema in cache_schemas:
+                fut = tpe.submit_connected(
+                    self,
+                    f'list_{cache_schema.database}_{cache_schema.schema}',
+                    self.list_relations_without_caching,
+                    cache_schema
+                )
+                futures.append(fut)
+
             for future in as_completed(futures):
                 # if we can't read the relations we need to just raise anyway,
                 # so just call future.result() and let that raise on failure
@@ -1001,24 +1000,18 @@ def _get_one_catalog(
         manifest: Manifest,
     ) -> agate.Table:

-        name = '.'.join([
-            str(information_schema.database),
-            'information_schema'
-        ])
-
-        with self.connection_named(name):
-            kwargs = {
-                'information_schema': information_schema,
-                'schemas': schemas
-            }
-            table = self.execute_macro(
-                GET_CATALOG_MACRO_NAME,
-                kwargs=kwargs,
-                release=True,
-                # pass in the full manifest so we get any local project
-                # overrides
-                manifest=manifest,
-            )
+        kwargs = {
+            'information_schema': information_schema,
+            'schemas': schemas
+        }
+        table = self.execute_macro(
+            GET_CATALOG_MACRO_NAME,
+            kwargs=kwargs,
+            release=False,
+            # pass in the full manifest so we get any local project
+            # overrides
+            manifest=manifest,
+        )

         results = self._catalog_filter_table(table, manifest)
         return results
@@ -1029,10 +1022,21 @@ def get_catalog(
         schema_map = self._get_catalog_schemas(manifest)

         with executor(self.config) as tpe:
-            futures: List[Future[agate.Table]] = [
-                tpe.submit(self._get_one_catalog, info, schemas, manifest)
-                for info, schemas in schema_map.items() if len(schemas) > 0
-            ]
+            futures: List[Future[agate.Table]] = []
+            for info, schemas in schema_map.items():
+                if len(schemas) == 0:
+                    continue
+                name = '.'.join([
+                    str(info.database),
+                    'information_schema'
+                ])
+
+                fut = tpe.submit_connected(
+                    self, name,
+                    self._get_one_catalog, info, schemas, manifest
+                )
+                futures.append(fut)
+
             catalogs, exceptions = catch_as_completed(futures)

         return catalogs, exceptions
@@ -1059,7 +1063,7 @@ def calculate_freshness(
         table = self.execute_macro(
             FRESHNESS_MACRO_NAME,
             kwargs=kwargs,
-            release=True,
+            release=False,
             manifest=manifest
         )
         # now we have a 1-row table of the maximum `loaded_at_field` value and
48 changes: 27 additions & 21 deletions core/dbt/task/runnable.py
@@ -297,13 +297,15 @@ def _cancel_connections(self, pool):
             dbt.ui.printer.print_timestamped_line(msg, yellow)

         else:
-            for conn_name in adapter.cancel_open_connections():
-                if self.manifest is not None:
-                    node = self.manifest.nodes.get(conn_name)
-                    if node is not None and node.is_ephemeral_model:
-                        continue
-                # if we don't have a manifest/don't have a node, print anyway.
-                dbt.ui.printer.print_cancel_line(conn_name)
+            with adapter.connection_named('master'):
+                for conn_name in adapter.cancel_open_connections():
+                    if self.manifest is not None:
+                        node = self.manifest.nodes.get(conn_name)
+                        if node is not None and node.is_ephemeral_model:
+                            continue
+                    # if we don't have a manifest/don't have a node, print
+                    # anyway.
+                    dbt.ui.printer.print_cancel_line(conn_name)

         pool.join()

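This hunk implements the last bullet of the commit message: cancelling another connection is itself a query run over the calling thread's connection (on postgres-flavored adapters it is roughly select pg_terminate_backend(pid)), so the master thread must acquire a connection of its own before iterating. A self-contained sketch of the pattern; StubAdapter is illustrative, and only connection_named() and cancel_open_connections() mirror the real adapter API:

from contextlib import contextmanager


class StubAdapter:
    def __init__(self):
        self.current_connection = None

    @contextmanager
    def connection_named(self, name):
        # acquire a connection for this thread under the given name
        self.current_connection = name
        try:
            yield
        finally:
            self.current_connection = None

    def cancel_open_connections(self):
        # a real adapter sends cancel queries over the current connection,
        # which is why one must already be acquired
        assert self.current_connection is not None, 'no connection acquired'
        yield from ['model.my_project.my_model']


adapter = StubAdapter()
with adapter.connection_named('master'):
    for conn_name in adapter.cancel_open_connections():
        print(f'CANCEL {conn_name}')
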
@@ -457,18 +459,15 @@ def list_schemas(
             db_lowercase = dbt.utils.lowercase(db_only.database)
             if db_only.database is None:
                 database_quoted = None
-                conn_name = 'list_schemas'
             else:
                 database_quoted = str(db_only)
-                conn_name = f'list_{db_only.database}'

-            with adapter.connection_named(conn_name):
-                # we should never create a null schema, so just filter them out
-                return [
-                    (db_lowercase, s.lower())
-                    for s in adapter.list_schemas(database_quoted)
-                    if s is not None
-                ]
+            # we should never create a null schema, so just filter them out
+            return [
+                (db_lowercase, s.lower())
+                for s in adapter.list_schemas(database_quoted)
+                if s is not None
+            ]

         def create_schema(relation: BaseRelation) -> None:
             db = relation.database or ''
@@ -480,9 +479,13 @@ def create_schema(relation: BaseRelation) -> None:
         create_futures = []

         with dbt.utils.executor(self.config) as tpe:
-            list_futures = [
-                tpe.submit(list_schemas, db) for db in required_databases
-            ]
+            for req in required_databases:
+                if req.database is None:
+                    name = 'list_schemas'
+                else:
+                    name = f'list_{req.database}'
+                fut = tpe.submit_connected(adapter, name, list_schemas, req)
+                list_futures.append(fut)

             for ls_future in as_completed(list_futures):
                 existing_schemas_lowered.update(ls_future.result())
@@ -499,9 +502,12 @@ def create_schema(relation: BaseRelation) -> None:
                     db_schema = (db_lower, schema.lower())
                     if db_schema not in existing_schemas_lowered:
                         existing_schemas_lowered.add(db_schema)
-                        create_futures.append(
-                            tpe.submit(create_schema, info)
-                        )
+
+                        fut = tpe.submit_connected(
+                            adapter, f'create_{info.database or ""}_{info.schema}',
+                            create_schema, info
+                        )
+                        create_futures.append(fut)

             for create_future in as_completed(create_futures):
                 # trigger/re-raise any exceptions while creating schemas
31 changes: 26 additions & 5 deletions core/dbt/utils.py
@@ -8,6 +8,7 @@
 import itertools
 import json
 import os
+from contextlib import contextmanager
 from enum import Enum
 from typing_extensions import Protocol
 from typing import (
@@ -518,8 +519,16 @@ def format_bytes(num_bytes):
     return "> 1024 TB"


+class ConnectingExecutor(concurrent.futures.Executor):
+    def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
+        def connected(conn_name, func, *args, **kwargs):
+            with self.connection_named(adapter, conn_name):
+                return func(*args, **kwargs)
+        return self.submit(connected, conn_name, func, *args, **kwargs)
+
+
 # a little concurrent.futures.Executor for single-threaded mode
-class SingleThreadedExecutor(concurrent.futures.Executor):
+class SingleThreadedExecutor(ConnectingExecutor):
     def submit(*args, **kwargs):
         # this basic pattern comes from concurrent.futures.Executor itself,
         # but without handling the `fn=` form.
@@ -544,6 +553,20 @@ def submit(*args, **kwargs):
             fut.set_result(result)
         return fut

+    @contextmanager
+    def connection_named(self, adapter, name):
+        yield
+
+
+class MultiThreadedExecutor(
+    ConnectingExecutor,
+    concurrent.futures.ThreadPoolExecutor,
+):
+    @contextmanager
+    def connection_named(self, adapter, name):
+        with adapter.connection_named(name):
+            yield
+

 class ThreadedArgs(Protocol):
     single_threaded: bool
@@ -554,13 +577,11 @@ class HasThreadingConfig(Protocol):
     threads: Optional[int]


-def executor(config: HasThreadingConfig) -> concurrent.futures.Executor:
+def executor(config: HasThreadingConfig) -> ConnectingExecutor:
     if config.args.single_threaded:
         return SingleThreadedExecutor()
     else:
-        return concurrent.futures.ThreadPoolExecutor(
-            max_workers=config.threads
-        )
+        return MultiThreadedExecutor(max_workers=config.threads)


 def fqn_search(
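
With these pieces in place, executor() returns a ConnectingExecutor in both modes, and call sites such as _relations_cache_for_schemas above no longer manage connections themselves: submit_connected() acquires a named connection around the work in multi-threaded mode, while the single-threaded executor's connection_named() is a no-op, so no new connection is opened. A hedged usage sketch; StubAdapter and StubConfig are stand-ins rather than dbt classes, and the import assumes this commit's dbt.utils is on the path:

import concurrent.futures
from contextlib import contextmanager

from dbt.utils import executor  # the factory defined in the hunk above


class StubAdapter:
    @contextmanager
    def connection_named(self, name):
        print(f'acquire {name!r}')
        try:
            yield
        finally:
            print(f'release {name!r}')


class StubConfig:
    # shaped like HasThreadingConfig: .args.single_threaded and .threads
    class args:
        single_threaded = False

    threads = 4


adapter = StubAdapter()

with executor(StubConfig()) as tpe:
    futures = [
        tpe.submit_connected(adapter, f'list_db_{i}', lambda i=i: i * i)
        for i in range(3)
    ]
    for fut in concurrent.futures.as_completed(futures):
        print(fut.result())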
3 changes: 2 additions & 1 deletion plugins/bigquery/dbt/adapters/bigquery/connections.py
@@ -10,6 +10,7 @@

 from dbt.utils import format_bytes
 from dbt.clients import agate_helper, gcloud
+from dbt.contracts.connection import ConnectionState
 from dbt.exceptions import (
     FailedToConnectException, RuntimeException, DatabaseException
 )
@@ -111,7 +112,7 @@ def cancel_open(self) -> None:

     @classmethod
     def close(cls, connection):
-        connection.state = 'closed'
+        connection.state = ConnectionState.CLOSED

         return connection
1 change: 0 additions & 1 deletion plugins/snowflake/dbt/adapters/snowflake/connections.py
@@ -328,7 +328,6 @@ def _rollback_handle(cls, connection):
         """On snowflake, rolling back the handle of an aborted session raises
         an exception.
         """
-        logger.debug('initiating rollback')
         try:
             connection.handle.rollback()
         except snowflake.connector.errors.ProgrammingError as e:
@@ -33,7 +33,8 @@ def do_test_file(self, filename):
         with open(file_path) as fh:
             query = fh.read()

-        status, table = self.adapter.execute(query, auto_begin=False, fetch=True)
+        with self.adapter.connection_named('master'):
+            status, table = self.adapter.execute(query, auto_begin=False, fetch=True)
         self.assertTrue(len(table.columns) > 0, "agate table had no columns")
         self.assertTrue(len(table.rows) > 0, "agate table had no rows")