Skip to content

Commit

Permalink
Merge pull request #1014 from fishtown-analytics/feature/tags
Browse files Browse the repository at this point in the history
Add custom tags
  • Loading branch information
drewbanin authored Oct 17, 2018
2 parents 6d66ab0 + ab14380 commit 84588a3
Show file tree
Hide file tree
Showing 18 changed files with 320 additions and 59 deletions.
15 changes: 14 additions & 1 deletion dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,23 @@
'type': 'object',
'additionalProperties': True,
},
'tags': {
'anyOf': [
{
'type': 'array',
'items': {
'type': 'string'
},
},
{
'type': 'string'
}
]
},
},
'required': [
'enabled', 'materialized', 'post-hook', 'pre-hook', 'vars',
'quoting', 'column_types'
'quoting', 'column_types', 'tags'
]
}

Expand Down
67 changes: 58 additions & 9 deletions dbt/graph/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
from dbt.utils import is_enabled, get_materialization, coalesce
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedNode
import dbt.exceptions

# Markers recognised inside a node spec string.
# NOTE(review): parents and children share the same '+' character; presumably
# a leading '+' means "include parents" and a trailing '+' means "include
# children" (the tests use 'tag:base+' for "and children") -- confirm against
# parse_spec.
SELECTOR_PARENTS = '+'
SELECTOR_CHILDREN = '+'
# NOTE(review): glob marker; how it is matched is not visible in this hunk.
SELECTOR_GLOB = '*'
# Separates an optional selector type from its value, e.g. 'tag:bi'.
# parse_spec splits on the first occurrence only.
SELECTOR_DELIMITER = ':'


class SELECTOR_FILTERS(object):
    """Names of the selector types that may prefix a node spec."""
    # Match against a node's fully-qualified name. parse_spec falls back to
    # this type when the spec contains no SELECTOR_DELIMITER.
    FQN = 'fqn'
    # Match against a node's tags; dispatched to get_nodes_by_tag.
    TAG = 'tag'


def split_specs(node_specs):
Expand All @@ -34,12 +41,27 @@ def parse_spec(node_spec):
index_end -= 1

node_selector = node_spec[index_start:index_end]
qualified_node_name = node_selector.split('.')

if SELECTOR_DELIMITER in node_selector:
selector_parts = node_selector.split(SELECTOR_DELIMITER, 1)
selector_type, selector_value = selector_parts

node_filter = {
"type": selector_type,
"value": selector_value
}

else:
node_filter = {
"type": SELECTOR_FILTERS.FQN,
"value": node_selector

}

return {
"select_parents": select_parents,
"select_children": select_children,
"qualified_node_name": qualified_node_name,
"filter": node_filter,
"raw": node_spec
}

Expand Down Expand Up @@ -98,12 +120,12 @@ def _node_is_match(qualified_name, package_names, fqn):
return False


