From b6f0eac2cd3f8c00df1a503784c6341a69c65d26 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Tue, 31 Oct 2023 16:24:27 -0400 Subject: [PATCH] Backport Catalog Fix to 1.7.latest (#8953) * Fix issues around new get_catalog_by_relations macro (#8856) * Fix issues around new get_catalog_by_relations macro * Add changelog entry * Fix unit test. * Additional unit testing * Fix cased comparison in catalog-retrieval function (#8940) * Fix cased comparison in catalog-retrieval function. * Fix cased comparison in catalog-retrieval function. --- .../unreleased/Fixes-20231024-155400.yaml | 6 + .../unreleased/Fixes-20231030-093734.yaml | 6 + core/dbt/adapters/base/impl.py | 136 +++++++++++------- core/dbt/task/generate.py | 32 ++++- tests/functional/docs/test_generate.py | 5 + tests/unit/test_postgres_adapter.py | 53 ++++--- 6 files changed, 157 insertions(+), 81 deletions(-) create mode 100644 .changes/unreleased/Fixes-20231024-155400.yaml create mode 100644 .changes/unreleased/Fixes-20231030-093734.yaml diff --git a/.changes/unreleased/Fixes-20231024-155400.yaml b/.changes/unreleased/Fixes-20231024-155400.yaml new file mode 100644 index 00000000000..cd10f06d005 --- /dev/null +++ b/.changes/unreleased/Fixes-20231024-155400.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Rework get_catalog implementation to retain previous adapter interface semantics +time: 2023-10-24T15:54:00.628086-04:00 +custom: + Author: peterallenwebb + Issue: "8846" diff --git a/.changes/unreleased/Fixes-20231030-093734.yaml b/.changes/unreleased/Fixes-20231030-093734.yaml new file mode 100644 index 00000000000..7322dd5042b --- /dev/null +++ b/.changes/unreleased/Fixes-20231030-093734.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix cased comparison in catalog-retrieval function. +time: 2023-10-30T09:37:34.258612-04:00 +custom: + Author: peterallenwebb + Issue: "8939" diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 9c8a1c2dcd5..e5861cec655 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -455,30 +455,16 @@ def _get_catalog_relations_by_info_schema( return relations_by_info_schema - def _get_catalog_relations( - self, manifest: Manifest, selected_nodes: Optional[Set] = None - ) -> List[BaseRelation]: - nodes: Iterator[ResultNode] - if selected_nodes: - selected: List[ResultNode] = [] - for unique_id in selected_nodes: - if unique_id in manifest.nodes: - node = manifest.nodes[unique_id] - if node.is_relational and not node.is_ephemeral_model: - selected.append(node) - elif unique_id in manifest.sources: - source = manifest.sources[unique_id] - selected.append(source) - nodes = iter(selected) - else: - nodes = chain( - [ - node - for node in manifest.nodes.values() - if (node.is_relational and not node.is_ephemeral_model) - ], - manifest.sources.values(), - ) + def _get_catalog_relations(self, manifest: Manifest) -> List[BaseRelation]: + + nodes = chain( + [ + node + for node in manifest.nodes.values() + if (node.is_relational and not node.is_ephemeral_model) + ], + manifest.sources.values(), + ) relations = [self.Relation.create_from(self.config, n) for n in nodes] return relations @@ -1166,43 +1152,83 @@ def _get_one_catalog_by_relations( results = self._catalog_filter_table(table, manifest) # type: ignore[arg-type] return results - def get_catalog( - self, manifest: Manifest, selected_nodes: Optional[Set] = None - ) -> Tuple[agate.Table, List[Exception]]: + def get_filtered_catalog( + self, manifest: Manifest, relations: Optional[Set[BaseRelation]] = None + ): + catalogs: agate.Table + if ( + relations is None + or len(relations) > 100 + or not self.supports(Capability.SchemaMetadataByRelations) + ): + # Do it the traditional way. We get the full catalog. + catalogs, exceptions = self.get_catalog(manifest) + else: + # Do it the new way. We try to save time by selecting information + # only for the exact set of relations we are interested in. + catalogs, exceptions = self.get_catalog_by_relations(manifest, relations) + + if relations and catalogs: + relation_map = { + ( + r.database.casefold() if r.database else None, + r.schema.casefold() if r.schema else None, + r.identifier.casefold() if r.identifier else None, + ) + for r in relations + } + + def in_map(row: agate.Row): + d = _expect_row_value("table_database", row).casefold() + s = _expect_row_value("table_schema", row).casefold() + i = _expect_row_value("table_name", row).casefold() + return (d, s, i) in relation_map + + catalogs = catalogs.where(in_map) + + return catalogs, exceptions + + def row_matches_relation(self, row: agate.Row, relations: Set[BaseRelation]): + pass + def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]: with executor(self.config) as tpe: futures: List[Future[agate.Table]] = [] - catalog_relations = self._get_catalog_relations(manifest, selected_nodes) - relation_count = len(catalog_relations) - if relation_count <= 100 and self.supports(Capability.SchemaMetadataByRelations): - relations_by_schema = self._get_catalog_relations_by_info_schema(catalog_relations) - for info_schema in relations_by_schema: - name = ".".join([str(info_schema.database), "information_schema"]) - relations = relations_by_schema[info_schema] - fut = tpe.submit_connected( - self, - name, - self._get_one_catalog_by_relations, - info_schema, - relations, - manifest, - ) - futures.append(fut) - else: - schema_map: SchemaSearchMap = self._get_catalog_schemas(manifest) - for info, schemas in schema_map.items(): - if len(schemas) == 0: - continue - name = ".".join([str(info.database), "information_schema"]) - fut = tpe.submit_connected( - self, name, self._get_one_catalog, info, schemas, manifest - ) - futures.append(fut) - - catalogs, exceptions = catch_as_completed(futures) + schema_map: SchemaSearchMap = self._get_catalog_schemas(manifest) + for info, schemas in schema_map.items(): + if len(schemas) == 0: + continue + name = ".".join([str(info.database), "information_schema"]) + fut = tpe.submit_connected( + self, name, self._get_one_catalog, info, schemas, manifest + ) + futures.append(fut) + catalogs, exceptions = catch_as_completed(futures) return catalogs, exceptions + def get_catalog_by_relations( + self, manifest: Manifest, relations: Set[BaseRelation] + ) -> Tuple[agate.Table, List[Exception]]: + with executor(self.config) as tpe: + futures: List[Future[agate.Table]] = [] + relations_by_schema = self._get_catalog_relations_by_info_schema(relations) + for info_schema in relations_by_schema: + name = ".".join([str(info_schema.database), "information_schema"]) + relations = set(relations_by_schema[info_schema]) + fut = tpe.submit_connected( + self, + name, + self._get_one_catalog_by_relations, + info_schema, + relations, + manifest, + ) + futures.append(fut) + + catalogs, exceptions = catch_as_completed(futures) + return catalogs, exceptions + def cancel_open_connections(self): """Cancel all open connections.""" return self.connections.cancel_open() diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 3a3f99cbb81..d623fddd9af 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -1,7 +1,7 @@ import os import shutil from datetime import datetime -from typing import Dict, List, Any, Optional, Tuple, Set +from typing import Dict, List, Any, Optional, Tuple, Set, Iterable import agate from dbt.dataclass_schema import ValidationError @@ -223,11 +223,6 @@ def run(self) -> CatalogArtifact: DOCS_INDEX_FILE_PATH, os.path.join(self.config.project_target_path, "index.html") ) - # Get the list of nodes that have been selected - selected_nodes = None - if self.job_queue is not None: - selected_nodes = self.job_queue.get_selected_nodes() - for asset_path in self.config.asset_paths: to_asset_path = os.path.join(self.config.project_target_path, asset_path) @@ -247,8 +242,18 @@ def run(self) -> CatalogArtifact: adapter = get_adapter(self.config) with adapter.connection_named("generate_catalog"): fire_event(BuildingCatalog()) + # Get a list of relations we need from the catalog + relations = None + if self.job_queue is not None: + selected_node_ids = self.job_queue.get_selected_nodes() + selected_nodes = self._get_nodes_from_ids(self.manifest, selected_node_ids) + relations = { + adapter.Relation.create_from(adapter.config, node_id) + for node_id in selected_nodes + } + # This generates the catalog as an agate.Table - catalog_table, exceptions = adapter.get_catalog(self.manifest, selected_nodes) + catalog_table, exceptions = adapter.get_filtered_catalog(self.manifest, relations) catalog_data: List[PrimitiveDict] = [ dict(zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row))) @@ -298,6 +303,19 @@ def run(self) -> CatalogArtifact: fire_event(CatalogWritten(path=os.path.abspath(catalog_path))) return results + @staticmethod + def _get_nodes_from_ids(manifest: Manifest, node_ids: Iterable[str]) -> List[ResultNode]: + selected: List[ResultNode] = [] + for unique_id in node_ids: + if unique_id in manifest.nodes: + node = manifest.nodes[unique_id] + if node.is_relational and not node.is_ephemeral_model: + selected.append(node) + elif unique_id in manifest.sources: + source = manifest.sources[unique_id] + selected.append(source) + return selected + def get_node_selector(self) -> ResourceTypeSelector: if self.manifest is None or self.graph is None: raise DbtInternalError("manifest and graph must be set to perform node selection") diff --git a/tests/functional/docs/test_generate.py b/tests/functional/docs/test_generate.py index 641e2fe0e0e..ed833077120 100644 --- a/tests/functional/docs/test_generate.py +++ b/tests/functional/docs/test_generate.py @@ -28,3 +28,8 @@ def test_select_limits_catalog(self, project): catalog = run_dbt(["docs", "generate", "--select", "my_model"]) assert len(catalog.nodes) == 1 assert "model.test.my_model" in catalog.nodes + + def test_select_limits_no_match(self, project): + run_dbt(["run"]) + catalog = run_dbt(["docs", "generate", "--select", "my_missing_model"]) + assert len(catalog.nodes) == 0 diff --git a/tests/unit/test_postgres_adapter.py b/tests/unit/test_postgres_adapter.py index 80b8d61b9b4..f092ae21062 100644 --- a/tests/unit/test_postgres_adapter.py +++ b/tests/unit/test_postgres_adapter.py @@ -1,8 +1,12 @@ +import dataclasses + import agate import decimal import unittest from unittest import mock +from dbt.adapters.base import BaseRelation +from dbt.contracts.relation import Path from dbt.task.debug import DebugTask from dbt.adapters.base.query_headers import MacroQueryStringSetter @@ -322,34 +326,45 @@ def test_set_zero_keepalive(self, psycopg2): ) @mock.patch.object(PostgresAdapter, "execute_macro") - @mock.patch.object(PostgresAdapter, "_get_catalog_relations_by_info_schema") + @mock.patch.object(PostgresAdapter, "_get_catalog_relations") def test_get_catalog_various_schemas(self, mock_get_relations, mock_execute): + self.catalog_test(mock_get_relations, mock_execute, False) + + @mock.patch.object(PostgresAdapter, "execute_macro") + @mock.patch.object(PostgresAdapter, "_get_catalog_relations") + def test_get_filtered_catalog(self, mock_get_relations, mock_execute): + self.catalog_test(mock_get_relations, mock_execute, True) + + def catalog_test(self, mock_get_relations, mock_execute, filtered=False): column_names = ["table_database", "table_schema", "table_name"] - rows = [ - ("dbt", "foo", "bar"), - ("dbt", "FOO", "baz"), - ("dbt", None, "bar"), - ("dbt", "quux", "bar"), - ("dbt", "skip", "bar"), + relations = [ + BaseRelation(path=Path(database="dbt", schema="foo", identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="FOO", identifier="baz")), + BaseRelation(path=Path(database="dbt", schema=None, identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="quux", identifier="bar")), + BaseRelation(path=Path(database="dbt", schema="skip", identifier="bar")), ] + rows = list(map(lambda x: dataclasses.astuple(x.path), relations)) mock_execute.return_value = agate.Table(rows=rows, column_names=column_names) - mock_get_relations.return_value = { - mock.MagicMock(database="dbt"): [ - mock.MagicMock(schema="foo"), - mock.MagicMock(schema="FOO"), - mock.MagicMock(schema="quux"), - ] - } + mock_get_relations.return_value = relations mock_manifest = mock.MagicMock() mock_manifest.get_used_schemas.return_value = {("dbt", "foo"), ("dbt", "quux")} - catalog, exceptions = self.adapter.get_catalog(mock_manifest) - self.assertEqual( - set(map(tuple, catalog)), - {("dbt", "foo", "bar"), ("dbt", "FOO", "baz"), ("dbt", "quux", "bar")}, - ) + if filtered: + catalog, exceptions = self.adapter.get_filtered_catalog( + mock_manifest, set([relations[0], relations[3]]) + ) + else: + catalog, exceptions = self.adapter.get_catalog(mock_manifest) + + tupled_catalog = set(map(tuple, catalog)) + if filtered: + self.assertEqual(tupled_catalog, {rows[0], rows[3]}) + else: + self.assertEqual(tupled_catalog, {rows[0], rows[1], rows[3]}) + self.assertEqual(exceptions, [])