diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index 74e7ed68811..b3c532e5a6c 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -3,7 +3,6 @@ from dbt.contracts.graph.parsed import ParsedSourceDefinition from dbt.contracts.util import Writable, Replaceable from dbt.logger import ( - LogMessage, TimingProcessor, JsonOnly, GLOBAL_LOGGER as logger, @@ -207,35 +206,6 @@ class FreshnessRunOutput(JsonSchemaMixin, Writable): sources: Dict[str, SourceFreshnessRunResult] -@dataclass -class RemoteCompileResult(JsonSchemaMixin): - raw_sql: str - compiled_sql: str - node: CompileResultNode - timing: List[TimingInfo] - logs: List[LogMessage] - - @property - def error(self): - return None - - -@dataclass -class RemoteExecutionResult(ExecutionResult): - logs: List[LogMessage] - - -@dataclass -class ResultTable(JsonSchemaMixin): - column_names: List[str] - rows: List[Any] - - -@dataclass -class RemoteRunResult(RemoteCompileResult): - table: ResultTable - - Primitive = Union[bool, str, float, None] CatalogKey = NamedTuple( @@ -298,8 +268,3 @@ class CatalogResults(JsonSchemaMixin, Writable): nodes: Dict[str, CatalogTable] generated_at: datetime _compile_results: Optional[Any] = None - - -@dataclass -class RemoteCatalogResults(CatalogResults): - logs: List[LogMessage] = field(default_factory=list) diff --git a/core/dbt/contracts/rpc.py b/core/dbt/contracts/rpc.py new file mode 100644 index 00000000000..a90e5986d2e --- /dev/null +++ b/core/dbt/contracts/rpc.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass, field +from numbers import Real +from typing import Optional, Union, List, Any, Dict + +from hologram import JsonSchemaMixin + +from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.results import ( + TimingInfo, + CatalogResults, + ExecutionResult, +) +from dbt.logger import LogMessage + +# Inputs + + +@dataclass +class RPCParameters(JsonSchemaMixin): + timeout: Optional[Real] + task_tags: Optional[Dict[str, Any]] + + +@dataclass +class RPCExecParameters(RPCParameters): + name: str + sql: str + macros: Optional[str] + + +@dataclass +class RPCCompileParameters(RPCParameters): + models: Union[None, str, List[str]] = None + exclude: Union[None, str, List[str]] = None + + +@dataclass +class RPCTestParameters(RPCCompileParameters): + data: bool = False + schema: bool = False + + +@dataclass +class RPCSeedParameters(RPCParameters): + show: bool = False + + +@dataclass +class RPCDocsGenerateParameters(RPCParameters): + compile: bool = True + + +@dataclass +class RPCCliParameters(RPCParameters): + cli: str + + +# Outputs + + +@dataclass +class RemoteCatalogResults(CatalogResults): + logs: List[LogMessage] = field(default_factory=list) + + +@dataclass +class RemoteCompileResult(JsonSchemaMixin): + raw_sql: str + compiled_sql: str + node: CompileResultNode + timing: List[TimingInfo] + logs: List[LogMessage] + + @property + def error(self): + return None + + +@dataclass +class RemoteExecutionResult(ExecutionResult): + logs: List[LogMessage] + + +@dataclass +class ResultTable(JsonSchemaMixin): + column_names: List[str] + rows: List[Any] + + +@dataclass +class RemoteRunResult(RemoteCompileResult): + table: ResultTable diff --git a/core/dbt/helper_types.py b/core/dbt/helper_types.py index 25e9f85506e..e45c4bc01f0 100644 --- a/core/dbt/helper_types.py +++ b/core/dbt/helper_types.py @@ -1,5 +1,6 @@ # never name this package "types", or mypy will crash in ugly ways from datetime import timedelta +from numbers import Real from typing import NewType from hologram import ( @@ -37,7 +38,14 @@ def json_schema(self) -> JsonDict: return {'type': 'number'} +class RealEncoder(FieldEncoder): + @property + def json_schema(self): + return {'type': 'number'} + + JsonSchemaMixin.register_field_encoders({ Port: PortEncoder(), timedelta: TimeDeltaFieldEncoder(), + Real: RealEncoder(), }) diff --git a/core/dbt/rpc/error.py b/core/dbt/rpc/error.py index 7f32d04328d..d5ec8702948 100644 --- a/core/dbt/rpc/error.py +++ b/core/dbt/rpc/error.py @@ -12,6 +12,7 @@ def __init__( message: Optional[str] = None, data: Optional[Dict[str, Any]] = None, logs: Optional[List[Dict[str, Any]]] = None, + tags: Optional[Dict[str, Any]] = None ) -> None: if code is None: code = -32000 @@ -23,6 +24,7 @@ def __init__( super().__init__(code=code, message=message, data=data) if logs is not None: self.logs = logs + self.error.data['tags'] = tags def __str__(self): return ( @@ -40,9 +42,25 @@ def logs(self, value): return self.error.data['logs'] = value + @property + def tags(self): + return self.error.data.get('tags') + + @tags.setter + def tags(self, value): + if value is None: + return + self.error.data['tags'] = value + @classmethod def from_error(cls, err): - return cls(err.code, err.message, err.data, err.data.get('logs')) + return cls( + code=err.code, + message=err.message, + data=err.data, + logs=err.data.get('logs'), + tags=err.data.get('tags'), + ) def invalid_params(data): @@ -53,17 +71,17 @@ def invalid_params(data): ) -def server_error(err, logs=None): +def server_error(err, logs=None, tags=None): exc = dbt.exceptions.Exception(str(err)) - return dbt_error(exc, logs) + return dbt_error(exc, logs, tags) -def timeout_error(timeout_value, logs=None): +def timeout_error(timeout_value, logs=None, tags=None): exc = dbt.exceptions.RPCTimeoutException(timeout_value) - return dbt_error(exc, logs) + return dbt_error(exc, logs, tags) -def dbt_error(exc, logs=None): +def dbt_error(exc, logs=None, tags=None): exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(), - logs=logs) + logs=logs, tags=tags) return exc diff --git a/core/dbt/rpc/logger.py b/core/dbt/rpc/logger.py index 1f7b733af6e..2bd06e84f19 100644 --- a/core/dbt/rpc/logger.py +++ b/core/dbt/rpc/logger.py @@ -9,7 +9,7 @@ from queue import Empty from typing import Optional, Any, Union -from dbt.contracts.results import ( +from dbt.contracts.rpc import ( RemoteCompileResult, RemoteExecutionResult, RemoteCatalogResults ) from dbt.exceptions import InternalException diff --git a/core/dbt/rpc/node_runners.py b/core/dbt/rpc/node_runners.py index 6138a922db9..39108ff3a47 100644 --- a/core/dbt/rpc/node_runners.py +++ b/core/dbt/rpc/node_runners.py @@ -1,6 +1,9 @@ +from abc import abstractmethod +from typing import Generic, TypeVar + import dbt.exceptions from dbt.compilation import compile_node -from dbt.contracts.results import ( +from dbt.contracts.rpc import ( RemoteCompileResult, RemoteRunResult, ResultTable, ) from dbt.logger import GLOBAL_LOGGER as logger @@ -8,9 +11,14 @@ from dbt.rpc.error import dbt_error, RPCException, server_error -class RPCCompileRunner(CompileRunner): +RPCSQLResult = TypeVar('RPCSQLResult', bound=RemoteCompileResult) + + +class GenericRPCRunner(CompileRunner, Generic[RPCSQLResult]): def __init__(self, config, adapter, node, node_index, num_nodes): - super().__init__(config, adapter, node, node_index, num_nodes) + CompileRunner.__init__( + self, config, adapter, node, node_index, num_nodes + ) def handle_exception(self, e, ctx): logger.debug('Got an exception: {}'.format(e), exc_info=True) @@ -33,14 +41,13 @@ def compile(self, manifest): return compile_node(self.adapter, self.config, self.node, manifest, {}, write=False) - def execute(self, compiled_node, manifest): - return RemoteCompileResult( - raw_sql=compiled_node.raw_sql, - compiled_sql=compiled_node.injected_sql, - node=compiled_node, - timing=[], # this will get added later - logs=[], - ) + @abstractmethod + def execute(self, compiled_node, manifest) -> RPCSQLResult: + pass + + @abstractmethod + def from_run_result(self, result, start_time, timing_info) -> RPCSQLResult: + pass def error_result(self, node, error, start_time, timing_info): raise error @@ -50,34 +57,38 @@ def ephemeral_result(self, node, start_time, timing_info): 'cannot execute ephemeral nodes remotely!' ) - def from_run_result(self, result, start_time, timing_info): + +class RPCCompileRunner(GenericRPCRunner[RemoteCompileResult]): + def execute(self, compiled_node, manifest) -> RemoteCompileResult: return RemoteCompileResult( - raw_sql=result.raw_sql, - compiled_sql=result.compiled_sql, - node=result.node, - timing=timing_info, + raw_sql=compiled_node.raw_sql, + compiled_sql=compiled_node.injected_sql, + node=compiled_node, + timing=[], # this will get added later logs=[], ) - -class RPCExecuteRunner(RPCCompileRunner): - def from_run_result(self, result, start_time, timing_info): - return RemoteRunResult( + def from_run_result( + self, result, start_time, timing_info + ) -> RemoteCompileResult: + return RemoteCompileResult( raw_sql=result.raw_sql, compiled_sql=result.compiled_sql, node=result.node, - table=result.table, timing=timing_info, logs=[], ) - def execute(self, compiled_node, manifest): - status, table = self.adapter.execute(compiled_node.injected_sql, - fetch=True) + +class RPCExecuteRunner(GenericRPCRunner[RemoteRunResult]): + def execute(self, compiled_node, manifest) -> RemoteRunResult: + _, execute_result = self.adapter.execute( + compiled_node.injected_sql, fetch=True + ) table = ResultTable( - column_names=list(table.column_names), - rows=[list(row) for row in table], + column_names=list(execute_result.column_names), + rows=[list(row) for row in execute_result], ) return RemoteRunResult( @@ -88,3 +99,15 @@ def execute(self, compiled_node, manifest): timing=[], logs=[], ) + + def from_run_result( + self, result, start_time, timing_info + ) -> RemoteRunResult: + return RemoteRunResult( + raw_sql=result.raw_sql, + compiled_sql=result.compiled_sql, + node=result.node, + table=result.table, + timing=timing_info, + logs=[], + ) diff --git a/core/dbt/rpc/task.py b/core/dbt/rpc/task.py index e9f8f7ef6e0..e698db1b955 100644 --- a/core/dbt/rpc/task.py +++ b/core/dbt/rpc/task.py @@ -1,51 +1,54 @@ import base64 import inspect -from abc import ABCMeta, abstractmethod -from typing import Union, List, Optional, Type +from abc import abstractmethod +from typing import Union, List, Optional, Type, TypeVar, Generic -from hologram import JsonSchemaMixin - -from dbt.exceptions import NotImplementedException +from dbt.contracts.rpc import RPCParameters +from dbt.exceptions import NotImplementedException, InternalException from dbt.rpc.logger import RemoteCallableResult, RemoteExecutionResult from dbt.rpc.error import invalid_params from dbt.task.compile import CompileTask -class RemoteCallable(metaclass=ABCMeta): +Parameters = TypeVar('Parameters', bound=RPCParameters) +Result = TypeVar('Result', bound=RemoteCallableResult) + + +class RemoteCallable(Generic[Parameters, Result]): METHOD_NAME: Optional[str] = None is_async = False @classmethod - def get_parameters(cls) -> Type[JsonSchemaMixin]: + def get_parameters(cls) -> Type[Parameters]: argspec = inspect.getfullargspec(cls.set_args) annotations = argspec.annotations if 'params' not in annotations: - raise TypeError( + raise InternalException( 'set_args must have parameter named params with a valid ' - 'JsonSchemaMixin type definition (no params annotation found)' + 'RPCParameters type definition (no params annotation found)' ) params_type = annotations['params'] - if not issubclass(params_type, JsonSchemaMixin): - raise TypeError( + if not issubclass(params_type, RPCParameters): + raise InternalException( 'set_args must have parameter named params with a valid ' - 'JsonSchemaMixin type definition (got {}, expected ' - 'JsonSchemaMixin subclass)'.format(params_type) + 'RPCParameters type definition (got {}, expected ' + 'RPCParameters subclass)'.format(params_type) ) - if params_type is JsonSchemaMixin: - raise TypeError( + if params_type is RPCParameters: + raise InternalException( 'set_args must have parameter named params with a valid ' - 'JsonSchemaMixin type definition (got JsonSchemaMixin itself!)' + 'RPCParameters type definition (got RPCParameters itself!)' ) return params_type @abstractmethod - def set_args(self, params: JsonSchemaMixin): + def set_args(self, params: Parameters): raise NotImplementedException( 'set_args not implemented' ) @abstractmethod - def handle_request(self) -> RemoteCallableResult: + def handle_request(self) -> Result: raise NotImplementedException( 'handle_request not implemented' ) @@ -88,15 +91,20 @@ def raise_invalid_base64(sql): ) -class RPCTask(CompileTask, RemoteCallable): +# If you call recursive_subclasses on a subclass of RPCTask, it should only +# return subtypes of the given subclass. +T = TypeVar('T', bound='RPCTask') + + +class RPCTask(CompileTask, RemoteCallable[Parameters, RemoteExecutionResult]): def __init__(self, args, config, manifest): super().__init__(args, config) self._base_manifest = manifest.deepcopy(config=config) @classmethod def recursive_subclasses( - cls, named_only: bool = True - ) -> List[Type['RPCTask']]: + cls: Type[T], named_only: bool = True + ) -> List[Type[T]]: classes = [] current = [cls] while current: @@ -108,7 +116,9 @@ def recursive_subclasses( classes = [c for c in classes if c.METHOD_NAME is not None] return classes - def get_result(self, results, elapsed_time, generated_at): + def get_result( + self, results, elapsed_time, generated_at + ) -> RemoteExecutionResult: return RemoteExecutionResult( results=results, elapsed_time=elapsed_time, diff --git a/core/dbt/rpc/task_handler.py b/core/dbt/rpc/task_handler.py index eaee2795d9f..9d260607929 100644 --- a/core/dbt/rpc/task_handler.py +++ b/core/dbt/rpc/task_handler.py @@ -3,15 +3,15 @@ import sys import threading import uuid -from contextlib import contextmanager from datetime import datetime -from typing import Any, Dict, Union, Optional, List +from typing import Any, Dict, Union, Optional, List, Type from hologram import JsonSchemaMixin, ValidationError from hologram.helpers import StrEnum import dbt.exceptions from dbt.adapters.factory import cleanup_connections +from dbt.contracts.rpc import RPCParameters from dbt.logger import ( GLOBAL_LOGGER as logger, list_handler, LogMessage, OutputHandler ) @@ -135,6 +135,52 @@ def _task_bootstrap( handler.emit_error(error.error) +class StateHandler: + """A helper context manager to manage task handler state.""" + def __init__(self, task_handler: 'RequestTaskHandler') -> None: + self.handler = task_handler + + def __enter__(self) -> None: + return None + + def set_end(self): + self.handler.ended = datetime.utcnow() + + def handle_success(self): + self.handler.state = TaskHandlerState.Success + self.set_end() + + def handle_error(self, exc_type, exc_value, exc_tb) -> bool: + if isinstance(exc_value, RPCException): + self.handler.error = exc_value + self.handler.state = TaskHandlerState.Error + elif isinstance(exc_value, dbt.exceptions.Exception): + self.handler.error = dbt_error(exc_value) + self.handler.state = TaskHandlerState.Error + else: + # we should only get here if we got a BaseException that is not + # an Exception (we caught those in _wait_for_results), or a bug + # in get_result's call stack. Either way, we should set an + # error so we can figure out what happened on thread death + self.handler.error = server_error(exc_value) + self.handler.state = TaskHandlerState.Error + self.set_end() + return False + + def __exit__(self, exc_type, exc_value, exc_tb) -> bool: + if exc_type is not None: + return self.handle_error(exc_type, exc_value, exc_tb) + + self.handle_success() + return False + + +class ErrorOnlyStateHandler(StateHandler): + """A state handler that does not touch state on success.""" + def handle_success(self): + pass + + class RequestTaskHandler(threading.Thread): """Handler for the single task triggered by a given jsonrpc request.""" def __init__(self, manager, task, http_request, json_rpc_request): @@ -147,7 +193,6 @@ def __init__(self, manager, task, http_request, json_rpc_request): self.thread: Optional[threading.Thread] = None self.started: Optional[datetime] = None self.ended: Optional[datetime] = None - self.timeout: Optional[float] = None self.task_id: uuid.UUID = uuid.uuid4() # the are multiple threads potentially operating on these attributes: # - the task manager has the RequestTaskHandler and any requests @@ -159,6 +204,8 @@ def __init__(self, manager, task, http_request, json_rpc_request): self.error: Optional[RPCException] = None self.state: TaskHandlerState = TaskHandlerState.NotStarted self.logs: List[LogMessage] = [] + self.task_kwargs: Optional[Dict[str, Any]] = None + self.task_params: Optional[RPCParameters] = None super().__init__( name='{}-handler-{}'.format(self.task_id, self.method), daemon=True, # if the RPC server goes away, we probably should too @@ -180,6 +227,20 @@ def method(self) -> str: def _single_threaded(self): return self.task.args.single_threaded or SINGLE_THREADED_HANDLER + @property + def timeout(self) -> Optional[float]: + if self.task_params is None or self.task_params.timeout is None: + return None + # task_params.timeout is a `Real` for encoding reasons, but we just + # want it as a float. + return float(self.task_params.timeout) + + @property + def tags(self) -> Optional[Dict[str, Any]]: + if self.task_params is None: + return None + return self.task_params.task_tags + def _wait_for_results(self) -> RemoteCallableResult: """Wait for results off the queue. If there is an exception raised, raise an appropriate RPC exception. @@ -218,34 +279,6 @@ def _wait_for_results(self) -> RemoteCallableResult: 'Invalid message type {} (result={})'.format(msg) ) - @contextmanager - def state_handler(self): - try: - try: - yield - finally: - # make sure to set this _before_ updating state - self.ended = datetime.utcnow() - except RPCException as exc: - self.error = exc - self.state = TaskHandlerState.Error - raise # this re-raises for single-threaded operation - except dbt.exceptions.Exception as exc: - self.error = dbt_error(exc) - self.state = TaskHandlerState.Error - raise - except BaseException as exc: - # we should only get here if we got a BaseException that is not an - # Exception (we caught those in _wait_for_results), or a bug in - # get_result's call stack. Either way, we should set an error so we - # can figure out what happened on thread death, and re-raise in - # case it's something python-internal. - self.error = server_error(exc) - self.state = TaskHandlerState.Error - raise - else: - self.state = TaskHandlerState.Success - def get_result(self) -> RemoteCallableResult: if self.process is None: raise dbt.exceptions.InternalException( @@ -263,6 +296,7 @@ def get_result(self) -> RemoteCallableResult: # RPC Exceptions come already preserialized for the jsonrpc # framework exc.logs = [l.to_dict() for l in self.logs] + exc.tags = self.tags raise # results get real logs @@ -271,7 +305,7 @@ def get_result(self) -> RemoteCallableResult: def run(self): try: - with self.state_handler(): + with StateHandler(self): self.result = self.get_result() except RPCException: pass # rpc exceptions are fine, the managing thread will handle it @@ -282,7 +316,7 @@ def handle_singlethreaded(self, kwargs): # note this shouldn't call self.run() as that has different semantics # (we want errors to raise) self.process.run() - with self.state_handler(): + with StateHandler(self): self.result = self.get_result() return self.result @@ -301,24 +335,32 @@ def start(self): self.state = TaskHandlerState.Running super().start() + def _collect_parameters(self): + # both get_parameters and the argparse can raise a TypeError. + cls: Type[RPCParameters] = self.task.get_parameters() + + try: + return cls.from_dict(self.task_kwargs) + except ValidationError as exc: + # raise a TypeError to indicate invalid parameters so we get a nice + # error from our json-rpc library + raise TypeError(exc) from exc + def handle(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: self.started = datetime.utcnow() self.state = TaskHandlerState.Initializing - self.timeout = kwargs.pop('timeout', None) - try: - params = self.task.get_parameters().from_dict(kwargs) - except ValidationError as exc: - # raise a TypeError to indicate invalid parameters - self.state = TaskHandlerState.Error - raise TypeError(exc) - except TypeError: - # we got this from our argument parser, already a nice TypeError - self.state = TaskHandlerState.Error - raise + self.task_kwargs = kwargs + with ErrorOnlyStateHandler(self): + # this will raise a TypeError if you provided bad arguments. + self.task_params = self._collect_parameters() + if self.task_params is None: + raise dbt.exceptions.InternalException( + 'Task params set to None!' + ) self.subscriber = QueueSubscriber() self.process = multiprocessing.Process( target=_task_bootstrap, - args=(self.task, self.subscriber.queue, params) + args=(self.task, self.subscriber.queue, self.task_params) ) if self._single_threaded: diff --git a/core/dbt/rpc/task_manager.py b/core/dbt/rpc/task_manager.py index ebdea474762..a709d9395e0 100644 --- a/core/dbt/rpc/task_manager.py +++ b/core/dbt/rpc/task_manager.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from functools import wraps from typing import ( - Any, Dict, Optional, List, Union, Set, Callable, Iterable, Tuple, Type + Any, Dict, Optional, List, Union, Set, Callable, Iterable, Tuple, Type, ) from hologram import JsonSchemaMixin, ValidationError @@ -15,7 +15,7 @@ import dbt.exceptions from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.results import ( +from dbt.contracts.rpc import ( RemoteCompileResult, RemoteRunResult, RemoteExecutionResult, @@ -58,6 +58,7 @@ class TaskRow(JsonSchemaMixin): end: Optional[datetime] elapsed: Optional[float] timeout: Optional[float] + tags: Optional[Dict[str, Any]] @classmethod def from_task(cls, task_handler: RequestTaskHandler, now_time: datetime): @@ -75,6 +76,7 @@ def from_task(cls, task_handler: RequestTaskHandler, now_time: datetime): if state.finished: elapsed_end = _assert_ended(task_handler) + end = elapsed_end elapsed = (elapsed_end - start).total_seconds() @@ -88,6 +90,7 @@ def from_task(cls, task_handler: RequestTaskHandler, now_time: datetime): end=end, elapsed=elapsed, timeout=task_handler.timeout, + tags=task_handler.tags, ) @@ -105,7 +108,8 @@ class KillResult(JsonSchemaMixin): @dataclass class PollResult(JsonSchemaMixin): - status: TaskHandlerState + tags: Optional[Dict[str, Any]] = None + status: TaskHandlerState = TaskHandlerState.NotStarted class GCResultState(StrEnum): @@ -158,31 +162,47 @@ class _GCArguments(JsonSchemaMixin): settings: Optional[GCSettings] +TaskTags = Optional[Dict[str, Any]] + + @dataclass class PollExecuteSuccessResult(PollResult, RemoteExecutionResult): status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success) + metadata=restrict_to(TaskHandlerState.Success), + default=TaskHandlerState.Success, ) @classmethod - def from_result(cls, status, base): + def from_result( + cls: Type['PollExecuteSuccessResult'], + status: TaskHandlerState, + base: RemoteExecutionResult, + tags: TaskTags, + ) -> 'PollExecuteSuccessResult': return cls( status=status, results=base.results, generated_at=base.generated_at, elapsed_time=base.elapsed_time, logs=base.logs, + tags=tags, ) @dataclass class PollCompileSuccessResult(PollResult, RemoteCompileResult): status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success) + metadata=restrict_to(TaskHandlerState.Success), + default=TaskHandlerState.Success, ) @classmethod - def from_result(cls, status, base): + def from_result( + cls: Type['PollCompileSuccessResult'], + status: TaskHandlerState, + base: RemoteCompileResult, + tags: TaskTags, + ) -> 'PollCompileSuccessResult': return cls( status=status, raw_sql=base.raw_sql, @@ -190,17 +210,24 @@ def from_result(cls, status, base): node=base.node, timing=base.timing, logs=base.logs, + tags=tags, ) @dataclass class PollRunSuccessResult(PollResult, RemoteRunResult): status: TaskHandlerState = field( - metadata=restrict_to(TaskHandlerState.Success) + metadata=restrict_to(TaskHandlerState.Success), + default=TaskHandlerState.Success, ) @classmethod - def from_result(cls, status, base): + def from_result( + cls: Type['PollRunSuccessResult'], + status: TaskHandlerState, + base: RemoteRunResult, + tags: TaskTags, + ) -> 'PollRunSuccessResult': return cls( status=status, raw_sql=base.raw_sql, @@ -209,6 +236,7 @@ def from_result(cls, status, base): timing=base.timing, logs=base.logs, table=base.table, + tags=tags, ) @@ -216,35 +244,43 @@ def from_result(cls, status, base): class PollCatalogSuccessResult(PollResult, RemoteCatalogResults): status: TaskHandlerState = field( metadata=restrict_to(TaskHandlerState.Success), - default=TaskHandlerState.Success + default=TaskHandlerState.Success, ) @classmethod - def from_result(cls, status, base): + def from_result( + cls: Type['PollCatalogSuccessResult'], + status: TaskHandlerState, + base: RemoteCatalogResults, + tags: TaskTags, + ) -> 'PollCatalogSuccessResult': return cls( status=status, nodes=base.nodes, generated_at=base.generated_at, _compile_results=base._compile_results, logs=base.logs, + tags=tags, ) -def poll_success(status, logs, result): +def poll_success( + status: TaskHandlerState, result: Any, tags: TaskTags +) -> PollResult: if status != TaskHandlerState.Success: raise dbt.exceptions.InternalException( 'got invalid result status in poll_success: {}'.format(status) ) if isinstance(result, RemoteExecutionResult): - return PollExecuteSuccessResult.from_result(status=status, base=result) + return PollExecuteSuccessResult.from_result(status, result, tags) # order matters here, as RemoteRunResult subclasses RemoteCompileResult elif isinstance(result, RemoteRunResult): - return PollRunSuccessResult.from_result(status=status, base=result) + return PollRunSuccessResult.from_result(status, result, tags) elif isinstance(result, RemoteCompileResult): - return PollCompileSuccessResult.from_result(status=status, base=result) + return PollCompileSuccessResult.from_result(status, result, tags) elif isinstance(result, RemoteCatalogResults): - return PollCatalogSuccessResult.from_result(status=status, base=result) + return PollCatalogSuccessResult.from_result(status, result, tags) else: raise dbt.exceptions.InternalException( 'got invalid result in poll_success: {}'.format(result) @@ -253,7 +289,7 @@ def poll_success(status, logs, result): @dataclass class PollInProgressResult(PollResult): - logs: List[LogMessage] + logs: List[LogMessage] = field(default_factory=list) @dataclass @@ -421,7 +457,7 @@ def process_poll( ) -> PollResult: task_id = uuid.UUID(request_token) try: - task = self.tasks[task_id] + task: RequestTaskHandler = self.tasks[task_id] except KeyError: # We don't recognize that ID. raise dbt.exceptions.UnknownAsyncIDException(task_id) from None @@ -460,11 +496,15 @@ def process_poll( return poll_success( status=state, - logs=task_logs, result=task.result, + tags=task.tags, ) - return PollInProgressResult(state, task_logs) + return PollInProgressResult( + status=state, + tags=task.tags, + logs=task_logs, + ) def _rpc_builtins(self) -> Dict[str, UnmanagedHandler]: if self._builtins: diff --git a/core/dbt/task/remote.py b/core/dbt/task/remote.py index 4986a08cc57..1de95a13c52 100644 --- a/core/dbt/task/remote.py +++ b/core/dbt/task/remote.py @@ -1,24 +1,32 @@ import shlex import signal import threading -from dataclasses import dataclass from datetime import datetime -from typing import Union, List, Optional - -from hologram import JsonSchemaMixin +from typing import Type import dbt.exceptions import dbt.ui.printer from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks from dbt.compilation import compile_manifest -from dbt.contracts.results import RemoteCatalogResults +from dbt.contracts.rpc import ( + RPCExecParameters, + RPCCompileParameters, + RPCTestParameters, + RPCSeedParameters, + RPCDocsGenerateParameters, + RPCCliParameters, + RemoteCatalogResults, + RemoteExecutionResult, +) from dbt.parser.results import ParseResult from dbt.parser.rpc import RPCCallParser, RPCMacroParser from dbt.parser.util import ParserUtils from dbt.logger import GLOBAL_LOGGER as logger -from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner -from dbt.rpc.task import RemoteCallableResult, RPCTask +from dbt.rpc.node_runners import ( + RPCCompileRunner, RPCExecuteRunner +) +from dbt.rpc.task import RPCTask, Parameters from dbt.task.generate import GenerateTask from dbt.task.run import RunTask @@ -26,41 +34,7 @@ from dbt.task.test import TestTask -@dataclass -class RPCExecParameters(JsonSchemaMixin): - name: str - sql: str - macros: Optional[str] - - -@dataclass -class RPCCompileProjectParameters(JsonSchemaMixin): - models: Union[None, str, List[str]] = None - exclude: Union[None, str, List[str]] = None - - -@dataclass -class RPCTestProjectParameters(RPCCompileProjectParameters): - data: bool = False - schema: bool = False - - -@dataclass -class RPCSeedProjectParameters(JsonSchemaMixin): - show: bool = False - - -@dataclass -class RPCDocsGenerateProjectParameters(JsonSchemaMixin): - compile: bool = True - - -@dataclass -class RPCCliParameters(JsonSchemaMixin): - cli: str - - -class _RPCExecTask(RPCTask): +class _RPCExecTask(RPCTask[RPCExecParameters]): def runtime_cleanup(self, selected_uids): """Do some pre-run cleanup that is usually performed in Task __init__. """ @@ -133,7 +107,7 @@ def set_args(self, params: RPCExecParameters): self.args.sql = params.sql self.args.macros = params.macros - def handle_request(self) -> RemoteCallableResult: + def handle_request(self) -> RemoteExecutionResult: # we could get a ctrl+c at any time, including during parsing. thread = None started = datetime.utcnow() @@ -149,7 +123,7 @@ def handle_request(self) -> RemoteCallableResult: thread.start() thread_done.wait() except KeyboardInterrupt: - adapter = get_adapter(self.config) + adapter = get_adapter(self.config) # type: ignore if adapter.is_cancelable(): for conn_name in adapter.cancel_open_connections(): @@ -179,6 +153,11 @@ def handle_request(self) -> RemoteCallableResult: class RemoteCompileTask(_RPCExecTask): METHOD_NAME = 'compile_sql' + def handle_request(self) -> RemoteExecutionResult: + # TODO: annotate that this is a RemoteExecutionResult of + # RemoteCompileResults. + return super().handle_request() + def get_runner_type(self): return RPCCompileRunner @@ -186,11 +165,16 @@ def get_runner_type(self): class RemoteRunTask(_RPCExecTask, RunTask): METHOD_NAME = 'run_sql' + def handle_request(self) -> RemoteExecutionResult: + # TODO: annotate that this is a RemoteExecutionResult of + # RemoteRunResult. + return super().handle_request() + def get_runner_type(self): return RPCExecuteRunner -class _RPCCommandTask(RPCTask): +class _RPCCommandTask(RPCTask[Parameters]): def __init__(self, args, config, manifest): super().__init__(args, config, manifest) self.manifest = self._base_manifest @@ -199,47 +183,50 @@ def load_manifest(self): # we started out with a manifest! pass - def handle_request(self) -> RemoteCallableResult: + def handle_request(self) -> RemoteExecutionResult: return self.run() -class RemoteCompileProjectTask(_RPCCommandTask): +class RemoteCompileProjectTask(_RPCCommandTask[RPCCompileParameters]): METHOD_NAME = 'compile' - def set_args(self, params: RPCCompileProjectParameters) -> None: + def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) -class RemoteRunProjectTask(_RPCCommandTask, RunTask): +class RemoteRunProjectTask(_RPCCommandTask[RPCCompileParameters], RunTask): METHOD_NAME = 'run' - def set_args(self, params: RPCCompileProjectParameters) -> None: + def set_args(self, params: RPCCompileParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) -class RemoteSeedProjectTask(_RPCCommandTask, SeedTask): +class RemoteSeedProjectTask(_RPCCommandTask[RPCSeedParameters], SeedTask): METHOD_NAME = 'seed' - def set_args(self, params: RPCSeedProjectParameters) -> None: + def set_args(self, params: RPCSeedParameters) -> None: self.args.show = params.show -class RemoteTestProjectTask(_RPCCommandTask, TestTask): +class RemoteTestProjectTask(_RPCCommandTask[RPCTestParameters], TestTask): METHOD_NAME = 'test' - def set_args(self, params: RPCTestProjectParameters) -> None: + def set_args(self, params: RPCTestParameters) -> None: self.args.models = self._listify(params.models) self.args.exclude = self._listify(params.exclude) self.args.data = params.data self.args.schema = params.schema -class RemoteDocsGenerateProjectTask(_RPCCommandTask, GenerateTask): +class RemoteDocsGenerateProjectTask( + _RPCCommandTask[RPCDocsGenerateParameters], + GenerateTask, +): METHOD_NAME = 'docs.generate' - def set_args(self, params: RPCDocsGenerateProjectParameters) -> None: + def set_args(self, params: RPCDocsGenerateParameters) -> None: self.args.models = None self.args.exclude = None self.args.compile = params.compile @@ -255,7 +242,7 @@ def get_catalog_results( ) -class RemoteRPCParameters(_RPCCommandTask): +class RemoteRPCParameters(_RPCCommandTask[RPCCliParameters]): METHOD_NAME = 'cli_args' def set_args(self, params: RPCCliParameters) -> None: @@ -264,11 +251,12 @@ def set_args(self, params: RPCCliParameters) -> None: split = shlex.split(params.cli) self.args = parse_args(split, RPCArgumentParser) - def get_rpc_task_cls(self): + def get_rpc_task_cls(self) -> Type[_RPCCommandTask]: # This is obnoxious, but we don't have actual access to the TaskManager # so instead we get to dig through all the subclasses of RPCTask # (recursively!) looking for a matching METHOD_NAME - for candidate in RPCTask.recursive_subclasses(): + candidate: Type[_RPCCommandTask] + for candidate in _RPCCommandTask.recursive_subclasses(): if candidate.METHOD_NAME == self.args.rpc_method: return candidate # this shouldn't happen @@ -277,7 +265,7 @@ def get_rpc_task_cls(self): .format(self.args.rpc_method, self.args.which) ) - def handle_request(self) -> JsonSchemaMixin: + def handle_request(self) -> RemoteExecutionResult: cls = self.get_rpc_task_cls() # we parsed args from the cli, so we're set on that front task = cls(self.args, self.config, self.manifest) diff --git a/test/integration/048_rpc_test/test_rpc.py b/test/integration/048_rpc_test/test_rpc.py index a00bcebd78f..8c3ede85f8b 100644 --- a/test/integration/048_rpc_test/test_rpc.py +++ b/test/integration/048_rpc_test/test_rpc.py @@ -12,6 +12,7 @@ from pytest import mark from test.integration.base import DBTIntegrationTest, use_profile +from dbt.version import __version__ from dbt.logger import log_manager from dbt.main import handle_and_check @@ -171,7 +172,7 @@ def url(self): def poll_for_result(self, request_token, request_id=1, timeout=60): start = time.time() - while timeout is None or ((time.time() - start) < timeout): + while True: time.sleep(0.5) response = self.query('poll', request_token=request_token, _test_request_id=request_id) response_json = response.json() @@ -181,6 +182,9 @@ def poll_for_result(self, request_token, request_id=1, timeout=60): self.assertIn('status', result) if result['status'] == 'success': return response + if timeout is not None: + self.assertGreater(timeout, (time.time() - start)) + def async_query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs): response = self.query(_method, _sql, _test_request_id, macros, **kwargs).json() @@ -497,8 +501,16 @@ def _get_sleep_query(self, duration=15, request_id=90890): @mark.flaky(rerun_filter=None) @use_profile('postgres') def test_ps_kill_postgres(self): - done_query = self.async_query('compile_sql', 'select 1 as id', name='done').json() - self.assertIsResult(done_query) + task_tags = { + 'dbt_version': __version__, + 'my_custom_tag': True, + } + done_query = self.async_query( + 'compile_sql', 'select 1 as id', name='done', task_tags=task_tags + ).json() + done_result = self.assertIsResult(done_query) + self.assertIn('tags', done_result) + self.assertEqual(done_result['tags'], task_tags) request_token, request_id = self._get_sleep_query() @@ -516,6 +528,7 @@ def test_ps_kill_postgres(self): self.assertIsNone(rowdict[0]['timeout']) self.assertEqual(rowdict[0]['task_id'], request_token) self.assertGreater(rowdict[0]['elapsed'], 0) + self.assertIsNone(rowdict[0]['tags']) complete_ps_result = self.query('ps', completed=True, active=False).json() result = self.assertIsResult(complete_ps_result) @@ -526,6 +539,7 @@ def test_ps_kill_postgres(self): self.assertEqual(rowdict[0]['state'], 'success') self.assertIsNone(rowdict[0]['timeout']) self.assertGreater(rowdict[0]['elapsed'], 0) + self.assertEqual(rowdict[0]['tags'], task_tags) all_ps_result = self.query('ps', completed=True, active=True).json() result = self.assertIsResult(all_ps_result) @@ -537,11 +551,13 @@ def test_ps_kill_postgres(self): self.assertEqual(rowdict[0]['state'], 'success') self.assertIsNone(rowdict[0]['timeout']) self.assertGreater(rowdict[0]['elapsed'], 0) + self.assertEqual(rowdict[0]['tags'], task_tags) self.assertEqual(rowdict[1]['request_id'], request_id) self.assertEqual(rowdict[1]['method'], 'run_sql') self.assertEqual(rowdict[1]['state'], 'running') self.assertIsNone(rowdict[1]['timeout']) self.assertGreater(rowdict[1]['elapsed'], 0) + self.assertIsNone(rowdict[1]['tags']) # try to GC our running task gc_response = self.query('gc', task_ids=[request_token]).json() @@ -561,11 +577,13 @@ def test_ps_kill_postgres(self): self.assertEqual(rowdict[0]['state'], 'success') self.assertIsNone(rowdict[0]['timeout']) self.assertGreater(rowdict[0]['elapsed'], 0) + self.assertEqual(rowdict[0]['tags'], task_tags) self.assertEqual(rowdict[1]['request_id'], request_id) self.assertEqual(rowdict[1]['method'], 'run_sql') self.assertEqual(rowdict[1]['state'], 'error') self.assertIsNone(rowdict[1]['timeout']) self.assertGreater(rowdict[1]['elapsed'], 0) + self.assertIsNone(rowdict[1]['tags']) def kill_and_assert(self, request_token, request_id): kill_response = self.query('kill', task_id=request_token).json() @@ -609,13 +627,15 @@ def test_invalid_requests_postgres(self): data = self.async_query( 'compile_sql', 'select * from {{ reff("nonsource_descendant") }}', - name='mymodel' + name='mymodel', + task_tags={'some_tag': True, 'another_tag': 'blah blah blah'} ).json() error_data = self.assertIsErrorWith(data, 10004, 'Compilation Error', { 'type': 'CompilationException', 'message': "Compilation Error in rpc mymodel (from remote system)\n 'reff' is undefined", 'compiled_sql': None, 'raw_sql': 'select * from {{ reff("nonsource_descendant") }}', + 'tags': {'some_tag': True, 'another_tag': 'blah blah blah'} }) self.assertIn('logs', error_data) self.assertTrue(len(error_data['logs']) > 0) @@ -670,7 +690,7 @@ def test_timeout_postgres(self): self.assertIn('timeout', error_data) self.assertEqual(error_data['timeout'], 1) self.assertIn('message', error_data) - self.assertEqual(error_data['message'], 'RPC timed out after 1s') + self.assertEqual(error_data['message'], 'RPC timed out after 1.0s') self.assertIn('logs', error_data) self.assertTrue(len(error_data['logs']) > 0) diff --git a/tox.ini b/tox.ini index 9ac72d6586c..278f59265a8 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ commands = /bin/bash -c '$(which mypy) \ core/dbt/adapters/cache.py \ core/dbt/clients \ core/dbt/config \ + core/dbt/contracts/rpc.py \ core/dbt/deprecations.py \ core/dbt/exceptions.py \ core/dbt/flags.py \ @@ -42,6 +43,7 @@ commands = /bin/bash -c '$(which mypy) \ core/dbt/task/generate.py \ core/dbt/task/init.py \ core/dbt/task/list.py \ + core/dbt/task/remote.py \ core/dbt/task/run_operation.py \ core/dbt/task/runnable.py \ core/dbt/task/seed.py \