Skip to content

Commit

Permalink
Merge pull request #2485 from Raalsky/direct-nodes-selector
Browse files Browse the repository at this point in the history
Direct child model selector
  • Loading branch information
beckjake authored May 26, 2020
2 parents 75dbb0b + 1d298ea commit c3c99f3
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 38 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
### Features
- Added a `full_refresh` config item that overrides the behavior of the `--full-refresh` flag ([#1009](https://github.com/fishtown-analytics/dbt/issues/1009), [#2348](https://github.com/fishtown-analytics/dbt/pull/2348))
- Added a "docs" field to macros, with a "show" subfield to allow for hiding macros from the documentation site ([#2430](https://github.com/fishtown-analytics/dbt/issues/2430))
- Added intersection syntax for model selector ([#2167](https://github.com/fishtown-analytics/dbt/issues/2167), [#2417](https://github.com/fishtown-analytics/dbt/pull/2417))
- Added intersection syntax for model selector ([#2167](https://github.com/fishtown-analytics/dbt/issues/2167), [#2417](https://github.com/fishtown-analytics/dbt/pull/2417))
- Extended the model selection syntax to limit selection to at most n levels of parents/children, e.g. `dbt run --models 3+m1+2` ([#2052](https://github.com/fishtown-analytics/dbt/issues/2052), [#2485](https://github.com/fishtown-analytics/dbt/pull/2485))

Contributors:
- [@raalsky](https://github.com/Raalsky) ([#2417](https://github.com/fishtown-analytics/dbt/pull/2417))
- [@raalsky](https://github.com/Raalsky) ([#2417](https://github.com/fishtown-analytics/dbt/pull/2417), [#2485](https://github.com/fishtown-analytics/dbt/pull/2485))
- [@alf-mindshift](https://github.com/alf-mindshift) ([#2431](https://github.com/fishtown-analytics/dbt/pull/2431))

## dbt 0.17.0 (Release TBD)
Expand Down
64 changes: 52 additions & 12 deletions core/dbt/graph/selector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from enum import Enum
from itertools import chain
from pathlib import Path
Expand All @@ -11,8 +12,8 @@
from dbt.node_types import NodeType
from dbt.exceptions import RuntimeException, InternalException, warn_or_error

SELECTOR_PARENTS = '+'
SELECTOR_CHILDREN = '+'
SELECTOR_PARENTS = r'^(?P<depth>[0-9]*)\+(?P<node>.*)$'
SELECTOR_CHILDREN = r'^(?P<node>.*)\+(?P<depth>[0-9]*)$'
SELECTOR_GLOB = '*'
SELECTOR_CHILDREN_AND_ANCESTORS = '@'
SELECTOR_DELIMITER = ':'
Expand All @@ -36,20 +37,28 @@ class SelectionCriteria:
def __init__(self, node_spec: str):
self.raw = node_spec
self.select_children = False
self.select_children_max_depth = None
self.select_parents = False
self.select_parents_max_depth = None
self.select_childrens_parents = False

if node_spec.startswith(SELECTOR_CHILDREN_AND_ANCESTORS):
self.select_childrens_parents = True
node_spec = node_spec[1:]

if node_spec.startswith(SELECTOR_PARENTS):
matches = re.match(SELECTOR_PARENTS, node_spec)
if matches:
self.select_parents = True
node_spec = node_spec[1:]
if matches['depth']:
self.select_parents_max_depth = int(matches['depth'])
node_spec = matches['node']

if node_spec.endswith(SELECTOR_CHILDREN):
matches = re.match(SELECTOR_CHILDREN, node_spec)
if matches:
self.select_children = True
node_spec = node_spec[:-1]
if matches['depth']:
self.select_children_max_depth = int(matches['depth'])
node_spec = matches['node']

if self.select_children and self.select_childrens_parents:
raise RuntimeException(
Expand Down Expand Up @@ -329,16 +338,41 @@ def select_childrens_parents(self, selected: Set[str]) -> Set[str]:
ancestors_for = self.select_children(selected) | selected
return self.select_parents(ancestors_for) | ancestors_for

def select_children(self, selected: Set[str]) -> Set[str]:
def descendants(self, node, max_depth: Optional[int] = None) -> Set[str]:
    """Collect the nodes that can be reached from `node` in the graph.

    :param node: identifier of the starting node; must exist in the graph.
    :param max_depth: passed to networkx as the search `cutoff`; `None`
        leaves the search unbounded.
    :return: the reached nodes, with `node` itself removed.
    :raises InternalException: if `node` is not in the graph.
    """
    if not self.graph.has_node(node):
        raise InternalException(f'Node {node} not found in the graph!')
    distances = nx.single_source_shortest_path_length(
        G=self.graph, source=node, cutoff=max_depth
    )
    # Only the reached node names matter here; drop the starting node.
    return distances.keys() - {node}

def ancestors(self, node, max_depth: Optional[int] = None) -> Set[str]:
    """Collect the nodes that have a path to `node` in the graph.

    :param node: identifier of the target node; must exist in the graph.
    :param max_depth: passed to networkx as the search `cutoff`; `None`
        leaves the search unbounded.
    :return: the reached nodes, with `node` itself removed.
    :raises InternalException: if `node` is not in the graph.
    """
    if not self.graph.has_node(node):
        raise InternalException(f'Node {node} not found in the graph!')
    # The search runs inside nx.utils.reversed(), i.e. against the
    # reversed graph, so it walks towards the parents of `node`.
    with nx.utils.reversed(self.graph):
        distances = nx.single_source_shortest_path_length(
            G=self.graph, source=node, cutoff=max_depth
        )
    # Only the reached node names matter here; drop the starting node.
    return distances.keys() - {node}

def select_children(self,
selected: Set[str],
max_depth: Optional[int] = None) -> Set[str]:
descendants: Set[str] = set()
for node in selected:
descendants.update(nx.descendants(self.graph, node))
descendants.update(self.descendants(node, max_depth=max_depth))
return descendants

def select_parents(self, selected: Set[str]) -> Set[str]:
def select_parents(self,
selected: Set[str],
max_depth: Optional[int] = None) -> Set[str]:
ancestors: Set[str] = set()
for node in selected:
ancestors.update(nx.ancestors(self.graph, node))
ancestors.update(self.ancestors(node, max_depth=max_depth))
return ancestors

def select_successors(self, selected: Set[str]) -> Set[str]:
Expand All @@ -354,9 +388,15 @@ def collect_models(
if spec.select_childrens_parents:
additional.update(self.select_childrens_parents(selected))
if spec.select_parents:
additional.update(self.select_parents(selected))
additional.update(
self.select_parents(selected,
max_depth=spec.select_parents_max_depth)
)
if spec.select_children:
additional.update(self.select_children(selected))
additional.update(
self.select_children(selected,
max_depth=spec.select_children_max_depth)
)
return additional

def subgraph(self, nodes: Iterable[str]) -> 'Graph':
Expand Down
46 changes: 46 additions & 0 deletions test/integration/007_graph_selection_tests/test_graph_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ def test__postgres__tags_and_children(self):
self.assertTrue('users' in created_models)
self.assert_correct_schemas()

@use_profile('postgres')
def test__postgres__tags_and_children_limited(self):
    # Running with `tag:base+2` is expected to build exactly three models
    # and to leave `users_rollup_dependency` unbuilt.
    self.run_sql_file("seed.sql")

    results = self.run_dbt(['run', '--models', 'tag:base+2'])
    self.assertEqual(len(results), 3)

    created_models = self.get_models_in_schema()
    # assertNotIn (rather than assertFalse(x in y)) reports the offending
    # name and the container when the check fails.
    self.assertNotIn('base_users', created_models)
    self.assertNotIn('emails', created_models)
    self.assertIn('emails_alt', created_models)
    self.assertIn('users_rollup', created_models)
    self.assertIn('users', created_models)
    self.assertNotIn('users_rollup_dependency', created_models)
    self.assert_correct_schemas()

@use_profile('snowflake')
def test__snowflake__specific_model(self):
self.run_sql_file("seed.sql")
Expand Down Expand Up @@ -112,6 +128,22 @@ def test__snowflake__specific_model_and_children(self):
self.assertFalse('BASE_USERS' in created_models)
self.assertFalse('EMAILS' in created_models)

@use_profile('postgres')
def test__postgres__specific_model_and_children_limited(self):
    # Running with `users+1` is expected to build exactly three models
    # and to leave `users_rollup_dependency` unbuilt.
    self.run_sql_file("seed.sql")

    run_results = self.run_dbt(['run', '--models', 'users+1'])
    self.assertEqual(len(run_results), 3)

    # Built tables must match their reference tables.
    table_pairs = (
        ("seed", "users"),
        ("summary_expected", "users_rollup"),
    )
    for left, right in table_pairs:
        self.assertTablesEqual(left, right)

    models = self.get_models_in_schema()
    self.assertIn('emails_alt', models)
    for absent in ('base_users', 'emails', 'users_rollup_dependency'):
        self.assertNotIn(absent, models)
    self.assert_correct_schemas()

@use_profile('postgres')
def test__postgres__specific_model_and_parents(self):
self.run_sql_file("seed.sql")
Expand Down Expand Up @@ -142,6 +174,20 @@ def test__snowflake__specific_model_and_parents(self):
self.assertFalse('BASE_USERS' in created_models)
self.assertFalse('EMAILS' in created_models)

@use_profile('postgres')
def test__postgres__specific_model_and_parents_limited(self):
    # Running with `1+users_rollup` is expected to build exactly two models.
    self.run_sql_file("seed.sql")

    results = self.run_dbt(['run', '--models', '1+users_rollup'])
    self.assertEqual(len(results), 2)

    self.assertTablesEqual("seed", "users")
    self.assertTablesEqual("summary_expected", "users_rollup")
    created_models = self.get_models_in_schema()
    # assertNotIn (rather than assertFalse(x in y)) reports the offending
    # name and the container when the check fails.
    self.assertNotIn('base_users', created_models)
    self.assertNotIn('emails', created_models)
    self.assert_correct_schemas()

@use_profile('postgres')
def test__postgres__specific_model_with_exclusion(self):
self.run_sql_file("seed.sql")
Expand Down
92 changes: 68 additions & 24 deletions test/unit/test_graph_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,34 @@ def test__select_children_except_in_package(self):
['b'],
set(['m.X.a','m.X.c', 'm.Y.d','m.X.e','m.Y.f','m.X.g']))

def test__select_children(self):
    # Expected selection for the unbounded children spec `X.c+`.
    include = ['X.c+']
    expected = {'m.X.c', 'm.Y.f', 'm.X.g'}
    self.run_specs_and_assert(self.package_graph, include, [], expected)

def test__select_children_limited(self):
    # Expected selection for the depth-limited children spec `X.a+1`.
    include = ['X.a+1']
    expected = {'m.X.a', 'm.Y.b', 'm.X.c'}
    self.run_specs_and_assert(self.package_graph, include, [], expected)

def test__select_parents(self):
    # Expected selection for the unbounded parents spec `+Y.f`.
    include = ['+Y.f']
    expected = {'m.X.c', 'm.Y.f', 'm.X.a'}
    self.run_specs_and_assert(self.package_graph, include, [], expected)

def test__select_parents_limited(self):
    # Expected selection for the depth-limited parents spec `1+Y.f`.
    include = ['1+Y.f']
    expected = {'m.X.c', 'm.Y.f'}
    self.run_specs_and_assert(self.package_graph, include, [], expected)

def test__select_children_except_tag(self):
self.run_specs_and_assert(
self.package_graph,
Expand Down Expand Up @@ -269,10 +297,12 @@ def test__select_concat_intersection_exclude_intersection_concat(self):
set(['m.X.e', 'm.Y.f'])
)

def parse_spec_and_assert(self, spec, parents, children, filter_type, filter_value, childrens_parents):
def parse_spec_and_assert(self, spec, parents, parents_max_depth, children, children_max_depth, filter_type, filter_value, childrens_parents):
parsed = graph_selector.SelectionCriteria(spec)
self.assertEqual(parsed.select_parents, parents)
self.assertEqual(parsed.select_parents_max_depth, parents_max_depth)
self.assertEqual(parsed.select_children, children)
self.assertEqual(parsed.select_children_max_depth, children_max_depth)
self.assertEqual(parsed.selector_type, filter_type)
self.assertEqual(parsed.selector_value, filter_value)
self.assertEqual(parsed.select_childrens_parents, childrens_parents)
Expand All @@ -282,37 +312,51 @@ def invalid_spec(self, spec):
graph_selector.SelectionCriteria(spec)

def test__spec_parsing(self):
self.parse_spec_and_assert('a', False, False, 'fqn', 'a', False)
self.parse_spec_and_assert('+a', True, False, 'fqn', 'a', False)
self.parse_spec_and_assert('a+', False, True, 'fqn', 'a', False)
self.parse_spec_and_assert('+a+', True, True, 'fqn', 'a', False)
self.parse_spec_and_assert('@a', False, False, 'fqn', 'a', True)
self.parse_spec_and_assert('a', False, None, False, None, 'fqn', 'a', False)
self.parse_spec_and_assert('+a', True, None, False, None, 'fqn', 'a', False)
self.parse_spec_and_assert('256+a', True, 256, False, None, 'fqn', 'a', False)
self.parse_spec_and_assert('a+', False, None, True, None, 'fqn', 'a', False)
self.parse_spec_and_assert('a+256', False, None, True, 256, 'fqn', 'a', False)
self.parse_spec_and_assert('+a+', True, None, True, None, 'fqn', 'a', False)
self.parse_spec_and_assert('16+a+32', True, 16, True, 32, 'fqn', 'a', False)
self.parse_spec_and_assert('@a', False, None, False, None, 'fqn', 'a', True)
self.invalid_spec('@a+')

self.parse_spec_and_assert('a.b', False, False, 'fqn', 'a.b', False)
self.parse_spec_and_assert('+a.b', True, False, 'fqn', 'a.b', False)
self.parse_spec_and_assert('a.b+', False, True, 'fqn', 'a.b', False)
self.parse_spec_and_assert('+a.b+', True, True, 'fqn', 'a.b', False)
self.parse_spec_and_assert('@a.b', False, False, 'fqn', 'a.b', True)
self.parse_spec_and_assert('a.b', False, None, False, None, 'fqn', 'a.b', False)
self.parse_spec_and_assert('+a.b', True, None, False, None, 'fqn', 'a.b', False)
self.parse_spec_and_assert('256+a.b', True, 256, False, None, 'fqn', 'a.b', False)
self.parse_spec_and_assert('a.b+', False, None, True, None, 'fqn', 'a.b', False)
self.parse_spec_and_assert('a.b+256', False, None, True, 256, 'fqn', 'a.b', False)
self.parse_spec_and_assert('+a.b+', True, None, True, None, 'fqn', 'a.b', False)
self.parse_spec_and_assert('16+a.b+32', True, 16, True, 32, 'fqn', 'a.b', False)
self.parse_spec_and_assert('@a.b', False, None, False, None, 'fqn', 'a.b', True)
self.invalid_spec('@a.b+')

self.parse_spec_and_assert('a.b.*', False, False, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('+a.b.*', True, False, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('a.b.*+', False, True, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('+a.b.*+', True, True, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('@a.b.*', False, False, 'fqn', 'a.b.*', True)
self.parse_spec_and_assert('a.b.*', False, None, False, None, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('+a.b.*', True, None, False, None, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('256+a.b.*', True, 256, False, None, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('a.b.*+', False, None, True, None, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('a.b.*+256', False, None, True, 256, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('+a.b.*+', True, None, True, None, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('16+a.b.*+32', True, 16, True, 32, 'fqn', 'a.b.*', False)
self.parse_spec_and_assert('@a.b.*', False, None, False, None, 'fqn', 'a.b.*', True)
self.invalid_spec('@a.b*+')

self.parse_spec_and_assert('tag:a', False, False, 'tag', 'a', False)
self.parse_spec_and_assert('+tag:a', True, False, 'tag', 'a', False)
self.parse_spec_and_assert('tag:a+', False, True, 'tag', 'a', False)
self.parse_spec_and_assert('+tag:a+', True, True, 'tag', 'a', False)
self.parse_spec_and_assert('@tag:a', False, False, 'tag', 'a', True)
self.parse_spec_and_assert('tag:a', False, None, False, None, 'tag', 'a', False)
self.parse_spec_and_assert('+tag:a', True, None, False, None, 'tag', 'a', False)
self.parse_spec_and_assert('256+tag:a', True, 256, False, None, 'tag', 'a', False)
self.parse_spec_and_assert('tag:a+', False, None, True, None, 'tag', 'a', False)
self.parse_spec_and_assert('tag:a+256', False, None, True, 256, 'tag', 'a', False)
self.parse_spec_and_assert('+tag:a+', True, None, True, None, 'tag', 'a', False)
self.parse_spec_and_assert('16+tag:a+32', True, 16, True, 32, 'tag', 'a', False)
self.parse_spec_and_assert('@tag:a', False, None, False, None, 'tag', 'a', True)
self.invalid_spec('@tag:a+')

self.parse_spec_and_assert('source:a', False, False, 'source', 'a', False)
self.parse_spec_and_assert('source:a+', False, True, 'source', 'a', False)
self.parse_spec_and_assert('@source:a', False, False, 'source', 'a', True)
self.parse_spec_and_assert('source:a', False, None, False, None, 'source', 'a', False)
self.parse_spec_and_assert('source:a+', False, None, True, None, 'source', 'a', False)
self.parse_spec_and_assert('source:a+1', False, None, True, 1, 'source', 'a', False)
self.parse_spec_and_assert('source:a+32', False, None, True, 32, 'source', 'a', False)
self.parse_spec_and_assert('@source:a', False, None, False, None, 'source', 'a', True)
self.invalid_spec('@source:a+')

def test__package_name_getter(self):
Expand Down

0 comments on commit c3c99f3

Please sign in to comment.