Compile on-run-(start|end) hooks to file (#412)
compile on-run-(start|end) hooks to a file
drewbanin authored May 9, 2017
1 parent 9977388 commit ce46052
Showing 9 changed files with 174 additions and 123 deletions.
15 changes: 11 additions & 4 deletions dbt/compilation.py
@@ -6,6 +6,7 @@
import dbt.utils
import dbt.include
import dbt.wrapper
import dbt.tracking

from dbt.model import Model
from dbt.utils import This, Var, is_enabled, get_materialization, NodeType, \
@@ -38,6 +39,7 @@ def print_compile_stats(stats):
NodeType.Archive: 'archives',
NodeType.Analysis: 'analyses',
NodeType.Macro: 'macros',
NodeType.Operation: 'operations',
}

results = {
@@ -46,6 +48,7 @@ def print_compile_stats(stats):
NodeType.Archive: 0,
NodeType.Analysis: 0,
NodeType.Macro: 0,
NodeType.Operation: 0,
}

results.update(stats)
@@ -235,8 +238,8 @@ def get_compiler_context(self, model, flat_graph):

context.update(wrapper.get_context_functions())

context['run_started_at'] = '{{ run_started_at }}'
context['invocation_id'] = '{{ invocation_id }}'
context['run_started_at'] = dbt.tracking.active_user.run_started_at
context['invocation_id'] = dbt.tracking.active_user.invocation_id
context['sql_now'] = adapter.date_function()

for unique_id, macro in flat_graph.get('macros').items():
@@ -280,7 +283,8 @@ def compile_node(self, node, flat_graph):
injected_node, _ = prepend_ctes(compiled_node, flat_graph)

if compiled_node.get('resource_type') in [NodeType.Test,
NodeType.Analysis]:
NodeType.Analysis,
NodeType.Operation]:
# data tests get wrapped in count(*)
# TODO : move this somewhere more reasonable
if 'data' in injected_node['tags'] and \
@@ -350,7 +354,7 @@ def link_node(self, linker, node, flat_graph):
def link_graph(self, linker, flat_graph):
linked_graph = {
'nodes': {},
'macros': flat_graph.get('macros'),
'macros': flat_graph.get('macros')
}

for name, node in flat_graph.get('nodes').items():
@@ -468,6 +472,8 @@ def load_all_nodes(self, root_project, all_projects):
all_nodes.update(
dbt.parser.parse_archives_from_projects(root_project,
all_projects))
all_nodes.update(
dbt.parser.load_and_parse_run_hooks(root_project, all_projects))

return all_nodes

@@ -478,6 +484,7 @@ def compile(self):
all_projects = self.get_all_projects()

all_macros = self.load_all_macros(root_project, all_projects)

all_nodes = self.load_all_nodes(root_project, all_projects)

flat_graph = {
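Note on the context change above: get_compiler_context previously re-emitted the literal placeholders '{{ run_started_at }}' and '{{ invocation_id }}', deferring rendering to a later pass; it now binds the concrete values from dbt.tracking.active_user at compile time, and the new Operation node type is counted in the compile stats. A minimal sketch of the difference, assuming jinja2 is installed (the uuid literal is made up):

from jinja2 import Template

template = Template("select '{{ invocation_id }}' as invocation_id")

# Before: the placeholder survived compilation, to be rendered later.
print(template.render({'invocation_id': '{{ invocation_id }}'}))
# select '{{ invocation_id }}' as invocation_id

# After: the concrete value from the active user is bound immediately.
print(template.render({'invocation_id': '6ba7b810-9dad-11d1-80b4-00c04fd430c8'}))
# select '6ba7b810-9dad-11d1-80b4-00c04fd430c8' as invocation_id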
3 changes: 2 additions & 1 deletion dbt/contracts/graph/unparsed.py
@@ -19,7 +19,8 @@
unparsed_node_contract = unparsed_base_contract.extend({
Required('resource_type'): Any(NodeType.Model,
NodeType.Test,
NodeType.Analysis)
NodeType.Analysis,
NodeType.Operation)
})

unparsed_nodes_contract = Schema([unparsed_node_contract])
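With NodeType.Operation accepted by the contract, an unparsed hook node validates like any other node. A reduced sketch of the check using voluptuous, which these contracts are built on (the two-field schema below is illustrative, not the full contract):

from voluptuous import Schema, Required, Any, ALLOW_EXTRA

contract = Schema({
    Required('resource_type'): Any('model', 'test', 'analysis', 'operation'),
}, extra=ALLOW_EXTRA)

# Validates now that 'operation' is a permitted resource type:
contract({'resource_type': 'operation', 'name': 'on-run-start'})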
58 changes: 58 additions & 0 deletions dbt/parser.py
@@ -331,6 +331,64 @@ def load_and_parse_sql(package_name, root_project, all_projects, root_dir,
return parse_sql_nodes(result, root_project, all_projects, tags)


def get_hooks_from_project(project_cfg, hook_type):
hooks = project_cfg.get(hook_type, [])

if type(hooks) not in (list, tuple):
hooks = [hooks]

return hooks


def get_hooks(all_projects, hook_type):
project_hooks = {}

for project_name, project in all_projects.items():
hooks = get_hooks_from_project(project, hook_type)

if len(hooks) > 0:
project_hooks[project_name] = ";\n".join(hooks)

return project_hooks
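
get_hooks_from_project accepts either a single string or a list under the hook key, and get_hooks joins each project's hooks into one ";\n"-separated SQL string per project. A self-contained sketch of that normalization (the project dicts stand in for parsed dbt_project.yml files):

def get_hooks_from_project(project_cfg, hook_type):
    hooks = project_cfg.get(hook_type, [])
    if type(hooks) not in (list, tuple):
        hooks = [hooks]
    return hooks

def get_hooks(all_projects, hook_type):
    project_hooks = {}
    for project_name, project in all_projects.items():
        hooks = get_hooks_from_project(project, hook_type)
        if len(hooks) > 0:
            project_hooks[project_name] = ";\n".join(hooks)
    return project_hooks

all_projects = {
    'my_project': {'on-run-start': "grant select on analytics to group reporter"},
    'audit_pkg': {'on-run-start': ["create table if not exists audit (id int)",
                                   "insert into audit values (1)"]},
}

print(get_hooks(all_projects, 'on-run-start'))
# {'my_project': 'grant select on analytics to group reporter',
#  'audit_pkg': 'create table if not exists audit (id int);\ninsert into audit values (1)'}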


def load_and_parse_run_hook_type(root_project, all_projects, hook_type):

if dbt.flags.STRICT_MODE:
dbt.contracts.project.validate_list(all_projects)

project_hooks = get_hooks(all_projects, hook_type)

result = []
for project_name, hooks in project_hooks.items():
project = all_projects[project_name]

hook_path = dbt.utils.get_pseudo_hook_path(hook_type)

result.append({
'name': hook_type,
'root_path': "{}/dbt_project.yml".format(project_name),
'resource_type': NodeType.Operation,
'path': hook_path,
'package_name': project_name,
'raw_sql': hooks
})

tags = {hook_type}
return parse_sql_nodes(result, root_project, all_projects, tags=tags)


def load_and_parse_run_hooks(root_project, all_projects):
hook_nodes = {}
for hook_type in dbt.utils.RunHookType.Both:
project_hooks = load_and_parse_run_hook_type(root_project,
all_projects,
hook_type)
hook_nodes.update(project_hooks)

return hook_nodes


def load_and_parse_macros(package_name, root_project, all_projects, root_dir,
relative_dirs, resource_type, tags=None):
extension = "[!.#~]*.sql"
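For reference, this is roughly the unparsed node that load_and_parse_run_hook_type assembles for one project's on-run-start hooks before handing it to parse_sql_nodes with tags={'on-run-start'} (field values inferred from the code above; the project name and SQL are illustrative):

{
    'name': 'on-run-start',
    'root_path': 'my_project/dbt_project.yml',
    'resource_type': NodeType.Operation,       # 'operation'
    'path': 'hooks/on-run-start.sql',          # get_pseudo_hook_path('on-run-start')
    'package_name': 'my_project',
    'raw_sql': 'grant select on analytics to group reporter',
}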
88 changes: 42 additions & 46 deletions dbt/runner.py
@@ -5,12 +5,11 @@
import os
import time
import itertools
from datetime import datetime

from dbt.adapters.factory import get_adapter
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.utils import get_materialization, NodeType, is_type
from dbt.utils import get_materialization, NodeType, is_type, get_nodes_by_tags

import dbt.clients.jinja
import dbt.compilation
@@ -370,26 +369,6 @@ def execute_archive(profile, node, context):
return result


def run_hooks(profile, hooks, context, source):
if type(hooks) not in (list, tuple):
hooks = [hooks]

ctx = {
"target": profile,
"state": "start",
"invocation_id": context['invocation_id'],
"run_started_at": context['run_started_at']
}

compiled_hooks = [
dbt.clients.jinja.get_rendered(hook, ctx) for hook in hooks
]

adapter = get_adapter(profile)

return adapter.execute_all(profile=profile, sqls=compiled_hooks)


def track_model_run(index, num_nodes, run_model_result):
invocation_id = dbt.tracking.active_user.invocation_id
dbt.tracking.track_model_run({
@@ -461,10 +440,8 @@ def call_table_exists(schema, table):
return adapter.table_exists(
profile, schema, table, node.get('name'))

self.run_started_at = datetime.now()

return {
"run_started_at": datetime.now(),
"run_started_at": dbt.tracking.active_user.run_started_at,
"invocation_id": dbt.tracking.active_user.invocation_id,
"get_columns_in_table": call_get_columns_in_table,
"get_missing_columns": call_get_missing_columns,
@@ -513,7 +490,6 @@ def execute_node(self, node, flat_graph, existing, profile, adapter):
return node, result

def compile_node(self, node, flat_graph):

compiler = dbt.compilation.Compiler(self.project)
node = compiler.compile_node(node, flat_graph)
return node
@@ -634,6 +610,18 @@ def as_concurrent_dep_list(self, linker, nodes_to_run):

return concurrent_dependency_list

def run_hooks(self, profile, flat_graph, hook_type):
adapter = get_adapter(profile)

nodes = flat_graph.get('nodes', {}).values()
start_hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation)
hooks = [self.compile_node(hook, flat_graph) for hook in start_hooks]

master_connection = adapter.begin(profile)
compiled_hooks = [hook['wrapped_sql'] for hook in hooks]
adapter.execute_all(profile=profile, sqls=compiled_hooks)
master_connection = adapter.commit(master_connection)
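
The runner no longer renders hooks ad hoc from project config: it selects the pre-parsed operation nodes by tag, compiles them through the ordinary compile_node path, and executes the batch inside a single begin/commit pair. A sketch of that execution pattern with a stand-in adapter (FakeAdapter and the SQL are illustrative, not dbt's adapter API):

class FakeAdapter:
    def begin(self, profile):
        print("BEGIN;")
        return "master_connection"

    def execute_all(self, profile, sqls):
        for sql in sqls:
            print("{};".format(sql))

    def commit(self, connection):
        print("COMMIT;")
        return connection

adapter = FakeAdapter()
hooks = [{'wrapped_sql': "grant select on analytics to group reporter"}]

master_connection = adapter.begin(profile=None)
adapter.execute_all(profile=None, sqls=[hook['wrapped_sql'] for hook in hooks])
master_connection = adapter.commit(master_connection)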

def on_model_failure(self, linker, selected_nodes):
def skip_dependent(node):
dependent_nodes = linker.get_dependent_nodes(node.get('unique_id'))
@@ -687,12 +675,7 @@ def execute_nodes(self, flat_graph, node_dependency_list, on_failure,
start_time = time.time()

if should_run_hooks:
master_connection = adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-start', []),
self.node_context({}),
'on-run-start hooks')
master_connection = adapter.commit(master_connection)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookType.Start)

def get_idx(node):
return node_id_to_index_map.get(node.get('unique_id'))
@@ -739,12 +722,7 @@ def get_idx(node):
pool.join()

if should_run_hooks:
adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-end', []),
self.node_context({}),
'on-run-end hooks')
adapter.commit(master_connection)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookType.End)

execution_time = time.time() - start_time

@@ -755,18 +733,35 @@ def get_idx(node):

def get_ancestor_ephemeral_nodes(self, flat_graph, linked_graph,
selected_nodes):
node_names = {
node: flat_graph['nodes'].get(node).get('name')
for node in selected_nodes
if node in flat_graph['nodes']
}

include_spec = [
'+{}'.format(node_names[node])
for node in selected_nodes if node in node_names
]

all_ancestors = dbt.graph.selector.select_nodes(
self.project,
linked_graph,
['+{}'.format(flat_graph.get('nodes').get(node).get('name'))
for node in selected_nodes],
include_spec,
[])

return set([ancestor for ancestor in all_ancestors
if(flat_graph['nodes'][ancestor].get(
'resource_type') == NodeType.Model and
get_materialization(
flat_graph['nodes'][ancestor]) == 'ephemeral')])
res = []

for ancestor in all_ancestors:
if ancestor not in flat_graph['nodes']:
continue
ancestor_node = flat_graph['nodes'][ancestor]
is_model = ancestor_node.get('resource_type') == NodeType.Model
is_ephemeral = get_materialization(ancestor_node) == 'ephemeral'
if is_model and is_ephemeral:
res.append(ancestor)

return set(res)
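
The refactor also makes the selection step explicit: '+name' in the include spec asks the selector for each chosen node plus all of its ancestors (hence the all_ancestors name above), and the loop then keeps only the ephemeral models among them. A small sketch of the spec construction (node names are illustrative):

node_names = {'model.my_project.orders': 'orders', 'model.my_project.users': 'users'}
include_spec = ['+{}'.format(name) for name in node_names.values()]
print(sorted(include_spec))  # ['+orders', '+users']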

def get_nodes_to_run(self, graph, include_spec, exclude_spec,
resource_types, tags):
@@ -874,7 +869,8 @@ def compile_models(self, include_spec, exclude_spec):
NodeType.Model,
NodeType.Test,
NodeType.Archive,
NodeType.Analysis
NodeType.Analysis,
NodeType.Operation
]

return self.run_types_from_graph(include_spec,
6 changes: 3 additions & 3 deletions dbt/tracking.py
@@ -2,6 +2,7 @@
from dbt import version as dbt_version
from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger
from snowplow_tracker import SelfDescribingJson, disable_contracts
from datetime import datetime

import platform
import uuid
@@ -42,16 +43,15 @@ def __init__(self):
self.do_not_track = True

self.id = None
self.invocation_id = None
self.invocation_id = str(uuid.uuid4())
self.run_started_at = datetime.now()

def state(self):
return "do not track" if self.do_not_track else "tracking"

def initialize(self):
self.do_not_track = False

self.invocation_id = str(uuid.uuid4())

cookie = self.get_cookie()
self.id = cookie.get('id')

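invocation_id and run_started_at are now pinned when the User is constructed instead of being set in initialize() or computed at render time, so every node and hook compiled during a run observes the same pair of values. A reduced sketch of the new behavior:

from datetime import datetime
import uuid

class User:
    def __init__(self):
        self.invocation_id = str(uuid.uuid4())
        self.run_started_at = datetime.now()

user = User()
first = user.run_started_at
second = user.run_started_at
assert first == second  # stable for the whole run, unlike repeated datetime.now() calls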
21 changes: 21 additions & 0 deletions dbt/utils.py
@@ -28,6 +28,13 @@ class NodeType(object):
Test = 'test'
Archive = 'archive'
Macro = 'macro'
Operation = 'operation'


class RunHookType:
Start = 'on-run-start'
End = 'on-run-end'
Both = [Start, End]


class This(object):
@@ -263,6 +270,11 @@ def get_pseudo_test_path(node_name, source_path, test_type):
return os.path.join(*pseudo_path_parts)


def get_pseudo_hook_path(hook_name):
path_parts = ['hooks', "{}.sql".format(hook_name)]
return os.path.join(*path_parts)


def get_run_status_line(results):
total = len(results)
errored = len([r for r in results if r.errored or r.failed])
@@ -277,3 +289,12 @@ def get_run_status_line(results):
errored=errored,
skipped=skipped
))


def get_nodes_by_tags(nodes, match_tags, resource_type):
matched_nodes = []
for node in nodes:
node_tags = node.get('tags', set())
if len(node_tags & match_tags):
matched_nodes.append(node)
return matched_nodes
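
Taken together, the new utilities give each hook batch a synthetic file path and make hooks selectable by tag. A usage sketch, assuming the definitions above are in scope and a POSIX path separator (the node dicts are stand-ins for parsed graph nodes); note that get_nodes_by_tags accepts a resource_type argument but does not use it in this version:

print(get_pseudo_hook_path('on-run-start'))  # hooks/on-run-start.sql

nodes = [
    {'unique_id': 'operation.my_project.on-run-start', 'tags': {'on-run-start'}},
    {'unique_id': 'model.my_project.users', 'tags': set()},
]

start_hooks = get_nodes_by_tags(nodes, {RunHookType.Start}, NodeType.Operation)
print([node['unique_id'] for node in start_hooks])
# ['operation.my_project.on-run-start']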