Skip to content

Commit

Permalink
(fixes #311) Configure tags, and select them with --models
Browse files Browse the repository at this point in the history
  • Loading branch information
drewbanin committed Sep 20, 2018
1 parent f588876 commit 68631ae
Show file tree
Hide file tree
Showing 15 changed files with 314 additions and 55 deletions.
15 changes: 14 additions & 1 deletion dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,23 @@
'type': 'object',
'additionalProperties': True,
},
'tags': {
'anyOf': [
{
'type': 'array',
'items': {
'type': 'string'
},
},
{
'type': 'string'
}
]
},
},
'required': [
'enabled', 'materialized', 'post-hook', 'pre-hook', 'vars',
'quoting', 'column_types'
'quoting', 'column_types', 'tags'
]
}

Expand Down
59 changes: 53 additions & 6 deletions dbt/graph/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
from dbt.utils import is_enabled, get_materialization, coalesce
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedNode
import dbt.exceptions

SELECTOR_PARENTS = '+'
SELECTOR_CHILDREN = '+'
SELECTOR_GLOB = '*'
SELECTOR_DELIMITER = ':'


class SELECTOR_FILTERS:
FQN = 'fqn'
TAG = 'tag'


def split_specs(node_specs):
Expand All @@ -34,12 +41,27 @@ def parse_spec(node_spec):
index_end -= 1

node_selector = node_spec[index_start:index_end]
qualified_node_name = node_selector.split('.')

if SELECTOR_DELIMITER in node_selector:
selector_parts = node_selector.split(SELECTOR_DELIMITER, 1)
selector_type, selector_value = selector_parts

node_filter = {
"type": selector_type,
"value": selector_value
}

else:
node_filter = {
"type": SELECTOR_FILTERS.FQN,
"value": node_selector

}

return {
"select_parents": select_parents,
"select_children": select_children,
"qualified_node_name": qualified_node_name,
"filter": node_filter,
"raw": node_spec
}

Expand Down Expand Up @@ -74,10 +96,11 @@ def is_selected_node(real_node, node_selector):
return True


def get_nodes_by_qualified_name(graph, qualified_name):
def get_nodes_by_qualified_name(graph, name):
""" returns a node if matched, else throws a CompilerError. qualified_name
should be either 1) a node name or 2) a dot-notation qualified selector"""

qualified_name = name.split('.')
package_names = get_package_names(graph)

for node in graph.nodes():
Expand All @@ -98,13 +121,37 @@ def get_nodes_by_qualified_name(graph, qualified_name):
break


def get_nodes_by_tag(graph, tag_name):
""" yields nodes from graph that have the specified tag """

for node in graph.nodes():
tags = graph.node[node]['tags']

if tag_name in tags:
yield node


def get_nodes_from_spec(graph, spec):
select_parents = spec['select_parents']
select_children = spec['select_children']
qualified_node_name = spec['qualified_node_name']

selected_nodes = set(get_nodes_by_qualified_name(graph,
qualified_node_name))
filter_map = {
SELECTOR_FILTERS.FQN: get_nodes_by_qualified_name,
SELECTOR_FILTERS.TAG: get_nodes_by_tag,
}

node_filter = spec['filter']
filter_func = filter_map.get(node_filter['type'])

if filter_func is None:
valid_selectors = ", ".join(filter_map.keys())
logger.info("The '{}' selector specified in {} is invalid. Must be "
"one of [{}]".format(node_filter['type'], spec['raw'],
valid_selectors))
selected_nodes = set()

else:
selected_nodes = set(filter_func(graph, node_filter['value']))

additional_nodes = set()
test_nodes = set()
Expand Down
23 changes: 11 additions & 12 deletions dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class SourceConfig(object):
ConfigKeys = DBTConfigKeys

