diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6bad8e7eb85..f1823fb858f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,7 @@ # All manifest related fixtures. from tests.unit.utils.adapter import * # noqa from tests.unit.utils.event_manager import * # noqa +from tests.unit.utils.flags import * # noqa from tests.unit.utils.manifest import * # noqa from tests.unit.utils.project import * # noqa diff --git a/tests/unit/context/test_context.py b/tests/unit/context/test_context.py index 3df0109191a..10e591093ee 100644 --- a/tests/unit/context/test_context.py +++ b/tests/unit/context/test_context.py @@ -1,5 +1,4 @@ import os -from argparse import Namespace from typing import Any, Dict, Set from unittest import mock @@ -19,14 +18,11 @@ UnitTestNode, UnitTestOverrides, ) -from dbt.flags import set_from_args from dbt.node_types import NodeType from dbt_common.events.functions import reset_metadata_vars from tests.unit.mock_adapter import adapter_factory from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter -set_from_args(Namespace(WARN_ERROR=False), None) - class TestVar: @pytest.fixture diff --git a/tests/unit/context/test_query_header.py b/tests/unit/context/test_query_header.py index 40c0f1284d9..f14d28d40c4 100644 --- a/tests/unit/context/test_query_header.py +++ b/tests/unit/context/test_query_header.py @@ -1,16 +1,12 @@ import re -from argparse import Namespace from unittest import mock import pytest from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.context.query_header import generate_query_header_context -from dbt.flags import set_from_args from tests.unit.utils import config_from_parts_or_dicts -set_from_args(Namespace(WARN_ERROR=False), None) - class TestQueryHeaderContext: @pytest.fixture diff --git a/tests/unit/events/test_logging.py b/tests/unit/events/test_logging.py index 16441ad6de7..00284ecab78 100644 --- a/tests/unit/events/test_logging.py +++ b/tests/unit/events/test_logging.py @@ -1,5 +1,4 @@ from argparse import Namespace -from copy import deepcopy from pytest_mock import MockerFixture @@ -19,12 +18,10 @@ def test_clears_preexisting_event_manager_state(self) -> None: assert len(manager.loggers) == 1 assert len(manager.callbacks) == 1 - flags = deepcopy(get_flags()) - # setting both of these to none guarantees that no logger will be added - object.__setattr__(flags, "LOG_LEVEL", "none") - object.__setattr__(flags, "LOG_LEVEL_FILE", "none") + args = Namespace(log_level="none", log_level_file="none") + set_from_args(args, {}) - setup_event_logger(flags=flags) + setup_event_logger(get_flags()) assert len(manager.loggers) == 0 assert len(manager.callbacks) == 0 diff --git a/tests/unit/parser/test_manifest.py b/tests/unit/parser/test_manifest.py index b7d470a3552..1f10ee04f25 100644 --- a/tests/unit/parser/test_manifest.py +++ b/tests/unit/parser/test_manifest.py @@ -20,7 +20,6 @@ def test_partial_parse_file_path(self, patched_open, patched_os_exist, patched_s mock_project = MagicMock(RuntimeConfig) mock_project.project_target_path = "mock_target_path" patched_os_exist.return_value = True - set_from_args(Namespace(), {}) ManifestLoader(mock_project, {}) # by default we use the project_target_path patched_open.assert_called_with("mock_target_path/partial_parse.msgpack", "rb") @@ -33,7 +32,6 @@ def test_profile_hash_change(self, mock_project): # This test validate that the profile_hash is updated when the connection keys change profile_hash = "750bc99c1d64ca518536ead26b28465a224be5ffc918bf2a490102faa5a1bcf5" mock_project.credentials.connection_info.return_value = "test" - set_from_args(Namespace(), {}) manifest = ManifestLoader(mock_project, {}) assert manifest.manifest.state_check.profile_hash.checksum == profile_hash mock_project.credentials.connection_info.return_value = "test1" @@ -67,7 +65,6 @@ def test_partial_parse_safe_update_project_parser_files_partially( mock_saved_manifest.files = {} patched_read_manifest_for_partial_parse.return_value = mock_saved_manifest - set_from_args(Namespace(), {}) loader = ManifestLoader(mock_project, {}) loader.safe_update_project_parser_files_partially({}) @@ -150,7 +147,6 @@ def test_partial_parse_file_diff_flag( mock_file_diff = mocker.patch("dbt.parser.read_files.FileDiff.from_dict") mock_file_diff.return_value = FileDiff([], [], []) - set_from_args(Namespace(), {}) ManifestLoader.get_full_manifest(config=mock_project) assert not mock_file_diff.called diff --git a/tests/unit/test_compilation.py b/tests/unit/test_compilation.py index 458efb90901..c18e7fb15d2 100644 --- a/tests/unit/test_compilation.py +++ b/tests/unit/test_compilation.py @@ -1,17 +1,15 @@ import os import tempfile -import unittest -from argparse import Namespace from queue import Empty from unittest import mock -from dbt import compilation -from dbt.flags import set_from_args +import pytest + +from dbt.compilation import Graph, Linker from dbt.graph.cli import parse_difference +from dbt.graph.queue import GraphQueue from dbt.graph.selector import NodeSelector -set_from_args(Namespace(WARN_ERROR=False), None) - def _mock_manifest(nodes): config = mock.MagicMock(enabled=True) @@ -33,41 +31,48 @@ def _mock_manifest(nodes): return manifest -class LinkerTest(unittest.TestCase): - def setUp(self): - self.linker = compilation.Linker() +class TestLinker: + @pytest.fixture + def linker(self) -> Linker: + return Linker() - def test_linker_add_node(self): + def test_linker_add_node(self, linker: Linker) -> None: expected_nodes = ["A", "B", "C"] for node in expected_nodes: - self.linker.add_node(node) + linker.add_node(node) - actual_nodes = self.linker.nodes() + actual_nodes = linker.nodes() for node in expected_nodes: - self.assertIn(node, actual_nodes) + assert node in actual_nodes - self.assertEqual(len(actual_nodes), len(expected_nodes)) + assert len(actual_nodes) == len(expected_nodes) - def test_linker_write_graph(self): + def test_linker_write_graph(self, linker: Linker) -> None: expected_nodes = ["A", "B", "C"] for node in expected_nodes: - self.linker.add_node(node) + linker.add_node(node) manifest = _mock_manifest("ABC") (fd, fname) = tempfile.mkstemp() os.close(fd) try: - self.linker.write_graph(fname, manifest) + linker.write_graph(fname, manifest) assert os.path.exists(fname) finally: os.unlink(fname) - def assert_would_join(self, queue): + def assert_would_join(self, queue: GraphQueue) -> None: """test join() without timeout risk""" - self.assertEqual(queue.inner.unfinished_tasks, 0) - - def _get_graph_queue(self, manifest, include=None, exclude=None): - graph = compilation.Graph(self.linker.graph) + assert queue.inner.unfinished_tasks == 0 + + def _get_graph_queue( + self, + manifest, + linker: Linker, + include=None, + exclude=None, + ) -> GraphQueue: + graph = Graph(linker.graph) selector = NodeSelector(graph, manifest) # TODO: The "eager" string below needs to be replaced with programatic access # to the default value for the indirect selection parameter in @@ -77,114 +82,114 @@ def _get_graph_queue(self, manifest, include=None, exclude=None): spec = parse_difference(include, exclude) return selector.get_graph_queue(spec) - def test_linker_add_dependency(self): + def test_linker_add_dependency(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("A", "C"), ("B", "C")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - queue = self._get_graph_queue(_mock_manifest("ABC")) + queue = self._get_graph_queue(_mock_manifest("ABC"), linker) got = queue.get(block=False) - self.assertEqual(got.unique_id, "C") - with self.assertRaises(Empty): + assert got.unique_id == "C" + with pytest.raises(Empty): queue.get(block=False) - self.assertFalse(queue.empty()) + assert not queue.empty() queue.mark_done("C") - self.assertFalse(queue.empty()) + assert not queue.empty() got = queue.get(block=False) - self.assertEqual(got.unique_id, "B") - with self.assertRaises(Empty): + assert got.unique_id == "B" + with pytest.raises(Empty): queue.get(block=False) - self.assertFalse(queue.empty()) + assert not queue.empty() queue.mark_done("B") - self.assertFalse(queue.empty()) + assert not queue.empty() got = queue.get(block=False) - self.assertEqual(got.unique_id, "A") - with self.assertRaises(Empty): + assert got.unique_id == "A" + with pytest.raises(Empty): queue.get(block=False) - self.assertTrue(queue.empty()) + assert queue.empty() queue.mark_done("A") self.assert_would_join(queue) - self.assertTrue(queue.empty()) + assert queue.empty() - def test_linker_add_disjoint_dependencies(self): + def test_linker_add_disjoint_dependencies(self, linker: Linker) -> None: actual_deps = [("A", "B")] additional_node = "Z" for (l, r) in actual_deps: - self.linker.dependency(l, r) - self.linker.add_node(additional_node) + linker.dependency(l, r) + linker.add_node(additional_node) - queue = self._get_graph_queue(_mock_manifest("ABCZ")) + queue = self._get_graph_queue(_mock_manifest("ABCZ"), linker) # the first one we get must be B, it has the longest dep chain first = queue.get(block=False) - self.assertEqual(first.unique_id, "B") - self.assertFalse(queue.empty()) + assert first.unique_id == "B" + assert not queue.empty() queue.mark_done("B") - self.assertFalse(queue.empty()) + assert not queue.empty() second = queue.get(block=False) - self.assertIn(second.unique_id, {"A", "Z"}) - self.assertFalse(queue.empty()) + assert second.unique_id in {"A", "Z"} + assert not queue.empty() queue.mark_done(second.unique_id) - self.assertFalse(queue.empty()) + assert not queue.empty() third = queue.get(block=False) - self.assertIn(third.unique_id, {"A", "Z"}) - with self.assertRaises(Empty): + assert third.unique_id in {"A", "Z"} + with pytest.raises(Empty): queue.get(block=False) - self.assertNotEqual(second.unique_id, third.unique_id) - self.assertTrue(queue.empty()) + assert second.unique_id != third.unique_id + assert queue.empty() queue.mark_done(third.unique_id) self.assert_would_join(queue) - self.assertTrue(queue.empty()) + assert queue.empty() - def test_linker_dependencies_limited_to_some_nodes(self): + def test_linker_dependencies_limited_to_some_nodes(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "D")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - queue = self._get_graph_queue(_mock_manifest("ABCD"), ["B"]) + queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["B"]) got = queue.get(block=False) - self.assertEqual(got.unique_id, "B") - self.assertTrue(queue.empty()) + assert got.unique_id == "B" + assert queue.empty() queue.mark_done("B") self.assert_would_join(queue) - queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), ["A", "B"]) + queue_2 = queue = self._get_graph_queue(_mock_manifest("ABCD"), linker, ["A", "B"]) got = queue_2.get(block=False) - self.assertEqual(got.unique_id, "B") - self.assertFalse(queue_2.empty()) - with self.assertRaises(Empty): + assert got.unique_id == "B" + assert not queue_2.empty() + with pytest.raises(Empty): queue_2.get(block=False) queue_2.mark_done("B") - self.assertFalse(queue_2.empty()) + assert not queue_2.empty() got = queue_2.get(block=False) - self.assertEqual(got.unique_id, "A") - self.assertTrue(queue_2.empty()) - with self.assertRaises(Empty): + assert got.unique_id == "A" + assert queue_2.empty() + with pytest.raises(Empty): queue_2.get(block=False) - self.assertTrue(queue_2.empty()) + assert queue_2.empty() queue_2.mark_done("A") self.assert_would_join(queue_2) - def test__find_cycles__cycles(self): + def test__find_cycles__cycles(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "A")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - self.assertIsNotNone(self.linker.find_cycles()) + assert linker.find_cycles() is not None - def test__find_cycles__no_cycles(self): + def test__find_cycles__no_cycles(self, linker: Linker) -> None: actual_deps = [("A", "B"), ("B", "C"), ("C", "D")] for (l, r) in actual_deps: - self.linker.dependency(l, r) + linker.dependency(l, r) - self.assertIsNone(self.linker.find_cycles()) + assert linker.find_cycles() is None diff --git a/tests/unit/test_contracts_graph_parsed.py b/tests/unit/test_contracts_graph_parsed.py index 7a62c394b22..b94271fab08 100644 --- a/tests/unit/test_contracts_graph_parsed.py +++ b/tests/unit/test_contracts_graph_parsed.py @@ -7,7 +7,6 @@ from hypothesis import given from hypothesis.strategies import builds, lists -from dbt import flags from dbt.artifacts.resources import ( ColumnInfo, Dimension, @@ -67,7 +66,10 @@ replace_config, ) -flags.set_from_args(Namespace(SEND_ANONYMOUS_USAGE_STATS=False), None) + +@pytest.fixture +def flags_for_args() -> Namespace: + return Namespace(SEND_ANONYMOUS_USAGE_STATS=False) @pytest.fixture diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index 85d1ea4add5..69d30132ef4 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -1,6 +1,3 @@ -from argparse import Namespace - -from dbt.flags import set_from_args from dbt.internal_deprecations import deprecated @@ -11,6 +8,5 @@ def to_be_decorated(): # simple test that the return value is not modified def test_deprecated_func(): - set_from_args(Namespace(WARN_ERROR=False), None) assert hasattr(to_be_decorated, "__wrapped__") assert to_be_decorated() == 5 diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py index bd9892bcc6c..8a19b0ad39f 100644 --- a/tests/unit/test_events.py +++ b/tests/unit/test_events.py @@ -1,6 +1,5 @@ import logging import re -from argparse import Namespace from typing import TypeVar import pytest @@ -20,7 +19,6 @@ WarnLevel, ) from dbt.events.types import RunResultError -from dbt.flags import set_from_args from dbt.task.printer import print_run_result_error from dbt_common.events import types from dbt_common.events.base_types import msg_from_base_event @@ -29,8 +27,6 @@ from dbt_common.events.functions import msg_to_dict, msg_to_json from dbt_common.events.helpers import get_json_string_utcnow -set_from_args(Namespace(WARN_ERROR=False), None) - # takes in a class and finds any subclasses for it def get_all_subclasses(cls): diff --git a/tests/unit/test_graph_selection.py b/tests/unit/test_graph_selection.py index be283e59926..5d5cbf7469d 100644 --- a/tests/unit/test_graph_selection.py +++ b/tests/unit/test_graph_selection.py @@ -1,5 +1,4 @@ import string -from argparse import Namespace from unittest import mock import networkx as nx @@ -8,12 +7,8 @@ import dbt.graph.cli as graph_cli import dbt.graph.selector as graph_selector import dbt_common.exceptions -from dbt import flags -from dbt.contracts.project import ProjectFlags from dbt.node_types import NodeType -flags.set_from_args(Namespace(), ProjectFlags()) - def _get_graph(): integer_graph = nx.balanced_tree(2, 2, nx.DiGraph()) diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 7d369e6b00d..51fdf8a2024 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -1,5 +1,3 @@ -from argparse import Namespace - from google.protobuf.json_format import MessageToDict from dbt.adapters.events.types import PluginLoadError, RollbackFailed @@ -11,7 +9,6 @@ MainReportArgs, MainReportVersion, ) -from dbt.flags import set_from_args from dbt.version import installed from dbt_common.events import types_pb2 from dbt_common.events.base_types import EventLevel, msg_from_base_event @@ -22,9 +19,6 @@ reset_metadata_vars, ) -set_from_args(Namespace(WARN_ERROR=False), None) - - info_keys = { "name", "code", diff --git a/tests/unit/utils/flags.py b/tests/unit/utils/flags.py new file mode 100644 index 00000000000..20bb4a44ea0 --- /dev/null +++ b/tests/unit/utils/flags.py @@ -0,0 +1,33 @@ +import sys +from argparse import Namespace + +if sys.version_info < (3, 9): + from typing import Generator +else: + from collections.abc import Generator + +import pytest + +from dbt.flags import set_from_args + + +@pytest.fixture +def args_for_flags() -> Namespace: + """Defines the namespace args to be used in `set_from_args` of `set_test_flags` fixture. + + This fixture is meant to be overrided by tests that need specific flags to be set. + """ + return Namespace() + + +@pytest.fixture(autouse=True) +def set_test_flags(args_for_flags: Namespace) -> Generator[None, None, None]: + """Sets up and tears down the global flags for every pytest unit test + + Override `args_for_flags` fixture as needed to set any specific flags. + """ + set_from_args(args_for_flags, {}) + # fixtures stop setup upon yield + yield None + # everything after yield is run at test teardown + set_from_args(Namespace(), {}) diff --git a/tests/unit/utils/manifest.py b/tests/unit/utils/manifest.py index c62d0bd0edf..a7c269cdab2 100644 --- a/tests/unit/utils/manifest.py +++ b/tests/unit/utils/manifest.py @@ -1,5 +1,3 @@ -from argparse import Namespace - import pytest from dbt_semantic_interfaces.type_enums import MetricType @@ -36,11 +34,8 @@ UnitTestDefinition, ) from dbt.contracts.graph.unparsed import UnitTestInputFixture, UnitTestOutputFixture -from dbt.flags import set_from_args from dbt.node_types import NodeType -set_from_args(Namespace(WARN_ERROR=False), None) - def make_model( pkg,