From 3f948ae501659045c79516130a87abf9e4d33b3c Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Wed, 6 Mar 2019 13:43:13 -0700
Subject: [PATCH 1/9] Error handling improvements

All dbt errors now have proper error codes/messages
The raised message at runtime ends up in result.error.data.message
The raised message type at runtime ends up in result.error.data.typename
result.error.message is a plaintext name for result.error.code
dbt.exceptions.Exception.data() becomes result.error.data
Collect dbt logs and make them available to requests/responses
---
 core/dbt/compat.py                            |   5 +-
 core/dbt/exceptions.py                        |  58 +++++--
 core/dbt/logger.py                            |  28 +++-
 core/dbt/node_runners.py                      | 141 +++++++++++-------
 core/dbt/rpc.py                               |  55 +++++++
 core/dbt/task/base.py                         |   2 +-
 core/dbt/task/compile.py                      |   2 +-
 core/dbt/task/rpc_server.py                   | 121 ++++++++++++++-
 core/dbt/task/runnable.py                     |  70 ++-------
 .../042_sources_test/test_sources.py          |  39 +++--
 10 files changed, 367 insertions(+), 154 deletions(-)
 create mode 100644 core/dbt/rpc.py

diff --git a/core/dbt/compat.py b/core/dbt/compat.py
index 2548476a124..50f9c217914 100644
--- a/core/dbt/compat.py
+++ b/core/dbt/compat.py
@@ -2,7 +2,6 @@
 
 import abc
 import codecs
-import json
 import warnings
 import decimal
 
@@ -35,12 +34,12 @@ if WHICH_PYTHON == 2:
     from SimpleHTTPServer import SimpleHTTPRequestHandler
     from SocketServer import TCPServer
-    from Queue import PriorityQueue
+    from Queue import PriorityQueue, Empty as QueueEmpty
     from thread import get_ident
 else:
     from http.server import SimpleHTTPRequestHandler
     from socketserver import TCPServer
-    from queue import PriorityQueue
+    from queue import PriorityQueue, Empty as QueueEmpty
     from threading import get_ident
 
diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py
index 7885d0b771e..ab459389047 100644
--- a/core/dbt/exceptions.py
+++ b/core/dbt/exceptions.py
@@ -1,11 +1,18 @@
-from dbt.compat import basestring, builtins
+from dbt.compat import builtins
 from dbt.logger import GLOBAL_LOGGER as logger
 import dbt.flags
-import re
 
 
 class Exception(builtins.Exception):
-    pass
+    CODE = -32000
+    MESSAGE = "Server Error"
+
+    def data(self):
+        # if overriding, make sure the result is json-serializable.
+        return {
+            'type': self.__class__.__name__,
+            'message': str(self),
+        }
 
 
 class MacroReturn(builtins.BaseException):
@@ -21,11 +28,10 @@ class InternalException(Exception):
     pass
 
 
-class RPCException(Exception):
-    pass
-
-
 class RuntimeException(RuntimeError, Exception):
+    CODE = 10001
+    MESSAGE = "Runtime error"
+
     def __init__(self, msg, node=None):
         self.stack = []
         self.node = node