AppendListFields = ['pre-hook', 'post-hook']
AppendListFields = ['pre-hook', 'post-hook', 'tags']
ExtendDictFields = ['vars', 'column_types', 'quoting']
ClobberFields = [
'alias',
Expand Down Expand Up @@ -94,22 +94,21 @@ def update_in_model_config(self, config):

# make sure we're not clobbering an array of hooks with a single hook
# string
hook_fields = ['pre-hook', 'post-hook']
for hook_field in hook_fields:
if hook_field in config:
config[hook_field] = self.__get_hooks(config, hook_field)
for field in self.AppendListFields:
if field in config:
config[field] = self.__get_as_list(config, field)

self.in_model_config.update(config)

def __get_hooks(self, relevant_configs, key):
def __get_as_list(self, relevant_configs, key):
if key not in relevant_configs:
return []

hooks = relevant_configs[key]
if not isinstance(hooks, (list, tuple)):
hooks = [hooks]
items = relevant_configs[key]
if not isinstance(items, (list, tuple)):
items = [items]

return hooks
return items

def smart_update(self, mutable_config, new_configs):
relevant_configs = {
Expand All @@ -118,9 +117,9 @@ def smart_update(self, mutable_config, new_configs):
}

for key in SourceConfig.AppendListFields:
new_hooks = self.__get_hooks(relevant_configs, key)
append_fields = self.__get_as_list(relevant_configs, key)
mutable_config[key].extend([
h for h in new_hooks if h not in mutable_config[key]
f for f in append_fields if f not in mutable_config[key]
])

for key in SourceConfig.ExtendDictFields:
Expand Down
5 changes: 5 additions & 0 deletions dbt/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dbt.hooks
import dbt.clients.jinja
import dbt.context.parser
from dbt.compat import basestring

from dbt.utils import coalesce
from dbt.logger import GLOBAL_LOGGER as logger
Expand Down Expand Up @@ -128,6 +129,10 @@ def parse_node(cls, node, node_path, root_project_config,
parsed_node.schema = get_schema(schema_override)
parsed_node.alias = config.config.get('alias', default_alias)

# Set tags on node provided in config blocks
model_tags = config.config.get('tags', [])
parsed_node.tags.extend(model_tags)

# Overwrite node config
config_dict = parsed_node.get('config', {})
config_dict.update(config.config)
Expand Down
1 change: 1 addition & 0 deletions dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
'column_types',
'bind',
'quoting',
'tags',
]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'ephemeral'
materialized = 'ephemeral',
tags = ['base']
)
}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

{{
config(materialized='ephemeral')
config(materialized='ephemeral', tags=['base'])
}}

select distinct email from {{ ref('base_users') }}
3 changes: 2 additions & 1 deletion test/integration/007_graph_selection_tests/models/users.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'table'
materialized = 'table',
tags='bi'
)
}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'view'
materialized = 'view',
tags = ['bi']
)
}}

Expand Down
30 changes: 30 additions & 0 deletions test/integration/007_graph_selection_tests/test_graph_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@ def test__postgres__specific_model(self):
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)

@attr(type='postgres')
def test__postgres__tags(self):
self.use_profile('postgres')
self.use_default_project()
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")

results = self.run_dbt(['run', '--models', 'tag:bi'])
self.assertEqual(len(results), 2)

created_models = self.get_models_in_schema()
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)
self.assertTrue('users' in created_models)
self.assertTrue('users_rollup' in created_models)

@attr(type='postgres')
def test__postgres__tags_and_children(self):
self.use_profile('postgres')
self.use_default_project()
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")

results = self.run_dbt(['run', '--models', 'tag:base+'])
self.assertEqual(len(results), 2)

created_models = self.get_models_in_schema()
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)
self.assertTrue('users_rollup' in created_models)
self.assertTrue('users' in created_models)

@attr(type='snowflake')
def test__snowflake__specific_model(self):
self.use_profile('snowflake')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test__postgres__schema_tests_specify_model(self):
['unique_users_id']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_tag(self):
self.run_schema_and_assert(
['tag:bi'],
None,
['unique_users_id',
'unique_users_rollup_gender']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_model_and_children(self):
self.run_schema_and_assert(
Expand All @@ -70,6 +79,16 @@ def test__postgres__schema_tests_specify_model_and_children(self):
['unique_users_id', 'unique_users_rollup_gender']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_tag_and_children(self):
self.run_schema_and_assert(
['tag:base+'],
None,
['unique_emails_email',
'unique_users_id',
'unique_users_rollup_gender']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_model_and_parents(self):
self.run_schema_and_assert(
Expand Down
63 changes: 63 additions & 0 deletions test/integration/007_graph_selection_tests/test_tag_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from test.integration.base import DBTIntegrationTest, use_profile

class TestGraphSelection(DBTIntegrationTest):

@property
def schema(self):
return "graph_selection_tests_007"

@property
def models(self):
return "test/integration/007_graph_selection_tests/models"

@property
def project_config(self):
return {
"models": {
"test": {
"users": {
"tags": "specified_as_string"
},

"users_rollup": {
"tags": ["specified_in_project"]
}
}
}
}

@use_profile('postgres')
def test__postgres__select_tag(self):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")

results = self.run_dbt(['run', '--models', 'tag:specified_as_string'])
self.assertEqual(len(results), 1)

models_run = [r.node['name'] for r in results]
self.assertTrue('users' in models_run)


@use_profile('postgres')
def test__postgres__select_tag_and_children(self):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")

results = self.run_dbt(['run', '--models', '+tag:specified_in_project+'])
self.assertEqual(len(results), 2)

models_run = [r.node['name'] for r in results]
self.assertTrue('users' in models_run)
self.assertTrue('users_rollup' in models_run)


# check that model configs aren't squashed by project configs
@use_profile('postgres')
def test__postgres__select_tag_in_model_with_project_Config(self):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")

results = self.run_dbt(['run', '--models', 'tag:bi'])
self.assertEqual(len(results), 2)

models_run = [r.node['name'] for r in results]
self.assertTrue('users' in models_run)
self.assertTrue('users_rollup' in models_run)

Loading

0 comments on commit 68631ae

Please sign in to comment.