Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile on-run-(start|end) hooks to file #412

Merged
merged 7 commits into from
May 9, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dbt.utils
import dbt.include
import dbt.wrapper
import dbt.tracking

from dbt.model import Model
from dbt.utils import This, Var, is_enabled, get_materialization, NodeType, \
Expand Down Expand Up @@ -38,6 +39,7 @@ def print_compile_stats(stats):
NodeType.Archive: 'archives',
NodeType.Analysis: 'analyses',
NodeType.Macro: 'macros',
NodeType.Operation: 'operations',
}

results = {
Expand All @@ -46,6 +48,7 @@ def print_compile_stats(stats):
NodeType.Archive: 0,
NodeType.Analysis: 0,
NodeType.Macro: 0,
NodeType.Operation: 0,
}

results.update(stats)
Expand Down Expand Up @@ -235,8 +238,8 @@ def get_compiler_context(self, model, flat_graph):

context.update(wrapper.get_context_functions())

context['run_started_at'] = '{{ run_started_at }}'
context['invocation_id'] = '{{ invocation_id }}'
context['run_started_at'] = dbt.tracking.active_user.run_started_at
context['invocation_id'] = dbt.tracking.active_user.invocation_id
context['sql_now'] = adapter.date_function()

for unique_id, macro in flat_graph.get('macros').items():
Expand Down Expand Up @@ -280,7 +283,8 @@ def compile_node(self, node, flat_graph):
injected_node, _ = prepend_ctes(compiled_node, flat_graph)

if compiled_node.get('resource_type') in [NodeType.Test,
NodeType.Analysis]:
NodeType.Analysis,
NodeType.Operation]:
# data tests get wrapped in count(*)
# TODO : move this somewhere more reasonable
if 'data' in injected_node['tags'] and \
Expand Down Expand Up @@ -350,7 +354,7 @@ def link_node(self, linker, node, flat_graph):
def link_graph(self, linker, flat_graph):
linked_graph = {
'nodes': {},
'macros': flat_graph.get('macros'),
'macros': flat_graph.get('macros')
}

for name, node in flat_graph.get('nodes').items():
Expand Down Expand Up @@ -468,6 +472,8 @@ def load_all_nodes(self, root_project, all_projects):
all_nodes.update(
dbt.parser.parse_archives_from_projects(root_project,
all_projects))
all_nodes.update(
dbt.parser.load_and_parse_run_hooks(root_project, all_projects))

return all_nodes

Expand All @@ -478,6 +484,7 @@ def compile(self):
all_projects = self.get_all_projects()

all_macros = self.load_all_macros(root_project, all_projects)

all_nodes = self.load_all_nodes(root_project, all_projects)

flat_graph = {
Expand Down
3 changes: 2 additions & 1 deletion dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
unparsed_node_contract = unparsed_base_contract.extend({
Required('resource_type'): Any(NodeType.Model,
NodeType.Test,
NodeType.Analysis)
NodeType.Analysis,
NodeType.Operation)
})