@@ -86,7 +92,28 @@ def __str__(self, prefix="! "):
"): [" " + line for line in lines[1:]]) +class RPCFailureResult(RuntimeException): + CODE = 10002 + MESSAGE = "RPC execution error" + + +class RPCTimeoutException(RuntimeException): + CODE = 10008 + MESSAGE = 'RPC timeout error' + + def __init__(self, timeout): + self.timeout = timeout + + def data(self): + return { + 'timeout': self.timeout, + 'message': 'RPC timed out after {}s'.format(self.timeout), + } + + class DatabaseException(RuntimeException): + CODE = 10003 + MESSAGE = "Database Error" def process_stack(self): lines = [] @@ -103,6 +130,9 @@ def type(self): class CompilationException(RuntimeException): + CODE = 10004 + MESSAGE = "Compilation Error" + @property def type(self): return 'Compilation' @@ -113,7 +143,8 @@ class RecursionException(RuntimeException): class ValidationException(RuntimeException): - pass + CODE = 10005 + MESSAGE = "Validation Error" class JSONValidationException(ValidationException): @@ -134,15 +165,16 @@ class AliasException(ValidationException): pass -class ParsingException(Exception): - pass - - class DependencyException(Exception): - pass + # this can happen due to raise_dependency_error and its callers + CODE = 10006 + MESSAGE = "Dependency Error" class DbtConfigError(RuntimeException): + CODE = 10007 + MESSAGE = "DBT Configuration Error" + def __init__(self, message, project=None, result_type='invalid_project'): self.project = project super(DbtConfigError, self).__init__(message) diff --git a/core/dbt/logger.py b/core/dbt/logger.py index 725f42fc771..873fc061866 100644 --- a/core/dbt/logger.py +++ b/core/dbt/logger.py @@ -75,6 +75,10 @@ def filter(self, record): return True +def default_formatter(): + return logging.Formatter('%(asctime)-18s (%(threadName)s): %(message)s') + + def initialize_logger(debug_mode=False, path=None): global initialized, logger, stdout_handler @@ -82,8 +86,7 @@ def initialize_logger(debug_mode=False, path=None): return if debug_mode: - stdout_handler.setFormatter( - logging.Formatter('%(asctime)-18s (%(threadName)s): %(message)s')) + stdout_handler.setFormatter(default_formatter()) stdout_handler.setLevel(logging.DEBUG) if path is not None: @@ -101,8 +104,7 @@ def initialize_logger(debug_mode=False, path=None): color_filter = ColorFilter() logdir_handler.addFilter(color_filter) - logdir_handler.setFormatter( - logging.Formatter('%(asctime)-18s (%(threadName)s): %(message)s')) + logdir_handler.setFormatter(default_formatter()) logdir_handler.setLevel(logging.DEBUG) logger.addHandler(logdir_handler) @@ -126,3 +128,21 @@ def log_cache_events(flag): GLOBAL_LOGGER = logger + + +class QueueLogHandler(logging.Handler): + def __init__(self, queue): + super(QueueLogHandler, self).__init__() + self.queue = queue + + def emit(self, record): + msg = self.format(record) + self.queue.put_nowait(['log', msg]) + + +def add_queue_handler(queue): + """Add a queue log handler to the global logger.""" + handler = QueueLogHandler(queue) + handler.setFormatter(default_formatter()) + handler.setLevel(logging.DEBUG) + GLOBAL_LOGGER.addHandler(handler) diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index c0d2d26e76f..dc5c71842a6 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -15,6 +15,7 @@ import dbt.flags import dbt.schema import dbt.writer +from dbt import rpc import threading import time @@ -44,6 +45,15 @@ def track_model_run(index, num_nodes, run_model_result): }) +class ExecutionContext(object): + """During execution and error handling, dbt makes use of mutable state: + timing information and the newest 
(compiled vs executed) form of the node. + """ + def __init__(self, node): + self.timing = [] + self.node = node + + class BaseRunner(object): def __init__(self, config, adapter, node, node_index, num_nodes): self.config = config @@ -115,71 +125,78 @@ def from_run_result(self, result, start_time, timing_info): timing_info=timing_info ) - def safe_run(self, manifest): - catchable_errors = (CompilationException, RuntimeException) - - # result = self.DefaultResult(self.node) - started = time.time() - timing = [] - error = None - node = self.node + def compile_and_execute(self, manifest, ctx): result = None + self.adapter.acquire_connection(self.node.get('name')) + with collect_timing_info('compile') as timing_info: + # if we fail here, we still have a compiled node to return + # this has the benefit of showing a build path for the errant + # model + ctx.node = self.compile(manifest) + ctx.timing.append(timing_info) + + # for ephemeral nodes, we only want to compile, not run + if not ctx.node.is_ephemeral_model: + with collect_timing_info('execute') as timing_info: + result = self.run(ctx.node, manifest) + ctx.node = result.node + + ctx.timing.append(timing_info) - try: - self.adapter.acquire_connection(self.node.get('name')) - with collect_timing_info('compile') as timing_info: - # if we fail here, we still have a compiled node to return - # this has the benefit of showing a build path for the errant - # model - node = self.compile(manifest) - - timing.append(timing_info) - - # for ephemeral nodes, we only want to compile, not run - if not node.is_ephemeral_model: - with collect_timing_info('execute') as timing_info: - result = self.run(node, manifest) - node = result.node + return result - timing.append(timing_info) + def _handle_catchable_exception(self, e, ctx): + if e.node is None: + e.node = ctx.node - # result.extend(item.serialize() for item in timing) + return dbt.compat.to_string(e) - except catchable_errors as e: - if e.node is None: - e.node = node + def _handle_internal_exception(self, e, ctx): + build_path = self.node.build_path + prefix = 'Internal error executing {}'.format(build_path) - error = dbt.compat.to_string(e) + error = "{prefix}\n{error}\n\n{note}".format( + prefix=dbt.ui.printer.red(prefix), + error=str(e).strip(), + note=INTERNAL_ERROR_STRING) + logger.debug(error) + return dbt.compat.to_string(e) - except InternalException as e: - build_path = self.node.build_path - prefix = 'Internal error executing {}'.format(build_path) + def _handle_generic_exception(self, e, ctx): + node_description = self.node.get('build_path') + if node_description is None: + node_description = self.node.unique_id + prefix = "Unhandled error while executing {description}".format( + description=node_description) - error = "{prefix}\n{error}\n\n{note}".format( - prefix=dbt.ui.printer.red(prefix), - error=str(e).strip(), - note=INTERNAL_ERROR_STRING - ) - logger.debug(error) - error = dbt.compat.to_string(e) + error = "{prefix}\n{error}".format( + prefix=dbt.ui.printer.red(prefix), + error=str(e).strip()) - except Exception as e: - node_description = self.node.get('build_path') - if node_description is None: - node_description = self.node.unique_id - prefix = "Unhandled error while executing {description}".format( - description=node_description - ) + logger.error(error) + logger.debug('', exc_info=True) + return dbt.compat.to_string(e) - error = "{prefix}\n{error}".format( - prefix=dbt.ui.printer.red(prefix), - error=str(e).strip() - ) + def handle_exception(self, e, ctx): + catchable_errors = 
(CompilationException, RuntimeException) + if isinstance(e, catchable_errors): + error = self._handle_catchable_exception(e, ctx) + elif isinstance(e, InternalException): + error = self._handle_internal_exception(e, ctx) + else: + error = self._handle_generic_exception(e, ctx) + return error - logger.error(error) - logger.debug('', exc_info=True) - error = dbt.compat.to_string(e) + def safe_run(self, manifest): + started = time.time() + ctx = ExecutionContext(self.node) + error = None + result = None + try: + result = self.compile_and_execute(manifest, ctx) + except Exception as e: + error = self.handle_exception(e, ctx) finally: exc_str = self._safe_release_connection() @@ -190,11 +207,11 @@ def safe_run(self, manifest): if error is not None: # we could include compile time for runtime errors here - result = self.error_result(node, error, started, []) + result = self.error_result(ctx.node, error, started, []) elif result is not None: - result = self.from_run_result(result, started, timing) + result = self.from_run_result(result, started, ctx.timing) else: - result = self.ephemeral_result(node, started, timing) + result = self.ephemeral_result(ctx.node, started, ctx.timing) return result def _safe_release_connection(self): @@ -505,6 +522,14 @@ def __init__(self, config, adapter, node, node_index, num_nodes): super(RPCCompileRunner, self).__init__(config, adapter, node, node_index, num_nodes) + def handle_exception(self, e, ctx): + if isinstance(e, dbt.exceptions.Exception): + return rpc.dbt_error(e) + elif isinstance(e, rpc.RPCException): + return e + else: + return rpc.server_error(e) + def before_execute(self): pass @@ -522,10 +547,10 @@ def execute(self, compiled_node, manifest): ) def error_result(self, node, error, start_time, timing_info): - raise dbt.exceptions.RPCException(error) + raise error def ephemeral_result(self, node, start_time, timing_info): - raise dbt.exceptions.NotImplementedException( + raise NotImplementedException( 'cannot execute ephemeral nodes remotely!' 
         )
 
diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py
new file mode 100644
index 00000000000..88bca400d5d
--- /dev/null
+++ b/core/dbt/rpc.py
@@ -0,0 +1,55 @@
+from jsonrpc.exceptions import JSONRPCDispatchException, JSONRPCInvalidParams
+
+import dbt.exceptions
+
+
+class RPCException(JSONRPCDispatchException):
+    def __init__(self, code=None, message=None, data=None, logs=None):
+        if code is None:
+            code = -32000
+        if message is None:
+            message = 'Server error'
+        if data is None:
+            data = {}
+
+        super(RPCException, self).__init__(code=code, message=message,
+                                           data=data)
+        self.logs = logs
+
+    @property
+    def logs(self):
+        return self.error.data.get('logs')
+
+    @logs.setter
+    def logs(self, value):
+        if value is None:
+            return
+        self.error.data['logs'] = value
+
+    @classmethod
+    def from_error(cls, err):
+        return cls(err.code, err.message, err.data, err.data.get('logs'))
+
+
+def invalid_params(err, logs):
+    return RPCException(
+        code=JSONRPCInvalidParams.code,
+        message=JSONRPCInvalidParams.MESSAGE,
+        data={'logs': logs}
+    )
+
+
+def server_error(err, logs=None):
+    exc = dbt.exceptions.Exception(str(err))
+    return dbt_error(exc, logs)
+
+
+def timeout_error(timeout_value, logs=None):
+    exc = dbt.exceptions.RPCTimeoutException(timeout_value)
+    return dbt_error(exc, logs)
+
+
+def dbt_error(exc, logs=None):
+    exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(),
+                       logs=logs)
+    return exc
diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py
index a592a941af1..68c1f230a21 100644
--- a/core/dbt/task/base.py
+++ b/core/dbt/task/base.py
@@ -7,7 +7,7 @@
 from dbt.config.profile import read_profile, PROFILES_DIR
 from dbt import tracking
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.utils import to_string
+from dbt.compat import to_string
 import dbt.exceptions
 
diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py
index 1514a5f7ebc..f223d8534a2 100644
--- a/core/dbt/task/compile.py
+++ b/core/dbt/task/compile.py
@@ -36,7 +36,7 @@ class RemoteCompileTask(CompileTask, RemoteCallable):
     METHOD_NAME = 'compile'
 
     def __init__(self, args, config):
-        super(CompileTask, self).__init__(args, config)
+        super(RemoteCompileTask, self).__init__(args, config)
         self.parser = None
         self._base_manifest = GraphLoader.load_all(
             config,
diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py
index e78769eccb5..90b61a56445 100644
--- a/core/dbt/task/rpc_server.py
+++ b/core/dbt/task/rpc_server.py
@@ -1,16 +1,126 @@
 import json
-import os
+import multiprocessing
+import time
 
 from jsonrpc import Dispatcher, JSONRPCResponseManager
 from werkzeug.wrappers import Request, Response
 from werkzeug.serving import run_simple
 
-from dbt.logger import RPC_LOGGER as logger
+from dbt.logger import RPC_LOGGER as logger, add_queue_handler
 from dbt.task.base import ConfiguredTask
 from dbt.task.compile import CompileTask, RemoteCompileTask
 from dbt.task.run import RemoteRunTask
 from dbt.utils import JSONEncoder
+from dbt.compat import QueueEmpty
+import dbt.exceptions
+from dbt import rpc
+
+
+class RequestTaskHandler(object):
+    def __init__(self, task):
+        self.task = task
+        self.queue = None
+        self.process = None
+        self.started = None
+        self.timeout = None
+        self.logs = []
+
+    def _next_timeout(self):
+        if self.timeout is None:
+            return None
+        end = self.started + self.timeout
+        timeout = end - time.time()
+        if timeout < 0:
+            raise dbt.exceptions.RPCTimeoutException(self.timeout)
+        return timeout
+
+    def _wait_for_results(self):
+        """Wait for results off the queue. If there is a timeout set, and it
+        is exceeded, raise an RPCTimeoutException.
+        """
+        while True:
+            get_timeout = self._next_timeout()
+            try:
+                msgtype, value = self.queue.get(timeout=get_timeout)
+            except QueueEmpty:
+                raise dbt.exceptions.RPCTimeoutException(self.timeout)
+
+            if msgtype == 'log':
+                self.logs.append(value)
+            else:
+                return msgtype, value
+
+    def _join_process(self):
+        try:
+            msgtype, result = self._wait_for_results()
+        except dbt.exceptions.RPCTimeoutException as exc:
+            self.process.terminate()
+            raise rpc.timeout_error(self.timeout)
+        except dbt.exceptions.Exception as exc:
+            raise rpc.dbt_error(exc)
+        except Exception as exc:
+            raise rpc.server_error(exc)
+        finally:
+            self.process.join()
+
+        self.process = None
+        self.queue = None
+
+        if msgtype == 'error':
+            raise rpc.RPCException.from_error(result)
+
+        return result
+
+    def get_result(self):
+        try:
+            result = self._join_process()
+        except rpc.RPCException as exc:
+            exc.logs = self.logs
+            raise
+
+        result['logs'] = self.logs
+        return result
+
+    def task_bootstrap(self, kwargs):
+        # the first thing we do in a new process: start logging
+        add_queue_handler(self.queue)
+
+        error = None
+        result = None
+        try:
+            result = self.task.handle_request(**kwargs)
+        except rpc.RPCException as exc:
+            error = exc
+        except dbt.exceptions.Exception as exc:
+            logger.debug('dbt runtime exception', exc_info=True)
+            error = rpc.dbt_error(exc)
+        except Exception as exc:
+            logger.debug('uncaught python exception', exc_info=True)
+            error = rpc.server_error(exc)
+
+        # put whatever result we got onto the queue as well.
+        if error is not None:
+            self.queue.put(['error', error.error])
+        else:
+            self.queue.put(['result', result])
+
+    def handle(self, kwargs):
+        self.started = time.time()
+        self.timeout = kwargs.pop('timeout', None)
+        self.queue = multiprocessing.Queue()
+        self.process = multiprocessing.Process(
+            target=self.task_bootstrap,
+            args=(kwargs,)
+        )
+        self.process.start()
+        return self.get_result()
+
+    @classmethod
+    def factory(cls, task):
+        def handler(**kwargs):
+            return cls(task).handle(kwargs)
+        return handler
 
 
 class RPCServerTask(ConfiguredTask):
@@ -25,7 +135,7 @@ def __init__(self, args, config, tasks=None):
             self.register(cls(args, config))
 
     def register(self, task):
-        self.dispatcher.add_method(task.safe_handle_request,
+        self.dispatcher.add_method(RequestTaskHandler.factory(task),
                                    name=task.METHOD_NAME)
 
     @property
@@ -60,10 +170,7 @@ def run(self):
     def handle_request(self, request):
         msg = 'Received request ({0}) from {0.remote_addr}, data={0.data}'
         logger.info(msg.format(request))
-        # request_data is the request as a parsedjson object
-        response = JSONRPCResponseManager.handle(
-            request.data, self.dispatcher
-        )
+        response = JSONRPCResponseManager.handle(request.data, self.dispatcher)
         json_data = json.dumps(response.data, cls=JSONEncoder)
         response = Response(json_data, mimetype='application/json')
         # this looks and feels dumb, but our json encoder converts decimals and
diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py
index cfc4d22b69b..c749cf82642 100644
--- a/core/dbt/task/runnable.py
+++ b/core/dbt/task/runnable.py
@@ -3,15 +3,14 @@
 import re
 import time
 from abc import abstractmethod
-from multiprocessing import Process, Pipe
 from multiprocessing.dummy import Pool as ThreadPool
+from jsonrpc.exceptions import JSONRPCInvalidParams
 
-import six
-
+from dbt import rpc
 from dbt.task.base import ConfiguredTask
 from dbt.adapters.factory import get_adapter
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.compat import abstractclassmethod, to_unicode
+from dbt.compat import to_unicode
 from dbt.compilation import compile_manifest
 from dbt.contracts.graph.manifest import CompileResultNode
 from dbt.contracts.results import ExecutionResult
@@ -346,50 +345,6 @@ def handle_request(self, **kwargs):
             'from_kwargs not implemented'
         )
 
-    def _subprocess_handle_request(self, conn, **kwargs):
-        error = None
-        result = None
-        try:
-            result = self.handle_request(**kwargs)
-        except dbt.exceptions.RuntimeException as exc:
-            logger.debug('dbt runtime exception', exc_info=True)
-            # we have to convert this to a string for RPC responses
-            error = str(exc)
-        except dbt.exceptions.RPCException as exc:
-            error = str(exc)
-        except Exception as exc:
-            logger.debug('uncaught python exception', exc_info=True)
-            error = str(exc)
-        conn.send([result, error])
-        conn.close()
-
-    def safe_handle_request(self, **kwargs):
-        # assumption here: we are within a thread/process already and can block
-        # however we like to enforce the timeout
-        timeout = kwargs.pop('timeout', None)
-        parent_conn, child_conn = Pipe()
-        proc = Process(
-            target=self._subprocess_handle_request,
-            args=(child_conn,),
-            kwargs=kwargs
-        )
-        proc.start()
-        if parent_conn.poll(timeout):
-            result, error = parent_conn.recv()
-        else:
-            error = 'timed out after {}s'.format(timeout)
-            proc.terminate()
-
-        parent_conn.close()
-
-        proc.join()
-        if error:
-            raise dbt.exceptions.RPCException(error)
-        else:
-            return result
-
     def decode_sql(self, sql):
         """Base64 decode a string. This should only be used for sql in calls.
 