def get_nodes_by_qualified_name(graph, qualified_name):
"""Yield all nodes in the graph that match qualified_name.
def get_nodes_by_qualified_name(graph, qualified_name_selector):
"""Yield all nodes in the graph that match the qualified_name_selector.
:param List[str] qualified_name: The components of the selector or node
name, split on '.'.
:param str qualified_name_selector: The selector or node name
"""
qualified_name = qualified_name_selector.split(".")
package_names = get_package_names(graph)

for node in graph.nodes():
Expand All @@ -112,13 +134,40 @@ def get_nodes_by_qualified_name(graph, qualified_name):
yield node


def get_nodes_by_tag(graph, tag_name):
    """Yield every node in ``graph`` that carries the tag ``tag_name``.

    :param graph: Graph object exposing ``nodes()`` and a ``node`` mapping
        from node id to that node's attribute dict.
    :param str tag_name: The tag to look for. It is compared against whole
        tags, never as a substring of a tag.
    """
    for node in graph.nodes():
        # A node without a 'tags' attribute (or with an empty/None value)
        # simply has no tags; it must not abort the whole selection.
        tags = graph.node[node].get('tags') or []

        # The config schema also accepts a single string for 'tags'. Wrap it
        # so the membership test below compares whole tags instead of doing
        # a substring search (e.g. 'b' in 'bi').
        if not isinstance(tags, (list, tuple, set, frozenset)):
            tags = [tags]

        if tag_name in tags:
            yield node


def get_nodes_from_spec(graph, spec):
select_parents = spec['select_parents']
select_children = spec['select_children']
qualified_node_name = spec['qualified_node_name']

selected_nodes = set(get_nodes_by_qualified_name(graph,
qualified_node_name))
filter_map = {
SELECTOR_FILTERS.FQN: get_nodes_by_qualified_name,
SELECTOR_FILTERS.TAG: get_nodes_by_tag,
}

node_filter = spec['filter']
filter_func = filter_map.get(node_filter['type'])

if filter_func is None:
valid_selectors = ", ".join(filter_map.keys())
logger.info("The '{}' selector specified in {} is invalid. Must be "
"one of [{}]".format(
node_filter['type'],
spec['raw'],
valid_selectors))

selected_nodes = set()

else:
selected_nodes = set(filter_func(graph, node_filter['value']))

additional_nodes = set()
test_nodes = set()
Expand Down
23 changes: 11 additions & 12 deletions dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class SourceConfig(object):
ConfigKeys = DBTConfigKeys

AppendListFields = ['pre-hook', 'post-hook']
AppendListFields = ['pre-hook', 'post-hook', 'tags']
ExtendDictFields = ['vars', 'column_types', 'quoting']
ClobberFields = [
'alias',
Expand Down Expand Up @@ -94,22 +94,21 @@ def update_in_model_config(self, config):

# make sure we're not clobbering an array of hooks with a single hook
# string
hook_fields = ['pre-hook', 'post-hook']
for hook_field in hook_fields:
if hook_field in config:
config[hook_field] = self.__get_hooks(config, hook_field)
for field in self.AppendListFields:
if field in config:
config[field] = self.__get_as_list(config, field)

self.in_model_config.update(config)

def __get_hooks(self, relevant_configs, key):
def __get_as_list(self, relevant_configs, key):
if key not in relevant_configs:
return []

hooks = relevant_configs[key]
if not isinstance(hooks, (list, tuple)):
hooks = [hooks]
items = relevant_configs[key]
if not isinstance(items, (list, tuple)):
items = [items]

return hooks
return items

def smart_update(self, mutable_config, new_configs):
relevant_configs = {
Expand All @@ -118,9 +117,9 @@ def smart_update(self, mutable_config, new_configs):
}

for key in SourceConfig.AppendListFields:
new_hooks = self.__get_hooks(relevant_configs, key)
append_fields = self.__get_as_list(relevant_configs, key)
mutable_config[key].extend([
h for h in new_hooks if h not in mutable_config[key]
f for f in append_fields if f not in mutable_config[key]
])

for key in SourceConfig.ExtendDictFields:
Expand Down
4 changes: 4 additions & 0 deletions dbt/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def parse_node(cls, node, node_path, root_project_config,
parsed_node.schema = get_schema(schema_override)
parsed_node.alias = config.config.get('alias', default_alias)

# Set tags on node provided in config blocks
model_tags = config.config.get('tags', [])
parsed_node.tags.extend(model_tags)

# Overwrite node config
config_dict = parsed_node.get('config', {})
config_dict.update(config.config)
Expand Down
1 change: 1 addition & 0 deletions dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'column_types',
'bind',
'quoting',
'tags',
]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'ephemeral'
materialized = 'ephemeral',
tags = ['base']
)
}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

{{
config(materialized='ephemeral')
config(materialized='ephemeral', tags=['base'])
}}

select distinct email from {{ ref('base_users') }}
3 changes: 2 additions & 1 deletion test/integration/007_graph_selection_tests/models/users.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'table'
materialized = 'table',
tags='bi'
)
}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{{
config(
materialized = 'view'
materialized = 'view',
tags = ['bi']
)
}}

Expand Down
26 changes: 26 additions & 0 deletions test/integration/007_graph_selection_tests/test_graph_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ def test__postgres__specific_model(self):
self.assertFalse('base_users' in created_models)
self.assertFalse('emails' in created_models)

@attr(type='postgres')
def test__postgres__tags(self):
    # Running with the 'tag:bi' selector is expected to produce exactly two
    # results: 'users' and 'users_rollup' exist afterwards, while
    # 'base_users' and 'emails' do not.
    seed_path = "test/integration/007_graph_selection_tests/seed.sql"
    self.run_sql_file(seed_path)

    run_results = self.run_dbt(['run', '--models', 'tag:bi'])
    self.assertEqual(len(run_results), 2)

    in_schema = self.get_models_in_schema()
    for absent in ('base_users', 'emails'):
        self.assertFalse(absent in in_schema)
    for present in ('users', 'users_rollup'):
        self.assertTrue(present in in_schema)

@attr(type='postgres')
def test__postgres__tags_and_children(self):
    # Running with 'tag:base+' is expected to produce two results:
    # 'users_rollup' and 'users' exist afterwards, while the 'base'-tagged
    # models 'base_users' and 'emails' are not present in the schema.
    seed_path = "test/integration/007_graph_selection_tests/seed.sql"
    self.run_sql_file(seed_path)

    run_results = self.run_dbt(['run', '--models', 'tag:base+'])
    self.assertEqual(len(run_results), 2)

    in_schema = self.get_models_in_schema()
    for absent in ('base_users', 'emails'):
        self.assertFalse(absent in in_schema)
    for present in ('users_rollup', 'users'):
        self.assertTrue(present in in_schema)

@attr(type='snowflake')
def test__snowflake__specific_model(self):
self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def test__postgres__schema_tests_specify_model(self):
['unique_users_id']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_tag(self):
    # Select by the 'bi' tag and check against the two schema test names
    # expected for that selection.
    selectors = ['tag:bi']
    expected_tests = ['unique_users_id', 'unique_users_rollup_gender']
    self.run_schema_and_assert(selectors, None, expected_tests)

@attr(type='postgres')
def test__postgres__schema_tests_specify_model_and_children(self):
self.run_schema_and_assert(
Expand All @@ -67,6 +76,16 @@ def test__postgres__schema_tests_specify_model_and_children(self):
['unique_users_id', 'unique_users_rollup_gender']
)

@attr(type='postgres')
def test__postgres__schema_tests_specify_tag_and_children(self):
    # Select by the 'base' tag with a trailing '+' and check against the
    # three schema test names expected for that selection.
    selectors = ['tag:base+']
    expected_tests = [
        'unique_emails_email',
        'unique_users_id',
        'unique_users_rollup_gender',
    ]
    self.run_schema_and_assert(selectors, None, expected_tests)

@attr(type='postgres')
def test__postgres__schema_tests_specify_model_and_parents(self):
self.run_schema_and_assert(
Expand Down
63 changes: 63 additions & 0 deletions test/integration/007_graph_selection_tests/test_tag_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from test.integration.base import DBTIntegrationTest, use_profile

class TestGraphSelection(DBTIntegrationTest):
    # Exercises 'tag:' selectors when tags come from the project config
    # (as a bare string and as a list) in addition to in-model config blocks.

    @property
    def schema(self):
        return "graph_selection_tests_007"

    @property
    def models(self):
        return "test/integration/007_graph_selection_tests/models"

    @property
    def project_config(self):
        # Project-level tags: one given as a bare string, one as a list.
        per_model = {
            "users": {"tags": "specified_as_string"},
            "users_rollup": {"tags": ["specified_in_project"]},
        }
        return {"models": {"test": per_model}}

    @use_profile('postgres')
    def test__postgres__select_tag(self):
        seed_path = "test/integration/007_graph_selection_tests/seed.sql"
        self.run_sql_file(seed_path)

        run_results = self.run_dbt(
            ['run', '--models', 'tag:specified_as_string'])
        self.assertEqual(len(run_results), 1)

        names = [result.node['name'] for result in run_results]
        self.assertTrue('users' in names)

    @use_profile('postgres')
    def test__postgres__select_tag_and_children(self):
        seed_path = "test/integration/007_graph_selection_tests/seed.sql"
        self.run_sql_file(seed_path)

        run_results = self.run_dbt(
            ['run', '--models', '+tag:specified_in_project+'])
        self.assertEqual(len(run_results), 2)

        names = [result.node['name'] for result in run_results]
        for expected in ('users', 'users_rollup'):
            self.assertTrue(expected in names)

    # check that model configs aren't squashed by project configs
    @use_profile('postgres')
    def test__postgres__select_tag_in_model_with_project_Config(self):
        seed_path = "test/integration/007_graph_selection_tests/seed.sql"
        self.run_sql_file(seed_path)

        run_results = self.run_dbt(['run', '--models', 'tag:bi'])
        self.assertEqual(len(run_results), 2)

        names = [result.node['name'] for result in run_results]
        for expected in ('users', 'users_rollup'):
            self.assertTrue(expected in names)

Loading

0 comments on commit 84588a3

Please sign in to comment.