From a0e7c97065079bbbda73d4ad93962dfefc3aea0a Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 12 Jul 2024 15:24:01 -0700 Subject: [PATCH 01/16] support explicit schema resolver in aggregator --- .../source/snowflake/snowflake_config.py | 2 +- .../source/snowflake/snowflake_queries.py | 8 ++- .../source/snowflake/snowflake_v2.py | 7 +-- .../sql_parsing/sql_parsing_aggregator.py | 61 ++++++++----------- 4 files changed, 36 insertions(+), 42 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index f6247eb949417..056724f89eb10 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -206,7 +206,7 @@ class SnowflakeV2Config( ) lazy_schema_resolver: bool = Field( - default=False, + default=True, description="If enabled, uses lazy schema resolver to resolve schemas for tables and views. " "This is useful if you have a large number of schemas and want to avoid bulk fetching the schema for each table/view.", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index c647a624a5467..229814c183466 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -20,6 +20,7 @@ from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.api.source_helpers import auto_workunit from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.graph.client import DataHubGraph from datahub.ingestion.source.snowflake.constants import SnowflakeObjectDomain from datahub.ingestion.source.snowflake.snowflake_config import ( DEFAULT_TEMP_TABLES_PATTERNS, @@ -111,6 +112,7 @@ def __init__( connection: SnowflakeConnection, config: SnowflakeQueriesExtractorConfig, structured_report: SourceReport, + graph: Optional[DataHubGraph], ): self.connection = connection @@ -122,7 +124,8 @@ def __init__( platform=self.platform, platform_instance=self.config.platform_instance, env=self.config.env, - # graph=self.ctx.graph, + graph=graph, + eager_graph_load=False, generate_lineage=self.config.include_lineage, generate_queries=self.config.include_queries, generate_usage_statistics=self.config.include_usage_statistics, @@ -371,14 +374,13 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig): self.config = config self.report = SnowflakeQueriesSourceReport() - self.platform = "snowflake" - self.connection = self.config.connection.get_connection() self.queries_extractor = SnowflakeQueriesExtractor( connection=self.connection, config=self.config, structured_report=self.report, + graph=self.ctx.graph, ) self.report.queries_extractor = self.queries_extractor.report diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index d8eda98da422b..358aa09db9585 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -147,18 +147,17 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): platform=self.platform, platform_instance=self.config.platform_instance, env=self.config.env, - graph=( + graph=self.ctx.graph, + eager_graph_load=( # If we're ingestion schema metadata for tables/views, then we will populate # schemas into the resolver as we go. We only need to do a bulk fetch # if we're not ingesting schema metadata as part of ingestion. - self.ctx.graph - if not ( + ( self.config.include_technical_schema and self.config.include_tables and self.config.include_views ) and not self.config.lazy_schema_resolver - else None ), generate_usage_statistics=False, generate_operations=False, diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 677b96269fe58..d74faf72fb54e 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -251,7 +251,9 @@ def __init__( platform: str, platform_instance: Optional[str] = None, env: str = builder.DEFAULT_ENV, + schema_resolver: Optional[SchemaResolver] = None, graph: Optional[DataHubGraph] = None, + eager_graph_load: bool = True, generate_lineage: bool = True, generate_queries: bool = True, generate_query_subject_fields: bool = True, @@ -297,17 +299,29 @@ def __init__( # Set up the schema resolver. self._schema_resolver: SchemaResolver - if graph is None: + if schema_resolver is not None: + # If explicitly provided, use it. + assert self.platform.platform_name == schema_resolver.platform + assert self.platform_instance == schema_resolver.platform_instance + assert self.env == schema_resolver.env + self._schema_resolver = schema_resolver + elif graph is not None and eager_graph_load and self._need_schemas: + # Bulk load schemas using the graph client. + self._schema_resolver = graph.initialize_schema_resolver_from_datahub( + platform=self.platform.urn(), + platform_instance=self.platform_instance, + env=self.env, + ) + else: + # Otherwise, use a lazy-loading schema resolver. self._schema_resolver = self._exit_stack.enter_context( SchemaResolver( platform=self.platform.platform_name, platform_instance=self.platform_instance, env=self.env, + graph=graph, ) ) - else: - self._schema_resolver = None # type: ignore - self._initialize_schema_resolver_from_graph(graph) # Initialize internal data structures. # This leans pretty heavily on the our query fingerprinting capabilities. @@ -373,6 +387,8 @@ def __init__( # Usage aggregator. This will only be initialized if usage statistics are enabled. # TODO: Replace with FileBackedDict. + # TODO: The BaseUsageConfig class is much too broad for our purposes, and has a number of + # configs that won't be respected here. Using it is misleading. self._usage_aggregator: Optional[UsageAggregator[UrnStr]] = None if self.generate_usage_statistics: assert self.usage_config is not None @@ -392,7 +408,13 @@ def close(self) -> None: @property def _need_schemas(self) -> bool: - return self.generate_lineage or self.generate_usage_statistics + # Unless the aggregator is totally disabled, we will need schema information. + return ( + self.generate_lineage + or self.generate_usage_statistics + or self.generate_queries + or self.generate_operations + ) def register_schema( self, urn: Union[str, DatasetUrn], schema: models.SchemaMetadataClass @@ -414,35 +436,6 @@ def register_schemas_from_stream( yield wu - def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None: - # requires a graph instance - # if no schemas are currently registered in the schema resolver - # and we need the schema resolver (e.g. lineage or usage is enabled) - # then use the graph instance to fetch all schemas for the - # platform/instance/env combo - if not self._need_schemas: - return - - if ( - self._schema_resolver is not None - and self._schema_resolver.schema_count() > 0 - ): - # TODO: Have a mechanism to override this, e.g. when table ingestion is enabled but view ingestion is not. - logger.info( - "Not fetching any schemas from the graph, since " - f"there are {self._schema_resolver.schema_count()} schemas already registered." - ) - return - - # TODO: The initialize_schema_resolver_from_datahub method should take in a SchemaResolver - # that it can populate or add to, rather than creating a new one and dropping any schemas - # that were already loaded into the existing one. - self._schema_resolver = graph.initialize_schema_resolver_from_datahub( - platform=self.platform.urn(), - platform_instance=self.platform_instance, - env=self.env, - ) - def _maybe_format_query(self, query: str) -> str: if self.format_queries: with self.report.sql_formatting_timer: From cf2b940561b898c11716e629348a8c6a90bcb8fe Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 12 Jul 2024 17:09:41 -0700 Subject: [PATCH 02/16] call snowflake queries from main source --- .../source/snowflake/snowflake_config.py | 6 ++ .../source/snowflake/snowflake_queries.py | 11 ++-- .../source/snowflake/snowflake_report.py | 5 ++ .../source/snowflake/snowflake_v2.py | 66 +++++++++++++++---- .../source_report/ingestion_stage.py | 1 + 5 files changed, 71 insertions(+), 18 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 056724f89eb10..c7612df6c5035 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -131,6 +131,7 @@ class SnowflakeIdentifierConfig( # Changing default value here. convert_urns_to_lowercase: bool = Field( default=True, + description="Whether to convert dataset urns to lowercase.", ) @@ -205,6 +206,11 @@ class SnowflakeV2Config( description="Populates view->view and table->view column lineage using DataHub's sql parser.", ) + use_queries_v2: bool = Field( + default=False, + description="If enabled, uses the new queries extractor to extract queries from snowflake.", + ) + lazy_schema_resolver: bool = Field( default=True, description="If enabled, uses lazy schema resolver to resolve schemas for tables and views. " diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 229814c183466..1b1f9b66f2df2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -38,6 +38,7 @@ ) from datahub.ingestion.source.usage.usage_common import BaseUsageConfig from datahub.metadata.urns import CorpUserUrn +from datahub.sql_parsing.schema_resolver import SchemaResolver from datahub.sql_parsing.sql_parsing_aggregator import ( KnownLineageMapping, PreparsedQuery, @@ -77,12 +78,6 @@ class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilter hidden_from_docs=True, ) - convert_urns_to_lowercase: bool = pydantic.Field( - # Override the default. - default=True, - description="Whether to convert dataset urns to lowercase.", - ) - include_lineage: bool = True include_queries: bool = True include_usage_statistics: bool = True @@ -112,7 +107,8 @@ def __init__( connection: SnowflakeConnection, config: SnowflakeQueriesExtractorConfig, structured_report: SourceReport, - graph: Optional[DataHubGraph], + graph: Optional[DataHubGraph] = None, + schema_resolver: Optional[SchemaResolver] = None, ): self.connection = connection @@ -124,6 +120,7 @@ def __init__( platform=self.platform, platform_instance=self.config.platform_instance, env=self.config.env, + schema_resolver=schema_resolver, graph=graph, eager_graph_load=False, generate_lineage=self.config.include_lineage, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py index 4924546383aa4..80b6be36e5ffa 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -15,6 +15,9 @@ from datahub.utilities.perf_timer import PerfTimer if TYPE_CHECKING: + from datahub.ingestion.source.snowflake.snowflake_queries import ( + SnowflakeQueriesExtractorReport, + ) from datahub.ingestion.source.snowflake.snowflake_schema import ( SnowflakeDataDictionary, ) @@ -113,6 +116,8 @@ class SnowflakeV2Report( data_dictionary_cache: Optional["SnowflakeDataDictionary"] = None + queries_extractor: Optional["SnowflakeQueriesExtractorReport"] = None + # These will be non-zero if snowflake information_schema queries fail with error - # "Information schema query returned too much data. Please repeat query with more selective predicates."" # This will result in overall increase in time complexity diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 358aa09db9585..df22a406adfde 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -25,6 +25,7 @@ TestableSource, TestConnectionReport, ) +from datahub.ingestion.api.source_helpers import auto_workunit from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.snowflake.constants import ( GENERIC_PERMISSION_ERROR_KEY, @@ -42,6 +43,10 @@ SnowflakeLineageExtractor, ) from datahub.ingestion.source.snowflake.snowflake_profiler import SnowflakeProfiler +from datahub.ingestion.source.snowflake.snowflake_queries import ( + SnowflakeQueriesExtractor, + SnowflakeQueriesExtractorConfig, +) from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report from datahub.ingestion.source.snowflake.snowflake_schema import ( SnowflakeDataDictionary, @@ -72,6 +77,7 @@ from datahub.ingestion.source_report.ingestion_stage import ( LINEAGE_EXTRACTION, METADATA_EXTRACTION, + QUERIES_EXTRACTION, ) from datahub.sql_parsing.sql_parsing_aggregator import SqlParsingAggregator from datahub.utilities.registries.domain_registry import DomainRegistry @@ -452,8 +458,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: databases = schema_extractor.databases - self.connection.close() - # TODO: The checkpoint state for stale entity detection can be committed here. if self.config.shares: @@ -483,23 +487,63 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: discovered_datasets = discovered_tables + discovered_views - if self.config.include_table_lineage and self.lineage_extractor: - self.report.set_ingestion_stage("*", LINEAGE_EXTRACTION) - yield from self.lineage_extractor.get_workunits( - discovered_tables=discovered_tables, - discovered_views=discovered_views, + if self.config.use_queries_v2: + self.report.set_ingestion_stage("*", "View Parsing") + assert self.aggregator is not None + yield from auto_workunit(self.aggregator.gen_metadata()) + + self.report.set_ingestion_stage("*", QUERIES_EXTRACTION) + + schema_resolver = self.aggregator._schema_resolver + + queries_extractor = SnowflakeQueriesExtractor( + connection=self.connection, + config=SnowflakeQueriesExtractorConfig( + # TODO: Refactor this a bit so it's not as redundant. + database_pattern=self.config.database_pattern, + schema_pattern=self.config.schema_pattern, + table_pattern=self.config.table_pattern, + view_pattern=self.config.view_pattern, + match_fully_qualified_names=self.config.match_fully_qualified_names, + convert_urns_to_lowercase=self.config.convert_urns_to_lowercase, + env=self.config.env, + platform_instance=self.config.platform_instance, + window=self.config, + temporary_tables_pattern=self.config.temporary_tables_pattern, + include_lineage=self.config.include_table_lineage, + include_usage_statistics=self.config.include_usage_stats, + include_operations=self.config.include_operational_stats, + ), + structured_report=self.report, + schema_resolver=schema_resolver, ) - if ( - self.config.include_usage_stats or self.config.include_operational_stats - ) and self.usage_extractor: - yield from self.usage_extractor.get_usage_workunits(discovered_datasets) + # TODO: This is slightly suboptimal because we create two SqlParsingAggregator instances with different configs + # but a shared schema resolver. That's fine for now though - once we remove the old lineage/usage extractors, + # it should be pretty straightforward to refactor this and only initialize the aggregator once. + self.report.queries_extractor = queries_extractor.report + yield from queries_extractor.get_workunits_internal() + + else: + if self.config.include_table_lineage and self.lineage_extractor: + self.report.set_ingestion_stage("*", LINEAGE_EXTRACTION) + yield from self.lineage_extractor.get_workunits( + discovered_tables=discovered_tables, + discovered_views=discovered_views, + ) + + if ( + self.config.include_usage_stats or self.config.include_operational_stats + ) and self.usage_extractor: + yield from self.usage_extractor.get_usage_workunits(discovered_datasets) if self.config.include_assertion_results: yield from SnowflakeAssertionsHandler( self.config, self.report, self.connection ).get_assertion_workunits(discovered_datasets) + self.connection.close() + def report_warehouse_failure(self) -> None: if self.config.warehouse is not None: self.report_error( diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py index 14dc428b65389..4308b405e46e3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py +++ b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py @@ -14,6 +14,7 @@ USAGE_EXTRACTION_INGESTION = "Usage Extraction Ingestion" USAGE_EXTRACTION_OPERATIONAL_STATS = "Usage Extraction Operational Stats" USAGE_EXTRACTION_USAGE_AGGREGATION = "Usage Extraction Usage Aggregation" +QUERIES_EXTRACTION = "Queries Extraction" PROFILING = "Profiling" From fe69b78465a991eb3e4a51e2f0482ed3b123a4de Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 12 Jul 2024 17:28:35 -0700 Subject: [PATCH 03/16] fix lint --- .../src/datahub/ingestion/source/sql/sql_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py index 93c7025aeee4e..3ead59eed2d39 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py @@ -83,10 +83,10 @@ class SQLCommonConfig( description='Attach domains to databases, schemas or tables during ingestion using regex patterns. Domain key can be a guid like *urn:li:domain:ec428203-ce86-4db3-985d-5a8ee6df32ba* or a string like "Marketing".) If you provide strings, then datahub will attempt to resolve this name to a guid, and will error out if this fails. There can be multiple domain keys specified.', ) - include_views: Optional[bool] = Field( + include_views: bool = Field( default=True, description="Whether views should be ingested." ) - include_tables: Optional[bool] = Field( + include_tables: bool = Field( default=True, description="Whether tables should be ingested." ) From 3bdfb16d060da44240cfa5c9ff3d017aa940e1c0 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 12 Jul 2024 17:39:36 -0700 Subject: [PATCH 04/16] improve timers --- .../source/snowflake/snowflake_queries.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 1b1f9b66f2df2..1b595515e05ad 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -1,3 +1,4 @@ +import dataclasses import functools import json import logging @@ -52,6 +53,7 @@ DownstreamColumnRef, ) from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList +from datahub.utilities.perf_timer import PerfTimer logger = logging.getLogger(__name__) @@ -91,13 +93,14 @@ class SnowflakeQueriesSourceConfig(SnowflakeQueriesExtractorConfig): @dataclass class SnowflakeQueriesExtractorReport(Report): - window: Optional[BaseTimeWindowConfig] = None - + audit_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_aggregator: Optional[SqlAggregatorReport] = None @dataclass class SnowflakeQueriesSourceReport(SourceReport): + window: Optional[BaseTimeWindowConfig] = None queries_extractor: Optional[SnowflakeQueriesExtractorReport] = None @@ -175,8 +178,6 @@ def is_allowed_table(self, name: str) -> bool: def get_workunits_internal( self, ) -> Iterable[MetadataWorkUnit]: - self.report.window = self.config.window - # TODO: Add some logic to check if the cached audit log is stale or not. audit_log_file = self.local_temp_path / "audit_log.sqlite" use_cached_audit_log = audit_log_file.exists() @@ -193,11 +194,13 @@ def get_workunits_internal( queries = FileBackedList(shared_connection) logger.info("Fetching audit log") - for entry in self.fetch_audit_log(): - queries.append(entry) + with self.report.audit_log_fetch_timer: + for entry in self.fetch_audit_log(): + queries.append(entry) - for query in queries: - self.aggregator.add(query) + with self.report.audit_log_load_timer: + for query in queries: + self.aggregator.add(query) yield from auto_workunit(self.aggregator.gen_metadata()) @@ -387,6 +390,8 @@ def create(cls, config_dict: dict, ctx: PipelineContext) -> Self: return cls(ctx, config) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: + self.report.window = self.config.window + # TODO: Disable auto status processor? return self.queries_extractor.get_workunits_internal() From b334d43fcad8fa3745a5edc4a63e78651d02ba3c Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 18:07:53 -0700 Subject: [PATCH 05/16] fix logging --- .../src/datahub/ingestion/api/source.py | 15 ++++++--- .../source/snowflake/snowflake_lineage_v2.py | 6 ++-- .../source/snowflake/snowflake_schema_gen.py | 7 ++-- .../source/snowflake/snowflake_shares.py | 7 ++-- .../source/snowflake/snowflake_utils.py | 33 +++++++++++-------- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py index ad1b312ef445c..b79f69970a634 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source.py @@ -97,6 +97,7 @@ def report_log( context: Optional[str] = None, exc: Optional[BaseException] = None, log: bool = False, + stacklevel: int = 1, ) -> None: """ Report a user-facing warning for the ingestion run. @@ -109,7 +110,8 @@ def report_log( exc: The exception associated with the event. We'll show the stack trace when in debug mode. """ - stacklevel = 2 + # One for this method, and one for the containing report_* call. + stacklevel = stacklevel + 2 log_key = f"{title}-{message}" entries = self._entries[level] @@ -118,6 +120,8 @@ def report_log( context = f"{context[:_MAX_CONTEXT_STRING_LENGTH]} ..." log_content = f"{message} => {context}" if context else message + if title: + log_content = f"{title}: {log_content}" if exc: log_content += f"{log_content}: {exc}" @@ -255,9 +259,10 @@ def report_failure( context: Optional[str] = None, title: Optional[LiteralString] = None, exc: Optional[BaseException] = None, + log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.ERROR, message, title, context, exc, log=False + StructuredLogLevel.ERROR, message, title, context, exc, log=log ) def failure( @@ -266,9 +271,10 @@ def failure( context: Optional[str] = None, title: Optional[LiteralString] = None, exc: Optional[BaseException] = None, + log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.ERROR, message, title, context, exc, log=True + StructuredLogLevel.ERROR, message, title, context, exc, log=log ) def info( @@ -277,9 +283,10 @@ def info( context: Optional[str] = None, title: Optional[LiteralString] = None, exc: Optional[BaseException] = None, + log: bool = True, ) -> None: self._structured_logs.report_log( - StructuredLogLevel.INFO, message, title, context, exc, log=True + StructuredLogLevel.INFO, message, title, context, exc, log=log ) def __post_init__(self) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 3e65f06200418..d799c90594c4b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -306,9 +306,9 @@ def _populate_external_lineage_from_show_query( self.report.num_external_table_edges_scanned += 1 except Exception as e: logger.debug(e, exc_info=e) - self.report_warning( - "external_lineage", - f"Populating external table lineage from Snowflake failed due to error {e}.", + self.structured_reporter.warning( + "Error populating external table lineage from Snowflake", + exc=e, ) self.report_status(EXTERNAL_LINEAGE, False) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py index e604ed96b8eb6..a19253e5c5e15 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -461,9 +461,10 @@ def _process_schema( yield from self._process_tag(tag) if not snowflake_schema.views and not snowflake_schema.tables: - self.report_warning( - "No tables/views found in schema. If tables exist, please grant REFERENCES or SELECT permissions on them.", - f"{db_name}.{schema_name}", + self.structured_reporter.warning( + title="No tables/views found in schema", + message="If tables exist, please grant REFERENCES or SELECT permissions on them.", + context=f"{db_name}.{schema_name}", ) def fetch_views_for_schema( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index dad0ce7b59ee1..9c61fc3ec9bee 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -94,9 +94,10 @@ def report_missing_databases( missing_dbs = [db for db in inbounds + outbounds if db not in db_names] if missing_dbs and self.config.platform_instance: - self.report_warning( - "snowflake-shares", - f"Databases {missing_dbs} were not ingested. Siblings/Lineage will not be set for these.", + self.report.warning( + title="Extra Snowflake share configurations", + message="Some databases referenced by the share configs were not ingested. Siblings/lineage will not be set for these.", + context=f"{missing_dbs}", ) elif missing_dbs: logger.debug( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index c33fbb3d0bfc8..b1645e32a8229 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -167,21 +167,28 @@ def is_dataset_pattern_allowed( SnowflakeObjectDomain.MATERIALIZED_VIEW, ): return False + if len(dataset_params) != 3: - self.report_warning( - "invalid-dataset-pattern", - f"Found {dataset_params} of type {dataset_type}", + self.structured_reporter.info( + title="Unexpected dataset pattern", + message=f"Found a {dataset_type} with an unexpected number of parts. Database and schema filtering will not work as expected, but table filtering will still work.", + context=dataset_name, ) - # NOTE: this case returned `True` earlier when extracting lineage - return False + # We fall-through here so table/view filtering still works. - if not self.filter_config.database_pattern.allowed( - dataset_params[0].strip('"') - ) or not is_schema_allowed( - self.filter_config.schema_pattern, - dataset_params[1].strip('"'), - dataset_params[0].strip('"'), - self.filter_config.match_fully_qualified_names, + if ( + len(dataset_params) >= 1 + and not self.filter_config.database_pattern.allowed( + dataset_params[0].strip('"') + ) + ) or ( + len(dataset_params) >= 2 + and not is_schema_allowed( + self.filter_config.schema_pattern, + dataset_params[1].strip('"'), + dataset_params[0].strip('"'), + self.filter_config.match_fully_qualified_names, + ) ): return False @@ -210,7 +217,7 @@ def is_dataset_pattern_allowed( def cleanup_qualified_name(self, qualified_name: str) -> str: name_parts = qualified_name.split(".") if len(name_parts) != 3: - self.structured_reporter.report_warning( + self.structured_reporter.info( title="Unexpected dataset pattern", message="We failed to parse a Snowflake qualified name into its constituent parts. " "DB/schema/table filtering may not work as expected on these entities.", From fb11cda2745f66e2bc73f7ee73f2ce3d9f2e15ab Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 18:54:48 -0700 Subject: [PATCH 06/16] start refactoring SnowflakeCommonMixin --- .../source/snowflake/snowflake_assertion.py | 26 +++---- .../source/snowflake/snowflake_lineage_v2.py | 41 +++++++---- .../source/snowflake/snowflake_profiler.py | 2 +- .../source/snowflake/snowflake_queries.py | 65 +++++++++++------ .../source/snowflake/snowflake_schema_gen.py | 73 +++++++++++-------- .../source/snowflake/snowflake_shares.py | 20 +++-- .../source/snowflake/snowflake_summary.py | 12 ++- .../source/snowflake/snowflake_usage_v2.py | 30 +++++--- .../source/snowflake/snowflake_utils.py | 67 +++++++++-------- .../source/snowflake/snowflake_v2.py | 37 +++++----- .../tests/unit/test_snowflake_shares.py | 20 ++--- 11 files changed, 221 insertions(+), 172 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py index 2a1d18c83e6fa..a7c008d932a71 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py @@ -11,14 +11,13 @@ ) from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.source.snowflake.snowflake_config import ( - SnowflakeIdentifierConfig, - SnowflakeV2Config, -) +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config from datahub.ingestion.source.snowflake.snowflake_connection import SnowflakeConnection from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report -from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeIdentifierMixin +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeIdentifierBuilder, +) from datahub.metadata.com.linkedin.pegasus2avro.assertion import ( AssertionResult, AssertionResultType, @@ -40,23 +39,20 @@ class DataQualityMonitoringResult(BaseModel): VALUE: int -class SnowflakeAssertionsHandler(SnowflakeIdentifierMixin): +class SnowflakeAssertionsHandler: def __init__( self, config: SnowflakeV2Config, report: SnowflakeV2Report, connection: SnowflakeConnection, + identifiers: SnowflakeIdentifierBuilder, ) -> None: self.config = config self.report = report - self.logger = logger self.connection = connection + self.identifiers = identifiers self._urns_processed: List[str] = [] - @property - def identifier_config(self) -> SnowflakeIdentifierConfig: - return self.config - def get_assertion_workunits( self, discovered_datasets: List[str] ) -> Iterable[MetadataWorkUnit]: @@ -80,10 +76,10 @@ def _gen_platform_instance_wu(self, urn: str) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( entityUrn=urn, aspect=DataPlatformInstance( - platform=make_data_platform_urn(self.platform), + platform=make_data_platform_urn(self.identifiers.platform), instance=( make_dataplatform_instance_urn( - self.platform, self.config.platform_instance + self.identifiers.platform, self.config.platform_instance ) if self.config.platform_instance else None @@ -98,7 +94,7 @@ def _process_result_row( result = DataQualityMonitoringResult.parse_obj(result_row) assertion_guid = result.METRIC_NAME.split("__")[-1].lower() status = bool(result.VALUE) # 1 if PASS, 0 if FAIL - assertee = self.get_dataset_identifier( + assertee = self.identifiers.get_dataset_identifier( result.TABLE_NAME, result.TABLE_SCHEMA, result.TABLE_DATABASE ) if assertee in discovered_datasets: @@ -107,7 +103,7 @@ def _process_result_row( aspect=AssertionRunEvent( timestampMillis=datetime_to_ts_millis(result.MEASUREMENT_TIME), runId=result.MEASUREMENT_TIME.strftime("%Y-%m-%dT%H:%M:%SZ"), - asserteeUrn=self.gen_dataset_urn(assertee), + asserteeUrn=self.identifiers.gen_dataset_urn(assertee), status=AssertionRunStatus.COMPLETE, assertionUrn=make_assertion_urn(assertion_guid), result=AssertionResult( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index d799c90594c4b..50556efdcdd77 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass from datetime import datetime -from typing import Any, Callable, Collection, Iterable, List, Optional, Set, Tuple, Type +from typing import Any, Collection, Iterable, List, Optional, Set, Tuple, Type from pydantic import BaseModel, validator @@ -21,7 +21,11 @@ ) from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report -from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeFilter, + SnowflakeIdentifierBuilder, +) from datahub.ingestion.source.state.redundant_run_skip_handler import ( RedundantLineageRunSkipHandler, ) @@ -119,18 +123,19 @@ def __init__( config: SnowflakeV2Config, report: SnowflakeV2Report, connection: SnowflakeConnection, - dataset_urn_builder: Callable[[str], str], + filters: SnowflakeFilter, + identifiers: SnowflakeIdentifierBuilder, redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler], sql_aggregator: SqlParsingAggregator, ) -> None: self.config = config self.report = report - self.logger = logger - self.dataset_urn_builder = dataset_urn_builder self.connection = connection + self.filters = filters + self.identifiers = identifiers + self.redundant_run_skip_handler = redundant_run_skip_handler self.sql_aggregator = sql_aggregator - self.redundant_run_skip_handler = redundant_run_skip_handler self.start_time, self.end_time = ( self.report.lineage_start_time, self.report.lineage_end_time, @@ -233,7 +238,7 @@ def get_known_query_lineage( if not db_row.UPSTREAM_TABLES: return None - downstream_table_urn = self.dataset_urn_builder(dataset_name) + downstream_table_urn = self.identifiers.gen_dataset_urn(dataset_name) known_lineage = KnownQueryLineageInfo( query_text=query.query_text, @@ -288,7 +293,7 @@ def _populate_external_lineage_from_show_query( external_tables_query: str = SnowflakeQuery.show_external_tables() try: for db_row in self.connection.query(external_tables_query): - key = self.get_dataset_identifier( + key = self.identifiers.get_dataset_identifier( db_row["name"], db_row["schema_name"], db_row["database_name"] ) @@ -299,7 +304,7 @@ def _populate_external_lineage_from_show_query( upstream_urn=make_s3_urn_for_lineage( db_row["location"], self.config.env ), - downstream_urn=self.dataset_urn_builder(key), + downstream_urn=self.identifiers.gen_dataset_urn(key), ) self.report.num_external_table_edges_scanned += 1 @@ -362,7 +367,7 @@ def _process_external_lineage_result_row( self.report.num_external_table_edges_scanned += 1 return KnownLineageMapping( upstream_urn=make_s3_urn_for_lineage(loc, self.config.env), - downstream_urn=self.dataset_urn_builder(key), + downstream_urn=self.identifiers.gen_dataset_urn(key), ) return None @@ -422,12 +427,14 @@ def map_query_result_upstreams( ) if upstream_name and ( not self.config.validate_upstreams_against_patterns - or self.is_dataset_pattern_allowed( + or self.filters.is_dataset_pattern_allowed( upstream_name, upstream_table.upstream_object_domain, ) ): - upstreams.append(self.dataset_urn_builder(upstream_name)) + upstreams.append( + self.identifiers.gen_dataset_urn(upstream_name) + ) except Exception as e: logger.debug(e, exc_info=e) return upstreams @@ -491,7 +498,7 @@ def build_finegrained_lineage( return None column_lineage = ColumnLineageInfo( downstream=DownstreamColumnRef( - table=dataset_urn, column=self.snowflake_identifier(col) + table=dataset_urn, column=self.identifiers.snowflake_identifier(col) ), upstreams=sorted(column_upstreams), ) @@ -508,7 +515,7 @@ def build_finegrained_lineage_upstreams( and upstream_col.column_name and ( not self.config.validate_upstreams_against_patterns - or self.is_dataset_pattern_allowed( + or self.filters.is_dataset_pattern_allowed( upstream_col.object_name, upstream_col.object_domain, ) @@ -519,8 +526,10 @@ def build_finegrained_lineage_upstreams( ) column_upstreams.append( ColumnRef( - table=self.dataset_urn_builder(upstream_dataset_name), - column=self.snowflake_identifier(upstream_col.column_name), + table=self.identifiers.gen_dataset_urn(upstream_dataset_name), + column=self.identifiers.snowflake_identifier( + upstream_col.column_name + ), ) ) return column_upstreams diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py index 4deeb9f96f48e..89dc949e844f4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py @@ -86,7 +86,7 @@ def get_workunits( ) def get_dataset_name(self, table_name: str, schema_name: str, db_name: str) -> str: - return self.get_dataset_identifier(table_name, schema_name, db_name) + return self.identifiers.get_dataset_identifier(table_name, schema_name, db_name) def get_batch_kwargs( self, table: BaseTable, schema_name: str, db_name: str diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 1b595515e05ad..50cb1242016b7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -12,6 +12,7 @@ import pydantic from typing_extensions import Self +from datahub.configuration.common import ConfigModel from datahub.configuration.time_window_config import ( BaseTimeWindowConfig, BucketDuration, @@ -34,8 +35,9 @@ ) from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_utils import ( - SnowflakeFilterMixin, - SnowflakeIdentifierMixin, + SnowflakeFilter, + SnowflakeIdentifierBuilder, + SnowflakeStructuredReportMixin, ) from datahub.ingestion.source.usage.usage_common import BaseUsageConfig from datahub.metadata.urns import CorpUserUrn @@ -58,7 +60,7 @@ logger = logging.getLogger(__name__) -class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilterConfig): +class SnowflakeQueriesExtractorConfig(ConfigModel): # TODO: Support stateful ingestion for the time windows. window: BaseTimeWindowConfig = BaseTimeWindowConfig() @@ -87,7 +89,9 @@ class SnowflakeQueriesExtractorConfig(SnowflakeIdentifierConfig, SnowflakeFilter include_operations: bool = True -class SnowflakeQueriesSourceConfig(SnowflakeQueriesExtractorConfig): +class SnowflakeQueriesSourceConfig( + SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig +): connection: SnowflakeConnectionConfig @@ -104,12 +108,14 @@ class SnowflakeQueriesSourceReport(SourceReport): queries_extractor: Optional[SnowflakeQueriesExtractorReport] = None -class SnowflakeQueriesExtractor(SnowflakeFilterMixin, SnowflakeIdentifierMixin): +class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin): def __init__( self, connection: SnowflakeConnection, config: SnowflakeQueriesExtractorConfig, structured_report: SourceReport, + filters: SnowflakeFilter, + identifiers: SnowflakeIdentifierBuilder, graph: Optional[DataHubGraph] = None, schema_resolver: Optional[SchemaResolver] = None, ): @@ -117,12 +123,15 @@ def __init__( self.config = config self.report = SnowflakeQueriesExtractorReport() + self.filters = filters + self.identifiers = identifiers + self._structured_report = structured_report self.aggregator = SqlParsingAggregator( - platform=self.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, + platform=self.identifiers.platform, + platform_instance=self.identifiers.identifier_config.platform_instance, + env=self.identifiers.identifier_config.env, schema_resolver=schema_resolver, graph=graph, eager_graph_load=False, @@ -147,14 +156,6 @@ def __init__( def structured_reporter(self) -> SourceReport: return self._structured_report - @property - def filter_config(self) -> SnowflakeFilterConfig: - return self.config - - @property - def identifier_config(self) -> SnowflakeIdentifierConfig: - return self.config - @functools.cached_property def local_temp_path(self) -> pathlib.Path: if self.config.local_temp_path: @@ -173,7 +174,9 @@ def is_temp_table(self, name: str) -> bool: ) def is_allowed_table(self, name: str) -> bool: - return self.is_dataset_pattern_allowed(name, SnowflakeObjectDomain.TABLE) + return self.filters.is_dataset_pattern_allowed( + name, SnowflakeObjectDomain.TABLE + ) def get_workunits_internal( self, @@ -261,7 +264,9 @@ def fetch_audit_log( def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: # Copied from SnowflakeCommonMixin. - return self.snowflake_identifier(self.cleanup_qualified_name(qualified_name)) + return self.identifiers.snowflake_identifier( + self.filters.cleanup_qualified_name(qualified_name) + ) def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: json_fields = { @@ -283,13 +288,15 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: column_usage = {} for obj in direct_objects_accessed: - dataset = self.gen_dataset_urn( + dataset = self.identifiers.gen_dataset_urn( self.get_dataset_identifier_from_qualified_name(obj["objectName"]) ) columns = set() for modified_column in obj["columns"]: - columns.add(self.snowflake_identifier(modified_column["columnName"])) + columns.add( + self.identifiers.snowflake_identifier(modified_column["columnName"]) + ) upstreams.append(dataset) column_usage[dataset] = columns @@ -304,7 +311,7 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: context=f"{row}", ) - downstream = self.gen_dataset_urn( + downstream = self.identifiers.gen_dataset_urn( self.get_dataset_identifier_from_qualified_name(obj["objectName"]) ) column_lineage = [] @@ -313,18 +320,18 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ColumnLineageInfo( downstream=DownstreamColumnRef( dataset=downstream, - column=self.snowflake_identifier( + column=self.identifiers.snowflake_identifier( modified_column["columnName"] ), ), upstreams=[ ColumnRef( - table=self.gen_dataset_urn( + table=self.identifiers.gen_dataset_urn( self.get_dataset_identifier_from_qualified_name( upstream["objectName"] ) ), - column=self.snowflake_identifier( + column=self.identifiers.snowflake_identifier( upstream["columnName"] ), ) @@ -374,12 +381,22 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig): self.config = config self.report = SnowflakeQueriesSourceReport() + self.filters = SnowflakeFilter( + filter_config=self.config, + structured_reporter=self.report, + ) + self.identifiers = SnowflakeIdentifierBuilder( + identifier_config=self.config, + ) + self.connection = self.config.connection.get_connection() self.queries_extractor = SnowflakeQueriesExtractor( connection=self.connection, config=self.config, structured_report=self.report, + filters=self.filters, + identifiers=self.identifiers, graph=self.ctx.graph, ) self.report.queries_extractor = self.queries_extractor.report diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py index a19253e5c5e15..8efee52688496 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -2,7 +2,7 @@ import itertools import logging import queue -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Union from datahub.configuration.pattern_utils import is_schema_allowed from datahub.emitter.mce_builder import ( @@ -28,8 +28,6 @@ SnowflakeObjectDomain, ) from datahub.ingestion.source.snowflake.snowflake_config import ( - SnowflakeFilterConfig, - SnowflakeIdentifierConfig, SnowflakeV2Config, TagOption, ) @@ -54,8 +52,9 @@ ) from datahub.ingestion.source.snowflake.snowflake_tag import SnowflakeTagExtractor from datahub.ingestion.source.snowflake.snowflake_utils import ( - SnowflakeFilterMixin, - SnowflakeIdentifierMixin, + SnowflakeFilter, + SnowflakeIdentifierBuilder, + SnowflakeStructuredReportMixin, SnowsightUrlBuilder, ) from datahub.ingestion.source.sql.sql_utils import ( @@ -143,13 +142,16 @@ } -class SnowflakeSchemaGenerator(SnowflakeFilterMixin, SnowflakeIdentifierMixin): +class SnowflakeSchemaGenerator(SnowflakeStructuredReportMixin): + platform = "snowflake" + def __init__( self, config: SnowflakeV2Config, report: SnowflakeV2Report, connection: SnowflakeConnection, - dataset_urn_builder: Callable[[str], str], + filters: SnowflakeFilter, + identifiers: SnowflakeIdentifierBuilder, domain_registry: Optional[DomainRegistry], profiler: Optional[SnowflakeProfiler], aggregator: Optional[SqlParsingAggregator], @@ -158,7 +160,8 @@ def __init__( self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = report self.connection: SnowflakeConnection = connection - self.dataset_urn_builder = dataset_urn_builder + self.filters: SnowflakeFilter = filters + self.identifiers: SnowflakeIdentifierBuilder = identifiers self.data_dictionary: SnowflakeDataDictionary = SnowflakeDataDictionary( connection=self.connection @@ -186,19 +189,17 @@ def get_connection(self) -> SnowflakeConnection: def structured_reporter(self) -> SourceReport: return self.report - @property - def filter_config(self) -> SnowflakeFilterConfig: - return self.config + def gen_dataset_urn(self, dataset_identifier: str) -> str: + return self.identifiers.gen_dataset_urn(dataset_identifier) - @property - def identifier_config(self) -> SnowflakeIdentifierConfig: - return self.config + def snowflake_identifier(self, identifier: str) -> str: + return self.identifiers.snowflake_identifier(identifier) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.databases = [] for database in self.get_databases() or []: self.report.report_entity_scanned(database.name, "database") - if not self.filter_config.database_pattern.allowed(database.name): + if not self.filters.filter_config.database_pattern.allowed(database.name): self.report.report_dropped(f"{database.name}.*") else: self.databases.append(database) @@ -362,10 +363,10 @@ def fetch_schemas_for_database( for schema in self.data_dictionary.get_schemas_for_database(db_name): self.report.report_entity_scanned(schema.name, "schema") if not is_schema_allowed( - self.filter_config.schema_pattern, + self.filters.filter_config.schema_pattern, schema.name, db_name, - self.filter_config.match_fully_qualified_names, + self.filters.filter_config.match_fully_qualified_names, ): self.report.report_dropped(f"{db_name}.{schema.name}.*") else: @@ -441,12 +442,12 @@ def _process_schema( and self.config.parse_view_ddl ): for view in views: - view_identifier = self.get_dataset_identifier( + view_identifier = self.identifiers.get_dataset_identifier( view.name, schema_name, db_name ) if view.view_definition: self.aggregator.add_view_definition( - view_urn=self.dataset_urn_builder(view_identifier), + view_urn=self.identifiers.gen_dataset_urn(view_identifier), view_definition=view.view_definition, default_db=db_name, default_schema=schema_name, @@ -473,11 +474,13 @@ def fetch_views_for_schema( try: views: List[SnowflakeView] = [] for view in self.get_views_for_schema(schema_name, db_name): - view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + view_name = self.identifiers.get_dataset_identifier( + view.name, schema_name, db_name + ) self.report.report_entity_scanned(view_name, "view") - if not self.filter_config.view_pattern.allowed(view_name): + if not self.filters.filter_config.view_pattern.allowed(view_name): self.report.report_dropped(view_name) else: views.append(view) @@ -506,11 +509,13 @@ def fetch_tables_for_schema( try: tables: List[SnowflakeTable] = [] for table in self.get_tables_for_schema(schema_name, db_name): - table_identifier = self.get_dataset_identifier( + table_identifier = self.identifiers.get_dataset_identifier( table.name, schema_name, db_name ) self.report.report_entity_scanned(table_identifier) - if not self.filter_config.table_pattern.allowed(table_identifier): + if not self.filters.filter_config.table_pattern.allowed( + table_identifier + ): self.report.report_dropped(table_identifier) else: tables.append(table) @@ -547,7 +552,9 @@ def _process_table( db_name: str, ) -> Iterable[MetadataWorkUnit]: schema_name = snowflake_schema.name - table_identifier = self.get_dataset_identifier(table.name, schema_name, db_name) + table_identifier = self.identifiers.get_dataset_identifier( + table.name, schema_name, db_name + ) try: table.columns = self.get_columns_for_table( @@ -627,7 +634,9 @@ def _process_view( db_name: str, ) -> Iterable[MetadataWorkUnit]: schema_name = snowflake_schema.name - view_name = self.get_dataset_identifier(view.name, schema_name, db_name) + view_name = self.identifiers.get_dataset_identifier( + view.name, schema_name, db_name + ) try: view.columns = self.get_columns_for_table( @@ -678,8 +687,10 @@ def gen_dataset_workunits( for tag in table.column_tags[column_name]: yield from self._process_tag(tag) - dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) - dataset_urn = self.dataset_urn_builder(dataset_name) + dataset_name = self.identifiers.get_dataset_identifier( + table.name, schema_name, db_name + ) + dataset_urn = self.identifiers.gen_dataset_urn(dataset_name) status = Status(removed=False) yield MetadataChangeProposalWrapper( @@ -816,8 +827,10 @@ def gen_schema_metadata( schema_name: str, db_name: str, ) -> SchemaMetadata: - dataset_name = self.get_dataset_identifier(table.name, schema_name, db_name) - dataset_urn = self.dataset_urn_builder(dataset_name) + dataset_name = self.identifiers.get_dataset_identifier( + table.name, schema_name, db_name + ) + dataset_urn = self.identifiers.gen_dataset_urn(dataset_name) foreign_keys: Optional[List[ForeignKeyConstraint]] = None if isinstance(table, SnowflakeTable) and len(table.foreign_keys) > 0: @@ -876,7 +889,7 @@ def build_foreign_keys( for fk in table.foreign_keys: foreign_dataset = make_dataset_urn_with_platform_instance( platform=self.platform, - name=self.get_dataset_identifier( + name=self.identifiers.get_dataset_identifier( fk.referred_table, fk.referred_schema, fk.referred_database ), env=self.config.env, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 9c61fc3ec9bee..9550e548c9949 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Iterable, List +from typing import Iterable, List from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance from datahub.emitter.mcp import MetadataChangeProposalWrapper @@ -26,12 +26,10 @@ def __init__( self, config: SnowflakeV2Config, report: SnowflakeV2Report, - dataset_urn_builder: Callable[[str], str], ) -> None: self.config = config self.report = report self.logger = logger - self.dataset_urn_builder = dataset_urn_builder def get_shares_workunits( self, databases: List[SnowflakeDatabase] @@ -114,15 +112,15 @@ def gen_siblings( ) -> Iterable[MetadataWorkUnit]: if not sibling_databases: return - dataset_identifier = self.get_dataset_identifier( + dataset_identifier = self.identifiers.get_dataset_identifier( table_name, schema_name, database_name ) - urn = self.dataset_urn_builder(dataset_identifier) + urn = self.identifiers.gen_dataset_urn(dataset_identifier) sibling_urns = [ make_dataset_urn_with_platform_instance( - self.platform, - self.get_dataset_identifier( + self.identifiers.platform, + self.identifiers.get_dataset_identifier( table_name, schema_name, sibling_db.database ), sibling_db.platform_instance, @@ -142,14 +140,14 @@ def get_upstream_lineage_with_primary_sibling( table_name: str, primary_sibling_db: DatabaseId, ) -> MetadataWorkUnit: - dataset_identifier = self.get_dataset_identifier( + dataset_identifier = self.identifiers.get_dataset_identifier( table_name, schema_name, database_name ) - urn = self.dataset_urn_builder(dataset_identifier) + urn = self.identifiers.gen_dataset_urn(dataset_identifier) upstream_urn = make_dataset_urn_with_platform_instance( - self.platform, - self.get_dataset_identifier( + self.identifiers.platform, + self.identifiers.get_dataset_identifier( table_name, schema_name, primary_sibling_db.database ), primary_sibling_db.platform_instance, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py index f78ae70291f8a..a09e314711881 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py @@ -9,7 +9,10 @@ from datahub.ingestion.api.decorators import SupportStatus, config_class, support_status from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.api.workunit import MetadataWorkUnit -from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeFilterConfig +from datahub.ingestion.source.snowflake.snowflake_config import ( + SnowflakeFilterConfig, + SnowflakeIdentifierConfig, +) from datahub.ingestion.source.snowflake.snowflake_connection import ( SnowflakeConnectionConfig, ) @@ -17,6 +20,9 @@ from datahub.ingestion.source.snowflake.snowflake_schema_gen import ( SnowflakeSchemaGenerator, ) +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeIdentifierBuilder, +) from datahub.ingestion.source_report.time_window import BaseTimeWindowReport from datahub.utilities.lossy_collections import LossyList @@ -69,7 +75,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: config=self.config, # type: ignore report=self.report, # type: ignore connection=self.connection, - dataset_urn_builder=lambda x: "", + identifiers=SnowflakeIdentifierBuilder( + identifier_config=SnowflakeIdentifierConfig() + ), domain_registry=None, profiler=None, aggregator=None, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index c5e0994059f2e..380bfd108ddb4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -2,7 +2,7 @@ import logging import time from datetime import datetime, timezone -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import pydantic @@ -20,7 +20,11 @@ ) from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_report import SnowflakeV2Report -from datahub.ingestion.source.snowflake.snowflake_utils import SnowflakeCommonMixin +from datahub.ingestion.source.snowflake.snowflake_utils import ( + SnowflakeCommonMixin, + SnowflakeFilter, + SnowflakeIdentifierBuilder, +) from datahub.ingestion.source.state.redundant_run_skip_handler import ( RedundantUsageRunSkipHandler, ) @@ -112,13 +116,14 @@ def __init__( config: SnowflakeV2Config, report: SnowflakeV2Report, connection: SnowflakeConnection, - dataset_urn_builder: Callable[[str], str], + filter: SnowflakeFilter, + identifiers: SnowflakeIdentifierBuilder, redundant_run_skip_handler: Optional[RedundantUsageRunSkipHandler], ) -> None: self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = report - self.dataset_urn_builder = dataset_urn_builder - self.logger = logger + self.filter = filter + self.identifiers = identifiers self.connection = connection self.redundant_run_skip_handler = redundant_run_skip_handler @@ -171,7 +176,7 @@ def get_usage_workunits( bucket_duration=self.config.bucket_duration, ), dataset_urns={ - self.dataset_urn_builder(dataset_identifier) + self.identifiers.gen_dataset_urn(dataset_identifier) for dataset_identifier in discovered_datasets }, ) @@ -232,7 +237,7 @@ def _get_workunits_internal( logger.debug(f"Processing usage row number {results.rownumber}") logger.debug(self.report.usage_aggregation.as_string()) - if not self.is_dataset_pattern_allowed( + if not self.filter.is_dataset_pattern_allowed( row["OBJECT_NAME"], row["OBJECT_DOMAIN"], ): @@ -279,7 +284,8 @@ def build_usage_statistics_for_dataset( fieldCounts=self._map_field_counts(row["FIELD_COUNTS"]), ) return MetadataChangeProposalWrapper( - entityUrn=self.dataset_urn_builder(dataset_identifier), aspect=stats + entityUrn=self.identifiers.gen_dataset_urn(dataset_identifier), + aspect=stats, ).as_workunit() except Exception as e: logger.debug( @@ -356,7 +362,9 @@ def _map_field_counts(self, field_counts_str: str) -> List[DatasetFieldUsageCoun return sorted( [ DatasetFieldUsageCounts( - fieldPath=self.snowflake_identifier(field_count["col"]), + fieldPath=self.identifiers.snowflake_identifier( + field_count["col"] + ), count=field_count["total"], ) for field_count in field_counts @@ -476,7 +484,7 @@ def _get_operation_aspect_work_unit( ), ) mcp = MetadataChangeProposalWrapper( - entityUrn=self.dataset_urn_builder(dataset_identifier), + entityUrn=self.identifiers.gen_dataset_urn(dataset_identifier), aspect=operation_aspect, ) wu = MetadataWorkUnit( @@ -561,7 +569,7 @@ def _is_unsupported_object_accessed(self, obj: Dict[str, Any]) -> bool: def _is_object_valid(self, obj: Dict[str, Any]) -> bool: if self._is_unsupported_object_accessed( obj - ) or not self.is_dataset_pattern_allowed( + ) or not self.filter.is_dataset_pattern_allowed( obj.get("objectName"), obj.get("objectDomain") ): return False diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index b1645e32a8229..384fa744e6e22 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -1,4 +1,5 @@ import abc +from functools import cached_property from typing import ClassVar, Literal, Optional, Tuple from typing_extensions import Protocol @@ -41,14 +42,6 @@ class SnowflakeCommonProtocol(Protocol): config: SnowflakeV2Config report: SnowflakeV2Report - def get_dataset_identifier( - self, table_name: str, schema_name: str, db_name: str - ) -> str: - ... - - def cleanup_qualified_name(self, qualified_name: str) -> str: - ... - def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: ... @@ -140,11 +133,14 @@ def get_external_url_for_database(self, db_name: str) -> Optional[str]: return f"{self.snowsight_base_url}#/data/databases/{db_name}/" -class SnowflakeFilterMixin(SnowflakeStructuredReportMixin): - @property - @abc.abstractmethod - def filter_config(self) -> SnowflakeFilterConfig: - ... +class SnowflakeFilter: + def __init__( + self, filter_config: SnowflakeFilterConfig, structured_reporter: SourceReport + ) -> None: + self.filter_config = filter_config + self.structured_reporter = structured_reporter + + # TODO: Refactor remaining filtering logic into this class. @staticmethod def _combine_identifier_parts( @@ -224,20 +220,18 @@ def cleanup_qualified_name(self, qualified_name: str) -> str: context=f"{qualified_name} has {len(name_parts)} parts", ) return qualified_name.replace('"', "") - return SnowflakeFilterMixin._combine_identifier_parts( + return SnowflakeFilter._combine_identifier_parts( table_name=name_parts[2].strip('"'), schema_name=name_parts[1].strip('"'), db_name=name_parts[0].strip('"'), ) -class SnowflakeIdentifierMixin(abc.ABC): +class SnowflakeIdentifierBuilder: platform = "snowflake" - @property - @abc.abstractmethod - def identifier_config(self) -> SnowflakeIdentifierConfig: - ... + def __init__(self, identifier_config: SnowflakeIdentifierConfig) -> None: + self.identifier_config = identifier_config def snowflake_identifier(self, identifier: str) -> str: # to be in in sync with older connector, convert name to lowercase @@ -249,7 +243,7 @@ def get_dataset_identifier( self, table_name: str, schema_name: str, db_name: str ) -> str: return self.snowflake_identifier( - SnowflakeCommonMixin._combine_identifier_parts( + SnowflakeFilter._combine_identifier_parts( table_name=table_name, schema_name=schema_name, db_name=db_name ) ) @@ -264,18 +258,16 @@ def gen_dataset_urn(self, dataset_identifier: str) -> str: # TODO: We're most of the way there on fully removing SnowflakeCommonProtocol. -class SnowflakeCommonMixin(SnowflakeFilterMixin, SnowflakeIdentifierMixin): +class SnowflakeCommonMixin(SnowflakeStructuredReportMixin): + platform = "snowflake" + @property def structured_reporter(self: SnowflakeCommonProtocol) -> SourceReport: return self.report - @property - def filter_config(self: SnowflakeCommonProtocol) -> SnowflakeFilterConfig: - return self.config - - @property - def identifier_config(self: SnowflakeCommonProtocol) -> SnowflakeIdentifierConfig: - return self.config + @cached_property + def identifiers(self: SnowflakeCommonProtocol) -> SnowflakeIdentifierBuilder: + return SnowflakeIdentifierBuilder(self.config) @staticmethod def get_quoted_identifier_for_database(db_name): @@ -285,8 +277,23 @@ def get_quoted_identifier_for_database(db_name): def get_quoted_identifier_for_schema(db_name, schema_name): return f'"{db_name}"."{schema_name}"' - def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: - return self.snowflake_identifier(self.cleanup_qualified_name(qualified_name)) + def gen_dataset_urn(self: SnowflakeCommonProtocol, dataset_identifier: str) -> str: + # TODO: Remove this method. + identifiers = SnowflakeIdentifierBuilder(self.config) + return identifiers.gen_dataset_urn(dataset_identifier) + + def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str: + # TODO: Remove this method. + identifiers = SnowflakeIdentifierBuilder(self.config) + return identifiers.snowflake_identifier(identifier) + + def get_dataset_identifier_from_qualified_name( + self: SnowflakeCommonProtocol, qualified_name: str + ) -> str: + filter = SnowflakeFilter( + filter_config=self.config, structured_reporter=self.report + ) + return self.snowflake_identifier(filter.cleanup_qualified_name(qualified_name)) @staticmethod def get_quoted_identifier_for_table(db_name, schema_name, table_name): diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index df22a406adfde..8226d77379db9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -61,6 +61,8 @@ ) from datahub.ingestion.source.snowflake.snowflake_utils import ( SnowflakeCommonMixin, + SnowflakeFilter, + SnowflakeIdentifierBuilder, SnowsightUrlBuilder, ) from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler @@ -135,6 +137,11 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.report: SnowflakeV2Report = SnowflakeV2Report() self.logger = logger + self.filters = SnowflakeFilter( + filter_config=self.config, structured_reporter=self.report + ) + self.identifiers = SnowflakeIdentifierBuilder(identifier_config=self.config) + self.connection = self.config.get_connection() self.domain_registry: Optional[DomainRegistry] = None @@ -150,7 +157,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): if self.config.include_table_lineage: self.aggregator = SqlParsingAggregator( - platform=self.platform, + platform=self.identifiers.platform, platform_instance=self.config.platform_instance, env=self.config.env, graph=self.ctx.graph, @@ -185,7 +192,8 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): config, self.report, connection=self.connection, - dataset_urn_builder=self.gen_dataset_urn, + filters=self.filters, + identifiers=self.identifiers, redundant_run_skip_handler=redundant_lineage_run_skip_handler, sql_aggregator=self.aggregator, ) @@ -206,7 +214,8 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): config, self.report, connection=self.connection, - dataset_urn_builder=self.gen_dataset_urn, + filter=self.filters, + identifiers=self.identifiers, redundant_run_skip_handler=redundant_usage_run_skip_handler, ) @@ -450,7 +459,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: profiler=self.profiler, aggregator=self.aggregator, snowsight_url_builder=snowsight_url_builder, - dataset_urn_builder=self.gen_dataset_urn, + filters=self.filters, + identifiers=self.identifiers, ) self.report.set_ingestion_stage("*", METADATA_EXTRACTION) @@ -462,17 +472,17 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.config.shares: yield from SnowflakeSharesHandler( - self.config, self.report, self.gen_dataset_urn + self.config, self.report ).get_shares_workunits(databases) discovered_tables: List[str] = [ - self.get_dataset_identifier(table_name, schema.name, db.name) + self.identifiers.get_dataset_identifier(table_name, schema.name, db.name) for db in databases for schema in db.schemas for table_name in schema.tables ] discovered_views: List[str] = [ - self.get_dataset_identifier(table_name, schema.name, db.name) + self.identifiers.get_dataset_identifier(table_name, schema.name, db.name) for db in databases for schema in db.schemas for table_name in schema.views @@ -499,15 +509,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: queries_extractor = SnowflakeQueriesExtractor( connection=self.connection, config=SnowflakeQueriesExtractorConfig( - # TODO: Refactor this a bit so it's not as redundant. - database_pattern=self.config.database_pattern, - schema_pattern=self.config.schema_pattern, - table_pattern=self.config.table_pattern, - view_pattern=self.config.view_pattern, - match_fully_qualified_names=self.config.match_fully_qualified_names, - convert_urns_to_lowercase=self.config.convert_urns_to_lowercase, - env=self.config.env, - platform_instance=self.config.platform_instance, window=self.config, temporary_tables_pattern=self.config.temporary_tables_pattern, include_lineage=self.config.include_table_lineage, @@ -515,6 +516,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: include_operations=self.config.include_operational_stats, ), structured_report=self.report, + filters=self.filters, + identifiers=self.identifiers, schema_resolver=schema_resolver, ) @@ -539,7 +542,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.config.include_assertion_results: yield from SnowflakeAssertionsHandler( - self.config, self.report, self.connection + self.config, self.report, self.connection, self.identifiers ).get_assertion_workunits(discovered_datasets) self.connection.close() diff --git a/metadata-ingestion/tests/unit/test_snowflake_shares.py b/metadata-ingestion/tests/unit/test_snowflake_shares.py index fc753f99b7e8f..2e78f0bb3ae65 100644 --- a/metadata-ingestion/tests/unit/test_snowflake_shares.py +++ b/metadata-ingestion/tests/unit/test_snowflake_shares.py @@ -102,9 +102,7 @@ def test_snowflake_shares_workunit_no_shares( config = SnowflakeV2Config(account_id="abc12345", platform_instance="instance1") report = SnowflakeV2Report() - shares_handler = SnowflakeSharesHandler( - config, report, lambda x: make_snowflake_urn(x) - ) + shares_handler = SnowflakeSharesHandler(config, report) wus = list(shares_handler.get_shares_workunits(snowflake_databases)) @@ -204,9 +202,7 @@ def test_snowflake_shares_workunit_inbound_share( ) report = SnowflakeV2Report() - shares_handler = SnowflakeSharesHandler( - config, report, lambda x: make_snowflake_urn(x, "instance1") - ) + shares_handler = SnowflakeSharesHandler(config, report) wus = list(shares_handler.get_shares_workunits(snowflake_databases)) @@ -262,9 +258,7 @@ def test_snowflake_shares_workunit_outbound_share( ) report = SnowflakeV2Report() - shares_handler = SnowflakeSharesHandler( - config, report, lambda x: make_snowflake_urn(x, "instance1") - ) + shares_handler = SnowflakeSharesHandler(config, report) wus = list(shares_handler.get_shares_workunits(snowflake_databases)) @@ -313,9 +307,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share( ) report = SnowflakeV2Report() - shares_handler = SnowflakeSharesHandler( - config, report, lambda x: make_snowflake_urn(x, "instance1") - ) + shares_handler = SnowflakeSharesHandler(config, report) wus = list(shares_handler.get_shares_workunits(snowflake_databases)) @@ -376,9 +368,7 @@ def test_snowflake_shares_workunit_inbound_and_outbound_share_no_platform_instan ) report = SnowflakeV2Report() - shares_handler = SnowflakeSharesHandler( - config, report, lambda x: make_snowflake_urn(x) - ) + shares_handler = SnowflakeSharesHandler(config, report) assert sorted(config.outbounds().keys()) == ["db1", "db2_main"] assert sorted(config.inbounds().keys()) == [ From b0f60871844c8ebd67d76668ae70affb1de005bc Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 18:58:12 -0700 Subject: [PATCH 07/16] remove mroe methods from mixin --- .../source/snowflake/snowflake_utils.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 384fa744e6e22..f87e92398d74c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -45,9 +45,6 @@ class SnowflakeCommonProtocol(Protocol): def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: ... - def snowflake_identifier(self, identifier: str) -> str: - ... - def report_warning(self, key: str, reason: str) -> None: ... @@ -277,23 +274,16 @@ def get_quoted_identifier_for_database(db_name): def get_quoted_identifier_for_schema(db_name, schema_name): return f'"{db_name}"."{schema_name}"' - def gen_dataset_urn(self: SnowflakeCommonProtocol, dataset_identifier: str) -> str: - # TODO: Remove this method. - identifiers = SnowflakeIdentifierBuilder(self.config) - return identifiers.gen_dataset_urn(dataset_identifier) - - def snowflake_identifier(self: SnowflakeCommonProtocol, identifier: str) -> str: - # TODO: Remove this method. - identifiers = SnowflakeIdentifierBuilder(self.config) - return identifiers.snowflake_identifier(identifier) - def get_dataset_identifier_from_qualified_name( self: SnowflakeCommonProtocol, qualified_name: str ) -> str: filter = SnowflakeFilter( filter_config=self.config, structured_reporter=self.report ) - return self.snowflake_identifier(filter.cleanup_qualified_name(qualified_name)) + identifiers = SnowflakeIdentifierBuilder(self.config) + return identifiers.snowflake_identifier( + filter.cleanup_qualified_name(qualified_name) + ) @staticmethod def get_quoted_identifier_for_table(db_name, schema_name, table_name): @@ -309,13 +299,14 @@ def get_user_identifier( user_email: Optional[str], email_as_user_identifier: bool, ) -> str: + identifiers = SnowflakeIdentifierBuilder(self.config) if user_email: - return self.snowflake_identifier( + return identifiers.snowflake_identifier( user_email if email_as_user_identifier is True else user_email.split("@")[0] ) - return self.snowflake_identifier(user_name) + return identifiers.snowflake_identifier(user_name) # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint From 5542fe17e6cc391ed7f18c4ca2908a371581139d Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 19:01:10 -0700 Subject: [PATCH 08/16] remove SnowflakeCommonProtocol --- .../source/snowflake/snowflake_utils.py | 38 ++++--------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index f87e92398d74c..805a646b53faa 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -2,8 +2,6 @@ from functools import cached_property from typing import ClassVar, Literal, Optional, Tuple -from typing_extensions import Protocol - from datahub.configuration.pattern_utils import is_schema_allowed from datahub.emitter.mce_builder import make_dataset_urn_with_platform_instance from datahub.ingestion.api.source import SourceReport @@ -34,24 +32,6 @@ def report_error(self, key: str, reason: str) -> None: self.structured_reporter.failure(key, reason) -# Required only for mypy, since we are using mixin classes, and not inheritance. -# Reference - https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes -class SnowflakeCommonProtocol(Protocol): - platform: str = "snowflake" - - config: SnowflakeV2Config - report: SnowflakeV2Report - - def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: - ... - - def report_warning(self, key: str, reason: str) -> None: - ... - - def report_error(self, key: str, reason: str) -> None: - ... - - class SnowsightUrlBuilder: CLOUD_REGION_IDS_WITHOUT_CLOUD_SUFFIX: ClassVar = [ "us-west-2", @@ -254,16 +234,18 @@ def gen_dataset_urn(self, dataset_identifier: str) -> str: ) -# TODO: We're most of the way there on fully removing SnowflakeCommonProtocol. class SnowflakeCommonMixin(SnowflakeStructuredReportMixin): platform = "snowflake" + config: SnowflakeV2Config + report: SnowflakeV2Report + @property - def structured_reporter(self: SnowflakeCommonProtocol) -> SourceReport: + def structured_reporter(self) -> SourceReport: return self.report @cached_property - def identifiers(self: SnowflakeCommonProtocol) -> SnowflakeIdentifierBuilder: + def identifiers(self) -> SnowflakeIdentifierBuilder: return SnowflakeIdentifierBuilder(self.config) @staticmethod @@ -274,9 +256,7 @@ def get_quoted_identifier_for_database(db_name): def get_quoted_identifier_for_schema(db_name, schema_name): return f'"{db_name}"."{schema_name}"' - def get_dataset_identifier_from_qualified_name( - self: SnowflakeCommonProtocol, qualified_name: str - ) -> str: + def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: filter = SnowflakeFilter( filter_config=self.config, structured_reporter=self.report ) @@ -294,7 +274,7 @@ def get_quoted_identifier_for_table(db_name, schema_name, table_name): # Users without email were skipped from both user entries as well as aggregates. # However email is not mandatory field in snowflake user, user_name is always present. def get_user_identifier( - self: SnowflakeCommonProtocol, + self, user_name: str, user_email: Optional[str], email_as_user_identifier: bool, @@ -310,9 +290,7 @@ def get_user_identifier( # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint - def warn_if_stateful_else_error( - self: SnowflakeCommonProtocol, key: str, reason: str - ) -> None: + def warn_if_stateful_else_error(self, key: str, reason: str) -> None: if ( self.config.stateful_ingestion is not None and self.config.stateful_ingestion.enabled From b4a0fce14e7ab9628774a8ac2facec81f03e7ee2 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 19:14:23 -0700 Subject: [PATCH 09/16] simplify structured reporting mixin --- .../source/snowflake/snowflake_lineage_v2.py | 16 ++--- .../source/snowflake/snowflake_schema_gen.py | 68 +++++++------------ .../source/snowflake/snowflake_utils.py | 12 +--- .../source/snowflake/snowflake_v2.py | 25 ++++--- 4 files changed, 54 insertions(+), 67 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 50556efdcdd77..1ed5a45e8cf31 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -343,7 +343,7 @@ def _populate_external_lineage_from_copy_history( self.warn_if_stateful_else_error(LINEAGE_PERMISSION_ERROR, error_msg) else: logger.debug(e, exc_info=e) - self.report_warning( + self.structured_reporter.warning( "external_lineage", f"Populating table external lineage from Snowflake failed due to error {e}.", ) @@ -393,10 +393,9 @@ def _fetch_upstream_lineages_for_tables(self) -> Iterable[UpstreamLineageEdge]: error_msg = "Failed to get table/view to table lineage. Please grant imported privileges on SNOWFLAKE database. " self.warn_if_stateful_else_error(LINEAGE_PERMISSION_ERROR, error_msg) else: - logger.debug(e, exc_info=e) - self.report_warning( - "table-upstream-lineage", - f"Extracting lineage from Snowflake failed due to error {e}.", + self.structured_reporter.warning( + "Failed to extract table/view -> table lineage from Snowflake", + exc=e, ) self.report_status(TABLE_LINEAGE, False) @@ -407,9 +406,10 @@ def _process_upstream_lineage_row( return UpstreamLineageEdge.parse_obj(db_row) except Exception as e: self.report.num_upstream_lineage_edge_parsing_failed += 1 - self.report_warning( - f"Parsing lineage edge failed due to error {e}", - db_row.get("DOWNSTREAM_TABLE_NAME") or "", + self.structured_reporter.warning( + "Failed to parse lineage edge", + context=db_row.get("DOWNSTREAM_TABLE_NAME") or None, + exc=e, ) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py index 8efee52688496..3c4e633c1b855 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema_gen.py @@ -213,7 +213,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self._process_database(snowflake_db) except SnowflakePermissionError as e: - self.report_error(GENERIC_PERMISSION_ERROR_KEY, str(e)) + self.structured_reporter.failure( + GENERIC_PERMISSION_ERROR_KEY, + exc=e, + ) return def get_databases(self) -> Optional[List[SnowflakeDatabase]]: @@ -222,10 +225,9 @@ def get_databases(self) -> Optional[List[SnowflakeDatabase]]: # whose information_schema can be queried to start with. databases = self.data_dictionary.show_databases() except Exception as e: - logger.debug(f"Failed to list databases due to error {e}", exc_info=e) - self.report_error( - "list-databases", - f"Failed to list databases due to error {e}", + self.structured_reporter.failure( + "Failed to list databases", + exc=e, ) return None else: @@ -234,7 +236,7 @@ def get_databases(self) -> Optional[List[SnowflakeDatabase]]: ] = self.get_databases_from_ischema(databases) if len(ischema_databases) == 0: - self.report_error( + self.structured_reporter.failure( GENERIC_PERMISSION_ERROR_KEY, "No databases found. Please check permissions.", ) @@ -277,7 +279,7 @@ def _process_database( # This may happen if REFERENCE_USAGE permissions are set # We can not run show queries on database in such case. # This need not be a failure case. - self.report_warning( + self.structured_reporter.warning( "Insufficient privileges to operate on database, skipping. Please grant USAGE permissions on database to extract its metadata.", db_name, ) @@ -286,9 +288,8 @@ def _process_database( f"Failed to use database {db_name} due to error {e}", exc_info=e, ) - self.report_warning( - "Failed to get schemas for database", - db_name, + self.structured_reporter.warning( + "Failed to get schemas for database", db_name, exc=e ) return @@ -377,17 +378,14 @@ def fetch_schemas_for_database( # Ideal implementation would use PEP 678 – Enriching Exceptions with Notes raise SnowflakePermissionError(error_msg) from e.__cause__ else: - logger.debug( - f"Failed to get schemas for database {db_name} due to error {e}", - exc_info=e, - ) - self.report_warning( + self.structured_reporter.warning( "Failed to get schemas for database", db_name, + exc=e, ) if not schemas: - self.report_warning( + self.structured_reporter.warning( "No schemas found in database. If schemas exist, please grant USAGE permissions on them.", db_name, ) @@ -493,13 +491,10 @@ def fetch_views_for_schema( raise SnowflakePermissionError(error_msg) from e.__cause__ else: - logger.debug( - f"Failed to get views for schema {db_name}.{schema_name} due to error {e}", - exc_info=e, - ) - self.report_warning( + self.structured_reporter.warning( "Failed to get views for schema", f"{db_name}.{schema_name}", + exc=e, ) return [] @@ -527,13 +522,10 @@ def fetch_tables_for_schema( error_msg = f"Failed to get tables for schema {db_name}.{schema_name}. Please check permissions." raise SnowflakePermissionError(error_msg) from e.__cause__ else: - logger.debug( - f"Failed to get tables for schema {db_name}.{schema_name} due to error {e}", - exc_info=e, - ) - self.report_warning( + self.structured_reporter.warning( "Failed to get tables for schema", f"{db_name}.{schema_name}", + exc=e, ) return [] @@ -566,11 +558,9 @@ def _process_table( table.name, schema_name, db_name ) except Exception as e: - logger.debug( - f"Failed to get columns for table {table_identifier} due to error {e}", - exc_info=e, + self.structured_reporter.warning( + "Failed to get columns for table", table_identifier, exc=e ) - self.report_warning("Failed to get columns for table", table_identifier) if self.config.extract_tags != TagOption.skip: table.tags = self.tag_extractor.get_tags_on_object( @@ -603,11 +593,9 @@ def fetch_foreign_keys_for_table( table.name, schema_name, db_name ) except Exception as e: - logger.debug( - f"Failed to get foreign key for table {table_identifier} due to error {e}", - exc_info=e, + self.structured_reporter.warning( + "Failed to get foreign keys for table", table_identifier, exc=e ) - self.report_warning("Failed to get foreign key for table", table_identifier) def fetch_pk_for_table( self, @@ -621,11 +609,9 @@ def fetch_pk_for_table( table.name, schema_name, db_name ) except Exception as e: - logger.debug( - f"Failed to get primary key for table {table_identifier} due to error {e}", - exc_info=e, + self.structured_reporter.warning( + "Failed to get primary key for table", table_identifier, exc=e ) - self.report_warning("Failed to get primary key for table", table_identifier) def _process_view( self, @@ -647,11 +633,9 @@ def _process_view( view.name, schema_name, db_name ) except Exception as e: - logger.debug( - f"Failed to get columns for view {view_name} due to error {e}", - exc_info=e, + self.structured_reporter.warning( + "Failed to get columns for view", view_name, exc=e ) - self.report_warning("Failed to get columns for view", view_name) if self.config.extract_tags != TagOption.skip: view.tags = self.tag_extractor.get_tags_on_object( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 805a646b53faa..f60dbb9052763 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -24,13 +24,6 @@ class SnowflakeStructuredReportMixin(abc.ABC): def structured_reporter(self) -> SourceReport: ... - # TODO: Eventually I want to deprecate these methods and use the structured_reporter directly. - def report_warning(self, key: str, reason: str) -> None: - self.structured_reporter.warning(key, reason) - - def report_error(self, key: str, reason: str) -> None: - self.structured_reporter.failure(key, reason) - class SnowsightUrlBuilder: CLOUD_REGION_IDS_WITHOUT_CLOUD_SUFFIX: ClassVar = [ @@ -290,11 +283,12 @@ def get_user_identifier( # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint + # TODO: Add additional parameters to match the signature of the .warning and .failure methods def warn_if_stateful_else_error(self, key: str, reason: str) -> None: if ( self.config.stateful_ingestion is not None and self.config.stateful_ingestion.enabled ): - self.report_warning(key, reason) + self.structured_reporter.warning(key, reason) else: - self.report_error(key, reason) + self.structured_reporter.failure(key, reason) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 8226d77379db9..1f3a9ca320237 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -489,7 +489,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: ] if len(discovered_tables) == 0 and len(discovered_views) == 0: - self.report_error( + self.structured_reporter.failure( GENERIC_PERMISSION_ERROR_KEY, "No tables/views found. Please check permissions.", ) @@ -549,14 +549,14 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: def report_warehouse_failure(self) -> None: if self.config.warehouse is not None: - self.report_error( + self.structured_reporter.failure( GENERIC_PERMISSION_ERROR_KEY, f"Current role does not have permissions to use warehouse {self.config.warehouse}. Please update permissions.", ) else: - self.report_error( - "no-active-warehouse", - "No default warehouse set for user. Either set default warehouse for user or configure warehouse in recipe.", + self.structured_reporter.failure( + "Could not use a Snowflake warehouse", + "No default warehouse set for user. Either set a default warehouse for the user or configure a warehouse in the recipe.", ) def get_report(self) -> SourceReport: @@ -587,19 +587,28 @@ def inspect_session_metadata(self, connection: SnowflakeConnection) -> None: for db_row in connection.query(SnowflakeQuery.current_version()): self.report.saas_version = db_row["CURRENT_VERSION()"] except Exception as e: - self.report_error("version", f"Error: {e}") + self.structured_reporter.failure( + "Could not determine the current Snowflake version", + exc=e, + ) try: logger.info("Checking current role") for db_row in connection.query(SnowflakeQuery.current_role()): self.report.role = db_row["CURRENT_ROLE()"] except Exception as e: - self.report_error("version", f"Error: {e}") + self.structured_reporter.failure( + "Could not determine the current Snowflake role", + exc=e, + ) try: logger.info("Checking current warehouse") for db_row in connection.query(SnowflakeQuery.current_warehouse()): self.report.default_warehouse = db_row["CURRENT_WAREHOUSE()"] except Exception as e: - self.report_error("current_warehouse", f"Error: {e}") + self.structured_reporter.failure( + "Could not determine the current Snowflake warehouse", + exc=e, + ) try: logger.info("Checking current edition") From 67417a0acd361c207dd17175bda8476e02252ab4 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 19:43:37 -0700 Subject: [PATCH 10/16] only create one SnowflakeIdentifierBuilder --- .../ingestion/source/snowflake/snowflake_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index f60dbb9052763..01ab5d16b649c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -241,6 +241,8 @@ def structured_reporter(self) -> SourceReport: def identifiers(self) -> SnowflakeIdentifierBuilder: return SnowflakeIdentifierBuilder(self.config) + # TODO: These methods should be moved to SnowflakeIdentifierBuilder. + @staticmethod def get_quoted_identifier_for_database(db_name): return f'"{db_name}"' @@ -253,8 +255,7 @@ def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str filter = SnowflakeFilter( filter_config=self.config, structured_reporter=self.report ) - identifiers = SnowflakeIdentifierBuilder(self.config) - return identifiers.snowflake_identifier( + return self.identifiers.snowflake_identifier( filter.cleanup_qualified_name(qualified_name) ) @@ -272,14 +273,13 @@ def get_user_identifier( user_email: Optional[str], email_as_user_identifier: bool, ) -> str: - identifiers = SnowflakeIdentifierBuilder(self.config) if user_email: - return identifiers.snowflake_identifier( + return self.identifiers.snowflake_identifier( user_email if email_as_user_identifier is True else user_email.split("@")[0] ) - return identifiers.snowflake_identifier(user_name) + return self.identifiers.snowflake_identifier(user_name) # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint From ab95f984ed488d47aaae42e1fcb3b66b877b929c Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 19:44:39 -0700 Subject: [PATCH 11/16] fix athena lint --- .../src/datahub/ingestion/source/sql/athena.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index ae17cff60fedd..9ddc671e21133 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -251,11 +251,6 @@ class AthenaConfig(SQLCommonConfig): "queries executed by DataHub." ) - # overwrite default behavior of SQLAlchemyConfing - include_views: Optional[bool] = pydantic.Field( - default=True, description="Whether views should be ingested." - ) - _s3_staging_dir_population = pydantic_renamed_field( old_name="s3_staging_dir", new_name="query_result_location", From 8a1b4815fd2c503d826a16504b3af88f9045d7aa Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 20:08:53 -0700 Subject: [PATCH 12/16] finish refactoring into SnowflakeIdentifierBuilder --- .../source/snowflake/snowflake_lineage_v2.py | 16 ++- .../source/snowflake/snowflake_queries.py | 17 ++- .../source/snowflake/snowflake_summary.py | 3 +- .../source/snowflake/snowflake_tag.py | 10 +- .../source/snowflake/snowflake_usage_v2.py | 8 +- .../source/snowflake/snowflake_utils.py | 107 +++++++++--------- .../source/snowflake/snowflake_v2.py | 4 +- 7 files changed, 89 insertions(+), 76 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 1ed5a45e8cf31..9a21033a0cc1d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -215,7 +215,7 @@ def populate_known_query_lineage( results: Iterable[UpstreamLineageEdge], ) -> None: for db_row in results: - dataset_name = self.get_dataset_identifier_from_qualified_name( + dataset_name = self.identifiers.get_dataset_identifier_from_qualified_name( db_row.DOWNSTREAM_TABLE_NAME ) if dataset_name not in discovered_assets or not db_row.QUERIES: @@ -353,7 +353,7 @@ def _process_external_lineage_result_row( self, db_row: dict, discovered_tables: List[str] ) -> Optional[KnownLineageMapping]: # key is the down-stream table name - key: str = self.get_dataset_identifier_from_qualified_name( + key: str = self.identifiers.get_dataset_identifier_from_qualified_name( db_row["DOWNSTREAM_TABLE_NAME"] ) if key not in discovered_tables: @@ -422,8 +422,10 @@ def map_query_result_upstreams( for upstream_table in upstream_tables: if upstream_table and upstream_table.query_id == query_id: try: - upstream_name = self.get_dataset_identifier_from_qualified_name( - upstream_table.upstream_object_name + upstream_name = ( + self.identifiers.get_dataset_identifier_from_qualified_name( + upstream_table.upstream_object_name + ) ) if upstream_name and ( not self.config.validate_upstreams_against_patterns @@ -521,8 +523,10 @@ def build_finegrained_lineage_upstreams( ) ) ): - upstream_dataset_name = self.get_dataset_identifier_from_qualified_name( - upstream_col.object_name + upstream_dataset_name = ( + self.identifiers.get_dataset_identifier_from_qualified_name( + upstream_col.object_name + ) ) column_upstreams.append( ColumnRef( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 50cb1242016b7..fa76f139c3048 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -262,12 +262,6 @@ def fetch_audit_log( else: yield entry - def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: - # Copied from SnowflakeCommonMixin. - return self.identifiers.snowflake_identifier( - self.filters.cleanup_qualified_name(qualified_name) - ) - def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: json_fields = { "DIRECT_OBJECTS_ACCESSED", @@ -289,7 +283,9 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: for obj in direct_objects_accessed: dataset = self.identifiers.gen_dataset_urn( - self.get_dataset_identifier_from_qualified_name(obj["objectName"]) + self.identifiers.get_dataset_identifier_from_qualified_name( + obj["objectName"] + ) ) columns = set() @@ -312,7 +308,9 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ) downstream = self.identifiers.gen_dataset_urn( - self.get_dataset_identifier_from_qualified_name(obj["objectName"]) + self.identifiers.get_dataset_identifier_from_qualified_name( + obj["objectName"] + ) ) column_lineage = [] for modified_column in obj["columns"]: @@ -327,7 +325,7 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: upstreams=[ ColumnRef( table=self.identifiers.gen_dataset_urn( - self.get_dataset_identifier_from_qualified_name( + self.identifiers.get_dataset_identifier_from_qualified_name( upstream["objectName"] ) ), @@ -387,6 +385,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig): ) self.identifiers = SnowflakeIdentifierBuilder( identifier_config=self.config, + structured_reporter=self.report, ) self.connection = self.config.connection.get_connection() diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py index a09e314711881..fc93cdc44d496 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py @@ -76,7 +76,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: report=self.report, # type: ignore connection=self.connection, identifiers=SnowflakeIdentifierBuilder( - identifier_config=SnowflakeIdentifierConfig() + identifier_config=SnowflakeIdentifierConfig(), + structured_reporter=self.report, ), domain_registry=None, profiler=None, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py index e6b4ef1fd9607..1d8a3d892e5e0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py @@ -69,16 +69,18 @@ def _get_tags_on_object_with_propagation( ) -> List[SnowflakeTag]: identifier = "" if domain == SnowflakeObjectDomain.DATABASE: - identifier = self.get_quoted_identifier_for_database(db_name) + identifier = self.identifiers.get_quoted_identifier_for_database(db_name) elif domain == SnowflakeObjectDomain.SCHEMA: assert schema_name is not None - identifier = self.get_quoted_identifier_for_schema(db_name, schema_name) + identifier = self.identifiers.get_quoted_identifier_for_schema( + db_name, schema_name + ) elif ( domain == SnowflakeObjectDomain.TABLE ): # Views belong to this domain as well. assert schema_name is not None assert table_name is not None - identifier = self.get_quoted_identifier_for_table( + identifier = self.identifiers.get_quoted_identifier_for_table( db_name, schema_name, table_name ) else: @@ -140,7 +142,7 @@ def get_column_tags_for_table( elif self.config.extract_tags == TagOption.with_lineage: self.report.num_get_tags_on_columns_for_table_queries += 1 temp_column_tags = self.data_dictionary.get_tags_on_columns_for_table( - quoted_table_name=self.get_quoted_identifier_for_table( + quoted_table_name=self.identifiers.get_quoted_identifier_for_table( db_name, schema_name, table_name ), db_name=db_name, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index 380bfd108ddb4..aff15386c5083 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -247,7 +247,7 @@ def _get_workunits_internal( continue dataset_identifier = ( - self.get_dataset_identifier_from_qualified_name( + self.identifiers.get_dataset_identifier_from_qualified_name( row["OBJECT_NAME"] ) ) @@ -462,8 +462,10 @@ def _get_operation_aspect_work_unit( for obj in event.objects_modified: resource = obj.objectName - dataset_identifier = self.get_dataset_identifier_from_qualified_name( - resource + dataset_identifier = ( + self.identifiers.get_dataset_identifier_from_qualified_name( + resource + ) ) if dataset_identifier not in discovered_datasets: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 01ab5d16b649c..a1878963d3798 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -112,12 +112,6 @@ def __init__( # TODO: Refactor remaining filtering logic into this class. - @staticmethod - def _combine_identifier_parts( - table_name: str, schema_name: str, db_name: str - ) -> str: - return f"{db_name}.{schema_name}.{table_name}" - def is_dataset_pattern_allowed( self, dataset_name: Optional[str], @@ -161,7 +155,7 @@ def is_dataset_pattern_allowed( if dataset_type.lower() in { SnowflakeObjectDomain.TABLE } and not self.filter_config.table_pattern.allowed( - self.cleanup_qualified_name(dataset_name) + _cleanup_qualified_name(dataset_name, self.structured_reporter) ): return False @@ -169,39 +163,53 @@ def is_dataset_pattern_allowed( SnowflakeObjectDomain.VIEW, SnowflakeObjectDomain.MATERIALIZED_VIEW, } and not self.filter_config.view_pattern.allowed( - self.cleanup_qualified_name(dataset_name) + _cleanup_qualified_name(dataset_name, self.structured_reporter) ): return False return True - # Qualified Object names from snowflake audit logs have quotes for for snowflake quoted identifiers, - # For example "test-database"."test-schema".test_table - # whereas we generate urns without quotes even for quoted identifiers for backward compatibility - # and also unavailability of utility function to identify whether current table/schema/database - # name should be quoted in above method get_dataset_identifier - def cleanup_qualified_name(self, qualified_name: str) -> str: - name_parts = qualified_name.split(".") - if len(name_parts) != 3: - self.structured_reporter.info( - title="Unexpected dataset pattern", - message="We failed to parse a Snowflake qualified name into its constituent parts. " - "DB/schema/table filtering may not work as expected on these entities.", - context=f"{qualified_name} has {len(name_parts)} parts", - ) - return qualified_name.replace('"', "") - return SnowflakeFilter._combine_identifier_parts( - table_name=name_parts[2].strip('"'), - schema_name=name_parts[1].strip('"'), - db_name=name_parts[0].strip('"'), + +def _combine_identifier_parts( + *, table_name: str, schema_name: str, db_name: str +) -> str: + return f"{db_name}.{schema_name}.{table_name}" + + +# Qualified Object names from snowflake audit logs have quotes for for snowflake quoted identifiers, +# For example "test-database"."test-schema".test_table +# whereas we generate urns without quotes even for quoted identifiers for backward compatibility +# and also unavailability of utility function to identify whether current table/schema/database +# name should be quoted in above method get_dataset_identifier +def _cleanup_qualified_name( + qualified_name: str, structured_reporter: SourceReport +) -> str: + name_parts = qualified_name.split(".") + if len(name_parts) != 3: + structured_reporter.info( + title="Unexpected dataset pattern", + message="We failed to parse a Snowflake qualified name into its constituent parts. " + "DB/schema/table filtering may not work as expected on these entities.", + context=f"{qualified_name} has {len(name_parts)} parts", ) + return qualified_name.replace('"', "") + return _combine_identifier_parts( + db_name=name_parts[0].strip('"'), + schema_name=name_parts[1].strip('"'), + table_name=name_parts[2].strip('"'), + ) class SnowflakeIdentifierBuilder: platform = "snowflake" - def __init__(self, identifier_config: SnowflakeIdentifierConfig) -> None: + def __init__( + self, + identifier_config: SnowflakeIdentifierConfig, + structured_reporter: SourceReport, + ) -> None: self.identifier_config = identifier_config + self.structured_reporter = structured_reporter def snowflake_identifier(self, identifier: str) -> str: # to be in in sync with older connector, convert name to lowercase @@ -213,7 +221,7 @@ def get_dataset_identifier( self, table_name: str, schema_name: str, db_name: str ) -> str: return self.snowflake_identifier( - SnowflakeFilter._combine_identifier_parts( + _combine_identifier_parts( table_name=table_name, schema_name=schema_name, db_name=db_name ) ) @@ -226,6 +234,23 @@ def gen_dataset_urn(self, dataset_identifier: str) -> str: env=self.identifier_config.env, ) + def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: + return self.snowflake_identifier( + _cleanup_qualified_name(qualified_name, self.structured_reporter) + ) + + @staticmethod + def get_quoted_identifier_for_database(db_name): + return f'"{db_name}"' + + @staticmethod + def get_quoted_identifier_for_schema(db_name, schema_name): + return f'"{db_name}"."{schema_name}"' + + @staticmethod + def get_quoted_identifier_for_table(db_name, schema_name, table_name): + return f'"{db_name}"."{schema_name}"."{table_name}"' + class SnowflakeCommonMixin(SnowflakeStructuredReportMixin): platform = "snowflake" @@ -239,29 +264,7 @@ def structured_reporter(self) -> SourceReport: @cached_property def identifiers(self) -> SnowflakeIdentifierBuilder: - return SnowflakeIdentifierBuilder(self.config) - - # TODO: These methods should be moved to SnowflakeIdentifierBuilder. - - @staticmethod - def get_quoted_identifier_for_database(db_name): - return f'"{db_name}"' - - @staticmethod - def get_quoted_identifier_for_schema(db_name, schema_name): - return f'"{db_name}"."{schema_name}"' - - def get_dataset_identifier_from_qualified_name(self, qualified_name: str) -> str: - filter = SnowflakeFilter( - filter_config=self.config, structured_reporter=self.report - ) - return self.identifiers.snowflake_identifier( - filter.cleanup_qualified_name(qualified_name) - ) - - @staticmethod - def get_quoted_identifier_for_table(db_name, schema_name, table_name): - return f'"{db_name}"."{schema_name}"."{table_name}"' + return SnowflakeIdentifierBuilder(self.config, self.report) # Note - decide how to construct user urns. # Historically urns were created using part before @ from user's email. diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 1f3a9ca320237..5104aa1df5898 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -140,7 +140,9 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.filters = SnowflakeFilter( filter_config=self.config, structured_reporter=self.report ) - self.identifiers = SnowflakeIdentifierBuilder(identifier_config=self.config) + self.identifiers = SnowflakeIdentifierBuilder( + identifier_config=self.config, structured_reporter=self.report + ) self.connection = self.config.get_connection() From 0ba7bc08758503250991def5aa00bdf63ae6bd08 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 20:30:06 -0700 Subject: [PATCH 13/16] support copy history in snowflake-queries --- .../source/snowflake/snowflake_lineage_v2.py | 24 ++-- .../source/snowflake/snowflake_queries.py | 105 +++++++++++------- 2 files changed, 79 insertions(+), 50 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 9a21033a0cc1d..3f956ce3aa224 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -336,38 +336,44 @@ def _populate_external_lineage_from_copy_history( db_row, discovered_tables ) if known_lineage_mapping: + self.report.num_external_table_edges_scanned += 1 yield known_lineage_mapping except Exception as e: if isinstance(e, SnowflakePermissionError): error_msg = "Failed to get external lineage. Please grant imported privileges on SNOWFLAKE database. " self.warn_if_stateful_else_error(LINEAGE_PERMISSION_ERROR, error_msg) else: - logger.debug(e, exc_info=e) self.structured_reporter.warning( - "external_lineage", - f"Populating table external lineage from Snowflake failed due to error {e}.", + "Error fetching external lineage from Snowflake", + exc=e, ) self.report_status(EXTERNAL_LINEAGE, False) + @classmethod def _process_external_lineage_result_row( - self, db_row: dict, discovered_tables: List[str] + cls, + db_row: dict, + discovered_tables: Optional[List[str]], + identifiers: SnowflakeIdentifierBuilder, ) -> Optional[KnownLineageMapping]: # key is the down-stream table name - key: str = self.identifiers.get_dataset_identifier_from_qualified_name( + key: str = identifiers.get_dataset_identifier_from_qualified_name( db_row["DOWNSTREAM_TABLE_NAME"] ) - if key not in discovered_tables: + if discovered_tables is not None and key not in discovered_tables: return None if db_row["UPSTREAM_LOCATIONS"] is not None: external_locations = json.loads(db_row["UPSTREAM_LOCATIONS"]) + loc: str for loc in external_locations: if loc.startswith("s3://"): - self.report.num_external_table_edges_scanned += 1 return KnownLineageMapping( - upstream_urn=make_s3_urn_for_lineage(loc, self.config.env), - downstream_urn=self.identifiers.gen_dataset_urn(key), + upstream_urn=make_s3_urn_for_lineage( + loc, identifiers.identifier_config.env + ), + downstream_urn=identifiers.gen_dataset_urn(key), ) return None diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index fa76f139c3048..a580b5084efbb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -33,6 +33,9 @@ SnowflakeConnection, SnowflakeConnectionConfig, ) +from datahub.ingestion.source.snowflake.snowflake_lineage_v2 import ( + SnowflakeLineageExtractor, +) from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery from datahub.ingestion.source.snowflake.snowflake_utils import ( SnowflakeFilter, @@ -97,7 +100,9 @@ class SnowflakeQueriesSourceConfig( @dataclass class SnowflakeQueriesExtractorReport(Report): - audit_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_aggregator: Optional[SqlAggregatorReport] = None @@ -118,6 +123,7 @@ def __init__( identifiers: SnowflakeIdentifierBuilder, graph: Optional[DataHubGraph] = None, schema_resolver: Optional[SchemaResolver] = None, + discovered_tables: Optional[List[str]] = None, ): self.connection = connection @@ -125,6 +131,7 @@ def __init__( self.report = SnowflakeQueriesExtractorReport() self.filters = filters self.identifiers = identifiers + self.discovered_tables = discovered_tables self._structured_report = structured_report @@ -174,6 +181,9 @@ def is_temp_table(self, name: str) -> bool: ) def is_allowed_table(self, name: str) -> bool: + if self.discovered_tables and name not in self.discovered_tables: + return False + return self.filters.is_dataset_pattern_allowed( name, SnowflakeObjectDomain.TABLE ) @@ -196,9 +206,15 @@ def get_workunits_internal( shared_connection = ConnectionWrapper(audit_log_file) queries = FileBackedList(shared_connection) - logger.info("Fetching audit log") - with self.report.audit_log_fetch_timer: - for entry in self.fetch_audit_log(): + with self.report.copy_history_fetch_timer: + for entry in self.fetch_copy_history(): + queries.append(entry) + + # TODO: Add "show external tables" lineage to the main schema extractor. + # Because it's not a time-based thing, it doesn't really make sense in the snowflake-queries extractor. + + with self.report.query_log_fetch_timer: + for entry in self.fetch_query_log(): queries.append(entry) with self.report.audit_log_load_timer: @@ -207,55 +223,66 @@ def get_workunits_internal( yield from auto_workunit(self.aggregator.gen_metadata()) - def fetch_audit_log( - self, - ) -> Iterable[Union[KnownLineageMapping, PreparsedQuery]]: - """ - # TODO: we need to fetch this info from somewhere - discovered_tables = [] - - snowflake_lineage_v2 = SnowflakeLineageExtractor( - config=self.config, # type: ignore - report=self.report, # type: ignore - dataset_urn_builder=self.gen_dataset_urn, - redundant_run_skip_handler=None, - sql_aggregator=self.aggregator, # TODO this should be unused + def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: + # Derived from _populate_external_lineage_from_copy_history. + + query: str = SnowflakeQuery.copy_lineage_history( + start_time_millis=int(self.config.window.start_time.timestamp() * 1000), + end_time_millis=int(self.config.window.end_time.timestamp() * 1000), + downstreams_deny_pattern=self.config.temporary_tables_pattern, ) - for ( - known_lineage_mapping - ) in snowflake_lineage_v2._populate_external_lineage_from_copy_history( - discovered_tables=discovered_tables - ): - interim_results.append(known_lineage_mapping) - - for ( - known_lineage_mapping - ) in snowflake_lineage_v2._populate_external_lineage_from_show_query( - discovered_tables=discovered_tables - ): - interim_results.append(known_lineage_mapping) - """ - - audit_log_query = _build_enriched_audit_log_query( + try: + logger.info("Fetching copy history from Snowflake") + resp = self.connection.query(query) + + for row in resp: + try: + result = ( + SnowflakeLineageExtractor._process_external_lineage_result_row( + row, + discovered_tables=self.discovered_tables, + identifiers=self.identifiers, + ) + ) + except Exception as e: + self.structured_reporter.warning( + "Error parsing copy history row", + context=f"{row}", + exc=e, + ) + else: + if result: + yield result + except Exception as e: + self.structured_reporter.failure( + "Error fetching copy history from Snowflake", + exc=e, + ) + + def fetch_query_log( + self, + ) -> Iterable[PreparsedQuery]: + query_log_query = _build_enriched_query_log_query( start_time=self.config.window.start_time, end_time=self.config.window.end_time, bucket_duration=self.config.window.bucket_duration, deny_usernames=self.config.deny_usernames, ) - resp = self.connection.query(audit_log_query) + logger.info("Fetching query log from Snowflake") + resp = self.connection.query(query_log_query) for i, row in enumerate(resp): if i % 1000 == 0: - logger.info(f"Processed {i} audit log rows") + logger.info(f"Processed {i} query log rows") assert isinstance(row, dict) try: entry = self._parse_audit_log_row(row) except Exception as e: self.structured_reporter.warning( - "Error parsing audit log row", + "Error parsing query log row", context=f"{row}", exc=e, ) @@ -340,10 +367,6 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ) ) - # TODO: Support filtering the table names. - # if objects_modified: - # breakpoint() - # TODO implement email address mapping user = CorpUserUrn(res["user_name"]) @@ -419,7 +442,7 @@ def get_report(self) -> SnowflakeQueriesSourceReport: _MAX_TABLES_PER_QUERY = 20 -def _build_enriched_audit_log_query( +def _build_enriched_query_log_query( start_time: datetime, end_time: datetime, bucket_duration: BucketDuration, From 07533f14cbab0e6910be5c9600f0e78aa40cc4f8 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 20:43:10 -0700 Subject: [PATCH 14/16] add reporting context manager --- .../src/datahub/ingestion/api/source.py | 20 +++++++++ .../source/snowflake/snowflake_lineage_v2.py | 2 +- .../source/snowflake/snowflake_queries.py | 45 ++++++++++--------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py index b79f69970a634..788bec97a6488 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source.py @@ -1,3 +1,4 @@ +import contextlib import datetime import logging from abc import ABCMeta, abstractmethod @@ -10,6 +11,7 @@ Dict, Generic, Iterable, + Iterator, List, Optional, Sequence, @@ -289,6 +291,24 @@ def info( StructuredLogLevel.INFO, message, title, context, exc, log=log ) + @contextlib.contextmanager + def report_exc( + self, + message: LiteralString, + title: Optional[LiteralString] = None, + context: Optional[str] = None, + level: StructuredLogLevel = StructuredLogLevel.ERROR, + ) -> Iterator[None]: + # Convenience method that helps avoid boilerplate try/except blocks. + # TODO: I'm not super happy with the naming here - it's not obvious that this + # suppresses the exception in addition to reporting it. + try: + yield + except Exception as exc: + self._structured_logs.report_log( + level, message=message, title=title, context=context, exc=exc + ) + def __post_init__(self) -> None: self.start_time = datetime.datetime.now() self.running_time: datetime.timedelta = datetime.timedelta(seconds=0) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 3f956ce3aa224..151e9fb631620 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -333,7 +333,7 @@ def _populate_external_lineage_from_copy_history( try: for db_row in self.connection.query(query): known_lineage_mapping = self._process_external_lineage_result_row( - db_row, discovered_tables + db_row, discovered_tables, identifiers=self.identifiers ) if known_lineage_mapping: self.report.num_external_table_edges_scanned += 1 diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index a580b5084efbb..7fc92d11bee72 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -205,6 +205,7 @@ def get_workunits_internal( shared_connection = ConnectionWrapper(audit_log_file) queries = FileBackedList(shared_connection) + entry: Union[KnownLineageMapping, PreparsedQuery] with self.report.copy_history_fetch_timer: for entry in self.fetch_copy_history(): @@ -232,7 +233,9 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: downstreams_deny_pattern=self.config.temporary_tables_pattern, ) - try: + with self.structured_reporter.report_exc( + "Error fetching copy history from Snowflake" + ): logger.info("Fetching copy history from Snowflake") resp = self.connection.query(query) @@ -254,11 +257,6 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: else: if result: yield result - except Exception as e: - self.structured_reporter.failure( - "Error fetching copy history from Snowflake", - exc=e, - ) def fetch_query_log( self, @@ -270,24 +268,27 @@ def fetch_query_log( deny_usernames=self.config.deny_usernames, ) - logger.info("Fetching query log from Snowflake") - resp = self.connection.query(query_log_query) + with self.structured_reporter.report_exc( + "Error fetching query log from Snowflake" + ): + logger.info("Fetching query log from Snowflake") + resp = self.connection.query(query_log_query) - for i, row in enumerate(resp): - if i % 1000 == 0: - logger.info(f"Processed {i} query log rows") + for i, row in enumerate(resp): + if i % 1000 == 0: + logger.info(f"Processed {i} query log rows") - assert isinstance(row, dict) - try: - entry = self._parse_audit_log_row(row) - except Exception as e: - self.structured_reporter.warning( - "Error parsing query log row", - context=f"{row}", - exc=e, - ) - else: - yield entry + assert isinstance(row, dict) + try: + entry = self._parse_audit_log_row(row) + except Exception as e: + self.structured_reporter.warning( + "Error parsing query log row", + context=f"{row}", + exc=e, + ) + else: + yield entry def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: json_fields = { From db3199e1235f828d43d12bb3714cd853fae033dd Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Mon, 15 Jul 2024 20:50:51 -0700 Subject: [PATCH 15/16] remove extra self.logger calls --- .../datahub/ingestion/source/snowflake/snowflake_profiler.py | 1 - .../datahub/ingestion/source/snowflake/snowflake_schema.py | 4 +--- .../datahub/ingestion/source/snowflake/snowflake_shares.py | 1 - .../datahub/ingestion/source/snowflake/snowflake_summary.py | 2 -- .../src/datahub/ingestion/source/snowflake/snowflake_tag.py | 1 - .../src/datahub/ingestion/source/snowflake/snowflake_v2.py | 5 ++--- 6 files changed, 3 insertions(+), 11 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py index 89dc949e844f4..422bda5284dbc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_profiler.py @@ -37,7 +37,6 @@ def __init__( super().__init__(config, report, self.platform, state_handler) self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = report - self.logger = logger self.database_default_schema: Dict[str, str] = dict() def get_workunits( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py index ce8f20d23aa6b..600292c2c9942 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_schema.py @@ -185,8 +185,6 @@ def get_column_tags_for_table( class SnowflakeDataDictionary(SupportsAsObj): def __init__(self, connection: SnowflakeConnection) -> None: - self.logger = logger - self.connection = connection def as_obj(self) -> Dict[str, Dict[str, int]]: @@ -514,7 +512,7 @@ def get_tags_for_database_without_propagation( ) else: # This should never happen. - self.logger.error(f"Encountered an unexpected domain: {domain}") + logger.error(f"Encountered an unexpected domain: {domain}") continue return tags diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py index 9550e548c9949..794a6f4a59f46 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_shares.py @@ -29,7 +29,6 @@ def __init__( ) -> None: self.config = config self.report = report - self.logger = logger def get_shares_workunits( self, databases: List[SnowflakeDatabase] diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py index fc93cdc44d496..72952f6b76e8b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_summary.py @@ -1,5 +1,4 @@ import dataclasses -import logging from collections import defaultdict from typing import Dict, Iterable, List, Optional @@ -65,7 +64,6 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeSummaryConfig): super().__init__(ctx) self.config: SnowflakeSummaryConfig = config self.report: SnowflakeSummaryReport = SnowflakeSummaryReport() - self.logger = logging.getLogger(__name__) self.connection = self.config.get_connection() diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py index 1d8a3d892e5e0..9307eb607be26 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_tag.py @@ -27,7 +27,6 @@ def __init__( self.config = config self.data_dictionary = data_dictionary self.report = report - self.logger = logger self.tag_cache: Dict[str, _SnowflakeTagCache] = {} diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 5104aa1df5898..b335e0b0d414b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -135,7 +135,6 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): super().__init__(config, ctx) self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = SnowflakeV2Report() - self.logger = logger self.filters = SnowflakeFilter( filter_config=self.config, structured_reporter=self.report @@ -144,14 +143,14 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): identifier_config=self.config, structured_reporter=self.report ) - self.connection = self.config.get_connection() - self.domain_registry: Optional[DomainRegistry] = None if self.config.domain: self.domain_registry = DomainRegistry( cached_domains=[k for k in self.config.domain], graph=self.ctx.graph ) + self.connection = self.config.get_connection() + # For database, schema, tables, views, etc self.data_dictionary = SnowflakeDataDictionary(connection=self.connection) self.lineage_extractor: Optional[SnowflakeLineageExtractor] = None From dcb786f9b270b556fc0d86d8a04659043adbecbe Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Tue, 16 Jul 2024 17:51:38 -0700 Subject: [PATCH 16/16] fix some bugs --- .../ingestion/source/snowflake/snowflake_queries.py | 13 +++++++++---- .../ingestion/source/snowflake/snowflake_v2.py | 4 +++- .../datahub/sql_parsing/sql_parsing_aggregator.py | 8 ++++++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 7fc92d11bee72..d5b8f98e40075 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -368,8 +368,9 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ) ) - # TODO implement email address mapping - user = CorpUserUrn(res["user_name"]) + # TODO: Fetch email addresses from Snowflake to map user -> email + # TODO: Support email_domain fallback for generating user urns. + user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"])) timestamp: datetime = res["query_start_time"] timestamp = timestamp.astimezone(timezone.utc) @@ -380,14 +381,18 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ) entry = PreparsedQuery( - query_id=res["query_fingerprint"], + # Despite having Snowflake's fingerprints available, our own fingerprinting logic does a better + # job at eliminating redundant / repetitive queries. As such, we don't include the fingerprint + # here so that the aggregator auto-generates one. + # query_id=res["query_fingerprint"], + query_id=None, query_text=res["query_text"], upstreams=upstreams, downstream=downstream, column_lineage=column_lineage, column_usage=column_usage, inferred_schema=None, - confidence_score=1, + confidence_score=1.0, query_count=res["query_count"], user=user, timestamp=timestamp, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index b335e0b0d414b..a2a7ba004a921 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -156,7 +156,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.lineage_extractor: Optional[SnowflakeLineageExtractor] = None self.aggregator: Optional[SqlParsingAggregator] = None - if self.config.include_table_lineage: + if self.config.use_queries_v2 or self.config.include_table_lineage: self.aggregator = SqlParsingAggregator( platform=self.identifiers.platform, platform_instance=self.config.platform_instance, @@ -179,6 +179,8 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): ) self.report.sql_aggregator = self.aggregator.report + if self.config.include_table_lineage: + assert self.aggregator is not None redundant_lineage_run_skip_handler: Optional[ RedundantLineageRunSkipHandler ] = None diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index d74faf72fb54e..56602d747e3a6 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -276,8 +276,12 @@ def __init__( self.generate_usage_statistics = generate_usage_statistics self.generate_query_usage_statistics = generate_query_usage_statistics self.generate_operations = generate_operations - if self.generate_queries and not self.generate_lineage: - raise ValueError("Queries will only be generated if lineage is enabled") + if self.generate_queries and not ( + self.generate_lineage or self.generate_query_usage_statistics + ): + logger.warning( + "Queries will not be generated, as neither lineage nor query usage statistics are enabled" + ) self.usage_config = usage_config if (