add threads to commands (#1897) #1909

Merged
merged 1 commit on Nov 11, 2019
3 changes: 3 additions & 0 deletions core/dbt/contracts/rpc.py
@@ -41,12 +41,14 @@ class RPCExecParameters(RPCParameters):

@dataclass
class RPCCompileParameters(RPCParameters):
threads: Optional[int] = None
models: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None


@dataclass
class RPCSnapshotParameters(RPCParameters):
threads: Optional[int] = None
select: Union[None, str, List[str]] = None
exclude: Union[None, str, List[str]] = None

@@ -59,6 +61,7 @@ class RPCTestParameters(RPCCompileParameters):

@dataclass
class RPCSeedParameters(RPCParameters):
threads: Optional[int] = None
show: bool = False


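For context, a minimal sketch (not part of the diff) of how the new optional field behaves on these parameter dataclasses; CompileParamsSketch is a hypothetical stand-in for RPCCompileParameters and omits the serialization machinery of the real RPCParameters base class.

    from dataclasses import dataclass
    from typing import List, Optional, Union

    # Simplified stand-in for the contracts above: 'threads' defaults to None,
    # so requests that omit it keep whatever thread count the profile configures.
    @dataclass
    class CompileParamsSketch:
        threads: Optional[int] = None
        models: Union[None, str, List[str]] = None
        exclude: Union[None, str, List[str]] = None

    params = CompileParamsSketch(threads=5, models=['my_model'])
    assert params.threads == 5
    assert CompileParamsSketch().threads is None  # fall back to the profile's setting
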
4 changes: 4 additions & 0 deletions core/dbt/rpc/task_handler.py
@@ -86,6 +86,10 @@ def task_exec(self) -> None:
handler = QueueLogHandler(self.queue)
with handler.applicationbound():
self._spawn_setup()
# copy threads over into our credentials, if it exists and is set.
# some commands, like 'debug', won't have a threads value at all.
if getattr(self.task.args, 'threads', None) is not None:
self.task.config.threads = self.task.args.threads
rpc_exception = None
result = None
try:
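The guard above copies the request's thread count onto the task's config only when the attribute exists and is set, since commands like debug never define args.threads. A tiny illustration of that pattern with hypothetical stand-in objects (not the real dbt task classes):

    # Hypothetical stand-ins, only to show the getattr-guarded override from task_exec.
    class Args:
        def __init__(self, threads=None):
            if threads is not None:
                self.threads = threads  # debug-style commands never set this attribute

    class Config:
        threads = 4  # e.g. the value loaded from profiles.yml

    config = Config()
    for args in (Args(), Args(threads=7)):
        if getattr(args, 'threads', None) is not None:
            config.threads = args.threads

    print(config.threads)  # 7: only the request that supplied threads changed the config
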
10 changes: 10 additions & 0 deletions core/dbt/task/rpc/project_commands.py
@@ -55,6 +55,8 @@ class RemoteCompileProjectTask(
def set_args(self, params: RPCCompileParameters) -> None:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
if params.threads is not None:
self.args.threads = params.threads


class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask):
@@ -63,12 +65,16 @@ class RemoteRunProjectTask(RPCCommandTask[RPCCompileParameters], RunTask):
def set_args(self, params: RPCCompileParameters) -> None:
self.args.models = self._listify(params.models)
self.args.exclude = self._listify(params.exclude)
if params.threads is not None:
self.args.threads = params.threads


class RemoteSeedProjectTask(RPCCommandTask[RPCSeedParameters], SeedTask):
METHOD_NAME = 'seed'

def set_args(self, params: RPCSeedParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads
self.args.show = params.show


@@ -80,6 +86,8 @@ def set_args(self, params: RPCTestParameters) -> None:
self.args.exclude = self._listify(params.exclude)
self.args.data = params.data
self.args.schema = params.schema
if params.threads is not None:
self.args.threads = params.threads


class RemoteDocsGenerateProjectTask(
@@ -140,3 +148,5 @@ def set_args(self, params: RPCSnapshotParameters) -> None:
# select has an argparse `dest` value of `models`.
self.args.models = self._listify(params.select)
self.args.exclude = self._listify(params.exclude)
if params.threads is not None:
self.args.threads = params.threads
3 changes: 1 addition & 2 deletions core/dbt/task/rpc/server.py
@@ -123,8 +123,7 @@ def run_forever(self):
'Send requests to http://{}:{}/jsonrpc'.format(display_host, port)
)

-        app = self.handle_request
-        app = DispatcherMiddleware(app, {
+        app = DispatcherMiddleware(self.handle_request, {
'/jsonrpc': self.handle_jsonrpc_request,
})

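This change simply folds the redundant temporary variable into the DispatcherMiddleware call. As a reminder of what that middleware does, a minimal sketch; the import path assumes werkzeug >= 0.15 and the handler functions are placeholders, not dbt's real handlers:

    from werkzeug.middleware.dispatcher import DispatcherMiddleware
    from werkzeug.wrappers import Response

    def default_app(environ, start_response):
        # handles every path that is not mounted below
        return Response('plain handler')(environ, start_response)

    def jsonrpc_app(environ, start_response):
        return Response('jsonrpc handler')(environ, start_response)

    # Requests whose path starts with /jsonrpc go to jsonrpc_app;
    # everything else falls through to default_app.
    app = DispatcherMiddleware(default_app, {'/jsonrpc': jsonrpc_app})
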
135 changes: 135 additions & 0 deletions test/rpc/test_base.py
@@ -1,5 +1,6 @@
# flake8: disable=redefined-outer-name
import time
import yaml
from .util import (
ProjectDefinition, rpc_server, Querier, built_schema, get_querier,
)
@@ -458,3 +459,137 @@ def test_snapshots_cli(project_root, profiles_root, postgres_profile, unique_schema):
token = querier.is_async_result(querier.cli_args(cli='snapshot --select=snapshot_actual'))
results = querier.is_result(querier.async_wait(token))
assert len(results['results']) == 1


def assert_has_threads(results, num_threads):
assert 'logs' in results
c_logs = [l for l in results['logs'] if 'Concurrency' in l['message']]
assert len(c_logs) == 1, \
f'Got invalid number of concurrency logs ({len(c_logs)})'
assert 'message' in c_logs[0]
assert f'Concurrency: {num_threads} threads' in c_logs[0]['message']


def test_rpc_run_threads(project_root, profiles_root, postgres_profile, unique_schema):
project = ProjectDefinition(
models={'my_model.sql': 'select 1 as id'}
)
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
token = querier.is_async_result(querier.run(threads=5))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 5)

token = querier.is_async_result(querier.cli_args('run --threads=7'))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 7)


def test_rpc_compile_threads(project_root, profiles_root, postgres_profile, unique_schema):
project = ProjectDefinition(
models={'my_model.sql': 'select 1 as id'}
)
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
token = querier.is_async_result(querier.compile(threads=5))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 5)

token = querier.is_async_result(querier.cli_args('compile --threads=7'))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 7)


def test_rpc_test_threads(project_root, profiles_root, postgres_profile, unique_schema):
schema_yaml = {
'version': 2,
'models': [{
'name': 'my_model',
'columns': [
{
'name': 'id',
'tests': ['not_null', 'unique'],
},
],
}],
}
project = ProjectDefinition(
models={
'my_model.sql': 'select 1 as id',
'schema.yml': yaml.safe_dump(schema_yaml)}
)
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)
with querier_ctx as querier:
# first run dbt to get the model built
token = querier.is_async_result(querier.run())
querier.is_result(querier.async_wait(token))

token = querier.is_async_result(querier.test(threads=5))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 5)

token = querier.is_async_result(querier.cli_args('test --threads=7'))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 7)


def test_rpc_snapshot_threads(project_root, profiles_root, postgres_profile, unique_schema):
project = ProjectDefinition(
snapshots={'my_snapshots.sql': snapshot_data},
)
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)

with querier_ctx as querier:
token = querier.is_async_result(querier.snapshot(threads=5))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 5)

token = querier.is_async_result(querier.cli_args('snapshot --threads=7'))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 7)


def test_rpc_seed_threads(project_root, profiles_root, postgres_profile, unique_schema):
project = ProjectDefinition(
seeds={'data.csv': 'a,b\n1,hello\n2,goodbye'}
)
querier_ctx = get_querier(
project_def=project,
project_dir=project_root,
profiles_dir=profiles_root,
schema=unique_schema,
test_kwargs={},
)

with querier_ctx as querier:
token = querier.is_async_result(querier.seed(threads=5))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 5)

token = querier.is_async_result(querier.cli_args('seed --threads=7'))
results = querier.is_result(querier.async_wait(token))
assert_has_threads(results, 7)
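The tests above drive the same parameter through both the RPC methods and the cli_args passthrough. For reference, a hedged sketch of a raw client call (not part of this diff; the port, the use of the requests library, and the running server are assumptions), passing threads in a JSON-RPC request to the /jsonrpc endpoint mounted in server.py:

    import requests

    payload = {
        'jsonrpc': '2.0',
        'method': 'run',
        'id': 1,
        'params': {'threads': 5},
    }
    # Assumes `dbt rpc` is listening locally; the response carries an async
    # task token that is then polled, much as the tests do via async_wait.
    resp = requests.post('http://localhost:8580/jsonrpc', json=payload)
    print(resp.json())
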
29 changes: 27 additions & 2 deletions test/rpc/util.py
@@ -193,13 +193,16 @@ def compile(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if models is not None:
params['models'] = models
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='compile', params=params, request_id=request_id
)
@@ -208,13 +211,16 @@ def run(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if models is not None:
params['models'] = models
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='run', params=params, request_id=request_id
)
@@ -232,10 +238,17 @@ def run_operation(
method='run-operation', params=params, request_id=request_id
)

def seed(self, show: bool = None, request_id: int = 1):
def seed(
self,
show: bool = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if show is not None:
params['show'] = show
if threads is not None:
params['threads'] = threads
return self.request(
method='seed', params=params, request_id=request_id
)
@@ -244,13 +257,16 @@ def snapshot(
self,
select: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
request_id: int = 1,
):
params = {}
if select is not None:
params['select'] = select
if exclude is not None:
params['exclude'] = exclude
if threads is not None:
params['threads'] = threads
return self.request(
method='snapshot', params=params, request_id=request_id
)
@@ -259,6 +275,7 @@ def test(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
threads: Optional[int] = None,
data: bool = None,
schema: bool = None,
request_id: int = 1,
@@ -272,6 +289,8 @@ def test(
params['data'] = data
if schema is not None:
params['schema'] = schema
if threads is not None:
params['threads'] = threads
return self.request(
method='test', params=params, request_id=request_id
)
@@ -406,6 +425,7 @@ def __init__(
models=None,
macros=None,
snapshots=None,
seeds=None,
):
self.project = {
'name': name,
@@ -418,10 +438,11 @@ def __init__(
self.models = models
self.macros = macros
self.snapshots = snapshots
self.seeds = seeds

def _write_recursive(self, path, inputs):
for name, value in inputs.items():
if name.endswith('.sql'):
if name.endswith('.sql') or name.endswith('.csv'):
path.join(name).write(value)
elif name.endswith('.yml'):
if isinstance(value, str):
@@ -464,6 +485,9 @@ def write_macros(self, project_dir, remove=False):
def write_snapshots(self, project_dir, remove=False):
self._write_values(project_dir, remove, 'snapshots', self.snapshots)

def write_seeds(self, project_dir, remove=False):
self._write_values(project_dir, remove, 'data', self.seeds)

def write_to(self, project_dir, remove=False):
if remove:
project_dir.remove()
@@ -473,6 +497,7 @@ def write_to(self, project_dir, remove=False):
self.write_models(project_dir)
self.write_macros(project_dir)
self.write_snapshots(project_dir)
self.write_seeds(project_dir)


class TestArgs: