Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved flags fixturing for for repository unit tests #10190

Merged
merged 9 commits into from
May 21, 2024
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# All manifest related fixtures.
from tests.unit.utils.adapter import * # noqa
from tests.unit.utils.event_manager import * # noqa
from tests.unit.utils.flags import * # noqa
from tests.unit.utils.manifest import * # noqa
from tests.unit.utils.project import * # noqa

Expand Down
4 changes: 0 additions & 4 deletions tests/unit/context/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from argparse import Namespace
from typing import Any, Dict, Set
from unittest import mock

Expand All @@ -19,14 +18,11 @@
UnitTestNode,
UnitTestOverrides,
)
from dbt.flags import set_from_args
from dbt.node_types import NodeType
from dbt_common.events.functions import reset_metadata_vars
from tests.unit.mock_adapter import adapter_factory
from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter

set_from_args(Namespace(WARN_ERROR=False), None)


class TestVar:
@pytest.fixture
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/context/test_query_header.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import re
from argparse import Namespace
from unittest import mock

import pytest

from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.context.query_header import generate_query_header_context
from dbt.flags import set_from_args
from tests.unit.utils import config_from_parts_or_dicts

set_from_args(Namespace(WARN_ERROR=False), None)


class TestQueryHeaderContext:
@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/events/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from argparse import Namespace
from copy import deepcopy

from pytest_mock import MockerFixture

Expand All @@ -19,12 +18,10 @@ def test_clears_preexisting_event_manager_state(self) -> None:
assert len(manager.loggers) == 1
assert len(manager.callbacks) == 1

flags = deepcopy(get_flags())
# setting both of these to none guarantees that no logger will be added
object.__setattr__(flags, "LOG_LEVEL", "none")
object.__setattr__(flags, "LOG_LEVEL_FILE", "none")
args = Namespace(log_level="none", log_level_file="none")
set_from_args(args, {})

setup_event_logger(flags=flags)
setup_event_logger(get_flags())
assert len(manager.loggers) == 0
assert len(manager.callbacks) == 0