unparsed_nodes_contract = Schema([unparsed_node_contract])
Expand Down
58 changes: 58 additions & 0 deletions dbt/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,64 @@ def load_and_parse_sql(package_name, root_project, all_projects, root_dir,
return parse_sql_nodes(result, root_project, all_projects, tags)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this file was chmod +xed -- can you undo that



def get_hooks_from_project(project_cfg, hook_type):
    """Read the hooks configured under `hook_type` (e.g. 'on-run-start')
    from a project config dict.

    Always returns a list/tuple: a scalar value (a single hook string) is
    wrapped in a one-element list; a missing key yields [].
    """
    hooks = project_cfg.get(hook_type, [])

    # isinstance is the idiomatic type check (also tolerates subclasses),
    # vs. the original `type(hooks) not in (list, tuple)`.
    if not isinstance(hooks, (list, tuple)):
        hooks = [hooks]

    return hooks


def get_hooks(all_projects, hook_type):
    """Collect the `hook_type` hooks from every project.

    Returns a dict mapping project name -> a single SQL string (the
    project's hooks joined with ';\n'). Projects with no hooks of this
    type are omitted entirely.
    """
    project_hooks = {}

    for project_name, project in all_projects.items():
        hooks = get_hooks_from_project(project, hook_type)

        # Truthiness is the idiomatic emptiness check (vs. len(...) > 0).
        if hooks:
            project_hooks[project_name] = ";\n".join(hooks)

    return project_hooks


def load_and_parse_run_hook_type(root_project, all_projects, hook_type):
    """Build pseudo 'operation' nodes for one hook type ('on-run-start' or
    'on-run-end') across all projects, then parse them like ordinary SQL
    nodes.

    Each project with hooks contributes one node whose raw_sql is its
    joined hook string; the node is tagged with the hook type so it can be
    selected at run time.
    """
    if dbt.flags.STRICT_MODE:
        dbt.contracts.project.validate_list(all_projects)

    project_hooks = get_hooks(all_projects, hook_type)

    # The pseudo path depends only on the hook type, not the project --
    # hoisted out of the loop (it was recomputed per project).
    hook_path = dbt.utils.get_pseudo_hook_path(hook_type)

    result = []
    for project_name, hooks in project_hooks.items():
        # NOTE: the node has no real source file; root_path points at the
        # project config the hooks came from.
        result.append({
            'name': hook_type,
            'root_path': "{}/dbt_project.yml".format(project_name),
            'resource_type': NodeType.Operation,
            'path': hook_path,
            'package_name': project_name,
            'raw_sql': hooks
        })

    tags = {hook_type}
    return parse_sql_nodes(result, root_project, all_projects, tags=tags)


def load_and_parse_run_hooks(root_project, all_projects):
    """Parse both the on-run-start and on-run-end hooks from every project
    and return them merged into a single {unique_id: node} dict.
    """
    parsed_hooks = {}

    for hook_type in dbt.utils.RunHookTypes.Both:
        parsed_hooks.update(
            load_and_parse_run_hook_type(root_project,
                                         all_projects,
                                         hook_type))

    return parsed_hooks


def load_and_parse_macros(package_name, root_project, all_projects, root_dir,
relative_dirs, resource_type, tags=None):
extension = "[!.#~]*.sql"
Expand Down
88 changes: 42 additions & 46 deletions dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import os
import time
import itertools
from datetime import datetime

from dbt.adapters.factory import get_adapter
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.utils import get_materialization, NodeType, is_type
from dbt.utils import get_materialization, NodeType, is_type, get_nodes_by_tags

import dbt.clients.jinja
import dbt.compilation
Expand Down Expand Up @@ -370,26 +369,6 @@ def execute_archive(profile, node, context):
return result


def run_hooks(profile, hooks, context, source):
    """Render the given hook SQL strings with a minimal jinja context and
    execute them all against the adapter for `profile`.

    `hooks` may be a single string or a list/tuple of strings; `source` is
    accepted for the caller's benefit and not used here.
    """
    if type(hooks) not in (list, tuple):
        hooks = [hooks]

    render_ctx = {
        "target": profile,
        "state": "start",
        "invocation_id": context['invocation_id'],
        "run_started_at": context['run_started_at']
    }

    compiled_hooks = []
    for hook in hooks:
        compiled_hooks.append(dbt.clients.jinja.get_rendered(hook, render_ctx))

    return get_adapter(profile).execute_all(profile=profile,
                                            sqls=compiled_hooks)


def track_model_run(index, num_nodes, run_model_result):
invocation_id = dbt.tracking.active_user.invocation_id
dbt.tracking.track_model_run({
Expand Down Expand Up @@ -461,10 +440,8 @@ def call_table_exists(schema, table):
return adapter.table_exists(
profile, schema, table, node.get('name'))

self.run_started_at = datetime.now()

return {
"run_started_at": datetime.now(),
"run_started_at": dbt.tracking.active_user.run_started_at,
"invocation_id": dbt.tracking.active_user.invocation_id,
"get_columns_in_table": call_get_columns_in_table,
"get_missing_columns": call_get_missing_columns,
Expand Down Expand Up @@ -513,7 +490,6 @@ def execute_node(self, node, flat_graph, existing, profile, adapter):
return node, result

def compile_node(self, node, flat_graph):
    """Compile a single node using a compiler built for this project."""
    compiler = dbt.compilation.Compiler(self.project)
    return compiler.compile_node(node, flat_graph)
Expand Down Expand Up @@ -634,6 +610,18 @@ def as_concurrent_dep_list(self, linker, nodes_to_run):

return concurrent_dependency_list

def run_hooks(self, profile, flat_graph, hook_type):
    """Compile and execute every operation node tagged with `hook_type`
    (on-run-start or on-run-end), all within a single transaction.
    """
    adapter = get_adapter(profile)

    nodes = flat_graph.get('nodes', {}).values()
    # Renamed from `start_hooks`: this method handles end hooks too.
    hook_nodes = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation)
    hooks = [self.compile_node(hook, flat_graph) for hook in hook_nodes]

    master_connection = adapter.begin(profile)
    compiled_hooks = [hook['wrapped_sql'] for hook in hooks]
    adapter.execute_all(profile=profile, sqls=compiled_hooks)
    # The commit result was previously assigned back to master_connection
    # but never used -- dropped the dead assignment.
    adapter.commit(master_connection)

def on_model_failure(self, linker, selected_nodes):
def skip_dependent(node):
dependent_nodes = linker.get_dependent_nodes(node.get('unique_id'))
Expand Down Expand Up @@ -687,12 +675,7 @@ def execute_nodes(self, flat_graph, node_dependency_list, on_failure,
start_time = time.time()

if should_run_hooks:
master_connection = adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-start', []),
self.node_context({}),
'on-run-start hooks')
master_connection = adapter.commit(master_connection)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookTypes.Start)

def get_idx(node):
return node_id_to_index_map.get(node.get('unique_id'))
Expand Down Expand Up @@ -739,12 +722,7 @@ def get_idx(node):
pool.join()

if should_run_hooks:
adapter.begin(profile)
run_hooks(self.project.get_target(),
self.project.cfg.get('on-run-end', []),
self.node_context({}),
'on-run-end hooks')
adapter.commit(master_connection)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookTypes.End)

execution_time = time.time() - start_time

Expand All @@ -755,18 +733,35 @@ def get_idx(node):

def get_ancestor_ephemeral_nodes(self, flat_graph, linked_graph,
selected_nodes):
node_names = {
node: flat_graph['nodes'].get(node).get('name')
for node in selected_nodes
if node in flat_graph['nodes']
}

include_spec = [
'+{}'.format(node_names[node])
for node in selected_nodes if node in node_names
]

all_ancestors = dbt.graph.selector.select_nodes(
self.project,
linked_graph,
['+{}'.format(flat_graph.get('nodes').get(node).get('name'))
for node in selected_nodes],
include_spec,
[])

return set([ancestor for ancestor in all_ancestors
if(flat_graph['nodes'][ancestor].get(
'resource_type') == NodeType.Model and
get_materialization(
flat_graph['nodes'][ancestor]) == 'ephemeral')])
res = []

for ancestor in all_ancestors:
if ancestor not in flat_graph['nodes']:
continue
ancestor_node = flat_graph['nodes'][ancestor]
is_model = ancestor_node.get('resource_type') == NodeType.Model
is_ephemeral = get_materialization(ancestor_node) == 'ephemeral'
if is_model and is_ephemeral:
res.append(ancestor)

return set(res)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


def get_nodes_to_run(self, graph, include_spec, exclude_spec,
resource_types, tags):
Expand Down Expand Up @@ -874,7 +869,8 @@ def compile_models(self, include_spec, exclude_spec):
NodeType.Model,
NodeType.Test,
NodeType.Archive,
NodeType.Analysis
NodeType.Analysis,
NodeType.Operation
]