@@ -402,15 +357,22 @@ def decode_sql(self, sql):
         # in python3.x you can pass `validate=True` to b64decode to get this
         # behavior.
         if not re.match(b'^[A-Za-z0-9+/]*={0,2}$', base64_sql_bytes):
-            raise dbt.exceptions.RPCException(
-                'invalid base64-encoded sql input: {!s}'.format(sql)
-            )
+            self.raise_invalid_base64(sql)
 
         try:
             sql_bytes = base64.b64decode(base64_sql_bytes)
         except ValueError as exc:
-            raise dbt.exceptions.RPCException(
-                'invalid base64-encoded sql input: {!s}'.format(exc)
-            )
+            self.raise_invalid_base64(sql)
 
         return sql_bytes.decode('utf-8')
+
+    @staticmethod
+    def raise_invalid_base64(sql):
+        raise rpc.invalid_params(
+            code=JSONRPCInvalidParams.CODE,
+            message=JSONRPCInvalidParams.MESSAGE,
+            data={
+                'message': 'invalid base64-encoded sql input',
+                'sql': str(sql),
+            }
+        )
diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py
index 111feb124c6..fd6867b84ab 100644
--- a/test/integration/042_sources_test/test_sources.py
+++ b/test/integration/042_sources_test/test_sources.py
@@ -374,6 +374,8 @@ def assertResultHasSql(self, data, raw_sql, compiled_sql=None):
         if compiled_sql is None:
             compiled_sql = raw_sql
         result = self.assertIsResult(data)
