From 0e01064a237339779cb40dff6b516d045a2d2463 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Fri, 6 Jan 2017 13:05:58 -0500 Subject: [PATCH] pep8 compliance (#257) add pep8 check to continuous integration tests and bring codebase into compliance --- CHANGELOG.md | 1 + dbt/__init__.py | 1 - dbt/archival.py | 45 +++--- dbt/compilation.py | 240 +++++++++++++++++++++++-------- dbt/compiled_model.py | 31 ++-- dbt/config.py | 3 +- dbt/deprecations.py | 20 ++- dbt/linker.py | 32 +++-- dbt/main.py | 200 ++++++++++++++++++++++---- dbt/model.py | 325 ++++++++++++++++++++++++++++++++---------- dbt/project.py | 34 +++-- dbt/runner.py | 283 ++++++++++++++++++++++++------------ dbt/runtime.py | 1 + dbt/schema.py | 144 +++++++++++++------ dbt/schema_tester.py | 12 +- dbt/seeder.py | 58 ++++++-- dbt/source.py | 36 +++-- dbt/ssh_forward.py | 21 +-- dbt/targets.py | 89 +++++------- dbt/task/archive.py | 10 +- dbt/task/clean.py | 6 +- dbt/task/compile.py | 5 +- dbt/task/deps.py | 35 +++-- dbt/task/init.py | 13 +- dbt/task/run.py | 28 +++- dbt/task/seed.py | 2 +- dbt/task/test.py | 18 ++- dbt/templates.py | 4 +- dbt/tracking.py | 120 +++++++++++----- dbt/utils.py | 79 +++++++--- dbt/version.py | 34 +++-- setup.py | 4 +- tox.ini | 9 +- 33 files changed, 1376 insertions(+), 567 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16c3a649fa9..ab82713d39b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### Changes - add `--debug` flag, replace calls to `print()` with a global logger ([#256](https://github.com/analyst-collective/dbt/pull/256)) +- add pep8 check to continuous integration tests and bring codebase into compliance ([#257](https://github.com/analyst-collective/dbt/pull/257)) ## dbt release 0.6.0 diff --git a/dbt/__init__.py b/dbt/__init__.py index 8b137891791..e69de29bb2d 100644 --- a/dbt/__init__.py +++ b/dbt/__init__.py @@ -1 +0,0 @@ - diff --git a/dbt/archival.py b/dbt/archival.py index 23852e56acc..390f0f2e9ee 100644 --- a/dbt/archival.py +++ b/dbt/archival.py @@ -1,4 +1,3 @@ - from __future__ import print_function import dbt.targets import dbt.schema @@ -18,41 +17,50 @@ def __init__(self, project, archive_model): def compile(self): source_schema = self.archive_model.source_schema target_schema = self.archive_model.target_schema - source_table = self.archive_model.source_table - target_table = self.archive_model.target_table - unique_key = self.archive_model.unique_key - updated_at = self.archive_model.updated_at + source_table = self.archive_model.source_table + target_table = self.archive_model.target_table + unique_key = self.archive_model.unique_key + updated_at = self.archive_model.updated_at self.schema.create_schema(target_schema) - source_columns = self.schema.get_columns_in_table(source_schema, source_table) + source_columns = self.schema.get_columns_in_table( + source_schema, source_table) if len(source_columns) == 0: - raise RuntimeError('Source table "{}"."{}" does not exist'.format(source_schema, source_table)) + raise RuntimeError( + 'Source table "{}"."{}" does not ' + 'exist'.format(source_schema, source_table)) extra_cols = [ dbt.schema.Column("valid_from", "timestamp", None), dbt.schema.Column("valid_to", "timestamp", None), - dbt.schema.Column("scd_id","text", None), - dbt.schema.Column("dbt_updated_at","timestamp", None) + dbt.schema.Column("scd_id", "text", None), + dbt.schema.Column("dbt_updated_at", "timestamp", None) ] dest_columns = source_columns + extra_cols - self.schema.create_table(target_schema, target_table, dest_columns, sort=updated_at, 
dist=unique_key) + self.schema.create_table( + target_schema, + target_table, + dest_columns, + sort=updated_at, + dist=unique_key + ) env = jinja2.Environment() ctx = { - "columns" : source_columns, - "updated_at" : updated_at, - "unique_key" : unique_key, - "source_schema" : source_schema, - "source_table" : source_table, - "target_schema" : target_schema, - "target_table" : target_table + "columns": source_columns, + "updated_at": updated_at, + "unique_key": unique_key, + "source_schema": source_schema, + "source_table": source_table, + "target_schema": target_schema, + "target_table": target_table } - base_query = dbt.templates.SCDArchiveTemplate + base_query = dbt.templates.SCDArchiveTemplate template = env.from_string(base_query, globals=ctx) rendered = template.render(ctx) @@ -62,4 +70,3 @@ def runtime_compile(self, compiled_model): context = self.context.copy() context.update(model.context()) model.compile(context) - diff --git a/dbt/compilation.py b/dbt/compilation.py index 653cdbf55d9..db3faff2530 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -6,16 +6,22 @@ import sqlparse import dbt.project +from dbt.source import Source +from dbt.utils import find_model_by_fqn, find_model_by_name, \ + dependency_projects, split_path, This, Var, compiler_error, \ + to_string + +from dbt.linker import Linker +from dbt.runtime import RuntimeContext import dbt.targets import dbt.templates -from dbt.linker import Linker from dbt.logger import GLOBAL_LOGGER as logger -from dbt.runtime import RuntimeContext -from dbt.source import Source -from dbt.utils import find_model_by_fqn, find_model_by_name, dependency_projects, split_path, This, Var, compiler_error, to_string -CompilableEntities = ["models", "data tests", "schema tests", "archives", "analyses"] +CompilableEntities = [ + "models", "data tests", "schema tests", "archives", "analyses" +] + def compile_string(string, ctx): try: @@ -27,6 +33,7 @@ def compile_string(string, ctx): except jinja2.exceptions.UndefinedError as e: compiler_error(None, str(e)) + class Compiler(object): def __init__(self, project, create_template_class, args): self.project = project @@ -55,13 +62,23 @@ def model_sources(self, this_project, own_project=None): paths = own_project.get('source-paths', []) if self.create_template.label == 'build': - return Source(this_project, own_project=own_project).get_models(paths, self.create_template) + return Source( + this_project, + own_project=own_project + ).get_models(paths, self.create_template) + elif self.create_template.label == 'test': - return Source(this_project, own_project=own_project).get_test_models(paths, self.create_template) + return Source( + this_project, + own_project=own_project + ).get_test_models(paths, self.create_template) + elif self.create_template.label == 'archive': return [] else: - raise RuntimeError("unexpected create template type: '{}'".format(self.create_template.label)) + raise RuntimeError( + "unexpected create template " + "type: '{}'".format(self.create_template.label)) def get_macros(self, this_project, own_project=None): if own_project is None: @@ -71,7 +88,10 @@ def get_macros(self, this_project, own_project=None): def get_archives(self, project): archive_template = dbt.templates.ArchiveInsertTemplate() - return Source(project, own_project=project).get_archives(archive_template) + return Source( + project, + own_project=project + ).get_archives(archive_template) def project_schemas(self): source_paths = self.project.get('source-paths', []) @@ -91,8 +111,15 @@ def 
validate_models_unique(self, models): found_models[model.name].append(model) for model_name, model_list in found_models.items(): if len(model_list) > 1: - models_str = "\n - ".join([str(model) for model in model_list]) - raise RuntimeError("Found {} models with the same name! Can't create tables. Name='{}'\n - {}".format(len(model_list), model_name, models_str)) + models_str = "\n - ".join( + [str(model) for model in model_list]) + + raise RuntimeError( + "Found {} models with the same name! Can't " + "create tables. Name='{}'\n - {}".format( + len(model_list), model_name, models_str + ) + ) def __write(self, build_filepath, payload): target_path = os.path.join(self.project['target-path'], build_filepath) @@ -103,7 +130,6 @@ def __write(self, build_filepath, payload): with open(target_path, 'w') as f: f.write(to_string(payload)) - def __model_config(self, model, linker): def do_config(*args, **kwargs): if len(args) == 1 and len(kwargs) == 0: @@ -111,10 +137,14 @@ def do_config(*args, **kwargs): elif len(args) == 0 and len(kwargs) > 0: opts = kwargs else: - raise RuntimeError("Invalid model config given inline in {}".format(model)) + raise RuntimeError( + "Invalid model config given inline in {}".format(model) + ) if type(opts) != dict: - raise RuntimeError("Invalid model config given inline in {}".format(model)) + raise RuntimeError( + "Invalid model config given inline in {}".format(model) + ) model.update_in_model_config(opts) model.add_to_prologue("Config specified in model: {}".format(opts)) @@ -122,13 +152,17 @@ def do_config(*args, **kwargs): return do_config def model_can_reference(self, src_model, other_model): - """returns True if the src_model can reference the other_model. Models can access - other models in their package and dependency models, but a dependency model cannot - access models "up" the dependency chain""" + """ + returns True if the src_model can reference the other_model. Models + can access other models in their package and dependency models, but + a dependency model cannot access models "up" the dependency chain. 
+ """ # hack for now b/c we don't support recursive dependencies - return other_model.own_project['name'] == src_model.own_project['name'] \ - or src_model.own_project['name'] == src_model.project['name'] + return ( + other_model.own_project['name'] == src_model.own_project['name'] or + src_model.own_project['name'] == src_model.project['name'] + ) def __ref(self, linker, ctx, model, all_models, add_dependency=True): schema = ctx['env']['schema'] @@ -142,20 +176,32 @@ def do_ref(*args): other_model = find_model_by_name(all_models, other_model_name) elif len(args) == 2: other_model_package, other_model_name = args - other_model_name = self.create_template.model_name(other_model_name) - other_model = find_model_by_name(all_models, other_model_name, package_namespace=other_model_package) + other_model_name = self.create_template.model_name( + other_model_name + ) + + other_model = find_model_by_name( + all_models, + other_model_name, + package_namespace=other_model_package + ) else: - compiler_error(model, "ref() takes at most two arguments ({} given)".format(len(args))) + compiler_error( + model, + "ref() takes at most two arguments ({} given)".format( + len(args) + ) + ) other_model_fqn = tuple(other_model.fqn[:-1] + [other_model_name]) src_fqn = ".".join(source_model) ref_fqn = ".".join(other_model_fqn) - #if not self.model_can_reference(model, other_model): - # compiler_error(model, "Model '{}' exists but cannot be referenced from dependency model '{}'".format(ref_fqn, src_fqn)) - if not other_model.is_enabled: - raise RuntimeError("Model '{}' depends on model '{}' which is disabled in the project config".format(src_fqn, ref_fqn)) + raise RuntimeError( + "Model '{}' depends on model '{}' which is disabled in " + "the project config".format(src_fqn, ref_fqn) + ) # this creates a trivial cycle -- should this be a compiler error? # we can still interpolate the name w/o making a self-cycle @@ -174,7 +220,11 @@ def wrapped_do_ref(*args): try: return do_ref(*args) except RuntimeError as e: - root = os.path.relpath(model.root_dir, model.project['project-root']) + root = os.path.relpath( + model.root_dir, + model.project['project-root'] + ) + filepath = os.path.join(root, model.rel_filepath) logger.info("Compiler error in {}".format(filepath)) logger.info("Enabled models:") @@ -190,15 +240,19 @@ def get_context(self, linker, model, models, add_dependency=False): context = self.project.context() # built-ins - context['ref'] = self.__ref(linker, context, model, models, add_dependency) + context['ref'] = self.__ref( + linker, context, model, models, add_dependency + ) context['config'] = self.__model_config(model, linker) - context['this'] = This(context['env']['schema'], model.immediate_name, model.name) + context['this'] = This( + context['env']['schema'], model.immediate_name, model.name + ) context['var'] = Var(model, context=context) context['target'] = self.project.get_target() # these get re-interpolated at runtime! 
context['run_started_at'] = '{{ run_started_at }}' - context['invocation_id'] = '{{ invocation_id }}' + context['invocation_id'] = '{{ invocation_id }}' # add in context from run target context.update(self.target.context) @@ -223,10 +277,13 @@ def compile_model(self, linker, model, models, add_dependency=True): fs_loader = jinja2.FileSystemLoader(searchpath=model.root_dir) jinja = jinja2.Environment(loader=fs_loader) - # this is a dumb jinja2 bug -- on windows, forward slashes are EXPECTED + # this is a dumb jinja2 bug -- on windows, forward slashes + # are EXPECTED posix_filepath = '/'.join(split_path(model.rel_filepath)) template = jinja.get_template(posix_filepath) - context = self.get_context(linker, model, models, add_dependency=add_dependency) + context = self.get_context( + linker, model, models, add_dependency=add_dependency + ) rendered = template.render(context) except jinja2.exceptions.TemplateSyntaxError as e: @@ -244,7 +301,11 @@ def write_graph_file(self, linker, label): def combine_query_with_ctes(self, model, query, ctes, compiled_models): parsed_stmts = sqlparse.parse(query) if len(parsed_stmts) != 1: - raise RuntimeError("unexpectedly parsed {} queries from model {}".format(len(parsed_stmts), model)) + raise RuntimeError( + "unexpectedly parsed {} queries from model " + "{}".format(len(parsed_stmts), model) + ) + parsed = parsed_stmts[0] with_stmt = None @@ -259,16 +320,27 @@ def combine_query_with_ctes(self, model, query, ctes, compiled_models): with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') parsed.insert_before(first_token, with_stmt) else: - # stmt exists, add a comma (which will come after our injected CTE(s) ) - trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',') + # stmt exists, add a comma (which will come after our injected + # CTE(s) ) + trailing_comma = sqlparse.sql.Token( + sqlparse.tokens.Punctuation, ',' + ) parsed.insert_after(with_stmt, trailing_comma) - cte_mapping = [(model.cte_name, compiled_models[model]) for model in ctes] + cte_mapping = [ + (model.cte_name, compiled_models[model]) for model in ctes + ] - # these newlines are important -- comments could otherwise interfere w/ query - cte_stmts = [" {} as (\n{}\n)".format(name, contents) for (name, contents) in cte_mapping] + # these newlines are important -- comments could otherwise interfere + # w/ query + cte_stmts = [ + " {} as (\n{}\n)".format(name, contents) + for (name, contents) in cte_mapping + ] - cte_text = sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(cte_stmts)) + cte_text = sqlparse.sql.Token( + sqlparse.tokens.Keyword, ", ".join(cte_stmts) + ) parsed.insert_after(with_stmt, cte_text) return str(parsed) @@ -278,12 +350,18 @@ def __recursive_add_ctes(self, linker, model): return set() models_to_add = linker.cte_map[model] - recursive_models = [self.__recursive_add_ctes(linker, m) for m in models_to_add] + recursive_models = [ + self.__recursive_add_ctes(linker, m) for m in models_to_add + ] + for recursive_model_set in recursive_models: - models_to_add = models_to_add | recursive_model_set + models_to_add = models_to_add | recursive_model_set + return models_to_add - def add_cte_to_rendered_query(self, linker, primary_model, compiled_models): + def add_cte_to_rendered_query( + self, linker, primary_model, compiled_models + ): fqn_to_model = {tuple(model.fqn): model for model in compiled_models} sorted_nodes = linker.as_topological_ordering() @@ -296,7 +374,7 @@ def add_cte_to_rendered_query(self, linker, primary_model, compiled_models): continue 
model = fqn_to_model[node] - # add these in topological sort order -- that's significant for CTEs + # add these in topological sort order -- significant for CTEs if model.is_ephemeral and model in models_to_add: required_ctes.append(model) @@ -304,7 +382,9 @@ def add_cte_to_rendered_query(self, linker, primary_model, compiled_models): if len(required_ctes) == 0: return query else: - compiled_query = self.combine_query_with_ctes(primary_model, query, required_ctes, compiled_models) + compiled_query = self.combine_query_with_ctes( + primary_model, query, required_ctes, compiled_models + ) return compiled_query def remove_node_from_graph(self, linker, model, models): @@ -318,22 +398,34 @@ def remove_node_from_graph(self, linker, model, models): if other_model.is_enabled: this_fqn = ".".join(model.fqn) that_fqn = ".".join(other_model.fqn) - compiler_error(model, "Model '{}' depends on model '{}' which is disabled".format(that_fqn, this_fqn)) + compiler_error( + model, + "Model '{}' depends on model '{}' which is " + "disabled".format(that_fqn, this_fqn) + ) def compile_models(self, linker, models): - compiled_models = {model: self.compile_model(linker, model, models) for model in models} - sorted_models = [find_model_by_fqn(models, fqn) for fqn in linker.as_topological_ordering()] + compiled_models = {model: self.compile_model(linker, model, models) + for model in models} + sorted_models = [find_model_by_fqn(models, fqn) + for fqn in linker.as_topological_ordering()] written_models = [] for model in sorted_models: - # in-model configs were just evaluated. Evict anything that is newly-disabled + # in-model configs were just evaluated. Evict anything that is + # newly-disabled if not model.is_enabled: self.remove_node_from_graph(linker, model, models) continue - injected_stmt = self.add_cte_to_rendered_query(linker, model, compiled_models) + injected_stmt = self.add_cte_to_rendered_query( + linker, model, compiled_models + ) + context = self.get_context(linker, model, models) - wrapped_stmt = model.compile(injected_stmt, self.project, self.create_template, context) + wrapped_stmt = model.compile( + injected_stmt, self.project, self.create_template, context + ) serialized = model.serialize() linker.update_node_data(tuple(model.fqn), serialized) @@ -348,14 +440,22 @@ def compile_models(self, linker, models): def compile_analyses(self, linker, compiled_models): analyses = self.analysis_sources(self.project) - compiled_analyses = {analysis: self.compile_model(linker, analysis, compiled_models) for analysis in analyses} + compiled_analyses = { + analysis: self.compile_model( + linker, analysis, compiled_models + ) for analysis in analyses + } written_analyses = [] referenceable_models = {} referenceable_models.update(compiled_models) referenceable_models.update(compiled_analyses) for analysis in analyses: - injected_stmt = self.add_cte_to_rendered_query(linker, analysis, referenceable_models) + injected_stmt = self.add_cte_to_rendered_query( + linker, + analysis, + referenceable_models + ) build_path = analysis.build_path() self.__write(build_path, injected_stmt) written_analyses.append(analysis) @@ -369,7 +469,8 @@ def compile_schema_tests(self, linker): schema_tests = [] for schema in schemas: - schema_tests.extend(schema.compile()) # compiling a SchemaFile returns >= 0 SchemaTest models + # compiling a SchemaFile returns >= 0 SchemaTest models + schema_tests.extend(schema.compile()) written_tests = [] for schema_test in schema_tests: @@ -392,7 +493,9 @@ def compile_data_tests(self, linker): for 
data_test in tests: serialized = data_test.serialize() linker.update_node_data(tuple(data_test.fqn), serialized) - query = self.compile_model(linker, data_test, enabled_models, add_dependency=False) + query = self.compile_model( + linker, data_test, enabled_models, add_dependency=False + ) wrapped = data_test.render(query) self.__write(data_test.build_path(), wrapped) written_tests.append(data_test) @@ -424,7 +527,11 @@ def compile_archives(self): def get_models(self): all_models = self.model_sources(this_project=self.project) for project in dependency_projects(self.project): - all_models.extend(self.model_sources(this_project=self.project, own_project=project)) + all_models.extend( + self.model_sources( + this_project=self.project, own_project=project + ) + ) return all_models @@ -435,16 +542,23 @@ def compile(self, limit_to=None): all_macros = self.get_macros(this_project=self.project) for project in dependency_projects(self.project): - all_macros.extend(self.get_macros(this_project=self.project, own_project=project)) + all_macros.extend( + self.get_macros(this_project=self.project, own_project=project) + ) self.macro_generator = self.generate_macros(all_macros) if limit_to is not None and 'models' in limit_to: - enabled_models = [model for model in all_models if model.is_enabled and not model.is_empty] + enabled_models = [ + model for model in all_models + if model.is_enabled and not model.is_empty + ] else: enabled_models = [] - compiled_models, written_models = self.compile_models(linker, enabled_models) + compiled_models, written_models = self.compile_models( + linker, enabled_models + ) # TODO : only compile schema tests for enabled models if limit_to is not None and 'tests' in limit_to: @@ -458,12 +572,12 @@ def compile(self, limit_to=None): self.validate_models_unique(written_schema_tests) self.write_graph_file(linker, self.create_template.label) - if limit_to is not None and 'analyses' in limit_to and self.create_template.label not in ['test', 'archive']: + if limit_to is not None and 'analyses' in limit_to and \ + self.create_template.label not in ['test', 'archive']: written_analyses = self.compile_analyses(linker, compiled_models) else: written_analyses = [] - if limit_to is not None and 'archives' in limit_to: compiled_archives = self.compile_archives() else: @@ -471,8 +585,8 @@ def compile(self, limit_to=None): return { "models": len(written_models), - "schema tests" : len(written_schema_tests), - "data tests" : len(written_data_tests), + "schema tests": len(written_schema_tests), + "data tests": len(written_data_tests), "archives": len(compiled_archives), - "analyses" : len(written_analyses) + "analyses": len(written_analyses) } diff --git a/dbt/compiled_model.py b/dbt/compiled_model.py index eae1fd5eecd..936dbe0d4c9 100644 --- a/dbt/compiled_model.py +++ b/dbt/compiled_model.py @@ -2,6 +2,7 @@ import jinja2 from dbt.utils import compiler_error, to_unicode + class CompiledModel(object): def __init__(self, fqn, data): self.fqn = fqn @@ -76,12 +77,17 @@ def project(self): @property def schema(self): if self.target is None: - raise RuntimeError("`target` not set in compiled model {}".format(self)) + raise RuntimeError( + "`target` not set in compiled model {}".format(self) + ) else: return self.target.schema def should_execute(self, args, existing): - if args.non_destructive and self.materialization == 'view' and self.name in existing: + if args.non_destructive and \ + self.materialization == 'view' and \ + self.name in existing: + return False else: return 
self.data['enabled'] and self.materialization != 'ephemeral' @@ -90,14 +96,14 @@ def should_rename(self, args): if args.non_destructive and self.materialization == 'table': return False else: - return self.materialization in ['table' , 'view'] + return self.materialization in ['table', 'view'] def prepare(self, existing, target): if self.materialization == 'incremental': tmp_drop_type = None final_drop_type = None else: - tmp_drop_type = existing.get(self.tmp_name, None) + tmp_drop_type = existing.get(self.tmp_name, None) final_drop_type = existing.get(self.name, None) self.tmp_drop_type = tmp_drop_type @@ -105,7 +111,10 @@ def prepare(self, existing, target): self.target = target def __repr__(self): - return "".format(self.data['project_name'], self.name, self.data['build_path']) + return "".format( + self.data['project_name'], self.name, self.data['build_path'] + ) + class CompiledTest(CompiledModel): def __init__(self, fqn, data): @@ -121,7 +130,10 @@ def prepare(self, existing, target): self.target = target def __repr__(self): - return "".format(self.data['project_name'], self.name, self.data['build_path']) + return "".format( + self.data['project_name'], self.name, self.data['build_path'] + ) + class CompiledArchive(CompiledModel): def __init__(self, fqn, data): @@ -137,7 +149,10 @@ def prepare(self, existing, target): self.target = target def __repr__(self): - return "".format(self.data['project_name'], self.name, self.data['build_path']) + return "".format( + self.data['project_name'], self.name, self.data['build_path'] + ) + def make_compiled_model(fqn, data): run_type = data['dbt_run_type'] @@ -150,5 +165,3 @@ def make_compiled_model(fqn, data): return CompiledArchive(fqn, data) else: raise RuntimeError("invalid run_type given: {}".format(run_type)) - - diff --git a/dbt/config.py b/dbt/config.py index 8bbe77e220b..4edf2b11cf1 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -19,7 +19,8 @@ def read_config(profiles_dir): def send_anonymous_usage_stats(profiles_dir): config = read_config(profiles_dir) - if config is not None and config.get("send_anonymous_usage_stats") == False: + if config is not None \ + and not config.get("send_anonymous_usage_stats", True): return False return True diff --git a/dbt/deprecations.py b/dbt/deprecations.py index e6a26fc6341..676c5dce849 100644 --- a/dbt/deprecations.py +++ b/dbt/deprecations.py @@ -1,5 +1,6 @@ from dbt.logger import GLOBAL_LOGGER as logger + class DBTDeprecation(object): name = None description = None @@ -10,21 +11,27 @@ def show(self, *args, **kwargs): logger.info("* Deprecation Warning: {}\n".format(desc)) active_deprecations.add(self.name) + class DBTRunTargetDeprecation(DBTDeprecation): name = 'run-target' - description = """profiles.yml configuration option 'run-target' is deprecated. Please use 'target' instead. - The 'run-target' option will be removed (in favor of 'target') in DBT version 0.7.0""" + description = """profiles.yml configuration option 'run-target' is + deprecated. Please use 'target' instead. The 'run-target' option will be + removed (in favor of 'target') in DBT version 0.7.0""" + class DBTInvalidPackageName(DBTDeprecation): name = 'invalid-package-name' - description = """The package name '{package_name}' is not valid. Package names must only contain letters and underscores. - Packages with invalid names will fail to compile in DBT version 0.7.0""" + description = """The package name '{package_name}' is not valid. Package + names must only contain letters and underscores. 
Packages with invalid + names will fail to compile in DBT version 0.7.0""" def warn(name, *args, **kwargs): if name not in deprecations: # this should (hopefully) never happen - raise RuntimeError("Error showing deprecation warning: {}".format(name)) + raise RuntimeError( + "Error showing deprecation warning: {}".format(name) + ) deprecations[name].show(*args, **kwargs) @@ -39,7 +46,8 @@ def warn(name, *args, **kwargs): DBTInvalidPackageName() ] -deprecations = {d.name : d for d in deprecations_list} +deprecations = {d.name: d for d in deprecations_list} + def reset_deprecations(): active_deprecations.clear() diff --git a/dbt/linker.py b/dbt/linker.py index e874d355848..000dc183713 100644 --- a/dbt/linker.py +++ b/dbt/linker.py @@ -1,7 +1,7 @@ - import networkx as nx from collections import defaultdict + class Linker(object): def __init__(self, data=None): if data is None: @@ -22,15 +22,27 @@ def as_topological_ordering(self, limit_to=None): try: return nx.topological_sort(self.graph, nbunch=limit_to) except KeyError as e: - raise RuntimeError("Couldn't find model '{}' -- does it exist or is it diabled?".format(e)) + raise RuntimeError( + "Couldn't find model '{}' -- does it exist or is it " + "disabled?".format(e) + ) + except nx.exception.NetworkXUnfeasible as e: - cycle = " --> ".join([".".join(node) for node in nx.algorithms.find_cycle(self.graph)[0]]) - raise RuntimeError("Can't compile -- cycle exists in model graph\n{}".format(cycle)) + cycle = " --> ".join( + [".".join(node) for node in + nx.algorithms.find_cycle(self.graph)[0]] + ) + raise RuntimeError( + "Can't compile -- cycle exists in model graph\n" + "{}".format(cycle) + ) def as_dependency_list(self, limit_to=None): - """returns a list of list of nodes, eg. [[0,1], [2], [4,5,6]]. Each element contains nodes whose - dependenices are subsumed by the union of all lists before it. In this way, all nodes in list `i` - can be run simultaneously assuming that all lists before list `i` have been completed""" + """returns a list of list of nodes, eg. [[0,1], [2], [4,5,6]]. Each + element contains nodes whose dependenices are subsumed by the union of + all lists before it. 
In this way, all nodes in list `i` can be run + simultaneously assuming that all lists before list `i` have been + completed""" if limit_to is None: graph_nodes = set(self.graph.nodes()) @@ -41,7 +53,10 @@ def as_dependency_list(self, limit_to=None): if node in self.graph: graph_nodes.update(nx.descendants(self.graph, node)) else: - raise RuntimeError("Couldn't find model '{}' -- does it exist or is it diabled?".format(node)) + raise RuntimeError( + "Couldn't find model '{}' -- does it exist or is " + "it disabled?".format(node) + ) depth_nodes = defaultdict(list) @@ -83,4 +98,3 @@ def write_graph(self, outfile): def read_graph(self, infile): self.graph = nx.read_yaml(infile) - diff --git a/dbt/main.py b/dbt/main.py index bdcffa158d3..93d0b2b7618 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -20,6 +20,7 @@ import dbt.tracking import dbt.config as config + def main(args=None): if args is None: args = sys.argv[1:] @@ -32,12 +33,14 @@ def main(args=None): logger.info(str(e)) sys.exit(1) + def handle(args): parsed = parse_args(args) initialize_logger(parsed.debug) - # this needs to happen after args are parsed so we can determine the correct profiles.yml file + # this needs to happen after args are parsed so we can determine the + # correct profiles.yml file if not config.send_anonymous_usage_stats(parsed.profiles_dir): dbt.tracking.do_not_track() @@ -46,6 +49,7 @@ def handle(args): return res + def get_nearest_project_dir(): root_path = os.path.abspath(os.sep) cwd = os.getcwd() @@ -58,6 +62,7 @@ def get_nearest_project_dir(): return None + def run_from_args(parsed): task = None proj = None @@ -66,10 +71,12 @@ def run_from_args(parsed): # bypass looking for a project file if we're running `dbt init` task = parsed.cls(args=parsed) else: - nearest_project_dir = get_nearest_project_dir() if nearest_project_dir is None: - raise RuntimeError("fatal: Not a dbt project (or any of the parent directories). Missing dbt_project.yml file") + raise RuntimeError( + "fatal: Not a dbt project (or any of the parent directories). " + "Missing dbt_project.yml file" + ) os.chdir(nearest_project_dir) @@ -82,22 +89,34 @@ def run_from_args(parsed): dbt.tracking.track_invocation_start(project=proj, args=parsed) try: return task.run() - dbt.tracking.track_invocation_end(project=proj, args=parsed, result_type="ok", result=None) + dbt.tracking.track_invocation_end( + project=proj, args=parsed, result_type="ok", result=None + ) except Exception as e: - dbt.tracking.track_invocation_end(project=proj, args=parsed, result_type="error", result=str(e)) + dbt.tracking.track_invocation_end( + project=proj, args=parsed, result_type="error", result=str(e) + ) raise + def invoke_dbt(parsed): task = None proj = None try: - proj = project.read_project('dbt_project.yml', parsed.profiles_dir, validate=False, profile_to_load=parsed.profile) + proj = project.read_project( + 'dbt_project.yml', + parsed.profiles_dir, + validate=False, + profile_to_load=parsed.profile + ) proj.validate() except project.DbtProjectError as e: logger.info("Encountered an error while reading the project:") logger.info(" ERROR {}".format(str(e))) - logger.info("Did you set the correct --profile? Using: {}".format(parsed.profile)) + logger.info( + "Did you set the correct --profile? 
Using: {}" + .format(parsed.profile)) logger.info("Valid profiles:") @@ -105,7 +124,12 @@ def invoke_dbt(parsed): for profile in all_profiles: logger.info(" - {}".format(profile)) - dbt.tracking.track_invalid_invocation(project=proj, args=parsed, result_type="invalid_profile", result=str(e)) + dbt.tracking.track_invalid_invocation( + project=proj, + args=parsed, + result_type="invalid_profile", + result=str(e)) + return None if parsed.target is not None: @@ -114,9 +138,16 @@ def invoke_dbt(parsed): proj.cfg['target'] = parsed.target else: logger.info("Encountered an error while reading the project:") - logger.info(" ERROR Specified target {} is not a valid option for profile {}".format(parsed.target, proj.profile_to_load)) + logger.info(" ERROR Specified target {} is not a valid option " + "for profile {}" + .format(parsed.target, proj.profile_to_load)) logger.info("Valid targets are: {}".format(targets)) - dbt.tracking.track_invalid_invocation(project=proj, args=parsed, result_type="invalid_target", result="target not found") + dbt.tracking.track_invalid_invocation( + project=proj, + args=parsed, + result_type="invalid_target", + result="target not found") + return None log_dir = proj.get('log-path', 'logs') @@ -127,17 +158,54 @@ def invoke_dbt(parsed): return task, proj + def parse_args(args): - p = argparse.ArgumentParser(prog='dbt: data build tool', formatter_class=argparse.RawTextHelpFormatter) - p.add_argument('--version', action='version', version=dbt.version.get_version_information(), help="Show version information") - p.add_argument('-d', '--debug', action='store_true', help='Display debug logging during dbt execution. Useful for debugging and making bug reports.') + p = argparse.ArgumentParser( + prog='dbt: data build tool', + formatter_class=argparse.RawTextHelpFormatter + ) + + p.add_argument( + '--version', + action='version', + version=dbt.version.get_version_information(), + help="Show version information" + ) + p.add_argument( + '-d', + '--debug', + action='store_true', + help='''Display debug logging during dbt execution. Useful for + debugging and making bug reports.''') subs = p.add_subparsers() base_subparser = argparse.ArgumentParser(add_help=False) - base_subparser.add_argument('--profiles-dir', default=project.default_profiles_dir, type=str, help='Which dir to look in for the profiles.yml file. Default = {}'.format(project.default_profiles_dir)) - base_subparser.add_argument('--profile', required=False, type=str, help='Which profile to load (overrides profile setting in dbt_project.yml file)') - base_subparser.add_argument('--target', default=None, type=str, help='Which target to load for the given profile') + + base_subparser.add_argument( + '--profiles-dir', + default=project.default_profiles_dir, + type=str, + help=""" + Which directory to look in for the profiles.yml file. Default = {} + """.format(project.default_profiles_dir) + ) + + base_subparser.add_argument( + '--profile', + required=False, + type=str, + help=""" + Which profile to load. Overrides setting in dbt_project.yml. 
+ """ + ) + + base_subparser.add_argument( + '--target', + default=None, + type=str, + help='Which target to load for the given profile' + ) sub = subs.add_parser('init', parents=[base_subparser]) sub.add_argument('project_name', type=str, help='Name of the new project') @@ -147,9 +215,27 @@ def parse_args(args): sub.set_defaults(cls=clean_task.CleanTask, which='clean') sub = subs.add_parser('compile', parents=[base_subparser]) - sub.add_argument('--dry', action='store_true', help="Compile 'dry run' models") - sub.add_argument('--non-destructive', action='store_true', help="If specified, DBT will not drop views. Tables will be truncated instead of dropped. ") - sub.add_argument('--full-refresh', action='store_true', help="If specified, DBT will drop incremental models and fully-recalculate the incremental table from the model definition.") + sub.add_argument( + '--dry', + action='store_true', + help="Compile 'dry run' models" + ) + sub.add_argument( + '--non-destructive', + action='store_true', + help=""" + If specified, DBT will not drop views. Tables will be truncated instead + of dropped. + """ + ) + sub.add_argument( + '--full-refresh', + action='store_true', + help=""" + If specified, DBT will drop incremental models and fully-recalculate + the incremental table from the model definition. + """ + ) sub.set_defaults(cls=compile_task.CompileTask, which='compile') sub = subs.add_parser('debug', parents=[base_subparser]) @@ -159,25 +245,83 @@ def parse_args(args): sub.set_defaults(cls=deps_task.DepsTask, which='deps') sub = subs.add_parser('archive', parents=[base_subparser]) - sub.add_argument('--threads', type=int, required=False, help="Specify number of threads to use while archiving tables. Overrides settings in profiles.yml") + sub.add_argument( + '--threads', + type=int, + required=False, + help=""" + Specify number of threads to use while archiving tables. Overrides + settings in profiles.yml. + """ + ) sub.set_defaults(cls=archive_task.ArchiveTask, which='archive') sub = subs.add_parser('run', parents=[base_subparser]) sub.add_argument('--dry', action='store_true', help="'dry run' models") - sub.add_argument('--models', required=False, nargs='+', help="Specify the models to run. All models depending on these models will also be run") - sub.add_argument('--threads', type=int, required=False, help="Specify number of threads to use while executing models. Overrides settings in profiles.yml") - sub.add_argument('--non-destructive', action='store_true', help="If specified, DBT will not drop views. Tables will be truncated instead of dropped. ") - sub.add_argument('--full-refresh', action='store_true', help="If specified, DBT will drop incremental models and fully-recalculate the incremental table from the model definition.") + sub.add_argument( + '--models', + required=False, + nargs='+', + help=""" + Specify the models to run. All models depending on these models will + also be run. + """ + ) + sub.add_argument( + '--threads', + type=int, + required=False, + help=""" + Specify number of threads to use while executing models. Overrides + settings in profiles.yml. + """ + ) + sub.add_argument( + '--non-destructive', + action='store_true', + help=""" + If specified, DBT will not drop views. Tables will be truncated instead + of dropped. + """ + ) + sub.add_argument( + '--full-refresh', + action='store_true', + help=""" + If specified, DBT will drop incremental models and fully-recalculate + the incremental table from the model definition. 
+ """) sub.set_defaults(cls=run_task.RunTask, which='run') sub = subs.add_parser('seed', parents=[base_subparser]) - sub.add_argument('--drop-existing', action='store_true', help="Drop existing seed tables and recreate them") + sub.add_argument( + '--drop-existing', + action='store_true', + help="Drop existing seed tables and recreate them" + ) sub.set_defaults(cls=seed_task.SeedTask, which='seed') sub = subs.add_parser('test', parents=[base_subparser]) - sub.add_argument('--data', action='store_true', help='Run data tests defined in "tests" directory') - sub.add_argument('--schema', action='store_true', help='Run constraint validations from schema.yml files') - sub.add_argument('--threads', type=int, required=False, help="Specify number of threads to use while executing tests. Overrides settings in profiles.yml") + sub.add_argument( + '--data', + action='store_true', + help='Run data tests defined in "tests" directory.' + ) + sub.add_argument( + '--schema', + action='store_true', + help='Run constraint validations from schema.yml files' + ) + sub.add_argument( + '--threads', + type=int, + required=False, + help=""" + Specify number of threads to use while executing tests. Overrides + settings in profiles.yml + """ + ) + sub.set_defaults(cls=test_task.TestTask, which='test') if len(args) == 0: diff --git a/dbt/model.py b/dbt/model.py index a4e74b13ede..3832faf6bc1 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -8,37 +8,54 @@ import dbt.schema_tester import dbt.project import dbt.archival -from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, compiler_warning +from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, \ + compiler_warning + class SourceConfig(object): Materializations = ['view', 'table', 'incremental', 'ephemeral'] ConfigKeys = DBTConfigKeys - AppendListFields = ['pre-hook', 'post-hook'] + AppendListFields = ['pre-hook', 'post-hook'] ExtendDictFields = ['vars'] - ClobberFields = ['enabled', 'materialized', 'dist', 'sort', 'sql_where', 'unique_key', 'sort_type'] + ClobberFields = [ + 'enabled', + 'materialized', + 'dist', + 'sort', + 'sql_where', + 'unique_key', + 'sort_type' + ] def __init__(self, active_project, own_project, fqn): self.active_project = active_project self.own_project = own_project self.fqn = fqn - self.in_model_config = {} # the config options defined within the model + # the config options defined within the model + self.in_model_config = {} # make sure we categorize all configs - all_configs = self.AppendListFields + self.ExtendDictFields + self.ClobberFields + all_configs = self.AppendListFields + self.ExtendDictFields + \ + self.ClobberFields + for config in self.ConfigKeys: assert config in all_configs, config def _merge(self, *configs): merged_config = {} for config in configs: - intermediary_merged = deep_merge(merged_config.copy(), config.copy()) + intermediary_merged = deep_merge( + merged_config.copy(), config.copy() + ) + merged_config.update(intermediary_merged) return merged_config # this is re-evaluated every time `config` is called. - # we can cache it, but that complicates things. TODO : see how this fares performance-wise + # we can cache it, but that complicates things. 
+ # TODO : see how this fares performance-wise @property def config(self): """ @@ -59,10 +76,14 @@ def config(self): cfg = self._merge(defaults, active_config, self.in_model_config) else: own_config = self.load_config_from_own_project() - cfg = self._merge(defaults, own_config, self.in_model_config, active_config) - - # mask this as a table if it's an incremental model with --full-refresh provided - if cfg.get('materialized') == 'incremental' and self.active_project.args.full_refresh: + cfg = self._merge( + defaults, own_config, self.in_model_config, active_config + ) + + # mask this as a table if it's an incremental model with + # --full-refresh provided + if cfg.get('materialized') == 'incremental' and \ + self.active_project.args.full_refresh: cfg['materialized'] = 'table' return cfg @@ -70,7 +91,8 @@ def config(self): def update_in_model_config(self, config): config = config.copy() - # make sure we're not clobbering an array of hooks with a single hook string + # make sure we're not clobbering an array of hooks with a single hook + # string hook_fields = ['pre-hook', 'post-hook'] for hook_field in hook_fields: if hook_field in config: @@ -91,16 +113,24 @@ def __get_hooks(self, relevant_configs, key): for hook in new_hooks: if type(hook) != str: name = ".".join(self.fqn) - compiler_error(None, "{} for model {} is not a string!".format(key, name)) + compiler_error(None, "{} for model {} is not a string!".format( + key, name + )) hooks.append(hook) return hooks def smart_update(self, mutable_config, new_configs): - relevant_configs = {key: new_configs[key] for key in new_configs if key in self.ConfigKeys} + relevant_configs = { + key: new_configs[key] for key + in new_configs if key in self.ConfigKeys + } + for key in SourceConfig.AppendListFields: new_hooks = self.__get_hooks(relevant_configs, key) - mutable_config[key].extend([h for h in new_hooks if h not in mutable_config[key]]) + mutable_config[key].extend([ + h for h in new_hooks if h not in mutable_config[key] + ]) for key in SourceConfig.ExtendDictFields: dict_val = relevant_configs.get(key, {}) @@ -113,7 +143,8 @@ def smart_update(self, mutable_config, new_configs): return relevant_configs def get_project_config(self, project): - # most configs are overwritten by a more specific config, but pre/post hooks are appended! + # most configs are overwritten by a more specific config, but pre/post + # hooks are appended! 
config = {} for k in SourceConfig.AppendListFields: config[k] = [] @@ -137,7 +168,12 @@ def get_project_config(self, project): # mutates config relevant_configs = self.smart_update(config, level_config) - clobber_configs = {k:v for (k,v) in relevant_configs.items() if k not in SourceConfig.AppendListFields and k not in SourceConfig.ExtendDictFields} + clobber_configs = { + k: v for (k, v) in relevant_configs.items() + if k not in SourceConfig.AppendListFields and + k not in SourceConfig.ExtendDictFields + } + config.update(clobber_configs) model_configs = model_configs[level] @@ -149,6 +185,7 @@ def load_config_from_own_project(self): def load_config_from_active_project(self): return self.get_project_config(self.active_project) + class DBTSource(object): dbt_run_type = 'base' @@ -175,10 +212,12 @@ def is_empty(self): def compile(self): raise RuntimeError("Not implemented!") - + def serialize(self): serialized = { - "build_path": os.path.join(self.project['target-path'], self.build_path()), + "build_path": os.path.join( + self.project['target-path'], self.build_path() + ), "source_path": self.filepath, "name": self.name, "tmp_name": self.tmp_name(), @@ -226,12 +265,19 @@ def is_view(self): def is_enabled(self): enabled = self.config['enabled'] if enabled not in (True, False): - compiler_error(self, "'enabled' config must be either True or False. '{}' given.".format(enabled)) + compiler_error( + self, + "'enabled' config must be either True or False. '{}' given." + .format(enabled) + ) return enabled @property def fqn(self): - "fully-qualified name for model. Includes all subdirs below 'models' path and the filename" + """ + fully-qualified name for model. Includes all subdirs below 'models' + path and the filename + """ parts = split_path(self.filepath) name, _ = os.path.splitext(parts[-1]) return [self.own_project['name']] + parts[1:-1] + [name] @@ -259,19 +305,25 @@ def rename_query(self, schema): "final_name": self.name } - return 'alter table "{schema}"."{tmp_name}" rename to "{final_name}"'.format(**opts) + return 'alter table "{schema}"."{tmp_name}" rename to "{final_name}"' \ + .format(**opts) @property def nice_name(self): return "{}.{}".format(self.fqn[0], self.fqn[-1]) + class Model(DBTSource): dbt_run_type = 'run' - def __init__(self, project, model_dir, rel_filepath, own_project, create_template): + def __init__( + self, project, model_dir, rel_filepath, own_project, create_template + ): self.prologue = [] self.create_template = create_template - super(Model, self).__init__(project, model_dir, rel_filepath, own_project) + super(Model, self).__init__( + project, model_dir, rel_filepath, own_project + ) def add_to_prologue(self, s): safe_string = s.replace('{{', 'DBT_EXPR(').replace('}}', ')') @@ -286,40 +338,55 @@ def sort_qualifier(self, model_config): return '' if (self.is_view or self.is_ephemeral) and 'sort' in model_config: - return '' + return '' sort_keys = model_config['sort'] sort_type = model_config.get('sort_type', 'compound') if type(sort_type) != str: - compiler_error(self, "The provided sort_type '{}' is not valid!".format(sort_type)) + compiler_error( + self, + "The provided sort_type '{}' is not valid!".format(sort_type) + ) sort_type = sort_type.strip().lower() valid_sort_types = ['compound', 'interleaved'] if sort_type not in valid_sort_types: - compiler_error(self, "Invalid sort_type given: {} -- must be one of {}".format(sort_type, valid_sort_types)) + compiler_error( + self, + "Invalid sort_type given: {} -- must be one of {}".format( + sort_type, 
valid_sort_types + ) + ) if type(sort_keys) == str: sort_keys = [sort_keys] # remove existing quotes in field name, then wrap in quotes - formatted_sort_keys = ['"{}"'.format(sort_key.replace('"', '')) for sort_key in sort_keys] + formatted_sort_keys = [ + '"{}"'.format(sort_key.replace('"', '')) for sort_key in sort_keys + ] keys_csv = ', '.join(formatted_sort_keys) - return "{sort_type} sortkey ({keys_csv})".format(sort_type=sort_type, keys_csv=keys_csv) + return "{sort_type} sortkey ({keys_csv})".format( + sort_type=sort_type, keys_csv=keys_csv + ) def dist_qualifier(self, model_config): if 'dist' not in model_config: return '' if (self.is_view or self.is_ephemeral) and 'dist' in model_config: - return '' + return '' dist_key = model_config['dist'] if type(dist_key) != str: - compiler_error(self, "The provided distkey '{}' is not valid!".format(dist_key)) + compiler_error( + self, + "The provided distkey '{}' is not valid!".format(dist_key) + ) dist_key = dist_key.strip().lower() @@ -346,7 +413,9 @@ def compile_string(self, ctx, string): return string try: - fs_loader = jinja2.FileSystemLoader(searchpath=self.project['macro-paths']) + fs_loader = jinja2.FileSystemLoader( + searchpath=self.project['macro-paths'] + ) env = jinja2.Environment(loader=fs_loader) template = env.from_string(string, globals=ctx) return template.render(ctx) @@ -366,7 +435,11 @@ def compile(self, rendered_query, project, create_template, ctx): model_config = self.config if self.materialization not in SourceConfig.Materializations: - compiler_error(self, "Invalid materialize option given: '{}'. Must be one of {}".format(self.materialization, SourceConfig.Materializations)) + compiler_error( + self, + "Invalid materialize option given: '{}'. Must be one of {}" + .format(self.materialization, SourceConfig.Materializations) + ) schema = ctx['env'].get('schema', 'public') @@ -377,7 +450,11 @@ def compile(self, rendered_query, project, create_template, ctx): if self.materialization == 'incremental': identifier = self.name if 'sql_where' not in model_config: - compiler_error(self, "sql_where not specified in model materialized as incremental") + compiler_error( + self, + """sql_where not specified in model materialized as + incremental""" + ) raw_sql_where = model_config['sql_where'] sql_where = self.compile_string(ctx, raw_sql_where) @@ -387,7 +464,7 @@ def compile(self, rendered_query, project, create_template, ctx): sql_where = None unique_key = None - pre_hooks = self.get_hooks(ctx, 'pre-hook') + pre_hooks = self.get_hooks(ctx, 'pre-hook') post_hooks = self.get_hooks(ctx, 'post-hook') opts = { @@ -399,9 +476,9 @@ def compile(self, rendered_query, project, create_template, ctx): "sort_qualifier": sort_qualifier, "sql_where": sql_where, "prologue": self.get_prologue_string(), - "unique_key" : unique_key, - "pre-hooks" : pre_hooks, - "post-hooks" : post_hooks, + "unique_key": unique_key, + "pre-hooks": pre_hooks, + "post-hooks": post_hooks, "non_destructive": self.is_non_destructive() } @@ -419,11 +496,20 @@ def cte_name(self): return "__dbt__CTE__{}".format(self.name) def __repr__(self): - return "".format(self.project['name'], self.name, self.filepath) + return "".format( + self.project['name'], self.name, self.filepath + ) + class Analysis(Model): def __init__(self, project, target_dir, rel_filepath, own_project): - return super(Analysis, self).__init__(project, target_dir, rel_filepath, own_project, BaseCreateTemplate()) + return super(Analysis, self).__init__( + project, + target_dir, + rel_filepath, + 
own_project, + BaseCreateTemplate() + ) def build_path(self): build_dir = 'build-analysis' @@ -434,11 +520,21 @@ def build_path(self): def __repr__(self): return "".format(self.name, self.filepath) + class TestModel(Model): dbt_run_type = 'dry-run' - def __init__(self, project, target_dir, rel_filepath, own_project, create_template): - return super(TestModel, self).__init__(project, target_dir, rel_filepath, own_project, create_template) + def __init__( + self, + project, + target_dir, + rel_filepath, + own_project, + create_template + ): + return super(TestModel, self).__init__( + project, target_dir, rel_filepath, own_project, create_template + ) def build_path(self): build_dir = self.create_template.label @@ -448,7 +544,8 @@ def build_path(self): @property def fqn(self): - "fully-qualified name for model. Includes all subdirs below 'models' path and the filename" + """fully-qualified name for model. Includes all subdirs below 'models' + path and the filename""" parts = split_path(self.filepath) name, _ = os.path.splitext(parts[-1]) test_name = DryCreateTemplate.model_name(name) @@ -461,7 +558,10 @@ def original_fqn(self): return [self.project['name']] + parts[1:-1] + [name] def __repr__(self): - return "".format(self.project['name'], self.name, self.filepath) + return "".format( + self.project['name'], self.name, self.filepath + ) + class SchemaTest(DBTSource): test_type = "base" @@ -474,13 +574,16 @@ def __init__(self, project, target_dir, rel_filepath, model_name, options): self.options = options self.params = self.get_params(options) - super(SchemaTest, self).__init__(project, target_dir, rel_filepath, project) + super(SchemaTest, self).__init__( + project, target_dir, rel_filepath, project + ) @property def fqn(self): parts = split_path(self.filepath) name, _ = os.path.splitext(parts[-1]) - return [self.project['name']] + parts[1:-1] + ['schema', self.get_filename()] + return [self.project['name']] + parts[1:-1] + \ + ['schema', self.get_filename()] def serialize(self): serialized = DBTSource.serialize(self).copy() @@ -500,7 +603,9 @@ def unique_option_key(self): def get_filename(self): key = re.sub('[^0-9a-zA-Z]+', '_', self.unique_option_key()) - filename = "{test_type}_{model_name}_{key}".format(test_type=self.test_type, model_name=self.model_name, key=key) + filename = "{test_type}_{model_name}_{key}".format( + test_type=self.test_type, model_name=self.model_name, key=key + ) return filename def build_path(self): @@ -518,7 +623,10 @@ def render(self): def __repr__(self): class_name = self.__class__.__name__ - return "<{} {}.{}: {}>".format(class_name, self.project['name'], self.name, self.filepath) + return "<{} {}.{}: {}>".format( + class_name, self.project['name'], self.name, self.filepath + ) + class NotNullSchemaTest(SchemaTest): template = dbt.schema_tester.QUERY_VALIDATE_NOT_NULL @@ -528,7 +636,9 @@ def unique_option_key(self): return self.params['field'] def describe(self): - return 'VALIDATE NOT NULL {schema}.{table}.{field}'.format(**self.params) + return 'VALIDATE NOT NULL {schema}.{table}.{field}' \ + .format(**self.params) + class UniqueSchemaTest(SchemaTest): template = dbt.schema_tester.QUERY_VALIDATE_UNIQUE @@ -540,6 +650,7 @@ def unique_option_key(self): def describe(self): return 'VALIDATE UNIQUE {schema}.{table}.{field}'.format(**self.params) + class ReferentialIntegritySchemaTest(SchemaTest): template = dbt.schema_tester.QUERY_VALIDATE_REFERENTIAL_INTEGRITY test_type = "relationships" @@ -554,10 +665,14 @@ def get_params(self, options): } def 
unique_option_key(self): - return "{child_field}_to_{parent_table}_{parent_field}".format(**self.params) + return "{child_field}_to_{parent_table}_{parent_field}" \ + .format(**self.params) def describe(self): - return 'VALIDATE REFERENTIAL INTEGRITY {schema}.{child_table}.{child_field} to {schema}.{parent_table}.{parent_field}'.format(**self.params) + return """VALIDATE REFERENTIAL INTEGRITY + {schema}.{child_table}.{child_field} to + {schema}.{parent_table}.{parent_field}""".format(**self.params) + class AcceptedValuesSchemaTest(SchemaTest): template = dbt.schema_tester.QUERY_VALIDATE_ACCEPTED_VALUES @@ -568,8 +683,8 @@ def get_params(self, options): quoted_values_csv = ",".join(quoted_values) return { "schema": self.schema, - "table" : self.model_name, - "field" : options['field'], + "table": self.model_name, + "field": options['field'], "values_csv": quoted_values_csv } @@ -577,7 +692,10 @@ def unique_option_key(self): return "{field}".format(**self.params) def describe(self): - return 'VALIDATE ACCEPTED VALUES {schema}.{table}.{field} VALUES ({values_csv})'.format(**self.params) + return """VALIDATE ACCEPTED VALUES + {schema}.{table}.{field} VALUES + ({values_csv})""".format(**self.params) + class SchemaFile(DBTSource): SchemaTestMap = { @@ -588,7 +706,9 @@ class SchemaFile(DBTSource): } def __init__(self, project, target_dir, rel_filepath, own_project): - super(SchemaFile, self).__init__(project, target_dir, rel_filepath, own_project) + super(SchemaFile, self).__init__( + project, target_dir, rel_filepath, own_project + ) self.og_target_dir = target_dir self.schema = yaml.safe_load(self.contents) @@ -597,7 +717,11 @@ def get_test(self, test_type): return SchemaFile.SchemaTestMap[test_type] else: possible_types = ", ".join(SchemaFile.SchemaTestMap.keys()) - compiler_error(self, "Invalid validation type given in {}: '{}'. Possible: {}".format(self.filepath, test_type, possible_types)) + compiler_error( + self, + "Invalid validation type given in {}: '{}'. 
Possible: {}" + .format(self.filepath, test_type, possible_types) + ) def do_compile(self): schema_tests = [] @@ -605,10 +729,20 @@ def do_compile(self): constraints = constraint_blob.get('constraints', {}) for constraint_type, constraint_data in constraints.items(): if constraint_data is None: - compiler_error(self, "no constraints given to test: '{}.{}'".format(model_name, constraint_type)) + compiler_error( + self, + "no constraints given to test: '{}.{}'" + .format(model_name, constraint_type) + ) for params in constraint_data: schema_test_klass = self.get_test(constraint_type) - schema_test = schema_test_klass(self.project, self.og_target_dir, self.rel_filepath, model_name, params) + schema_test = schema_test_klass( + self.project, + self.og_target_dir, + self.rel_filepath, + model_name, + params + ) schema_tests.append(schema_test) return schema_tests @@ -621,18 +755,28 @@ def compile(self): compiler_error(self, str(e)) def __repr__(self): - return "".format(self.project['name'], self.model_name, self.filepath) + return "".format( + self.project['name'], self.model_name, self.filepath + ) + class Csv(DBTSource): def __init__(self, project, target_dir, rel_filepath, own_project): - super(Csv, self).__init__(project, target_dir, rel_filepath, own_project) + super(Csv, self).__init__( + project, target_dir, rel_filepath, own_project + ) def __repr__(self): - return "".format(self.project['name'], self.model_name, self.filepath) + return "".format( + self.project['name'], self.model_name, self.filepath + ) + class Macro(DBTSource): def __init__(self, project, target_dir, rel_filepath, own_project): - super(Macro, self).__init__(project, target_dir, rel_filepath, own_project) + super(Macro, self).__init__( + project, target_dir, rel_filepath, own_project + ) self.filepath = os.path.join(self.root_dir, self.rel_filepath) def get_macros(self, ctx): @@ -647,7 +791,9 @@ def get_macros(self, ctx): yield {"project": self.own_project, "name": key, "macro": item} def __repr__(self): - return "".format(self.project['name'], self.name, self.filepath) + return "".format( + self.project['name'], self.name, self.filepath + ) class ArchiveModel(DBTSource): @@ -661,15 +807,17 @@ def __init__(self, project, create_template, archive_data): self.source_schema = archive_data['source_schema'] self.target_schema = archive_data['target_schema'] - self.source_table = archive_data['source_table'] - self.target_table = archive_data['target_table'] - self.unique_key = archive_data['unique_key'] - self.updated_at = archive_data['updated_at'] + self.source_table = archive_data['source_table'] + self.target_table = archive_data['target_table'] + self.unique_key = archive_data['unique_key'] + self.updated_at = archive_data['updated_at'] target_dir = self.create_template.label rel_filepath = os.path.join(self.target_schema, self.target_table) - super(ArchiveModel, self).__init__(project, target_dir, rel_filepath, project) + super(ArchiveModel, self).__init__( + project, target_dir, rel_filepath, project + ) def validate(self, data): required = [ @@ -683,18 +831,21 @@ def validate(self, data): for key in required: if data.get(key, None) is None: - compiler_error("Invalid archive config: missing required field '{}'".format(key)) + compiler_error( + "Invalid archive config: missing required field '{}'" + .format(key) + ) def serialize(self): data = DBTSource.serialize(self).copy() serialized = { - "source_schema" : self.source_schema, - "target_schema" : self.target_schema, - "source_table" : self.source_table, - 
"target_table" : self.target_table, - "unique_key" : self.unique_key, - "updated_at" : self.updated_at + "source_schema": self.source_schema, + "target_schema": self.target_schema, + "source_table": self.source_table, + "target_table": self.target_table, + "unique_key": self.unique_key, + "updated_at": self.updated_at } data.update(serialized) @@ -704,7 +855,10 @@ def compile(self): archival = dbt.archival.Archival(self.project, self) query = archival.compile() - sql = self.create_template.wrap(self.target_schema, self.target_table, query, self.unique_key) + sql = self.create_template.wrap( + self.target_schema, self.target_table, query, self.unique_key + ) + return sql def build_path(self): @@ -714,14 +868,25 @@ def build_path(self): return os.path.join(*path_parts) def __repr__(self): - return " {} unique:{} updated_at:{}>".format(self.source_table, self.target_table, self.unique_key, self.updated_at) + return " {} unique:{} updated_at:{}>".format( + self.source_table, + self.target_table, + self.unique_key, + self.updated_at + ) + class DataTest(DBTSource): dbt_run_type = 'test' dbt_test_type = 'data' def __init__(self, project, target_dir, rel_filepath, own_project): - super(DataTest, self).__init__(project, target_dir, rel_filepath, own_project) + super(DataTest, self).__init__( + project, + target_dir, + rel_filepath, + own_project + ) def build_path(self): build_dir = "test" @@ -744,4 +909,6 @@ def immediate_name(self): return self.name def __repr__(self): - return "".format(self.project['name'], self.name, self.filepath) + return "".format( + self.project['name'], self.name, self.filepath + ) diff --git a/dbt/project.py b/dbt/project.py index c25b41bf52b..717d3ed402c 100644 --- a/dbt/project.py +++ b/dbt/project.py @@ -26,11 +26,13 @@ default_profiles_dir = os.path.join(os.path.expanduser('~'), '.dbt') + class DbtProjectError(Exception): def __init__(self, message, project): self.project = project super(DbtProjectError, self).__init__(message) + class Project(object): def __init__(self, cfg, profiles, profiles_dir, profile_to_load=None): @@ -46,12 +48,16 @@ def __init__(self, cfg, profiles, profiles_dir, profile_to_load=None): self.profile_to_load = self.cfg['profile'] if self.profile_to_load is None: - raise DbtProjectError("No profile was supplied in the dbt_project.yml file, or the command line", self) + raise DbtProjectError( + "No profile was supplied in the dbt_project.yml file, or the " + "command line", self) if self.profile_to_load in self.profiles: self.cfg.update(self.profiles[self.profile_to_load]) else: - raise DbtProjectError("Could not find profile named '{}'".format(self.profile_to_load), self) + raise DbtProjectError( + "Could not find profile named '{}'" + .format(self.profile_to_load), self) def __str__(self): return pprint.pformat({'project': self.cfg, 'profiles': self.profiles}) @@ -77,7 +83,8 @@ def handle_deprecations(self): self.cfg['target'] = self.cfg['run-target'] if not self.is_valid_package_name(): - dbt.deprecations.warn('invalid-package-name', package_name = self['name']) + dbt.deprecations.warn( + 'invalid-package-name', package_name=self['name']) def is_valid_package_name(self): if re.match(r"^[^\d\W]\w*\Z", self['name']): @@ -110,13 +117,16 @@ def validate(self): package_version = self.cfg.get('version', None) if package_name is None or package_version is None: - raise DbtProjectError("Project name and version is not provided", self) + raise DbtProjectError( + "Project name and version is not provided", self) - required_keys = ['host', 'user', 
'pass', 'schema', 'type', 'dbname', 'port'] + required_keys = ['host', 'user', 'pass', 'schema', 'type', + 'dbname', 'port'] for key in required_keys: if key not in target_cfg or len(str(target_cfg[key])) == 0: - raise DbtProjectError("Expected project configuration '{}' was not supplied".format(key), self) - + raise DbtProjectError( + "Expected project configuration '{}' was not supplied" + .format(key), self) def hashed_name(self): if self.cfg.get("name", None) is None: @@ -138,18 +148,22 @@ def read_profiles(profiles_dir=None): if os.path.isfile(path): with open(path, 'r') as f: m = yaml.safe_load(f) - valid_profiles = {k:v for (k,v) in m.items() if k != 'config'} + valid_profiles = {k: v for (k, v) in m.items() + if k != 'config'} profiles.update(valid_profiles) return profiles -def read_project(filename, profiles_dir=None, validate=True, profile_to_load=None): + +def read_project(filename, profiles_dir=None, validate=True, + profile_to_load=None): if profiles_dir is None: profiles_dir = default_profiles_dir with open(filename, 'r') as f: project_cfg = yaml.safe_load(f) - project_cfg['project-root'] = os.path.dirname(os.path.abspath(filename)) + project_cfg['project-root'] = os.path.dirname( + os.path.abspath(filename)) profiles = read_profiles(profiles_dir) proj = Project(project_cfg, profiles, profiles_dir, profile_to_load) diff --git a/dbt/runner.py b/dbt/runner.py index ec6ffb2ffa2..15049676f4d 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -2,7 +2,8 @@ from __future__ import print_function import psycopg2 -import os, sys +import os +import sys import logging import time import itertools @@ -16,23 +17,28 @@ from dbt.templates import BaseCreateTemplate import dbt.targets from dbt.source import Source -from dbt.utils import find_model_by_fqn, find_model_by_name, dependency_projects +from dbt.utils import find_model_by_fqn, find_model_by_name, \ + dependency_projects from dbt.compiled_model import make_compiled_model import dbt.tracking import dbt.schema from multiprocessing.dummy import Pool as ThreadPool -ABORTED_TRANSACTION_STRING = "current transaction is aborted, commands ignored until end of transaction block" +ABORTED_TRANSACTION_STRING = ("current transaction is aborted, commands " + "ignored until end of transaction block") + def get_timestamp(): return "{} |".format(time.strftime("%H:%M:%S")) + class RunModelResult(object): - def __init__(self, model, error=None, skip=False, status=None, execution_time=0): + def __init__(self, model, error=None, skip=False, status=None, + execution_time=0): self.model = model self.error = error - self.skip = skip + self.skip = skip self.status = status self.execution_time = execution_time @@ -44,6 +50,7 @@ def errored(self): def skipped(self): return self.skip + class BaseRunner(object): def __init__(self, project, schema_helper): self.project = project @@ -82,11 +89,13 @@ def execute_list(self, queries, source): status = 'None' for i, query in enumerate(queries): try: - handle, status = self.schema_helper.execute_without_auto_commit(query, handle) + handle, status = self.schema_helper.execute_without_auto_commit(query, handle) # noqa except psycopg2.ProgrammingError as e: error_msg = e.diag.message_primary - if error_msg is not None and "permission denied for" in error_msg: - raise RuntimeError("Permission denied while running {}".format(source)) + if error_msg is not None and \ + "permission denied for" in error_msg: + raise RuntimeError( + "Permission denied while running {}".format(source)) else: raise @@ -107,28 +116,30 @@ def 
execute_contents(self, target, model): kwargs = instruction['args'] func_map = { - 'expand_column_types_if_needed': lambda kwargs: self.schema_helper.expand_column_types_if_needed(**kwargs), + 'expand_column_types_if_needed': lambda kwargs: self.schema_helper.expand_column_types_if_needed(**kwargs), # noqa } func_map[function](kwargs) else: try: - handle, status = self.schema_helper.execute_without_auto_commit(part, handle) + handle, status = self.schema_helper.execute_without_auto_commit(part, handle) # noqa except psycopg2.ProgrammingError as e: if "permission denied for" in e.diag.message_primary: - raise RuntimeError(dbt.schema.READ_PERMISSION_DENIED_ERROR.format( - model=model.name, - error=str(e).strip(), - user=target.user, - )) + raise RuntimeError( + dbt.schema.READ_PERMISSION_DENIED_ERROR.format( + model=model.name, + error=str(e).strip(), + user=target.user)) else: raise handle.commit() return status + class ModelRunner(BaseRunner): run_type = 'run' + def pre_run_msg(self, model): print_vars = { "schema": model.target.schema, @@ -137,7 +148,8 @@ def pre_run_msg(self, model): "info": "START" } - output = "START {model_type} model {schema}.{model_name} ".format(**print_vars) + output = ("START {model_type} model {schema}.{model_name} " + .format(**print_vars)) return output def post_run_msg(self, result): @@ -149,35 +161,46 @@ def post_run_msg(self, result): "info": "ERROR creating" if result.errored else "OK created" } - output = "{info} {model_type} model {schema}.{model_name} ".format(**print_vars) + output = ("{info} {model_type} model {schema}.{model_name} " + .format(**print_vars)) return output def pre_run_all_msg(self, models): return "{} Running {} models".format(get_timestamp(), len(models)) def post_run_all_msg(self, results): - return "{} Finished running {} models".format(get_timestamp(), len(results)) + return ("{} Finished running {} models" + .format(get_timestamp(), len(results))) def status(self, result): return result.status def execute(self, target, model): if model.tmp_drop_type is not None: - if model.materialization == 'table' and self.project.args.non_destructive: + if model.materialization == 'table' and \ + self.project.args.non_destructive: self.schema_helper.truncate(target.schema, model.tmp_name) else: - self.schema_helper.drop(target.schema, model.tmp_drop_type, model.tmp_name) + self.schema_helper.drop( + target.schema, model.tmp_drop_type, model.tmp_name) status = self.execute_contents(target, model) if model.final_drop_type is not None: - if model.materialization == 'table' and self.project.args.non_destructive: - pass # we just inserted into this recently truncated table... do nothing here + if model.materialization == 'table' and \ + self.project.args.non_destructive: + # we just inserted into this recently truncated table... 
+ # do nothing here + pass else: - self.schema_helper.drop(target.schema, model.final_drop_type, model.name) + self.schema_helper.drop( + target.schema, model.final_drop_type, model.name) if model.should_rename(self.project.args): - self.schema_helper.rename(target.schema, model.tmp_name, model.name) + self.schema_helper.rename( + target.schema, + model.tmp_name, + model.name) return status @@ -186,10 +209,10 @@ def __run_hooks(self, hooks, context, source): hooks = [hooks] ctx = { - "target" : self.project.get_target(), - "state" : "start", - "invocation_id" : context['invocation_id'], - "run_started_at" : context['run_started_at'] + "target": self.project.get_target(), + "state": "start", + "invocation_id": context['invocation_id'], + "run_started_at": context['run_started_at'] } compiled_hooks = [compile_string(hook, ctx) for hook in hooks] @@ -203,23 +226,27 @@ def post_run_all(self, models, results, context): hooks = self.project.cfg.get('on-run-start', []) self.__run_hooks(hooks, context, 'on-run-end hooks') + class DryRunner(ModelRunner): run_type = 'dry-run' def pre_run_msg(self, model): - output = "DRY-RUN model {schema}.{model_name} ".format(schema=model.target.schema, model_name=model.name) + output = ("DRY-RUN model {schema}.{model_name} " + .format(schema=model.target.schema, model_name=model.name)) return output def post_run_msg(self, result): model = result.model - output = "DONE model {schema}.{model_name} ".format(schema=model.target.schema, model_name=model.name) + output = ("DONE model {schema}.{model_name} " + .format(schema=model.target.schema, model_name=model.name)) return output def pre_run_all_msg(self, models): return "Dry-running {} models".format(len(models)) def post_run_all_msg(self, results): - return "{} Finished dry-running {} models".format(get_timestamp(), len(results)) + return ("{} Finished dry-running {} models" + .format(get_timestamp(), len(results))) def post_run_all(self, models, results, context): count_dropped = 0 @@ -229,11 +256,13 @@ def post_run_all(self, models, results, context): model = result.model schema_name = model.target.schema - relation_type = 'table' if model.materialization == 'incremental' else 'view' + relation_type = ('table' if model.materialization == 'incremental' + else 'view') self.schema_helper.drop(schema_name, relation_type, model.name) count_dropped += 1 logger.info("Dropped {} dry-run models".format(count_dropped)) + class TestRunner(ModelRunner): run_type = 'test' @@ -257,14 +286,23 @@ def pre_run_all_msg(self, models): def post_run_all_msg(self, results): total = len(results) - passed = len([result for result in results if not result.errored and not result.skipped and result.status == 0]) - failed = len([result for result in results if not result.errored and not result.skipped and result.status > 0]) + passed = len([result for result in results if not + result.errored and not result.skipped and + result.status == 0]) + failed = len([result for result in results if not + result.errored and not result.skipped and + result.status > 0]) errored = len([result for result in results if result.errored]) skipped = len([result for result in results if result.skipped]) total_errors = failed + errored - overview = "PASS={passed} FAIL={total_errors} SKIP={skipped} TOTAL={total}".format(total=total, passed=passed, total_errors=total_errors, skipped=skipped) + overview = ("PASS={passed} FAIL={total_errors} SKIP={skipped} " + "TOTAL={total}".format( + total=total, + passed=passed, + total_errors=total_errors, + skipped=skipped)) if 
total_errors > 0: final = "Tests completed with errors" @@ -288,14 +326,19 @@ def status(self, result): def execute(self, target, model): rows = self.schema_helper.execute_and_fetch(model.compiled_contents) if len(rows) > 1: - raise RuntimeError("Bad test {name}: Returned {num_rows} rows instead of 1".format(name=model.name, num_rows=len(rows))) + raise RuntimeError( + "Bad test {name}: Returned {num_rows} rows instead of 1" + .format(name=model.name, num_rows=len(rows))) row = rows[0] if len(row) > 1: - raise RuntimeError("Bad test {name}: Returned {num_cols} cols instead of 1".format(name=model.name, num_cols=len(row))) + raise RuntimeError( + "Bad test {name}: Returned {num_cols} cols instead of 1" + .format(name=model.name, num_cols=len(row))) return row[0] + class ArchiveRunner(BaseRunner): run_type = 'archive' @@ -305,7 +348,8 @@ def pre_run_msg(self, model): "model_name": model.name, } - output = "START archive table {schema}.{model_name} ".format(**print_vars) + output = ("START archive table {schema}.{model_name} " + .format(**print_vars)) return output def post_run_msg(self, result): @@ -323,7 +367,8 @@ def pre_run_all_msg(self, models): return "Archiving {} tables".format(len(models)) def post_run_all_msg(self, results): - return "{} Finished archiving {} tables".format(get_timestamp(), len(results)) + return ("{} Finished archiving {} tables" + .format(get_timestamp(), len(results))) def status(self, result): return result.status @@ -332,6 +377,7 @@ def execute(self, target, model): status = self.execute_contents(target, model) return status + class RunManager(object): def __init__(self, project, target_path, graph_type, args): self.project = project @@ -339,26 +385,27 @@ def __init__(self, project, target_path, graph_type, args): self.graph_type = graph_type self.args = args - self.target = dbt.targets.get_target(self.project.run_environment(), self.args.threads) + self.target = dbt.targets.get_target( + self.project.run_environment(), + self.args.threads) if self.target.should_open_tunnel(): - logger.info("Opening ssh tunnel to host {}... ".format(self.target.ssh_host), end="") + logger.info("Opening ssh tunnel to host {}... 
" + .format(self.target.ssh_host), end="") sys.stdout.flush() self.target.open_tunnel_if_needed() logger.info("Connected") - self.schema = dbt.schema.Schema(self.project, self.target) self.context = { "run_started_at": datetime.now(), - "invocation_id" : dbt.tracking.invocation_id, - "get_columns_in_table" : self.schema.get_columns_in_table, - "get_missing_columns" : self.schema.get_missing_columns, - "already_exists" : self.schema.table_exists, + "invocation_id": dbt.tracking.invocation_id, + "get_columns_in_table": self.schema.get_columns_in_table, + "get_missing_columns": self.schema.get_missing_columns, + "already_exists": self.schema.table_exists, } - def deserialize_graph(self): linker = Linker() base_target_path = self.project['target-path'] @@ -382,22 +429,33 @@ def safe_execute_model(self, data): error = None try: status = self.execute_model(runner, model) - except (RuntimeError, psycopg2.ProgrammingError, psycopg2.InternalError) as e: - error = "Error executing {filepath}\n{error}".format(filepath=model['build_path'], error=str(e).strip()) + except (RuntimeError, + psycopg2.ProgrammingError, + psycopg2.InternalError) as e: + error = "Error executing {filepath}\n{error}".format( + filepath=model['build_path'], error=str(e).strip()) status = "ERROR" logger.exception(error) - if type(e) == psycopg2.InternalError and ABORTED_TRANSACTION_STRING == e.diag.message_primary: - return RunModelResult(model, error=ABORTED_TRANSACTION_STRING, status="SKIP") + if type(e) == psycopg2.InternalError and \ + ABORTED_TRANSACTION_STRING == e.diag.message_primary: + return RunModelResult( + model, error=ABORTED_TRANSACTION_STRING, status="SKIP") except Exception as e: - error = "Unhandled error while executing {filepath}\n{error}".format(filepath=model['build_path'], error=str(e).strip()) + error = ("Unhandled error while executing {filepath}\n{error}" + .format( + filepath=model['build_path'], error=str(e).strip())) logger.exception(error) raise e execution_time = time.time() - start_time - return RunModelResult(model, error=error, status=status, execution_time=execution_time) + return RunModelResult(model, + error=error, + status=status, + execution_time=execution_time) - def as_concurrent_dep_list(self, linker, models, existing, target, limit_to): + def as_concurrent_dep_list(self, linker, models, existing, target, + limit_to): model_dependency_list = [] dependency_list = linker.as_dependency_list(limit_to) for node_list in dependency_list: @@ -421,28 +479,40 @@ def skip_dependent(model): model_to_skip.do_skip() return skip_dependent - def print_fancy_output_line(self, message, status, index, total, execution_time=None): - prefix = "{timestamp} {index} of {total} {message}".format(timestamp=get_timestamp(), index=index, total=total, message=message) + def print_fancy_output_line(self, message, status, index, total, + execution_time=None): + prefix = "{timestamp} {index} of {total} {message}".format( + timestamp=get_timestamp(), + index=index, + total=total, + message=message) justified = prefix.ljust(80, ".") if execution_time is None: status_time = "" else: - status_time = " in {execution_time:0.2f}s".format(execution_time=execution_time) + status_time = " in {execution_time:0.2f}s".format( + execution_time=execution_time) - output = "{justified} [{status}{status_time}]".format(justified=justified, status=status, status_time=status_time) + output = "{justified} [{status}{status_time}]".format( + justified=justified, status=status, status_time=status_time) logger.info(output) def 
execute_models(self, runner, model_dependency_list, on_failure): - flat_models = list(itertools.chain.from_iterable(model_dependency_list)) + flat_models = list(itertools.chain.from_iterable( + model_dependency_list)) num_models = len(flat_models) if num_models == 0: - logger.info("WARNING: Nothing to do. Try checking your model configs and running `dbt compile`".format(self.target_path)) + logger.info("WARNING: Nothing to do. Try checking your model " + "configs and running `dbt compile`".format( + self.target_path)) return [] num_threads = self.target.threads - logger.info("Concurrency: {} threads (target='{}')".format(num_threads, self.project.get_target().get('name'))) + logger.info("Concurrency: {} threads (target='{}')".format( + num_threads, self.project.get_target().get('name')) + ) logger.info("Running!") pool = ThreadPool(num_threads) @@ -451,20 +521,24 @@ def execute_models(self, runner, model_dependency_list, on_failure): logger.info(runner.pre_run_all_msg(flat_models)) runner.pre_run_all(flat_models, self.context) - fqn_to_id_map = {model.fqn: i + 1 for (i, model) in enumerate(flat_models)} + fqn_to_id_map = {model.fqn: i + 1 for (i, model) + in enumerate(flat_models)} def get_idx(model): return fqn_to_id_map[model.fqn] model_results = [] for model_list in model_dependency_list: - for i, model in enumerate([model for model in model_list if model.should_skip()]): + for i, model in enumerate([model for model in model_list + if model.should_skip()]): msg = runner.skip_msg(model) - self.print_fancy_output_line(msg, 'SKIP', get_idx(model), num_models) + self.print_fancy_output_line( + msg, 'SKIP', get_idx(model), num_models) model_result = RunModelResult(model, skip=True) model_results.append(model_result) - models_to_execute = [model for model in model_list if not model.should_skip()] + models_to_execute = [model for model in model_list + if not model.should_skip()] threads = self.target.threads num_models_this_batch = len(models_to_execute) @@ -477,7 +551,13 @@ def on_complete(run_model_results): msg = runner.post_run_msg(run_model_result) status = runner.status(run_model_result) index = get_idx(run_model_result.model) - self.print_fancy_output_line(msg, status, index, num_models, run_model_result.execution_time) + self.print_fancy_output_line( + msg, + status, + index, + num_models, + run_model_result.execution_time + ) dbt.tracking.track_model_run({ "invocation_id": dbt.tracking.invocation_id, @@ -487,9 +567,9 @@ def on_complete(run_model_results): "run_status": run_model_result.status, "run_skipped": run_model_result.skip, "run_error": run_model_result.error, - "model_materialization": run_model_result.model['materialized'], + "model_materialization": run_model_result.model['materialized'], # noqa "model_id": run_model_result.model.hashed_name(), - "hashed_contents": run_model_result.model.hashed_contents(), + "hashed_contents": run_model_result.model.hashed_contents(), # noqa }) if run_model_result.errored: @@ -498,14 +578,25 @@ def on_complete(run_model_results): while model_index < num_models_this_batch: local_models = [] - for i in range(model_index, min(model_index + threads, num_models_this_batch)): + for i in range( + model_index, + min(model_index + threads, num_models_this_batch)): model = models_to_execute[i] local_models.append(model) msg = runner.pre_run_msg(model) - self.print_fancy_output_line(msg, 'RUN', get_idx(model), num_models) - - wrapped_models_to_execute = [{"runner": runner, "model": model} for model in local_models] - map_result = 
pool.map_async(self.safe_execute_model, wrapped_models_to_execute, callback=on_complete) + self.print_fancy_output_line( + msg, 'RUN', get_idx(model), num_models + ) + + wrapped_models_to_execute = [ + {"runner": runner, "model": model} + for model in local_models + ] + map_result = pool.map_async( + self.safe_execute_model, + wrapped_models_to_execute, + callback=on_complete + ) map_result.wait() run_model_results = map_result.get() @@ -523,36 +614,49 @@ def on_complete(run_model_results): def run_from_graph(self, runner, limit_to): logger.info("Loading dependency graph file") linker = self.deserialize_graph() - compiled_models = [make_compiled_model(fqn, linker.get_node(fqn)) for fqn in linker.nodes()] - relevant_compiled_models = [m for m in compiled_models if m.is_type(runner.run_type)] + compiled_models = [make_compiled_model(fqn, linker.get_node(fqn)) + for fqn in linker.nodes()] + relevant_compiled_models = [m for m in compiled_models + if m.is_type(runner.run_type)] for m in relevant_compiled_models: - if m.should_execute(self.args, existing = []): + if m.should_execute(self.args, existing=[]): context = self.context.copy() context.update(m.context()) m.compile(context) schema_name = self.target.schema - logger.info("Connecting to redshift") try: self.schema.create_schema_if_not_exists(schema_name) except psycopg2.OperationalError as e: - logger.info("ERROR: Could not connect to the target database. Try `dbt debug` for more information") + logger.info("ERROR: Could not connect to the target database. Try" + "`dbt debug` for more information") logger.info(str(e)) sys.exit(1) - existing = self.schema.query_for_existing(schema_name); + existing = self.schema.query_for_existing(schema_name) if limit_to is None: specified_models = None else: - specified_models = [find_model_by_name(relevant_compiled_models, name).fqn for name in limit_to] - model_dependency_list = self.as_concurrent_dep_list(linker, relevant_compiled_models, existing, self.target, specified_models) + specified_models = [find_model_by_name( + relevant_compiled_models, name + ).fqn for name in limit_to] + + model_dependency_list = self.as_concurrent_dep_list( + linker, + relevant_compiled_models, + existing, + self.target, + specified_models + ) on_failure = self.on_model_failure(linker, relevant_compiled_models) - results = self.execute_models(runner, model_dependency_list, on_failure) + results = self.execute_models( + runner, model_dependency_list, on_failure + ) return results @@ -568,10 +672,10 @@ def safe_run_from_graph(self, *args, **kwargs): self.target.cleanup() logger.info("Done") - def run_tests_from_graph(self, test_schemas, test_data): linker = self.deserialize_graph() - compiled_models = [make_compiled_model(fqn, linker.get_node(fqn)) for fqn in linker.nodes()] + compiled_models = [make_compiled_model(fqn, linker.get_node(fqn)) + for fqn in linker.nodes()] schema_name = self.target.schema @@ -579,26 +683,29 @@ def run_tests_from_graph(self, test_schemas, test_data): try: self.schema.create_schema_if_not_exists(schema_name) except psycopg2.OperationalError as e: - logger.info("ERROR: Could not connect to the target database. Try `dbt debug` for more information") + logger.info("ERROR: Could not connect to the target database. 
Try " + "`dbt debug` for more information") logger.info(str(e)) sys.exit(1) test_runner = TestRunner(self.project, self.schema) if test_schemas: - schema_tests = [m for m in compiled_models if m.is_test_type(test_runner.test_schema_type)] + schema_tests = [m for m in compiled_models + if m.is_test_type(test_runner.test_schema_type)] else: schema_tests = [] if test_data: - data_tests = [m for m in compiled_models if m.is_test_type(test_runner.test_data_type)] + data_tests = [m for m in compiled_models + if m.is_test_type(test_runner.test_data_type)] else: data_tests = [] all_tests = schema_tests + data_tests for m in all_tests: - if m.should_execute(self.args, existing = []): + if m.should_execute(self.args, existing=[]): context = self.context.copy() context.update(m.context()) m.compile(context) diff --git a/dbt/runtime.py b/dbt/runtime.py index 71fbffdb3c7..29e16b906f8 100644 --- a/dbt/runtime.py +++ b/dbt/runtime.py @@ -1,5 +1,6 @@ from dbt.utils import compiler_error + class RuntimeContext(dict): def __init__(self, model=None, *args, **kwargs): super(RuntimeContext, self).__init__(*args, **kwargs) diff --git a/dbt/schema.py b/dbt/schema.py index 26dc0840447..336ae944d97 100644 --- a/dbt/schema.py +++ b/dbt/schema.py @@ -6,23 +6,33 @@ import time import re -SCHEMA_PERMISSION_DENIED_MESSAGE = """The user '{user}' does not have sufficient permissions to create the schema '{schema}'. -Either create the schema manually, or adjust the permissions of the '{user}' user.""" +SCHEMA_PERMISSION_DENIED_MESSAGE = """ +The user '{user}' does not have sufficient permissions to create the schema +'{schema}'. Either create the schema manually, or adjust the permissions of +the '{user}' user.""" -RELATION_PERMISSION_DENIED_MESSAGE = """The user '{user}' does not have sufficient permissions to create the model '{model}' in the schema '{schema}'. -Please adjust the permissions of the '{user}' user on the '{schema}' schema. -With a superuser account, execute the following commands, then re-run dbt. +RELATION_PERMISSION_DENIED_MESSAGE = """ +The user '{user}' does not have sufficient permissions to create the model +'{model}' in the schema '{schema}'. Please adjust the permissions of the +'{user}' user on the '{schema}' schema. With a superuser account, execute the +following commands, then re-run dbt. grant usage, create on schema "{schema}" to "{user}"; grant select, insert, delete on all tables in schema "{schema}" to "{user}";""" -RELATION_NOT_OWNER_MESSAGE = """The user '{user}' does not have sufficient permissions to drop the model '{model}' in the schema '{schema}'. -This is likely because the relation was created by a different user. Either delete the model "{schema}"."{model}" manually, -or adjust the permissions of the '{user}' user in the '{schema}' schema.""" +RELATION_NOT_OWNER_MESSAGE = """ +The user '{user}' does not have sufficient permissions to drop the model +'{model}' in the schema '{schema}'. This is likely because the relation was +created by a different user. Either delete the model "{schema}"."{model}" +manually, or adjust the permissions of the '{user}' user in the '{schema}' +schema.""" -READ_PERMISSION_DENIED_ERROR = """Encountered an error while executing model '{model}'. +READ_PERMISSION_DENIED_ERROR = """ +Encountered an error while executing model '{model}'. 
> {error} -Check that the user '{user}' has sufficient permissions to read from all necessary source tables""" +Check that the user '{user}' has sufficient permissions to read from all +necessary source tables""" + class Column(object): def __init__(self, column, dtype, char_size): @@ -59,7 +69,8 @@ def string_size(self): return int(self.char_size) def can_expand_to(self, other_column): - "returns True if this column can be expanded to the size of the other column" + """returns True if this column can be expanded to the size of the + other column""" if not self.is_string() or not other_column.is_string(): return False @@ -72,6 +83,7 @@ def string_type(cls, size): def __repr__(self): return "".format(self.name, self.data_type) + class Schema(object): def __init__(self, project, target): self.project = project @@ -94,7 +106,8 @@ def get_table_columns_if_cached(self, schema, table): def get_schemas(self): existing = [] - results = self.execute_and_fetch('select nspname from pg_catalog.pg_namespace') + results = self.execute_and_fetch( + 'select nspname from pg_catalog.pg_namespace') return [name for (name,) in results] def create_schema(self, schema_name): @@ -102,10 +115,13 @@ def create_schema(self, schema_name): user = target_cfg['user'] try: - self.execute('create schema if not exists "{}"'.format(schema_name)) + self.execute( + 'create schema if not exists "{}"'.format(schema_name)) except psycopg2.ProgrammingError as e: if "permission denied for" in e.diag.message_primary: - raise RuntimeError(SCHEMA_PERMISSION_DENIED_MESSAGE.format(schema=schema_name, user=user)) + raise RuntimeError( + SCHEMA_PERMISSION_DENIED_MESSAGE.format( + schema=schema_name, user=user)) else: raise e @@ -113,8 +129,7 @@ def query_for_existing(self, schema): sql = """ select tablename as name, 'table' as type from pg_tables where schemaname = '{schema}' union all - select viewname as name, 'view' as type from pg_views where schemaname = '{schema}' """.format(schema=schema) - + select viewname as name, 'view' as type from pg_views where schemaname = '{schema}' """.format(schema=schema) # noqa results = self.execute_and_fetch(sql) existing = [(name, relation_type) for (name, relation_type) in results] @@ -129,7 +144,9 @@ def execute(self, sql): pre = time.time() cursor.execute(sql) post = time.time() - logger.debug("SQL status: %s in %0.2f seconds", cursor.statusmessage, post-pre) + logger.debug( + "SQL status: %s in %0.2f seconds", + cursor.statusmessage, post-pre) return cursor.statusmessage except Exception as e: self.target.rollback() @@ -145,7 +162,9 @@ def execute_and_fetch(self, sql): pre = time.time() cursor.execute(sql) post = time.time() - logger.debug("SQL status: %s in %0.2f seconds", cursor.statusmessage, post-pre) + logger.debug( + "SQL status: %s in %0.2f seconds", + cursor.statusmessage, post-pre) data = cursor.fetchall() logger.debug("SQL response: %s", data) return data @@ -159,11 +178,15 @@ def execute_and_handle_permissions(self, query, model_name): try: return self.execute(query) except psycopg2.ProgrammingError as e: - error_data = {"model": model_name, "schema": self.target.schema, "user": self.target.user} + error_data = {"model": model_name, + "schema": self.target.schema, + "user": self.target.user} if 'must be owner of relation' in e.diag.message_primary: - raise RuntimeError(RELATION_NOT_OWNER_MESSAGE.format(**error_data)) + raise RuntimeError( + RELATION_NOT_OWNER_MESSAGE.format(**error_data)) elif "permission denied for" in e.diag.message_primary: - raise 
RuntimeError(RELATION_PERMISSION_DENIED_MESSAGE.format(**error_data)) + raise RuntimeError( + RELATION_PERMISSION_DENIED_MESSAGE.format(**error_data)) else: raise e @@ -178,7 +201,9 @@ def execute_without_auto_commit(self, sql, handle=None): pre = time.time() cursor.execute(sql) post = time.time() - logger.debug("SQL status: %s in %0.2f seconds", cursor.statusmessage, post-pre) + logger.debug( + "SQL status: %s in %0.2f seconds", + cursor.statusmessage, post-pre) return handle, cursor.statusmessage except Exception as e: self.target.rollback() @@ -189,25 +214,32 @@ def execute_without_auto_commit(self, sql, handle=None): cursor.close() def truncate(self, schema, relation): - sql = 'truncate table "{schema}"."{relation}"'.format(schema=schema, relation=relation) + sql = ('truncate table "{schema}"."{relation}"' + .format(schema=schema, relation=relation)) logger.debug("dropping table %s.%s", schema, relation) self.execute_and_handle_permissions(sql, relation) logger.debug("dropped %s.%s", schema, relation) def drop(self, schema, relation_type, relation): - sql = 'drop {relation_type} if exists "{schema}"."{relation}" cascade'.format(schema=schema, relation_type=relation_type, relation=relation) + sql = ('drop {relation_type} if exists "{schema}"."{relation}" cascade' + .format( + schema=schema, + relation_type=relation_type, + relation=relation)) logger.debug("dropping %s %s.%s", relation_type, schema, relation) self.execute_and_handle_permissions(sql, relation) logger.debug("dropped %s %s.%s", relation_type, schema, relation) def sql_columns_in_table(self, schema_name, table_name): - sql = """ + sql = (""" select column_name, data_type, character_maximum_length from information_schema.columns - where table_name = '{table_name}'""".format(table_name=table_name).strip() + where table_name = '{table_name}'""" + .format(table_name=table_name).strip()) if schema_name is not None: - sql += " AND table_schema = '{schema_name}'".format(schema_name=schema_name) + sql += (" AND table_schema = '{schema_name}'" + .format(schema_name=schema_name)) return sql @@ -234,26 +266,37 @@ def get_columns_in_table(self, schema_name, table_name, use_cached=True): return columns def rename(self, schema, from_name, to_name): - rename_query = 'alter table "{schema}"."{from_name}" rename to "{to_name}"'.format(schema=schema, from_name=from_name, to_name=to_name) - logger.debug("renaming model %s.%s --> %s.%s", schema, from_name, schema, to_name) + rename_query = 'alter table "{schema}"."{from_name}" rename to "{to_name}"'.format(schema=schema, from_name=from_name, to_name=to_name) # noqa + logger.debug( + "renaming model %s.%s --> %s.%s", + schema, from_name, schema, to_name) self.execute_and_handle_permissions(rename_query, from_name) - logger.debug("renamed model %s.%s --> %s.%s", schema, from_name, schema, to_name) - - def get_missing_columns(self, from_schema, from_table, to_schema, to_table): - "Returns dict of {column:type} for columns in from_table that are missing from to_table" - from_columns = {col.name:col for col in self.get_columns_in_table(from_schema, from_table)} - to_columns = {col.name:col for col in self.get_columns_in_table(to_schema, to_table)} + logger.debug( + "renamed model %s.%s --> %s.%s", + schema, from_name, schema, to_name) + + def get_missing_columns(self, from_schema, from_table, to_schema, + to_table): + """Returns dict of {column:type} for columns in from_table that are + missing from to_table""" + from_columns = {col.name: col for col in + self.get_columns_in_table(from_schema, 
from_table)} + to_columns = {col.name: col for col in + self.get_columns_in_table(to_schema, to_table)} missing_columns = set(from_columns.keys()) - set(to_columns.keys()) - return [col for (col_name, col) in from_columns.items() if col_name in missing_columns] + return [col for (col_name, col) in from_columns.items() + if col_name in missing_columns] def create_table(self, schema, table, columns, sort, dist): - fields = ['"{field}" {data_type}'.format(field=column.name, data_type=column.data_type) for column in columns] + fields = ['"{field}" {data_type}'.format( + field=column.name, data_type=column.data_type + ) for column in columns] fields_csv = ",\n ".join(fields) dist = self.target.dist_qualifier(dist) sort = self.target.sort_qualifier('compound', sort) - sql = 'create table if not exists "{schema}"."{table}" (\n {fields}\n) {dist} {sort};'.format(schema=schema, table=table, fields=fields_csv, sort=sort, dist=dist) + sql = 'create table if not exists "{schema}"."{table}" (\n {fields}\n) {dist} {sort};'.format(schema=schema, table=table, fields=fields_csv, sort=sort, dist=dist) # noqa logger.debug('creating table "%s"."%s"'.format(schema, table)) self.execute_and_handle_permissions(sql, table) @@ -284,24 +327,33 @@ def alter_column_type(self, schema, table, column_name, new_column_type): update "{schema}"."{table}" set "{tmp_column}" = "{old_column}"; alter table "{schema}"."{table}" drop column "{old_column}" cascade; alter table "{schema}"."{table}" rename column "{tmp_column}" to "{old_column}"; - """.format(**opts) + """.format(**opts) # noqa status = self.execute(sql) return status def expand_column_types_if_needed(self, temp_table, to_schema, to_table): - source_columns = {col.name: col for col in self.get_columns_in_table(None, temp_table)} - dest_columns = {col.name: col for col in self.get_columns_in_table(to_schema, to_table)} + source_columns = {col.name: col for col in + self.get_columns_in_table(None, temp_table)} + dest_columns = {col.name: col for col in + self.get_columns_in_table(to_schema, to_table)} for column_name, source_column in source_columns.items(): dest_column = dest_columns.get(column_name) - if dest_column is not None and dest_column.can_expand_to(source_column): + if dest_column is not None and \ + dest_column.can_expand_to(source_column): new_type = Column.string_type(source_column.string_size()) - logger.debug("Changing col type from %s to %s in table %s.%s", dest_column.data_type, new_type, to_schema, to_table) - self.alter_column_type(to_schema, to_table, column_name, new_type) - - # update these cols in the cache! This is a hack to fix broken incremental models for type expansion. TODO + logger.debug("Changing col type from %s to %s in table %s.%s", + dest_column.data_type, + new_type, + to_schema, + to_table) + self.alter_column_type( + to_schema, to_table, column_name, new_type) + + # update these cols in the cache! This is a hack to fix broken + # incremental models for type expansion. 
TODO self.cache_table_columns(to_schema, to_table, source_columns) def table_exists(self, schema, table): diff --git a/dbt/schema_tester.py b/dbt/schema_tester.py index 0c2bde4cc72..8b2fbe35192 100644 --- a/dbt/schema_tester.py +++ b/dbt/schema_tester.py @@ -64,6 +64,7 @@ ); """ + class SchemaTester(object): def __init__(self, project): self.project = project @@ -84,12 +85,15 @@ def execute_query(self, model, sql): pre = time.time() cursor.execute(sql) post = time.time() - logger.debug("SQL status: %s in %d seconds", cursor.statusmessage, post-pre) + logger.debug( + "SQL status: %s in %d seconds", + cursor.statusmessage, post-pre) except psycopg2.ProgrammingError as e: logger.exception('programming error: %s', sql) return e.diag.message_primary except Exception as e: - logger.exception('encountered exception while running: %s', sql) + logger.exception( + 'encountered exception while running: %s', sql) e.model = model raise e @@ -97,7 +101,9 @@ def execute_query(self, model, sql): if len(result) != 1: logger.error("SQL: %s", sql) logger.error("RESULT: %s", result) - raise RuntimeError("Unexpected validation result. Expected 1 record, got {}".format(len(result))) + raise RuntimeError( + "Unexpected validation result. Expected 1 record, " + "got {}".format(len(result))) else: return result[0] diff --git a/dbt/seeder.py b/dbt/seeder.py index 52d81b66f39..768011361ad 100644 --- a/dbt/seeder.py +++ b/dbt/seeder.py @@ -9,6 +9,7 @@ from dbt.source import Source from dbt.logger import GLOBAL_LOGGER as logger + class Seeder: def __init__(self, project): self.project = project @@ -19,18 +20,24 @@ def find_csvs(self): return Source(self.project).get_csvs(self.project['data-paths']) def drop_table(self, cursor, schema, table): - sql = 'drop table if exists "{schema}"."{table}" cascade'.format(schema=schema, table=table) + sql = 'drop table if exists "{schema}"."{table}" cascade'.format( + schema=schema, table=table + ) logger.info("Dropping table {}.{}".format(schema, table)) cursor.execute(sql) def truncate_table(self, cursor, schema, table): - sql = 'truncate table "{schema}"."{table}"'.format(schema=schema, table=table) + sql = 'truncate table "{schema}"."{table}"'.format( + schema=schema, table=table + ) logger.info("Truncating table {}.{}".format(schema, table)) cursor.execute(sql) def create_table(self, cursor, schema, table, virtual_table): sql_table = csv_sql.make_table(virtual_table, db_schema=schema) - create_table_sql = csv_sql.make_create_table_statement(sql_table, dialect='postgresql') + create_table_sql = csv_sql.make_create_table_statement( + sql_table, dialect='postgresql' + ) logger.info("Creating table {}.{}".format(schema, table)) cursor.execute(create_table_sql) @@ -38,7 +45,12 @@ def insert_into_table(self, cursor, schema, table, virtual_table): headers = virtual_table.headers() header_csv = ", ".join(['"{}"'.format(h) for h in headers]) - base_insert = 'INSERT INTO "{schema}"."{table}" ({header_csv}) VALUES '.format(schema=schema, table=table, header_csv=header_csv) + base_insert = ('INSERT INTO "{schema}"."{table}" ({header_csv}) ' + 'VALUES '.format( + schema=schema, + table=table, + header_csv=header_csv + )) records = [] def quote_or_null(s): @@ -48,15 +60,17 @@ def quote_or_null(s): return "'{}'".format(s) for row in virtual_table.to_rows(): - record_csv = ', '.join([quote_or_null(val) for val in row]) - record_csv_wrapped = "({})".format(record_csv) - records.append(record_csv_wrapped) + record_csv = ', '.join([quote_or_null(val) for val in row]) + record_csv_wrapped = 
"({})".format(record_csv) + records.append(record_csv_wrapped) insert_sql = "{} {}".format(base_insert, ",\n".join(records)) - logger.info("Inserting {} records into table {}.{}".format(len(virtual_table.to_rows()), schema, table)) + logger.info("Inserting {} records into table {}.{}" + .format(len(virtual_table.to_rows()), schema, table)) cursor.execute(insert_sql) def existing_tables(self, cursor, schema): - sql = "select tablename as name from pg_tables where schemaname = '{schema}'".format(schema=schema) + sql = ("select tablename as name from pg_tables where " + "schemaname = '{schema}'".format(schema=schema)) cursor.execute(sql) existing = set([row[0] for row in cursor.fetchall()]) @@ -75,21 +89,35 @@ def do_seed(self, schema, cursor, drop_existing): if table_name in existing_tables: if drop_existing: self.drop_table(cursor, schema, table_name) - self.create_table(cursor, schema, table_name, virtual_table) + self.create_table( + cursor, + schema, + table_name, + virtual_table + ) else: self.truncate_table(cursor, schema, table_name) else: self.create_table(cursor, schema, table_name, virtual_table) try: - self.insert_into_table(cursor, schema, table_name, virtual_table) + self.insert_into_table( + cursor, schema, table_name, virtual_table + ) except psycopg2.ProgrammingError as e: - logger.info('Encountered an error while inserting into table "{}"."{}"'.format(schema, table_name)) - logger.info('Check for formatting errors in {}'.format(csv.filepath)) - logger.info('Try --drop-existing to delete and recreate the table instead') + logger.info( + 'Encountered an error while inserting into table "{}"."{}"' + .format(schema, table_name) + ) + logger.info( + 'Check for formatting errors in {}'.format(csv.filepath) + ) + logger.info( + 'Try --drop-existing to delete and recreate the table ' + 'instead' + ) logger.info(str(e)) - def seed(self, drop_existing=False): schema = self.target.schema diff --git a/dbt/source.py b/dbt/source.py index c53c13d7818..0a6e5c40cd5 100644 --- a/dbt/source.py +++ b/dbt/source.py @@ -1,7 +1,8 @@ - import os.path import fnmatch -from dbt.model import Model, Analysis, TestModel, SchemaFile, Csv, Macro, ArchiveModel, DataTest +from dbt.model import Model, Analysis, TestModel, SchemaFile, Csv, Macro, \ + ArchiveModel, DataTest + class Source(object): def __init__(self, project, own_project=None): @@ -9,12 +10,14 @@ def __init__(self, project, own_project=None): self.project_root = project['project-root'] self.project_name = project['name'] - self.own_project = own_project if own_project is not None else self.project + self.own_project = (own_project if own_project is not None + else self.project) self.own_project_root = self.own_project['project-root'] self.own_project_name = self.own_project['name'] def find(self, source_paths, file_pattern): - """returns abspath, relpath, filename of files matching file_regex in source_paths""" + """returns abspath, relpath, filename of files matching file_regex in + source_paths""" found = [] if type(source_paths) not in (list, tuple): @@ -28,28 +31,37 @@ def find(self, source_paths, file_pattern): rel_path = os.path.relpath(abs_path, root_path) if fnmatch.fnmatch(filename, file_pattern): - found.append((self.project, source_path, rel_path, self.own_project)) + found.append( + (self.project, + source_path, + rel_path, + self.own_project) + ) return found def get_models(self, model_dirs, create_template): pattern = "[!.#~]*.sql" - models = [Model(*model + (create_template,)) for model in self.find(model_dirs, pattern)] + models 
= [Model(*model + (create_template,)) + for model in self.find(model_dirs, pattern)] return models def get_test_models(self, model_dirs, create_template): pattern = "[!.#~]*.sql" - models = [TestModel(*model + (create_template,)) for model in self.find(model_dirs, pattern)] + models = [TestModel(*model + (create_template,)) + for model in self.find(model_dirs, pattern)] return models def get_analyses(self, analysis_dirs): pattern = "[!.#~]*.sql" - models = [Analysis(*analysis) for analysis in self.find(analysis_dirs, pattern)] + models = [Analysis(*analysis) + for analysis in self.find(analysis_dirs, pattern)] return models def get_schemas(self, model_dirs): "Get schema.yml files" pattern = "[!.#~]*.yml" - schemas = [SchemaFile(*schema) for schema in self.find(model_dirs, pattern)] + schemas = [SchemaFile(*schema) + for schema in self.find(model_dirs, pattern)] return schemas def get_tests(self, test_dirs): @@ -88,7 +100,7 @@ def get_archives(self, create_template): for table in tables: fields = table.copy() fields.update(schema) - archives.append(ArchiveModel(self.project, create_template, fields)) + archives.append(ArchiveModel( + self.project, create_template, fields + )) return archives - - diff --git a/dbt/ssh_forward.py b/dbt/ssh_forward.py index 7a457225a37..0ff32097998 100644 --- a/dbt/ssh_forward.py +++ b/dbt/ssh_forward.py @@ -1,5 +1,3 @@ - -#import sshtunnel import logging # modules are only imported once -- make sure that we don't have > 1 @@ -7,23 +5,6 @@ server = None + def get_or_create_tunnel(host, port, user, remote_host, remote_port, timeout): pass - #global server - #if server is None: - # logger = logging.getLogger(__name__) - - # bind_from = (host, port) - # bind_to = (remote_host, remote_port) - - # # hack - # sshtunnel.SSH_TIMEOUT = timeout - # server = sshtunnel.SSHTunnelForwarder(bind_from, ssh_username=user, remote_bind_address=bind_to, logger=logger) - # try: - # server.start() - # except sshtunnel.BaseSSHTunnelForwarderError as e: - # raise RuntimeError("Problem connecting through {}:{}: {}".format(host, port, str(e))) - # except KeyboardInterrupt: - # raise RuntimeError('Tunnel aborted (ctrl-c)') - - #return server diff --git a/dbt/targets.py b/dbt/targets.py index 9d23a7e48a6..c6a5613fbdd 100644 --- a/dbt/targets.py +++ b/dbt/targets.py @@ -3,15 +3,13 @@ import os import logging -#from paramiko import SSHConfig -#logging.getLogger("paramiko").setLevel(logging.WARNING) -#import dbt.ssh_forward - THREAD_MIN = 1 THREAD_MAX = 8 BAD_THREADS_ERROR = """Invalid value given for "threads" in active target. 
-Value given was {supplied} but it should be an int between {min_val} and {max_val}""" +Value given was {supplied} but it should be an int between {min_val} and +{max_val}.""" + class BaseSQLTarget(object): def __init__(self, cfg, threads): @@ -25,53 +23,19 @@ def __init__(self, cfg, threads): self.threads = self.__get_threads(cfg, threads) - #self.ssh_host = cfg.get('ssh-host', None) self.ssh_host = None self.handle = None - #def get_tunnel_config(self): - # config = SSHConfig() - - # config_filepath = os.path.join(os.path.expanduser('~'), '.ssh/config') - # config.parse(open(config_filepath)) - # options = config.lookup(self.ssh_host) - # return options - - #def __open_tunnel(self): - # config = self.get_tunnel_config() - # host = config.get('hostname') - # port = int(config.get('port', '22')) - # user = config.get('user') - # timeout = config.get('connecttimeout', 10) - # timeout = float(timeout) - - # if host is None: - # raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'hostname' field".format(self.ssh_host)) - # if user is None: - # raise RuntimeError("Invalid ssh config for Hostname {} -- missing 'user' field".format(self.ssh_host)) - - # # modules are only imported once -- this singleton makes sure we don't try to bind to the host twice (and lock) - # server = dbt.ssh_forward.get_or_create_tunnel(host, port, user, self.host, self.port, timeout) - - # # rebind the pg host and port - # self.host = 'localhost' - # self.port = server.local_bind_port - - # return server - def should_open_tunnel(self): - #return self.ssh_host is not None return False # make the user explicitly call this function to enable the ssh tunnel - # we don't want it to be automatically opened any time someone makes a new target + # we don't want it to be automatically opened any time someone makes a new + # target def open_tunnel_if_needed(self): - #self.ssh_tunnel = self.__open_tunnel() pass def cleanup(self): - #if self.ssh_tunnel is not None: - # self.ssh_tunnel.stop() pass def __get_threads(self, cfg, cli_threads=None): @@ -80,7 +44,13 @@ def __get_threads(self, cfg, cli_threads=None): else: supplied = cli_threads - bad_threads_error = RuntimeError(BAD_THREADS_ERROR.format(supplied=supplied, min_val=THREAD_MIN, max_val=THREAD_MAX)) + bad_threads_error = RuntimeError( + BAD_THREADS_ERROR.format( + supplied=supplied, + min_val=THREAD_MIN, + max_val=THREAD_MAX + ) + ) if type(supplied) != int: raise bad_threads_error @@ -91,13 +61,14 @@ def __get_threads(self, cfg, cli_threads=None): raise bad_threads_error def __get_spec(self): - return "dbname='{}' user='{}' host='{}' password='{}' port='{}' connect_timeout=10".format( - self.dbname, - self.user, - self.host, - self.password, - self.port - ) + return "dbname='{}' user='{}' host='{}' password='{}' port='{}' " \ + "connect_timeout=10".format( + self.dbname, + self.user, + self.host, + self.password, + self.port + ) def get_handle(self): # this is important -- if we use different handles, then redshift @@ -114,6 +85,7 @@ def rollback(self): def type(self): return self.target_type + class RedshiftTarget(BaseSQLTarget): def __init__(self, cfg, threads): super(RedshiftTarget, self).__init__(cfg, threads) @@ -128,15 +100,21 @@ def sort_qualifier(self, sort_type, sort_keys): valid_sort_types = ['compound', 'interleaved'] if sort_type not in valid_sort_types: - raise RuntimeError("Invalid sort_type given: {} -- must be one of {}".format(sort_type, valid_sort_types)) + raise RuntimeError( + "Invalid sort_type given: {} -- must be one of {}" + 
.format(sort_type, valid_sort_types) + ) if type(sort_keys) == str: sort_keys = [sort_keys] - formatted_sort_keys = ['"{}"'.format(sort_key) for sort_key in sort_keys] + formatted_sort_keys = ['"{}"'.format(sort_key) + for sort_key in sort_keys] keys_csv = ', '.join(formatted_sort_keys) - return "{sort_type} sortkey({keys_csv})".format(sort_type=sort_type, keys_csv=keys_csv) + return "{sort_type} sortkey({keys_csv})".format( + sort_type=sort_type, keys_csv=keys_csv + ) def dist_qualifier(self, dist_key): dist_key = dist_key.strip().lower() @@ -146,6 +124,7 @@ def dist_qualifier(self, dist_key): else: return 'diststyle key distkey("{}")'.format(dist_key) + class PostgresTarget(BaseSQLTarget): def __init__(self, cfg, threads): super(PostgresTarget, self).__init__(cfg, threads) @@ -167,6 +146,7 @@ def context(self): 'redshift': RedshiftTarget } + def get_target(cfg, threads=1): target_type = cfg['type'] if target_type in target_map: @@ -174,4 +154,7 @@ def get_target(cfg, threads=1): return klass(cfg, threads) else: valid_csv = ", ".join(["'{}'".format(t) for t in target_map]) - raise RuntimeError("Invalid target type provided: '{}'. Must be one of {}".format(target_type, valid_csv)) + raise RuntimeError( + "Invalid target type provided: '{}'. Must be one of {}" + .format(target_type, valid_csv) + ) diff --git a/dbt/task/archive.py b/dbt/task/archive.py index fdbd6d75bb8..527276870a8 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -1,9 +1,9 @@ - from dbt.runner import RunManager from dbt.templates import ArchiveInsertTemplate from dbt.compilation import Compiler from dbt.logger import GLOBAL_LOGGER as logger + class ArchiveTask: def __init__(self, args, project): self.args = args @@ -18,6 +18,12 @@ def compile(self): def run(self): self.compile() - runner = RunManager(self.project, self.project['target-path'], self.create_template.label, self.args) + + runner = RunManager( + self.project, + self.project['target-path'], + self.create_template.label, + self.args + ) results = runner.run_archive() diff --git a/dbt/task/clean.py b/dbt/task/clean.py index d84d37e6abb..c0eac09aa8d 100644 --- a/dbt/task/clean.py +++ b/dbt/task/clean.py @@ -17,10 +17,12 @@ def __is_project_path(self, path): def __is_protected_path(self, path): abs_path = os.path.abspath(path) - protected_paths = self.project['source-paths'] + self.project['test-paths'] + ['.'] + protected_paths = self.project['source-paths'] + \ + self.project['test-paths'] + ['.'] protected_abs_paths = [os.path.abspath for p in protected_paths] - return abs_path in set(protected_abs_paths) or self.__is_project_path(abs_path) + return abs_path in set(protected_abs_paths) or \ + self.__is_project_path(abs_path) def run(self): for path in self.project['clean-targets']: diff --git a/dbt/task/compile.py b/dbt/task/compile.py index f173cb33e6d..dcea876ceb6 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -1,4 +1,3 @@ - from dbt.compilation import Compiler, CompilableEntities from dbt.templates import BaseCreateTemplate, DryCreateTemplate from dbt.logger import GLOBAL_LOGGER as logger @@ -19,5 +18,7 @@ def run(self): compiler.initialize() results = compiler.compile(limit_to=CompilableEntities) - stat_line = ", ".join(["{} {}".format(results[k], k) for k in CompilableEntities]) + stat_line = ", ".join( + ["{} {}".format(results[k], k) for k in CompilableEntities] + ) logger.info("Compiled {}".format(stat_line)) diff --git a/dbt/task/deps.py b/dbt/task/deps.py index d5c31ecaf4c..9d7e40a1efe 100644 --- a/dbt/task/deps.py +++ 
b/dbt/task/deps.py @@ -8,11 +8,13 @@ from dbt.logger import GLOBAL_LOGGER as logger + def folder_from_git_remote(remote_spec): start = remote_spec.rfind('/') + 1 end = len(remote_spec) - (4 if remote_spec.endswith('.git') else 0) return remote_spec[start:end] + class DepsTask: def __init__(self, args, project): self.args = args @@ -36,7 +38,11 @@ def __pull_repo(self, repo, branch=None): out, err = proc.communicate() - exists = re.match("fatal: destination path '(.+)' already exists", err.decode('utf-8')) + exists = re.match( + "fatal: destination path '(.+)' already exists", + err.decode('utf-8') + ) + folder = None if exists: folder = exists.group(1) @@ -48,7 +54,8 @@ def __pull_repo(self, repo, branch=None): stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = proc.communicate() - remote_branch = 'origin/master' if branch is None else 'origin/{}'.format(branch) + remote_branch = 'origin/master' if branch is None \ + else 'origin/{}'.format(branch) proc = subprocess.Popen( ['git', 'reset', '--hard', remote_branch], cwd=full_path, @@ -69,7 +76,11 @@ def __pull_repo(self, repo, branch=None): def __split_at_branch(self, repo_spec): parts = repo_spec.split("@") - error = RuntimeError("Invalid dep specified: '{}' -- not a repo we can clone".format(repo_spec)) + error = RuntimeError( + "Invalid dep specified: '{}' -- not a repo we can clone".format( + repo_spec + ) + ) repo = None if repo_spec.startswith("git@"): @@ -90,7 +101,7 @@ def __split_at_branch(self, repo_spec): return repo, branch - def __pull_deps_recursive(self, repos, processed_repos = None, i=0): + def __pull_deps_recursive(self, repos, processed_repos=None, i=0): if processed_repos is None: processed_repos = set() for repo_string in repos: @@ -99,19 +110,27 @@ def __pull_deps_recursive(self, repos, processed_repos = None, i=0): try: if repo_folder in processed_repos: - logger.info("skipping already processed dependency {}".format(repo_folder)) + logger.info( + "skipping already processed dependency {}" + .format(repo_folder) + ) else: dep_folder = self.__pull_repo(repo, branch) dep_project = project.read_project( os.path.join(self.project['modules-path'], dep_folder, - 'dbt_project.yml'), self.project.profiles_dir, profile_to_load=self.project.profile_to_load + 'dbt_project.yml'), + self.project.profiles_dir, + profile_to_load=self.project.profile_to_load ) processed_repos.add(dep_folder) - self.__pull_deps_recursive(dep_project['repositories'], processed_repos, i+1) + self.__pull_deps_recursive( + dep_project['repositories'], processed_repos, i+1 + ) except IOError as e: if e.errno == errno.ENOENT: - logger.info("'{}' is not a valid dbt project - dbt_project.yml not found".format(repo)) + logger.info("'{}' is not a valid dbt project - " + "dbt_project.yml not found".format(repo)) exit(1) else: raise e diff --git a/dbt/task/init.py b/dbt/task/init.py index f3927f5140f..68a35fb1157 100644 --- a/dbt/task/init.py +++ b/dbt/task/init.py @@ -6,7 +6,8 @@ version: '1.0' source-paths: ["models"] # paths with source code to compile -analysis-paths: ["analysis"] # path with analysis files which are compiled, but not run +analysis-paths: ["analysis"] # path with analysis files which are compiled, but + # not run target-path: "target" # path for compiled code clean-targets: ["target"] # directories removed by the clean task test-paths: ["test"] # where to store test results @@ -14,8 +15,9 @@ # specify per-model configs #models: -# package_name: # define configs for this package (called "package_name" above) -# pardot: # assuming 
pardot is listed in the models/ directory +# package_name: # define configs for this package (called +# # "package_name" above) +# pardot: # assuming pardot is listed in models/ # enabled: false # disable all pardot models except where overriden # pardot_emails: # override the configs for the pardot_emails model # enabled: true # enable this specific model @@ -32,6 +34,7 @@ dbt_modules/ """ + class InitTask: def __init__(self, args, project=None): self.args = args @@ -47,7 +50,9 @@ def run(self): project_dir = self.args.project_name if os.path.exists(project_dir): - raise RuntimeError("directory {} already exists!".format(project_dir)) + raise RuntimeError("directory {} already exists!".format( + project_dir + )) os.mkdir(project_dir) diff --git a/dbt/task/run.py b/dbt/task/run.py index 79ca53361fd..6fda9249466 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -9,18 +9,22 @@ THREAD_LIMIT = 9 + class RunTask: def __init__(self, args, project): self.args = args self.project = project def compile(self): - create_template = DryCreateTemplate if self.args.dry else BaseCreateTemplate + create_template = DryCreateTemplate if self.args.dry \ + else BaseCreateTemplate compiler = Compiler(self.project, create_template, self.args) compiler.initialize() - results = compiler.compile(limit_to=['models'] ) + results = compiler.compile(limit_to=['models']) - stat_line = ", ".join(["{} {}".format(results[k], k) for k in CompilableEntities]) + stat_line = ", ".join([ + "{} {}".format(results[k], k) for k in CompilableEntities + ]) logger.info("Compiled {}".format(stat_line)) return create_template.label @@ -28,16 +32,26 @@ def compile(self): def run(self): graph_type = self.compile() - runner = RunManager(self.project, self.project['target-path'], graph_type, self.args) + runner = RunManager( + self.project, self.project['target-path'], graph_type, self.args + ) if self.args.dry: results = runner.dry_run(self.args.models) else: results = runner.run(self.args.models) - total = len(results) - passed = len([r for r in results if not r.errored and not r.skipped]) + total = len(results) + passed = len([r for r in results if not r.errored and not r.skipped]) errored = len([r for r in results if r.errored]) skipped = len([r for r in results if r.skipped]) - logger.info("Done. PASS={passed} ERROR={errored} SKIP={skipped} TOTAL={total}".format(total=total, passed=passed, errored=errored, skipped=skipped)) + logger.info( + "Done. 
PASS={passed} ERROR={errored} SKIP={skipped} TOTAL={total}" + .format( + total=total, + passed=passed, + errored=errored, + skipped=skipped + ) + ) diff --git a/dbt/task/seed.py b/dbt/task/seed.py index c301558cdac..eb36aef2fc8 100644 --- a/dbt/task/seed.py +++ b/dbt/task/seed.py @@ -1,7 +1,7 @@ - import os from dbt.seeder import Seeder + class SeedTask: def __init__(self, args, project): self.args = args diff --git a/dbt/task/test.py b/dbt/task/test.py index db27ca48ce1..ea6ceb97ce2 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -1,6 +1,6 @@ - -import os, sys +import os import psycopg2 +import sys import yaml from dbt.compilation import Compiler, CompilableEntities @@ -13,7 +13,8 @@ class TestTask: """ Testing: - 1) Create tmp views w/ 0 rows to ensure all tables, schemas, and SQL statements are valid + 1) Create tmp views w/ 0 rows to ensure all tables, schemas, and SQL + statements are valid 2) Read schema files and validate that constraints are satisfied a) not null b) uniquenss @@ -29,16 +30,21 @@ def compile(self): compiler.initialize() results = compiler.compile(limit_to=['tests']) - stat_line = ", ".join(["{} {}".format(results[k], k) for k in CompilableEntities]) + stat_line = ", ".join( + ["{} {}".format(results[k], k) for k in CompilableEntities] + ) logger.info("Compiled {}".format(stat_line)) return compiler def run(self): self.compile() - runner = RunManager(self.project, self.project['target-path'], 'build', self.args) + runner = RunManager( + self.project, self.project['target-path'], 'build', self.args + ) - if (self.args.data and self.args.schema) or (not self.args.data and not self.args.schema): + if (self.args.data and self.args.schema) or \ + (not self.args.data and not self.args.schema): res = runner.run_tests(test_schemas=True, test_data=True) elif self.args.data: res = runner.run_tests(test_schemas=False, test_data=True) diff --git a/dbt/templates.py b/dbt/templates.py index be1d25b26af..ac18f634692 100644 --- a/dbt/templates.py +++ b/dbt/templates.py @@ -90,7 +90,8 @@ def wrap(self, opts): if opts['materialization'] == 'view': sql = self.template.format(**opts) - elif opts['materialization'] == 'table' and not opts['non_destructive']: + elif (opts['materialization'] == 'table' and + not opts['non_destructive']): sql = self.template.format(**opts) elif opts['materialization'] == 'table' and opts['non_destructive']: @@ -295,4 +296,3 @@ class ArchiveInsertTemplate(object): def wrap(self, schema, table, query, unique_key): sql = self.archival_template.format(schema=schema, identifier=table, query=query, unique_key=unique_key, alter_template=self.alter_template, dest_cols=self.dest_cols, definitions=self.definitions) return sql - diff --git a/dbt/tracking.py b/dbt/tracking.py index 843d43e449d..f8ff5afd843 100644 --- a/dbt/tracking.py +++ b/dbt/tracking.py @@ -3,8 +3,6 @@ from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger from snowplow_tracker import SelfDescribingJson, disable_contracts -disable_contracts() - import platform import uuid import yaml @@ -12,6 +10,7 @@ import json import logging +disable_contracts() sp_logger.setLevel(100) COLLECTOR_URL = "events.fivetran.com/snowplow/forgiving_ain" @@ -19,16 +18,20 @@ COOKIE_PATH = os.path.join(os.path.expanduser('~'), '.dbt/.user.yml') -INVOCATION_SPEC = "https://raw.githubusercontent.com/analyst-collective/dbt/master/events/schemas/com.fishtownanalytics/invocation_event.json" -PLATFORM_SPEC = 
"https://raw.githubusercontent.com/analyst-collective/dbt/master/events/schemas/com.fishtownanalytics/platform_context.json" -RUN_MODEL_SPEC = "https://raw.githubusercontent.com/analyst-collective/dbt/master/events/schemas/com.fishtownanalytics/run_model_context.json" -INVOCATION_ENV_SPEC = "https://raw.githubusercontent.com/analyst-collective/dbt/master/events/schemas/com.fishtownanalytics/invocation_env_context.json" +BASE_URL = 'https://raw.githubusercontent.com/analyst-collective/'\ + 'dbt/master/events/schemas/com.fishtownanalytics/' + +INVOCATION_SPEC = BASE_URL + "invocation_event.json" +PLATFORM_SPEC = BASE_URL + "platform_context.json" +RUN_MODEL_SPEC = BASE_URL + "run_model_context.json" +INVOCATION_ENV_SPEC = BASE_URL + "invocation_env_context.json" DBT_INVOCATION_ENV = 'DBT_INVOCATION_ENV' emitter = Emitter(COLLECTOR_URL, protocol=COLLECTOR_PROTOCOL, buffer_size=1) tracker = Tracker(emitter, namespace="cf", app_id="dbt") + def __write_user(): user = { "id": str(uuid.uuid4()) @@ -43,6 +46,7 @@ def __write_user(): return user + def get_user(): if os.path.isfile(COOKIE_PATH): with open(COOKIE_PATH, "r") as fh: @@ -57,75 +61,87 @@ def get_user(): return user + def get_options(args): exclude = ['cls', 'target', 'profile'] - options = {k:v for (k, v) in args.__dict__.items() if k not in exclude} + options = {k: v for (k, v) in args.__dict__.items() if k not in exclude} return json.dumps(options) + def get_run_type(args): - if 'dry' in args and args.dry == True: + if 'dry' in args and args.dry is True: return 'dry' else: return 'regular' + def get_invocation_context(invocation_id, user, project, args): return { - "project_id" : None if project is None else project.hashed_name(), - "user_id" : user.get("id", None), - "invocation_id" : invocation_id, + "project_id": None if project is None else project.hashed_name(), + "user_id": user.get("id", None), + "invocation_id": invocation_id, - "command" : args.which, - "options" : get_options(args), - "version" : dbt_version.installed, + "command": args.which, + "options": get_options(args), + "version": dbt_version.installed, - "run_type" : get_run_type(args), + "run_type": get_run_type(args), } + def get_invocation_start_context(invocation_id, user, project, args): data = get_invocation_context(invocation_id, user, project, args) start_data = { - "progress" : "start", - "result_type" : None, - "result" : None + "progress": "start", + "result_type": None, + "result": None } data.update(start_data) return SelfDescribingJson(INVOCATION_SPEC, data) -def get_invocation_end_context(invocation_id, user, project, args, result_type, result): + +def get_invocation_end_context( + invocation_id, user, project, args, result_type, result +): data = get_invocation_context(invocation_id, user, project, args) start_data = { - "progress" : "end", - "result_type" : result_type, - "result" : result, + "progress": "end", + "result_type": result_type, + "result": result, } data.update(start_data) return SelfDescribingJson(INVOCATION_SPEC, data) -def get_invocation_invalid_context(invocation_id, user, project, args, result_type, result): + +def get_invocation_invalid_context( + invocation_id, user, project, args, result_type, result +): data = get_invocation_context(invocation_id, user, project, args) start_data = { - "progress" : "invalid", - "result_type" : result_type, - "result" : result, + "progress": "invalid", + "result_type": result_type, + "result": result, } data.update(start_data) return SelfDescribingJson(INVOCATION_SPEC, data) + def 
get_platform_context(): data = { - "platform" : platform.platform(), - "python" : platform.python_version(), - "python_version" : platform.python_implementation(), + "platform": platform.platform(), + "python": platform.python_version(), + "python_version": platform.python_implementation(), } return SelfDescribingJson(PLATFORM_SPEC, data) + def get_dbt_env_context(): default = 'manual' @@ -134,7 +150,7 @@ def get_dbt_env_context(): dbt_invocation_env = default data = { - "environment" : dbt_invocation_env, + "environment": dbt_invocation_env, } return SelfDescribingJson(INVOCATION_ENV_SPEC, data) @@ -150,6 +166,7 @@ def get_dbt_env_context(): __is_do_not_track = False + def track(*args, **kwargs): if __is_do_not_track: return @@ -158,32 +175,57 @@ def track(*args, **kwargs): try: tracker.track_struct_event(*args, **kwargs) except Exception as e: - logger.exception("An error was encountered while trying to send an event") + logger.exception( + "An error was encountered while trying to send an event" + ) + def track_invocation_start(project=None, args=None): - invocation_context = get_invocation_start_context(invocation_id, user, project, args) + invocation_context = get_invocation_start_context( + invocation_id, user, project, args + ) context = [invocation_context, platform_context, env_context] track(category="dbt", action='invocation', label='start', context=context) + def track_model_run(options): context = [SelfDescribingJson(RUN_MODEL_SPEC, options)] model_id = options['model_id'] - track(category="dbt", action='run_model', label=invocation_id, context=context) - -def track_invocation_end(project=None, args=None, result_type=None, result=None): - invocation_context = get_invocation_end_context(invocation_id, user, project, args, result_type, result) + track( + category="dbt", + action='run_model', + label=invocation_id, + context=context + ) + + +def track_invocation_end( + project=None, args=None, result_type=None, result=None +): + invocation_context = get_invocation_end_context( + invocation_id, user, project, args, result_type, result + ) context = [invocation_context, platform_context, env_context] track(category="dbt", action='invocation', label='end', context=context) -def track_invalid_invocation(project=None, args=None, result_type=None, result=None): - invocation_context = get_invocation_invalid_context(invocation_id, user, project, args, result_type, result) + +def track_invalid_invocation( + project=None, args=None, result_type=None, result=None +): + invocation_context = get_invocation_invalid_context( + invocation_id, user, project, args, result_type, result + ) context = [invocation_context, platform_context, env_context] - track(category="dbt", action='invocation', label='invalid', context=context) + track( + category="dbt", action='invocation', label='invalid', context=context + ) + def flush(): logger.debug("Flushing usage events") tracker.flush() + def do_not_track(): global __is_do_not_track logger.debug("Not sending anonymous usage events") diff --git a/dbt/utils.py b/dbt/utils.py index ee4cf827cb1..963045aded1 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -1,8 +1,6 @@ - import os -import dbt.project -import pprint import json + import dbt.project from dbt.logger import GLOBAL_LOGGER as logger @@ -19,6 +17,7 @@ 'vars' ] + class This(object): def __init__(self, schema, table, name): self.schema = schema @@ -31,20 +30,31 @@ def schema_table(self, schema, table): def __repr__(self): return self.schema_table(self.schema, self.table) + def compiler_error(model, msg): 
if model is None: name = '' else: name = model.nice_name - raise RuntimeError("! Compilation error while compiling model {}:\n! {}".format(name, msg)) + raise RuntimeError( + "! Compilation error while compiling model {}:\n! {}" + .format(name, msg) + ) + def compiler_warning(model, msg): - logger.info("* Compilation warning while compiling model {}:\n* {}".format(model.nice_name, msg)) + logger.info( + "* Compilation warning while compiling model {}:\n* {}" + .format(model.nice_name, msg) + ) + class Var(object): - UndefinedVarError = "Required var '{}' not found in config:\nVars supplied to {} = {}" - NoneVarError = "Supplied var '{}' is undefined in config:\nVars supplied to {} = {}" + UndefinedVarError = "Required var '{}' not found in config:\nVars "\ + "supplied to {} = {}" + NoneVarError = "Supplied var '{}' is undefined in config:\nVars supplied"\ + "to {} = {}" def __init__(self, model, context): self.model = model @@ -57,54 +67,87 @@ def pretty_dict(self, data): def __call__(self, var_name, default=None): pretty_vars = self.pretty_dict(self.local_vars) if var_name not in self.local_vars and default is None: - compiler_error(self.model, self.UndefinedVarError.format(var_name, self.model.nice_name, pretty_vars)) + compiler_error( + self.model, + self.UndefinedVarError.format( + var_name, self.model.nice_name, pretty_vars + ) + ) elif var_name in self.local_vars: raw = self.local_vars[var_name] if raw is None: - compiler_error(self.model, self.NoneVarError.format(var_name, self.model.nice_name, pretty_vars)) + compiler_error( + self.model, + self.NoneVarError.format( + var_name, self.model.nice_name, pretty_vars + ) + ) compiled = self.model.compile_string(self.context, raw) return compiled else: return default + def find_model_by_name(models, name, package_namespace=None): found = [] for model in models: if model.name == name: if package_namespace is None: found.append(model) - elif package_namespace is not None and package_namespace == model.project['name']: + elif (package_namespace is not None and + package_namespace == model.project['name']): found.append(model) - nice_package_name = 'ANY' if package_namespace is None else package_namespace + nice_package_name = 'ANY' if package_namespace is None \ + else package_namespace if len(found) == 0: - raise RuntimeError("Can't find a model named '{}' in package '{}' -- does it exist?".format(name, nice_package_name)) + raise RuntimeError( + "Can't find a model named '{}' in package '{}' -- does it exist?" 
+ .format(name, nice_package_name) + ) elif len(found) == 1: return found[0] else: - raise RuntimeError("Model specification is ambiguous: model='{}' package='{}' -- {} models match criteria: {}".format(name, nice_package_name, len(found), found)) + raise RuntimeError( + "Model specification is ambiguous: model='{}' package='{}' -- " + "{} models match criteria: {}" + .format(name, nice_package_name, len(found), found) + ) + def find_model_by_fqn(models, fqn): for model in models: if tuple(model.fqn) == tuple(fqn): return model - raise RuntimeError("Couldn't find a compiled model with fqn: '{}'".format(fqn)) + + raise RuntimeError( + "Couldn't find a compiled model with fqn: '{}'".format(fqn) + ) + def dependency_projects(project): for obj in os.listdir(project['modules-path']): full_obj = os.path.join(project['modules-path'], obj) if os.path.isdir(full_obj): try: - yield dbt.project.read_project(os.path.join(full_obj, 'dbt_project.yml'), project.profiles_dir, profile_to_load=project.profile_to_load) + yield dbt.project.read_project( + os.path.join(full_obj, 'dbt_project.yml'), + project.profiles_dir, + profile_to_load=project.profile_to_load + ) except dbt.project.DbtProjectError as e: - logger.info("Error reading dependency project at {}".format(full_obj)) + logger.info( + "Error reading dependency project at {}".format(full_obj) + ) logger.info(str(e)) + def split_path(path): norm = os.path.normpath(path) return path.split(os.sep) -# influenced by: http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data + +# http://stackoverflow.com/questions/20656135/python-deep-merge-dictionary-data def deep_merge(destination, source): if isinstance(source, dict): for key, value in source.items(): @@ -120,6 +163,7 @@ def deep_merge(destination, source): destination[key] = value return destination + def to_unicode(s, encoding): try: unicode @@ -127,6 +171,7 @@ def to_unicode(s, encoding): except NameError: return s + def to_string(s): try: unicode diff --git a/dbt/version.py b/dbt/version.py index 659f4607bfc..1264884b70d 100644 --- a/dbt/version.py +++ b/dbt/version.py @@ -1,6 +1,6 @@ - -import os, re import argparse +import os +import re try: # For Python 3.0 and later @@ -9,7 +9,10 @@ # Fall back to Python 2's urllib2 from urllib2 import urlopen -REMOTE_VERISON_FILE = 'https://raw.githubusercontent.com/analyst-collective/dbt/master/.bumpversion.cfg' +REMOTE_VERSION_FILE = \ + 'https://raw.githubusercontent.com/analyst-collective/dbt/' \ + 'master/.bumpversion.cfg' + def __parse_version(contents): matches = re.search(r"current_version = ([\.0-9]+)", contents) @@ -19,12 +22,14 @@ def __parse_version(contents): version = matches.groups()[0] return version + def get_version(): return __version__ + def get_latest_version(): try: - f = urlopen(REMOTE_VERISON_FILE) + f = urlopen(REMOTE_VERSION_FILE) contents = f.read() except: contents = '' @@ -32,29 +37,36 @@ def get_latest_version(): contents = contents.decode('utf-8') return __parse_version(contents) + def not_latest(): - return """Your version of dbt is out of date! You can find instructions for upgrading here: -http://dbt.readthedocs.io/en/master/guide/upgrading/ -""" + return """Your version of dbt is out of date! 
You can find instructions + for upgrading here: + + http://dbt.readthedocs.io/en/master/guide/upgrading/ + """ + def get_version_string(): - return "installed version: {}\n latest version: {}".format(installed, latest) + return "installed version: {}\n latest version: {}".format( + installed, latest + ) + def get_version_information(): basic = get_version_string() if is_latest(): - basic +='\nUp to date!' + basic += '\nUp to date!' else: basic += '\n{}'.format(not_latest()) return basic + def is_latest(): return installed == latest + __version__ = '0.6.0' installed = get_version() latest = get_latest_version() - - diff --git a/setup.py b/setup.py index 25d0d81e84f..30f325fdf28 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ author_email="admin@analystcollective.org", url="https://github.com/analyst-collective/dbt", packages=find_packages(), - test_suite = 'test', + test_suite='test', entry_points={ 'console_scripts': [ 'dbt = dbt.main:main', @@ -31,7 +31,5 @@ 'csvkit==0.9.1', 'snowplow-tracker==0.7.2', 'celery==3.1.23', - #'paramiko==2.0.1', - #'sshtunnel==0.0.8.2' ], ) diff --git a/tox.ini b/tox.ini index ef84fef34b7..45e20a0a89b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,12 @@ [tox] -envlist = unit-py27, unit-py35, integration-py27, integration-py35 +envlist = unit-py27, unit-py35, integration-py27, integration-py35, pep8 + +[testenv:pep8] +basepython = python3.5 +commands = /bin/bash -c '$(which pep8) dbt/ --exclude dbt/templates.py' +deps = + -rrequirements.txt + -rdev_requirements.txt [testenv:unit-py27] basepython = python2.7
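
The `pep8` tox environment added above simply runs the checker over dbt/ while excluding dbt/templates.py. For contributors who want to reproduce the CI check without tox, the pep8 package also exposes a small Python API; the snippet below is an illustrative sketch only, and assumes the standalone pep8 package is installed locally and that passing a bare 'templates.py' pattern to the API excludes that file the same way the --exclude flag does:

    import pep8

    # Roughly equivalent to the tox command:
    #   pep8 dbt/ --exclude dbt/templates.py
    # (the exclude pattern here is an assumption about how the CLI's
    # fnmatch-based --exclude option maps onto the Python API)
    style = pep8.StyleGuide(exclude=['templates.py'])
    report = style.check_files(['dbt/'])

    if report.total_errors:
        raise SystemExit(
            '{} pep8 violation(s) found'.format(report.total_errors))

Like the command-line checker, check_files() prints each violation as it is found, and total_errors makes it easy to turn the result into a non-zero exit status, which is what lets the new tox environment fail the build when the codebase drifts out of compliance.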