return self.run_types_from_graph(include_spec,
Expand Down
6 changes: 3 additions & 3 deletions dbt/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dbt import version as dbt_version
from snowplow_tracker import Subject, Tracker, Emitter, logger as sp_logger
from snowplow_tracker import SelfDescribingJson, disable_contracts
from datetime import datetime

import platform
import uuid
Expand Down Expand Up @@ -42,16 +43,15 @@ def __init__(self):
self.do_not_track = True

self.id = None
self.invocation_id = None
self.invocation_id = str(uuid.uuid4())
self.run_started_at = datetime.now()

def state(self):
return "do not track" if self.do_not_track else "tracking"

def initialize(self):
self.do_not_track = False

self.invocation_id = str(uuid.uuid4())

cookie = self.get_cookie()
self.id = cookie.get('id')

Expand Down
21 changes: 21 additions & 0 deletions dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class NodeType(object):
Test = 'test'
Archive = 'archive'
Macro = 'macro'
Operation = 'operation'


class RunHookTypes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RunHookType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reasonable

Start = 'on-run-start'
End = 'on-run-end'
Both = [Start, End]


class This(object):
Expand Down Expand Up @@ -263,6 +270,11 @@ def get_pseudo_test_path(node_name, source_path, test_type):
return os.path.join(*pseudo_path_parts)


def get_pseudo_hook_path(hook_name):
    """Return the fake source path ('hooks/<hook_name>.sql') used for a
    run-hook pseudo node, which has no real file on disk.
    """
    return os.path.join('hooks', "{}.sql".format(hook_name))


def get_run_status_line(results):
total = len(results)
errored = len([r for r in results if r.errored or r.failed])
Expand All @@ -277,3 +289,12 @@ def get_run_status_line(results):
errored=errored,
skipped=skipped
))


def get_nodes_by_tags(nodes, match_tags, resource_type):
    """Return the nodes whose 'tags' set intersects `match_tags`, in input
    order.

    NOTE(review): `resource_type` is accepted but never used here -- kept
    for interface compatibility with callers; either filter on it or drop
    it in a follow-up. Assumes each node's 'tags' value is a set (a list
    would raise TypeError on `&`) -- TODO confirm upstream.
    """
    # Set intersection is truthy when non-empty; no len() needed.
    return [node for node in nodes
            if node.get('tags', set()) & match_tags]
Loading