Expand Down
4 changes: 0 additions & 4 deletions tests/unit/parser/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_partial_parse_file_path(self, patched_open, patched_os_exist, patched_s
mock_project = MagicMock(RuntimeConfig)
mock_project.project_target_path = "mock_target_path"
patched_os_exist.return_value = True
set_from_args(Namespace(), {})
ManifestLoader(mock_project, {})
# by default we use the project_target_path
patched_open.assert_called_with("mock_target_path/partial_parse.msgpack", "rb")
Expand All @@ -33,7 +32,6 @@ def test_profile_hash_change(self, mock_project):
# This test validate that the profile_hash is updated when the connection keys change
profile_hash = "750bc99c1d64ca518536ead26b28465a224be5ffc918bf2a490102faa5a1bcf5"
mock_project.credentials.connection_info.return_value = "test"
set_from_args(Namespace(), {})
manifest = ManifestLoader(mock_project, {})
assert manifest.manifest.state_check.profile_hash.checksum == profile_hash
mock_project.credentials.connection_info.return_value = "test1"
Expand Down Expand Up @@ -67,7 +65,6 @@ def test_partial_parse_safe_update_project_parser_files_partially(
mock_saved_manifest.files = {}
patched_read_manifest_for_partial_parse.return_value = mock_saved_manifest

set_from_args(Namespace(), {})
loader = ManifestLoader(mock_project, {})
loader.safe_update_project_parser_files_partially({})

Expand Down Expand Up @@ -150,7 +147,6 @@ def test_partial_parse_file_diff_flag(
mock_file_diff = mocker.patch("dbt.parser.read_files.FileDiff.from_dict")
mock_file_diff.return_value = FileDiff([], [], [])

set_from_args(Namespace(), {})
ManifestLoader.get_full_manifest(config=mock_project)
assert not mock_file_diff.called

Expand Down
149 changes: 77 additions & 72 deletions tests/unit/test_compilation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
import tempfile
import unittest
from argparse import Namespace
from queue import Empty
from unittest import mock

from dbt import compilation
from dbt.flags import set_from_args
import pytest

from dbt.compilation import Graph, Linker
from dbt.graph.cli import parse_difference
from dbt.graph.queue import GraphQueue
from dbt.graph.selector import NodeSelector

set_from_args(Namespace(WARN_ERROR=False), None)


def _mock_manifest(nodes):
config = mock.MagicMock(enabled=True)
Expand All @@ -33,41 +31,48 @@ def _mock_manifest(nodes):
return manifest


class LinkerTest(unittest.TestCase):
def setUp(self):
self.linker = compilation.Linker()
class TestLinker:
@pytest.fixture
def linker(self) -> Linker:
return Linker()

def test_linker_add_node(self):
def test_linker_add_node(self, linker: Linker) -> None:
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
linker.add_node(node)

actual_nodes = self.linker.nodes()
actual_nodes = linker.nodes()
for node in expected_nodes:
self.assertIn(node, actual_nodes)
assert node in actual_nodes

self.assertEqual(len(actual_nodes), len(expected_nodes))
assert len(actual_nodes) == len(expected_nodes)

def test_linker_write_graph(self):
def test_linker_write_graph(self, linker: Linker) -> None:
expected_nodes = ["A", "B", "C"]
for node in expected_nodes:
self.linker.add_node(node)
linker.add_node(node)

manifest = _mock_manifest("ABC")
(fd, fname) = tempfile.mkstemp()
os.close(fd)
try:
self.linker.write_graph(fname, manifest)
linker.write_graph(fname, manifest)
assert os.path.exists(fname)
finally:
os.unlink(fname)

def assert_would_join(self, queue):
def assert_would_join(self, queue: GraphQueue) -> None:
"""test join() without timeout risk"""
self.assertEqual(queue.inner.unfinished_tasks, 0)

def _get_graph_queue(self, manifest, include=None, exclude=None):
graph = compilation.Graph(self.linker.graph)
assert queue.inner.unfinished_tasks == 0

def _get_graph_queue(
self,
manifest,
linker: Linker,
include=None,
exclude=None,
) -> GraphQueue:
graph = Graph(linker.graph)
selector = NodeSelector(graph, manifest)
# TODO: The "eager" string below needs to be replaced with programatic access
# to the default value for the indirect selection parameter in
Expand All @@ -77,114 +82,114 @@ def _get_graph_queue(self, manifest, include=None, exclude=None):
spec = parse_difference(include, exclude)
return selector.get_graph_queue(spec)

def test_linker_add_dependency(self):
def test_linker_add_dependency(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("A", "C"), ("B", "C")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

queue = self._get_graph_queue(_mock_manifest("ABC"))
queue = self._get_graph_queue(_mock_manifest("ABC"), linker)

got = queue.get(block=False)
self.assertEqual(got.unique_id, "C")
with self.assertRaises(Empty):
assert got.unique_id == "C"
with pytest.raises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
assert not queue.empty()
queue.mark_done("C")
self.assertFalse(queue.empty())
assert not queue.empty()

got = queue.get(block=False)
self.assertEqual(got.unique_id, "B")
with self.assertRaises(Empty):
assert got.unique_id == "B"
with pytest.raises(Empty):
queue.get(block=False)
self.assertFalse(queue.empty())
assert not queue.empty()
queue.mark_done("B")
self.assertFalse(queue.empty())
assert not queue.empty()

got = queue.get(block=False)
self.assertEqual(got.unique_id, "A")
with self.assertRaises(Empty):
assert got.unique_id == "A"
with pytest.raises(Empty):
queue.get(block=False)
self.assertTrue(queue.empty())
assert queue.empty()
queue.mark_done("A")
self.assert_would_join(queue)
self.assertTrue(queue.empty())
assert queue.empty()

def test_linker_add_disjoint_dependencies(self):
def test_linker_add_disjoint_dependencies(self, linker: Linker) -> None:
actual_deps = [("A", "B")]
additional_node = "Z"

for (l, r) in actual_deps:
self.linker.dependency(l, r)
self.linker.add_node(additional_node)
linker.dependency(l, r)
linker.add_node(additional_node)

queue = self._get_graph_queue(_mock_manifest("ABCZ"))
queue = self._get_graph_queue(_mock_manifest("ABCZ"), linker)
# the first one we get must be B, it has the longest dep chain
first = queue.get(block=False)
self.assertEqual(first.unique_id, "B")
self.assertFalse(queue.empty())
assert first.unique_id == "B"
assert not queue.empty()
queue.mark_done("B")
self.assertFalse(queue.empty())
assert not queue.empty()

second = queue.get(block=False)
self.assertIn(second.unique_id, {"A", "Z"})
self.assertFalse(queue.empty())
assert second.unique_id in {"A", "Z"}
assert not queue.empty()
queue.mark_done(second.unique_id)
self.assertFalse(queue.empty())
assert not queue.empty()

third = queue.get(block=False)
self.assertIn(third.unique_id, {"A", "Z"})
with self.assertRaises(Empty):
assert third.unique_id in {"A", "Z"}
with pytest.raises(Empty):
queue.get(block=False)
self.assertNotEqual(second.unique_id, third.unique_id)
self.assertTrue(queue.empty())
assert second.unique_id != third.unique_id
assert queue.empty()
queue.mark_done(third.unique_id)
self.assert_would_join(queue)
self.assertTrue(queue.empty())
assert queue.empty()

def test_linker_dependencies_limited_to_some_nodes(self):
def test_linker_dependencies_limited_to_some_nodes(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

queue = self._get_graph_queue(_mock_manifest("ABCD"), ["B"])
queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["B"])
got = queue.get(block=False)
self.assertEqual(got.unique_id, "B")
self.assertTrue(queue.empty())
assert got.unique_id == "B"
assert queue.empty()
queue.mark_done("B")
self.assert_would_join(queue)

queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), ["A", "B"])
queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["A", "B"])
got = queue_2.get(block=False)
self.assertEqual(got.unique_id, "B")
self.assertFalse(queue_2.empty())
with self.assertRaises(Empty):
assert got.unique_id == "B"
assert not queue_2.empty()
with pytest.raises(Empty):
queue_2.get(block=False)
queue_2.mark_done("B")
self.assertFalse(queue_2.empty())
assert not queue_2.empty()

got = queue_2.get(block=False)
self.assertEqual(got.unique_id, "A")
self.assertTrue(queue_2.empty())
with self.assertRaises(Empty):
assert got.unique_id == "A"
assert queue_2.empty()
with pytest.raises(Empty):
queue_2.get(block=False)
self.assertTrue(queue_2.empty())
assert queue_2.empty()
queue_2.mark_done("A")
self.assert_would_join(queue_2)

def test__find_cycles__cycles(self):
def test__find_cycles__cycles(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "A")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

self.assertIsNotNone(self.linker.find_cycles())
assert linker.find_cycles() is not None

def test__find_cycles__no_cycles(self):
def test__find_cycles__no_cycles(self, linker: Linker) -> None:
actual_deps = [("A", "B"), ("B", "C"), ("C", "D")]

for (l, r) in actual_deps:
self.linker.dependency(l, r)
linker.dependency(l, r)

self.assertIsNone(self.linker.find_cycles())
assert linker.find_cycles() is None
6 changes: 4 additions & 2 deletions tests/unit/test_contracts_graph_parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from hypothesis import given
from hypothesis.strategies import builds, lists

from dbt import flags
from dbt.artifacts.resources import (
ColumnInfo,
Dimension,
Expand Down Expand Up @@ -67,7 +66,10 @@
replace_config,
)

flags.set_from_args(Namespace(SEND_ANONYMOUS_USAGE_STATS=False), None)

@pytest.fixture
def flags_for_args() -> Namespace:
return Namespace(SEND_ANONYMOUS_USAGE_STATS=False)


@pytest.fixture
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/test_deprecations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from argparse import Namespace

from dbt.flags import set_from_args
from dbt.internal_deprecations import deprecated


Expand All @@ -11,6 +8,5 @@ def to_be_decorated():

# simple test that the return value is not modified
def test_deprecated_func():
set_from_args(Namespace(WARN_ERROR=False), None)
assert hasattr(to_be_decorated, "__wrapped__")
assert to_be_decorated() == 5
Loading
Loading