+        self.assertIn('logs', result)
+        self.assertTrue(len(result['logs']) > 0)
         self.assertIn('raw_sql', result)
         self.assertIn('compiled_sql', result)
         self.assertEqual(result['raw_sql'], raw_sql)
@@ -495,28 +497,34 @@ def test_invalid_requests(self):
             'select * from {{ reff("nonsource_descendant") }}',
             name='mymodel'
         ).json()
-        error = self.assertIsErrorWithCode(data, -32000)
-        self.assertEqual(error['message'], 'Server error')
+        error = self.assertIsErrorWithCode(data, 10004)
+        self.assertEqual(error['message'], 'Compilation Error')
         self.assertIn('data', error)
-        self.assertEqual(error['data']['type'], 'RPCException')
+        error_data = error['data']
+        self.assertEqual(error_data['type'], 'CompilationException')
         self.assertEqual(
-            error['data']['message'],
+            error_data['message'],
             "Compilation Error in rpc mymodel (from remote system)\n 'reff' is undefined"
         )
+        self.assertIn('logs', error_data)
+        self.assertTrue(len(error_data['logs']) > 0)
 
         data = self.query(
             'run',
             'hi this is not sql',
             name='foo'
         ).json()
-        error = self.assertIsErrorWithCode(data, -32000)
-        self.assertEqual(error['message'], 'Server error')
+        error = self.assertIsErrorWithCode(data, 10003)
+        self.assertEqual(error['message'], 'Database Error')
         self.assertIn('data', error)
-        self.assertEqual(error['data']['type'], 'RPCException')
+        error_data = error['data']
+        self.assertEqual(error_data['type'], 'DatabaseException')
         self.assertEqual(
-            error['data']['message'],
-            'Database Error in rpc foo (from remote system)\n syntax error at or near "hi"\n LINE 1: hi this is not sql\n ^'
+            error_data['message'],
+            'Database Error\n syntax error at or near "hi"\n LINE 1: hi this is not sql\n ^'
         )
+        self.assertIn('logs', error_data)
+        self.assertTrue(len(error_data['logs']) > 0)
 
     @use_profile('postgres')
     def test_timeout(self):
@@ -526,8 +534,13 @@ def test_timeout(self):
             name='foo',
             timeout=1
         ).json()
-        error = self.assertIsErrorWithCode(data, -32000)
-        self.assertEqual(error['message'], 'Server error')
+        error = self.assertIsErrorWithCode(data, 10008)
+        self.assertEqual(error['message'], 'RPC timeout error')
         self.assertIn('data', error)
-        self.assertEqual(error['data']['type'], 'RPCException')
-        self.assertEqual(error['data']['message'], 'timed out after 1s')
+        error_data = error['data']
+        self.assertIn('timeout', error_data)
+        self.assertEqual(error_data['timeout'], 1)
+        self.assertIn('message', error_data)
+        self.assertEqual(error_data['message'], 'RPC timed out after 1s')
+        self.assertIn('logs', error_data)
+        self.assertTrue(len(error_data['logs']) > 0)

From 7e181280b397635c9b16b98eb21e77f23b8dff1e Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 08:19:24 -0700
Subject: [PATCH 2/9] PR feedback: QueueMessageType class, remove extra assignments

---
 core/dbt/rpc.py             | 13 +++++++++++++
 core/dbt/task/rpc_server.py | 17 +++++++++--------
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py
index 88bca400d5d..ce598b8866d 100644
--- a/core/dbt/rpc.py
+++ b/core/dbt/rpc.py
@@ -53,3 +53,16 @@ def dbt_error(exc, logs=None):
     exc = RPCException(code=exc.CODE, message=exc.MESSAGE, data=exc.data(),
                        logs=logs)
     return exc
+
+
+class QueueMessageType(object):
+    Error = 'error'
+    Result = 'result'
+    Log = 'log'
+
+    @classmethod
+    def terminating(cls):
+        return [
+            cls.Error,
+            cls.Result
+        ]
diff --git a/core/dbt/task/rpc_server.py b/core/dbt/task/rpc_server.py
index 90b61a56445..3ef909cd782 100644
--- a/core/dbt/task/rpc_server.py
+++ b/core/dbt/task/rpc_server.py
@@ -46,10 +46,14 @@ def _wait_for_results(self):
             except QueueEmpty:
                 raise dbt.exceptions.RPCTimeoutException(self.timeout)
 
