Skip to content

Commit

Permalink
Merge pull request #1909 from fishtown-analytics/feature/honor-threads-flag
Browse files (browse the repository at this point in the history)

add threads to commands (#1897)
  • Loading branch information
beckjake committed Nov 11, 2019
2 parents abc7662 + 6ab537d commit c26cf19
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 4 deletions.
3 changes: 3 additions & 0 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ class RPCExecParameters(RPCParameters):

@dataclass
class RPCCompileParameters(RPCParameters):
    # Parameters for compile-style RPC requests. Every field is optional
    # and defaults to None.

    # Requested thread count; None means the request did not supply one.
    threads: Optional[int] = None
    # Model selector(s): a single string, a list of strings, or None.
    models: Union[None, str, List[str]] = None
    # Exclusion selector(s); accepts the same shapes as `models`.
    exclude: Union[None, str, List[str]] = None


@dataclass
class RPCSnapshotParameters(RPCParameters):
    # Parameters for snapshot RPC requests. Every field is optional and
    # defaults to None.

    # Requested thread count; None means the request did not supply one.
    threads: Optional[int] = None
    # Snapshot selector(s): a single string, a list of strings, or None.
    select: Union[None, str, List[str]] = None
    # Exclusion selector(s); accepts the same shapes as `select`.
    exclude: Union[None, str, List[str]] = None

Expand All @@ -59,6 +61,7 @@ class RPCTestParameters(RPCCompileParameters):

@dataclass
class RPCSeedParameters(RPCParameters):
    # Parameters for seed RPC requests.

    # Requested thread count; None means the request did not supply one.
    threads: Optional[int] = None
    # Boolean `show` option for the seed request; off by default.
    show: bool = False


Expand Down
4 changes: 4 additions & 0 deletions core/dbt/rpc/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def task_exec(self) -> None:
handler = QueueLogHandler(self.queue)
with handler.applicationbound():
self._spawn_setup()
# copy threads over into our credentials, if it exists and is set.
# some commands, like 'debug', won't have a threads value at all.
if getattr(self.task.args, 'threads', None) is not None:
self.task.config.threads = self.task.args.threads
rpc_exception = None
result = None
try:
Expand Down
10 changes: 10 additions & 0 deletions core/dbt/task/rpc/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class RemoteCompileProjectTask(
def set_args(self, params: RPCCompileParameters) -> None:
    """Copy the request parameters onto this task's args.

    `models` and `exclude` are normalized through `self._listify`;
    `threads` is only assigned when the request supplied a value.
    """
    task_args = self.args
    task_args.models = self._listify(params.models)
    task_args.exclude = self._listify(params.exclude)

    requested_threads = params.threads
    if requested_threads is None:
        # nothing requested: leave any existing threads value alone
        return
    task_args.threads = requested_threads


class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask):
Expand All @@ -63,12 +65,16 @@ class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask):
def set_args(self, params: RPCCompileParameters) -> None:
    """Transfer the request's selectors and thread count to `self.args`.

    Selectors go through `self._listify`; a `threads` value of None is
    ignored so that whatever is already on the args stays in effect.
    """
    target = self.args
    target.models = self._listify(params.models)
    target.exclude = self._listify(params.exclude)

    thread_count = params.threads
    if thread_count is not None:
        target.threads = thread_count


class RemoteSeedProjectTask(RPCCommandTask[RPCSeedParameters], SeedTask):
    # Seed task driven by RPCSeedParameters; METHOD_NAME is 'seed'.
    METHOD_NAME = 'seed'

    def set_args(self, params: RPCSeedParameters) -> None:
        """Apply the seed request's parameters to `self.args`.

        `threads` is copied only when the request supplied a value;
        `show` is always copied.
        """
        task_args = self.args
        thread_count = params.threads
        if thread_count is not None:
            task_args.threads = thread_count
        task_args.show = params.show


Expand All @@ -80,6 +86,8 @@ def set_args(self, params: RPCTestParameters) -> None:
self.args.exclude = self._listify(params.exclude)
self.args.data = params.data
self.args.schema = params.schema
if params.threads is not None:
self.args.threads = params.threads


class RemoteDocsGenerateProjectTask(
Expand Down Expand Up @@ -140,3 +148,5 @@ def set_args(self, params: RPCSnapshotParameters) -> None:
# select has an argparse `dest` value of `models`.
self.args.models = self._listify(params.select)
self.args.exclude = self._listify(params.exclude)
if params.threads is not None:
self.args.threads = params.threads
3 changes: 1 addition & 2 deletions core/dbt/task/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def run_forever(self):
'Send requests to http://{}:{}/jsonrpc'.format(display_host, port)
)