-            if msgtype == 'log':
+            if msgtype == rpc.QueueMessageType.Log:
                 self.logs.append(value)
-            else:
+            elif msgtype in rpc.QueueMessageType.terminating():
                 return msgtype, value
+            else:
+                raise dbt.exceptions.InternalException(
+                    'Got invalid queue message type {}'.format(msgtype)
+                )
 
     def _join_process(self):
         try:
@@ -64,10 +68,7 @@ def _join_process(self):
         finally:
             self.process.join()
 
-        self.process = None
-        self.queue = None
-
-        if msgtype == 'error':
+        if msgtype == rpc.QueueMessageType.Error:
             raise rpc.RPCException.from_error(result)
 
         return result
@@ -101,9 +102,9 @@ def task_bootstrap(self, kwargs):
 
         # put whatever result we got onto the queue as well.
         if error is not None:
-            self.queue.put(['error', error.error])
+            self.queue.put([rpc.QueueMessageType.Error, error.error])
         else:
-            self.queue.put(['result', result])
+            self.queue.put([rpc.QueueMessageType.Result, result])
 
     def handle(self, kwargs):
         self.started = time.time()

From 6620a3cd90e34a92b1ef54ed43edbf6b2093f27a Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 10:14:01 -0700
Subject: [PATCH 3/9] wrap all context-raised exceptions in node info

Fixes "called by "
---
 core/dbt/context/common.py |  2 +-
 core/dbt/exceptions.py     | 27 +++++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py
index 53e120a8985..ca4f71ff209 100644
--- a/core/dbt/context/common.py
+++ b/core/dbt/context/common.py
@@ -383,7 +383,7 @@ def generate_base(model, model_dict, config, manifest, source_config,
         "config": provider.Config(model_dict, source_config),
         "database": config.credentials.database,
         "env_var": env_var,
-        "exceptions": dbt.exceptions.CONTEXT_EXPORTS,
+        "exceptions": dbt.exceptions.wrapped_exports(model),
         "execute": provider.execute,
         "flags": dbt.flags,
         # TODO: Do we have to leave this in?
diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py
index ab459389047..ae1fecdd1a6 100644
--- a/core/dbt/exceptions.py
+++ b/core/dbt/exceptions.py
@@ -1,3 +1,7 @@
+import sys
+import six
+import functools
+
 from dbt.compat import builtins
 from dbt.logger import GLOBAL_LOGGER as logger
 import dbt.flags
@@ -648,3 +652,26 @@ def warn_or_error(msg, node=None, log_fmt=None):
         relation_wrong_type,
     ]
 }
+
+
+def wrapper(model):
+    def wrap(func):
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except Exception:
+                exc_type, exc, exc_tb = sys.exc_info()
+                if hasattr(exc, 'node') and exc.node is None:
+                    exc.node = model
+                six.reraise(exc_type, exc, exc_tb)
+
+        return inner
+    return wrap
+
+
+def wrapped_exports(model):
+    wrap = wrapper(model)
+    return {
+        name: wrap(export) for name, export in CONTEXT_EXPORTS.items()
+    }

From d890642c280c700c45b034559b6aa5b1cbd0b495 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 10:54:10 -0700
Subject: [PATCH 4/9] add NOTICE level logging, make log messages richer types

---
 core/dbt/logger.py | 72 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 55 insertions(+), 17 deletions(-)

diff --git a/core/dbt/logger.py b/core/dbt/logger.py
index 873fc061866..af9efd1cbee 100644
--- a/core/dbt/logger.py
+++ b/core/dbt/logger.py
@@ -4,11 +4,9 @@
 import logging.handlers
 import os
 import sys
-import warnings
 
 import colorama
 
-
 # Colorama needs some help on windows because we're using logger.info
 # instead of print(). If the Windows env doesn't have a TERM var set,
 # then we should override the logging stream to use the colorama
 # converter.
 colorama_stdout = sys.stdout
 colorama_wrap = True
 
+colorama.init(wrap=colorama_wrap)
+
+DEBUG = logging.DEBUG
+NOTICE = 15
+INFO = logging.INFO
+WARNING = logging.WARNING
+ERROR = logging.ERROR
+CRITICAL = logging.CRITICAL
+
+logging.addLevelName(NOTICE, 'NOTICE')
+
+
+class Logger(logging.Logger):
+    def notice(self, msg, *args, **kwargs):
+        if self.isEnabledFor(NOTICE):
+            self._log(NOTICE, msg, args, **kwargs)
+
+
+logging.setLoggerClass(Logger)
+
+
 if sys.platform == 'win32' and not os.environ.get('TERM'):
     colorama_wrap = False
     colorama_stdout = colorama.AnsiToWin32(sys.stdout).stream
@@ -29,22 +48,22 @@
 
 # create a global console logger for dbt
 stdout_handler = logging.StreamHandler(colorama_stdout)
 stdout_handler.setFormatter(logging.Formatter('%(message)s'))
-stdout_handler.setLevel(logging.INFO)
+stdout_handler.setLevel(INFO)
 
 logger = logging.getLogger('dbt')
 logger.addHandler(stdout_handler)
-logger.setLevel(logging.DEBUG)
-logging.getLogger().setLevel(logging.CRITICAL)
+logger.setLevel(DEBUG)
+logging.getLogger().setLevel(CRITICAL)
 
 # Quiet these down in the logs
-logging.getLogger('botocore').setLevel(logging.INFO)
-logging.getLogger('requests').setLevel(logging.INFO)
-logging.getLogger('urllib3').setLevel(logging.INFO)
-logging.getLogger('google').setLevel(logging.INFO)
-logging.getLogger('snowflake.connector').setLevel(logging.INFO)
-logging.getLogger('parsedatetime').setLevel(logging.INFO)
+logging.getLogger('botocore').setLevel(INFO)
+logging.getLogger('requests').setLevel(INFO)
+logging.getLogger('urllib3').setLevel(INFO)
+logging.getLogger('google').setLevel(INFO)
+logging.getLogger('snowflake.connector').setLevel(INFO)
+logging.getLogger('parsedatetime').setLevel(INFO)
 
 # we never want to seek werkzeug logs
-logging.getLogger('werkzeug').setLevel(logging.CRITICAL)
+logging.getLogger('werkzeug').setLevel(CRITICAL)
 
 # provide this for the cache.
 CACHE_LOGGER = logging.getLogger('dbt.cache')
@@ -87,7 +106,7 @@ def initialize_logger(debug_mode=False, path=None):
 
     if debug_mode:
         stdout_handler.setFormatter(default_formatter())
-        stdout_handler.setLevel(logging.DEBUG)
+        stdout_handler.setLevel(DEBUG)
 
     if path is not None:
         make_log_dir_if_missing(path)
@@ -105,14 +124,14 @@ def initialize_logger(debug_mode=False, path=None):
         logdir_handler.addFilter(color_filter)
 
         logdir_handler.setFormatter(default_formatter())
-        logdir_handler.setLevel(logging.DEBUG)
+        logdir_handler.setLevel(DEBUG)
 
         logger.addHandler(logdir_handler)
 
         # Log Python warnings to file
         warning_logger = logging.getLogger('py.warnings')
         warning_logger.addHandler(logdir_handler)
-        warning_logger.setLevel(logging.DEBUG)
+        warning_logger.setLevel(DEBUG)
 
     initialized = True
@@ -149,6 +168,25 @@ def log_cache_events(flag):
 
 GLOBAL_LOGGER = logger
 
 
+class QueueFormatter(logging.Formatter):
+    def format(self, record):
+        record.message = record.getMessage()
+        record.asctime = self.formatTime(record, self.datefmt)
+        formatted = self.formatMessage(record)
+
+        output = {
+            'message': formatted,
+            'timestamp': record.asctime,
+            'levelname': record.levelname,
+            'level': record.levelno,
+        }
+        if record.exc_info:
+            if not record.exc_text:
+                record.exc_text = self.formatException(record.exc_info)
+            output['exc_info'] = record.exc_text
+        return output
+
+
 class QueueLogHandler(logging.Handler):
     def __init__(self, queue):
         super(QueueLogHandler, self).__init__()
@@ -143,6 +181,6 @@ def emit(self, record):
 def add_queue_handler(queue):
     """Add a queue log handler to the global logger."""
     handler = QueueLogHandler(queue)
-    handler.setFormatter(default_formatter())
-    handler.setLevel(logging.DEBUG)
+    handler.setFormatter(QueueFormatter())
+    handler.setLevel(DEBUG)
     GLOBAL_LOGGER.addHandler(handler)

From c86390e139244bccfafe093faa7bd6a083bf1e86 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 10:54:50 -0700
Subject: [PATCH 5/9] use notice logging for "Found x models, ...", change a
 couple other levels

---
 core/dbt/compilation.py   | 2 +-
 core/dbt/logger.py        | 2 +-
 core/dbt/main.py          | 2 +-
 core/dbt/task/base.py     | 8 ++++----
 core/dbt/task/runnable.py | 4 ++--
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py
index 52368161aab..5a1045e0350 100644
--- a/core/dbt/compilation.py
+++ b/core/dbt/compilation.py
@@ -45,7 +45,7 @@ def print_compile_stats(stats):
     stat_line = ", ".join(
         ["{} {}".format(ct, names.get(t)) for t, ct in results.items()])
 
-    logger.info("Found {}".format(stat_line))
+    logger.notice("Found {}".format(stat_line))
 
 
 def _add_prepended_cte(prepended_ctes, new_cte):
diff --git a/core/dbt/logger.py b/core/dbt/logger.py
index af9efd1cbee..fd2eb764790 100644
--- a/core/dbt/logger.py
+++ b/core/dbt/logger.py
@@ -48,7 +48,7 @@ def notice(self, msg, *args, **kwargs):
 # create a global console logger for dbt
 stdout_handler = logging.StreamHandler(colorama_stdout)
 stdout_handler.setFormatter(logging.Formatter('%(message)s'))
-stdout_handler.setLevel(INFO)
+stdout_handler.setLevel(NOTICE)
 
 logger = logging.getLogger('dbt')
 logger.addHandler(stdout_handler)
diff --git a/core/dbt/main.py b/core/dbt/main.py
index 231c8e80802..fc4353214d5 100644
--- a/core/dbt/main.py
+++ b/core/dbt/main.py
@@ -166,7 +166,7 @@ def track_run(task):
         )
     except (dbt.exceptions.NotImplementedException,
             dbt.exceptions.FailedToConnectException) as e:
-        logger.info('ERROR: {}'.format(e))
+        logger.error('ERROR: {}'.format(e))
         dbt.tracking.track_invocation_end(
             config=task.config, args=task.args, result_type="error"
         )
diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py
index 68c1f230a21..29355478685 100644
--- a/core/dbt/task/base.py
+++ b/core/dbt/task/base.py
@@ -52,16 +52,16 @@ def from_args(cls, args):
         try:
             config = cls.ConfigType.from_args(args)
         except dbt.exceptions.DbtProjectError as exc:
-            logger.info("Encountered an error while reading the project:")
-            logger.info(to_string(exc))
+            logger.error("Encountered an error while reading the project:")
+            logger.error("  ERROR: {}".format(str(exc)))
 
             tracking.track_invalid_invocation(
                 args=args,
                 result_type=exc.result_type)
 
             raise dbt.exceptions.RuntimeException('Could not run dbt')
         except dbt.exceptions.DbtProfileError as exc:
-            logger.info("Encountered an error while reading profiles:")
-            logger.info("  ERROR {}".format(str(exc)))
+            logger.error("Encountered an error while reading profiles:")
+            logger.error("  ERROR {}".format(str(exc)))
 
             all_profiles = read_profiles(args.profiles_dir).keys()
 
diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py
index c749cf82642..35a346ec97b 100644
--- a/core/dbt/task/runnable.py
+++ b/core/dbt/task/runnable.py
@@ -273,8 +273,8 @@ def run(self):
         self._runtime_initialize()
 
         if len(self._flattened_nodes) == 0:
-            logger.info("WARNING: Nothing to do. Try checking your model "
-                        "configs and model specification args")
+            logger.warning("WARNING: Nothing to do. Try checking your model "
+                           "configs and model specification args")
             return []
         else:
             logger.info("")

From fbaae2e4930f8e8261cdf43c4692047b63ea43af Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 12:02:36 -0700
Subject: [PATCH 6/9] fix Python 2.7

---
 core/dbt/logger.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/core/dbt/logger.py b/core/dbt/logger.py
index fd2eb764790..333f5de725e 100644
--- a/core/dbt/logger.py
+++ b/core/dbt/logger.py
@@ -150,6 +150,22 @@ def log_cache_events(flag):
 
 
 class QueueFormatter(logging.Formatter):
+    def formatMessage(self, record):
+        superself = super(QueueFormatter, self)
+        if hasattr(superself, 'formatMessage'):
+            # python 3.x
+            return superself.formatMessage(record)
+
+        # python 2.x, handling weird unicode things
+        try:
+            return self._fmt % record.__dict__
+        except UnicodeDecodeError as e:
+            try:
+                record.name = record.name.decode('utf-8')
+                return self._fmt % record.__dict__
+            except UnicodeDecodeError as e:
+                raise e
+
     def format(self, record):
         record.message = record.getMessage()
         record.asctime = self.formatTime(record, self.datefmt)

From fc22cb2bf0a1103a7ced53e3d5f4733be955ff1c Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Fri, 8 Mar 2019 13:28:44 -0700
Subject: [PATCH 7/9] when encoding json, handle dates and times like
 datetimes

---
 core/dbt/utils.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/core/dbt/utils.py b/core/dbt/utils.py