app = self.handle_request
app = DispatcherMiddleware(app, {
app = DispatcherMiddleware(self.handle_request, {
'/jsonrpc': self.handle_jsonrpc_request,
})

Expand Down
135 changes: 135 additions & 0 deletions test/rpc/test_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: disable=redefined-outer-name
import time
import yaml
from .util import (
ProjectDefinition, rpc_server, Querier, built_schema, get_querier,
)
Expand Down Expand Up @@ -458,3 +459,137 @@ def test_snapshots_cli(project_root, profiles_root, postgres_profile, unique_sch
token = querier.is_async_result(querier.cli_args(cli='snapshot --select=snapshot_actual'))
results = querier.is_result(querier.async_wait(token))
assert len(results['results']) == 1


def assert_has_threads(results, num_threads):
    """Assert that `results` holds exactly one matching concurrency log.

    Args:
        results: mapping with a 'logs' entry, a list of log-record dicts.
        num_threads: thread count expected in the
            'Concurrency: N threads' message.

    Raises:
        AssertionError: if 'logs' is missing, if the number of log records
            mentioning 'Concurrency' is not exactly one, or if that record
            does not report `num_threads` threads.
    """
    assert 'logs' in results, "results has no 'logs' entry"
    # Records without a message are skipped instead of raising KeyError,
    # so a malformed record shows up as an assertion failure below rather
    # than as an unrelated crash inside the filter.
    concurrency_messages = [
        record['message'] for record in results['logs']
        if 'Concurrency' in (record.get('message') or '')
    ]
    assert len(concurrency_messages) == 1, \
        f'Got invalid number of concurrency logs ({len(concurrency_messages)})'
    expected = f'Concurrency: {num_threads} threads'
    assert expected in concurrency_messages[0], \
        f'Expected {expected!r} in {concurrency_messages[0]!r}'


def test_rpc_run_threads(project_root, profiles_root, postgres_profile, unique_schema):
    """`run` reports the requested thread count for RPC params and CLI args."""
    project_def = ProjectDefinition(models={'my_model.sql': 'select 1 as id'})
    with get_querier(
        project_def=project_def,
        project_dir=project_root,
        profiles_dir=profiles_root,
        schema=unique_schema,
        test_kwargs={},
    ) as querier:
        # threads supplied as a parameter of the `run` method
        rpc_token = querier.is_async_result(querier.run(threads=5))
        rpc_results = querier.is_result(querier.async_wait(rpc_token))
        assert_has_threads(rpc_results, 5)

        # threads supplied through the cli_args string
        cli_token = querier.is_async_result(querier.cli_args('run --threads=7'))
        cli_results = querier.is_result(querier.async_wait(cli_token))
        assert_has_threads(cli_results, 7)


def test_rpc_compile_threads(project_root, profiles_root, postgres_profile, unique_schema):
    """`compile` reports the requested thread count for RPC params and CLI args."""
    project_def = ProjectDefinition(models={'my_model.sql': 'select 1 as id'})
    with get_querier(
        project_def=project_def,
        project_dir=project_root,
        profiles_dir=profiles_root,
        schema=unique_schema,
        test_kwargs={},
    ) as querier:
        # threads supplied as a parameter of the `compile` method
        rpc_token = querier.is_async_result(querier.compile(threads=5))
        rpc_results = querier.is_result(querier.async_wait(rpc_token))
        assert_has_threads(rpc_results, 5)

        # threads supplied through the cli_args string
        cli_token = querier.is_async_result(querier.cli_args('compile --threads=7'))
        cli_results = querier.is_result(querier.async_wait(cli_token))
        assert_has_threads(cli_results, 7)


def test_rpc_test_threads(project_root, profiles_root, postgres_profile, unique_schema):
    """`test` reports the requested thread count for RPC params and CLI args."""
    id_column = {
        'name': 'id',
        'tests': ['not_null', 'unique'],
    }
    model_spec = {
        'name': 'my_model',
        'columns': [id_column],
    }
    schema_text = yaml.safe_dump({'version': 2, 'models': [model_spec]})
    project_def = ProjectDefinition(
        models={
            'my_model.sql': 'select 1 as id',
            'schema.yml': schema_text,
        },
    )
    with get_querier(
        project_def=project_def,
        project_dir=project_root,
        profiles_dir=profiles_root,
        schema=unique_schema,
        test_kwargs={},
    ) as querier:
        # the model has to exist before its schema tests can run
        build_token = querier.is_async_result(querier.run())
        querier.is_result(querier.async_wait(build_token))

        # threads supplied as a parameter of the `test` method
        rpc_token = querier.is_async_result(querier.test(threads=5))
        rpc_results = querier.is_result(querier.async_wait(rpc_token))
        assert_has_threads(rpc_results, 5)

        # threads supplied through the cli_args string
        cli_token = querier.is_async_result(querier.cli_args('test --threads=7'))
        cli_results = querier.is_result(querier.async_wait(cli_token))
        assert_has_threads(cli_results, 7)


def test_rpc_snapshot_threads(project_root, profiles_root, postgres_profile, unique_schema):
    """`snapshot` reports the requested thread count for RPC params and CLI args."""
    project_def = ProjectDefinition(
        snapshots={'my_snapshots.sql': snapshot_data},
    )
    with get_querier(
        project_def=project_def,
        project_dir=project_root,
        profiles_dir=profiles_root,
        schema=unique_schema,
        test_kwargs={},
    ) as querier:
        # threads supplied as a parameter of the `snapshot` method
        rpc_token = querier.is_async_result(querier.snapshot(threads=5))
        rpc_results = querier.is_result(querier.async_wait(rpc_token))
        assert_has_threads(rpc_results, 5)

        # threads supplied through the cli_args string
        cli_token = querier.is_async_result(querier.cli_args('snapshot --threads=7'))
        cli_results = querier.is_result(querier.async_wait(cli_token))
        assert_has_threads(cli_results, 7)


def test_rpc_seed_threads(project_root, profiles_root, postgres_profile, unique_schema):
    """`seed` reports the requested thread count for RPC params and CLI args."""
    project_def = ProjectDefinition(
        seeds={'data.csv': 'a,b\n1,hello\n2,goodbye'},
    )
    with get_querier(
        project_def=project_def,
        project_dir=project_root,
        profiles_dir=profiles_root,
        schema=unique_schema,
        test_kwargs={},
    ) as querier:
        # threads supplied as a parameter of the `seed` method
        rpc_token = querier.is_async_result(querier.seed(threads=5))
        rpc_results = querier.is_result(querier.async_wait(rpc_token))
        assert_has_threads(rpc_results, 5)

        # threads supplied through the cli_args string
        cli_token = querier.is_async_result(querier.cli_args('seed --threads=7'))
        cli_results = querier.is_result(querier.async_wait(cli_token))
        assert_has_threads(cli_results, 7)
29 changes: 27 additions & 2 deletions test/rpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,16 @@ def compile(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if models is not None:
params['models'] = models
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='compile', params=params, request_id=request_id
)
Expand All @@ -208,13 +211,16 @@ def run(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if models is not None:
params['models'] = models
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='run', params=params, request_id=request_id
)
Expand All @@ -232,10 +238,17 @@ def run_operation(
method='run-operation', params=params, request_id=request_id
)

def seed(
    self,
    show: Optional[bool] = None,
    threads: Optional[int] = None,
    request_id: int = 1,
):
    """Send a 'seed' request carrying only the options that were set.

    Args:
        show: value for the 'show' param; omitted from the request when None.
        threads: value for the 'threads' param; omitted when None.
        request_id: id forwarded with the request.

    Returns:
        Whatever `self.request` returns for the 'seed' method.
    """
    # None means "not specified": such options are left out of the params
    # entirely rather than being sent as null.
    options = {'show': show, 'threads': threads}
    params = {
        name: value for name, value in options.items() if value is not None
    }
    return self.request(
        method='seed', params=params, request_id=request_id
    )
Expand All @@ -244,13 +257,16 @@ def snapshot(
self,
select: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if select is not None:
params['select'] = select
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='snapshot', params=params, request_id=request_id
)
Expand All @@ -259,6 +275,7 @@ def test(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
data: bool = None,
schema: bool = None,
request_id: int = 1,
Expand All @@ -272,6 +289,8 @@ def test(
params['data'] = data
if schema is not None:
params['schema'] = schema
if threads is not None:
params['threads'] = threads
return self.request(
method='test', params=params, request_id=request_id
)
Expand Down Expand Up @@ -406,6 +425,7 @@ def __init__(
models=None,
macros=None,
snapshots=None,
seeds=None,
):
self.project = {
'name': name,
Expand All @@ -418,10 +438,11 @@ def __init__(
self.models = models
self.macros = macros
self.snapshots = snapshots
self.seeds = seeds

def _write_recursive(self, path, inputs):
for name, value in inputs.items():
if name.endswith('.sql'):
if name.endswith('.sql') or name.endswith('.csv'):
path.join(name).write(value)
elif name.endswith('.yml'):
if isinstance(value, str):
Expand Down Expand Up @@ -464,6 +485,9 @@ def write_macros(self, project_dir, remove=False):
def write_snapshots(self, project_dir, remove=False):
    """Hand this project's snapshots to `_write_values`.

    Forwards `project_dir` and `remove` unchanged, with 'snapshots' as the
    target name (presumably the subdirectory; see `_write_values`).
    """
    snapshot_files = self.snapshots
    self._write_values(project_dir, remove, 'snapshots', snapshot_files)

def write_seeds(self, project_dir, remove=False):
    """Hand this project's seeds to `_write_values`.

    Forwards `project_dir` and `remove` unchanged, with 'data' as the
    target name (presumably the subdirectory; see `_write_values`).
    """
    seed_files = self.seeds
    self._write_values(project_dir, remove, 'data', seed_files)

def write_to(self, project_dir, remove=False):
if remove:
project_dir.remove()
Expand All @@ -473,6 +497,7 @@ def write_to(self, project_dir, remove=False):
self.write_models(project_dir)
self.write_macros(project_dir)
self.write_snapshots(project_dir)
self.write_seeds(project_dir)


class TestArgs:
Expand Down

0 comments on commit c26cf19

Please sign in to comment.