index 5139cee9dc6..e4cb9eb7c24 100644
--- a/core/dbt/utils.py
+++ b/core/dbt/utils.py
@@ -1,18 +1,14 @@
-from datetime import datetime
-from decimal import Decimal
-
 import collections
 import copy
+import datetime
 import functools
 import hashlib
 import itertools
 import json
-import numbers
 import os
 
 import dbt.exceptions
 
-from dbt.include.global_project import PACKAGES
 from dbt.compat import basestring, DECIMALS
 from dbt.logger import GLOBAL_LOGGER as logger
 from dbt.node_types import NodeType
@@ -442,7 +438,7 @@ def add_ephemeral_model_prefix(s):
 def timestring():
     """Get the current datetime as an RFC 3339-compliant string"""
     # isoformat doesn't include the mandatory trailing 'Z' for UTC.
-    return datetime.utcnow().isoformat() + 'Z'
+    return datetime.datetime.utcnow().isoformat() + 'Z'
 
 
 class JSONEncoder(json.JSONEncoder):
@@ -453,8 +449,9 @@ class JSONEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, DECIMALS):
             return float(obj)
-        if isinstance(obj, datetime):
+        if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
             return obj.isoformat()
+
         return super(JSONEncoder, self).default(obj)

From 81426ae800abf896bf583d031a48dcf3f4267674 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Mon, 11 Mar 2019 18:16:52 -0600
Subject: [PATCH 8/9] add optional "macros" parameter to dbt rpc calls

---
 core/dbt/parser/util.py                       |   5 +-
 core/dbt/rpc.py                               |   6 +-
 core/dbt/task/compile.py                      |  30 ++++--
 core/dbt/task/runnable.py                     |   2 -
 .../042_sources_test/macros/macro.sql         |   7 ++
 .../042_sources_test/test_sources.py          | 100 +++++++++++++++---
 6 files changed, 120 insertions(+), 30 deletions(-)
 create mode 100644 test/integration/042_sources_test/macros/macro.sql

diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py
index 92c90fe68a0..09b03185818 100644
--- a/core/dbt/parser/util.py
+++ b/core/dbt/parser/util.py
@@ -231,12 +231,15 @@ def process_sources(cls, manifest, current_project):
         return manifest
 
     @classmethod
-    def add_new_refs(cls, manifest, current_project, node):
+    def add_new_refs(cls, manifest, current_project, node, macros):
        """Given a new node that is not in the manifest, copy the manifest and
        insert the new node into it as if it were part of regular ref
        processing
        """
        manifest = manifest.deepcopy(config=current_project)
+        # it's ok for macros to silently override a local project macro name
+        manifest.macros.update(macros)
+
        if node.unique_id in manifest.nodes:
            # this should be _impossible_ due to the fact that rpc calls get
            # a unique ID that starts with 'rpc'!
diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py
index ce598b8866d..f20f2ab32fb 100644
--- a/core/dbt/rpc.py
+++ b/core/dbt/rpc.py
@@ -31,11 +31,11 @@ def from_error(cls, err):
         return cls(err.code, err.message, err.data, err.data.get('logs'))
 
 
-def invalid_params(err, logs):
+def invalid_params(data):
     return RPCException(
-        code=JSONRPCInvalidParams.code,
+        code=JSONRPCInvalidParams.CODE,
         message=JSONRPCInvalidParams.MESSAGE,
-        data={'logs': logs}
+        data=data
     )
 
diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py
index f223d8534a2..3cbbe78f880 100644
--- a/core/dbt/task/compile.py
+++ b/core/dbt/task/compile.py
@@ -6,6 +6,7 @@
 from dbt.node_runners import CompileRunner, RPCCompileRunner
 from dbt.node_types import NodeType
 from dbt.parser.analysis import RPCCallParser
+from dbt.parser.macros import MacroParser
 from dbt.parser.util import ParserUtils
 import dbt.ui.printer
 
@@ -37,7 +38,6 @@ class RemoteCompileTask(CompileTask, RemoteCallable):
 
     def __init__(self, args, config):
         super(RemoteCompileTask, self).__init__(args, config)
-        self.parser = None
         self._base_manifest = GraphLoader.load_all(
             config,
             internal_manifest=get_adapter(config).check_internal_manifest()
@@ -56,15 +56,28 @@ def runtime_cleanup(self, selected_uids):
         self._skipped_children = {}
         self._raise_next_tick = None
 
-    def handle_request(self, name, sql):
-        self.parser = RPCCallParser(
+    def handle_request(self, name, sql, macros=None):
+        request_path = os.path.join(self.config.target_path, 'rpc', name)
+        all_projects = load_all_projects(self.config)
+        macro_overrides = {}
+        if macros is not None:
+            macros = self.decode_sql(macros)
+            macro_parser = MacroParser(self.config, all_projects)
+            macro_overrides.update(macro_parser.parse_macro_file(
+                macro_file_path='from remote system',
+                macro_file_contents=macros,
+                root_path=request_path,
+                package_name=self.config.project_name,
+                resource_type=NodeType.Macro
+            ))
+
+        rpc_parser = RPCCallParser(
             self.config,
-            all_projects=load_all_projects(self.config),
+            all_projects=all_projects,
             macro_manifest=self._base_manifest
         )
 
         sql = self.decode_sql(sql)
-        request_path = os.path.join(self.config.target_path, 'rpc', name)
         node_dict = {
             'name': name,
             'root_path': request_path,
@@ -74,13 +87,16 @@ def handle_request(self, name, sql):
             'package_name': self.config.project_name,
             'raw_sql': sql,
         }
-        unique_id, node = self.parser.parse_sql_node(node_dict)
+
+        unique_id, node = rpc_parser.parse_sql_node(node_dict)
 
         self.manifest = ParserUtils.add_new_refs(
             manifest=self._base_manifest,
             current_project=self.config,
-            node=node
+            node=node,
+            macros=macro_overrides
         )
 
+        # don't write our new, weird manifest!
         self.linker = compile_manifest(self.config, self.manifest, write=False)
 
         selected_uids = [node.unique_id]
diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py
index 35a346ec97b..d06bf83c35c 100644
--- a/core/dbt/task/runnable.py
+++ b/core/dbt/task/runnable.py
@@ -369,8 +367,6 @@ def decode_sql(self, sql):
     @staticmethod
     def raise_invalid_base64(sql):
         raise rpc.invalid_params(
-            code=JSONRPCInvalidParams.CODE,
-            message=JSONRPCInvalidParams.MESSAGE,
             data={
                 'message': 'invalid base64-encoded sql input',
                 'sql': str(sql),
             }
         )
diff --git a/test/integration/042_sources_test/macros/macro.sql b/test/integration/042_sources_test/macros/macro.sql
new file mode 100644
index 00000000000..c1d3b1f47bb
--- /dev/null
+++ b/test/integration/042_sources_test/macros/macro.sql
@@ -0,0 +1,7 @@
+{% macro override_me() -%}
+    exceptions.raise_compiler_error('this is a bad macro')
+{%- endmacro %}
+
+{% macro happy_little_macro() -%}
+    {{ override_me() }}
+{%- endmacro %}
diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py
index fd6867b84ab..d4f0edc81cf 100644
--- a/test/integration/042_sources_test/test_sources.py
+++ b/test/integration/042_sources_test/test_sources.py
@@ -1,14 +1,21 @@
 import unittest
-from nose.plugins.attrib import attr
 from datetime import datetime, timedelta
 import json
 import os
+import multiprocessing
+from base64 import standard_b64encode as b64
+import requests
+import socket
+import time
+
 
 from dbt.exceptions import CompilationException
 from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \
     AnyStringWith
 from dbt.main import handle_and_check
 
+
 class BaseSourcesTest(DBTIntegrationTest):
     @property
     def schema(self):
@@ -260,16 +267,6 @@ def test_postgres_malformed_schema_strict_will_break_run(self):
         self.run_dbt_with_vars(['run'], strict=True)
 
 
-import multiprocessing
-from base64 import standard_b64encode as b64
-import json
-import requests
-import socket
-import time
-import os
-
-
-
 class ServerProcess(multiprocessing.Process):
     def __init__(self, cli_vars=None):
         self.port = 22991
@@ -303,7 +300,7 @@ def start(self):
         raise Exception('server never appeared!')
 
 
-@unittest.skipIf(os.name=='nt', 'Windows not supported for now')
+@unittest.skipIf(os.name == 'nt', 'Windows not supported for now')
 class TestRPCServer(BaseSourcesTest):
     def setUp(self):
         super(TestRPCServer, self).setUp()
@@ -316,10 +313,20 @@ def tearDown(self):
         self._server.terminate()
         super(TestRPCServer, self).tearDown()
 
-    def build_query(self, method, kwargs, sql=None, test_request_id=1):
+    @property
+    def project_config(self):
+        return {
+            'data-paths': ['test/integration/042_sources_test/data'],
+            'quoting': {'database': True, 'schema': True, 'identifier': True},
+            'macro-paths': ['test/integration/042_sources_test/macros'],
+        }
+
+    def build_query(self, method, kwargs, sql=None, test_request_id=1,
+                    macros=None):
         if sql is not None:
             kwargs['sql'] = b64(sql.encode('utf-8')).decode('utf-8')
 
+        if macros is not None:
+            kwargs['macros'] = b64(macros.encode('utf-8')).decode('utf-8')
+
         return {
             'jsonrpc': '2.0',
             'method': method,
@@ -333,8 +340,8 @@ def perform_query(self, query):
         response = requests.post(url, headers=headers, data=json.dumps(query))
         return response
 
-    def query(self, _method, _sql=None, _test_request_id=1, **kwargs):
-        built = self.build_query(_method, kwargs, _sql, _test_request_id)
+    def query(self, _method, _sql=None, _test_request_id=1, macros=None,
+              **kwargs):
+        built = self.build_query(_method, kwargs, _sql, _test_request_id,
+                                 macros)
         return self.perform_query(built)
 
     def assertResultHasTimings(self, result, *names):
@@ -425,7 +432,6 @@ def test_compile(self):
             'select * from {{ source("test_source", "test_table") }}',
             name='foo'
         ).json()
-
         self.assertSuccessfulCompilationResult(
             source,
             'select * from {{ source("test_source", "test_table") }}',
@@ -434,6 +440,30 @@ def test_compile(self):
                 self.unique_schema())
         )
 
+        macro = self.query(
+            'compile',
+            'select {{ my_macro() }}',
+            name='foo',
+            macros='{% macro my_macro() %}1 as id{% endmacro %}'
+        ).json()
+        self.assertSuccessfulCompilationResult(
+            macro,
+            'select {{ my_macro() }}',
+            compiled_sql='select 1 as id'
+        )
+
+        macro_override = self.query(
+            'compile',
+            'select {{ happy_little_macro() }}',
+            name='foo',
+            macros='{% macro override_me() %}2 as id{% endmacro %}'
+        ).json()
+        self.assertSuccessfulCompilationResult(
+            macro_override,
+            'select {{ happy_little_macro() }}',
+            compiled_sql='select 2 as id'
+        )
+
     @use_profile('postgres')
     def test_run(self):
         # seed + run dbt to make models before using them!
@@ -470,7 +500,6 @@ def test_run(self):
             'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
             name='foo'
         ).json()
-
         self.assertSuccessfulRunResult(
             source,
             'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
@@ -483,6 +512,32 @@ def test_run(self):
             }
         )
 
+        macro = self.query(
+            'run',
+            'select {{ my_macro() }}',
+            name='foo',
+            macros='{% macro my_macro() %}1 as id{% endmacro %}'
+        ).json()
+        self.assertSuccessfulRunResult(
+            macro,
+            raw_sql='select {{ my_macro() }}',
+            compiled_sql='select 1 as id',
+            table={'column_names': ['id'], 'rows': [[1.0]]}
+        )
+
+        macro_override = self.query(
+            'run',
+            'select {{ happy_little_macro() }}',
+            name='foo',
+            macros='{% macro override_me() %}2 as id{% endmacro %}'
+        ).json()
+        self.assertSuccessfulRunResult(
+            macro_override,
+            raw_sql='select {{ happy_little_macro() }}',
+            compiled_sql='select 2 as id',
+            table={'column_names': ['id'], 'rows': [[2.0]]}
+        )
+
     @use_profile('postgres')
     def test_invalid_requests(self):
         data = self.query(
@@ -526,6 +581,17 @@ def test_invalid_requests(self):
         self.assertIn('logs', error_data)
         self.assertTrue(len(error_data['logs']) > 0)
 
+        macro_no_override = self.query(
+            'run',
+            'select {{ happy_little_macro() }}',
+            name='foo',
+        ).json()
+        error = self.assertIsErrorWithCode(macro_no_override, 10003)
+        self.assertEqual(error['message'], 'Database Error')
+        self.assertIn('data', error)
+        error_data = error['data']
+        self.assertEqual(error_data['type'], 'DatabaseException')
+
     @use_profile('postgres')
     def test_timeout(self):
         data = self.query(

From 9c8e08811ba8105bceb9b3483beb9e32127485a8 Mon Sep 17 00:00:00 2001
From: Jacob Beck
Date: Tue, 12 Mar 2019 15:09:07 -0600
Subject: [PATCH 9/9] redshift can just change this on you apparently

---
 test/integration/029_docs_generate_tests/test_docs_generate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py
index 2164f8e3bfb..26339897ad5 100644
--- a/test/integration/029_docs_generate_tests/test_docs_generate.py
+++ b/test/integration/029_docs_generate_tests/test_docs_generate.py
@@ -146,7 +146,7 @@ def _redshift_stats(self):
             "diststyle": {
                 "id": "diststyle",
                 "label": "Dist Style",
-                "value": "EVEN",
+                "value": AnyStringWith(None),
                 "description": "Distribution style or distribution key column, if key distribution is defined.",
                 "include": True
"include": True },