From b796f5932ee79778d815060686050dd7707dbd9d Mon Sep 17 00:00:00 2001 From: Dushyant Bhalgami Date: Mon, 1 Jul 2024 13:40:53 +0200 Subject: [PATCH] fix(ingestion/prefect-plugin): auth token with datasets --- .../src/prefect_datahub/datahub_emitter.py | 120 ++++++++++-------- .../src/prefect_datahub/example/flow.py | 40 ++++-- .../src/prefect_datahub/example/save_block.py | 5 +- .../tests/unit/test_datahub_emitter.py | 4 +- 4 files changed, 100 insertions(+), 69 deletions(-) diff --git a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/datahub_emitter.py b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/datahub_emitter.py index d2bce2a959c216..5991503416aec7 100644 --- a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/datahub_emitter.py +++ b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/datahub_emitter.py @@ -24,6 +24,7 @@ from prefect.client.schemas.objects import Flow from prefect.context import FlowRunContext, TaskRunContext from prefect.settings import PREFECT_API_URL +from pydantic.v1 import SecretStr from prefect_datahub.entities import _Entity @@ -86,14 +87,18 @@ class DatahubEmitter(Block): datahub_rest_url: str = "http://localhost:8080" env: str = builder.DEFAULT_ENV platform_instance: Optional[str] = None + token: Optional[SecretStr] = None + _datajobs_to_emit: Dict[str, Any] = {} def __init__(self, *args: Any, **kwargs: Any): """ Initialize datahub rest emitter """ super().__init__(*args, **kwargs) - self.datajobs_to_emit: Dict[str, _Entity] = {} - self.emitter = DatahubRestEmitter(gms_server=self.datahub_rest_url) + # self._datajobs_to_emit: Dict[str, _Entity] = {} + + token = self.token.get_secret_value() if self.token is not None else None + self.emitter = DatahubRestEmitter(gms_server=self.datahub_rest_url, token=token) self.emitter.test_connection() def _entities_to_urn_list(self, iolets: List[_Entity]) -> List[DatasetUrn]: @@ -333,57 +338,59 @@ def _emit_tasks( dataflow (DataFlow): The datahub dataflow entity. workspace_name Optional(str): The prefect cloud workpace name. 
""" - assert flow_run_ctx.flow_run + try: + assert flow_run_ctx.flow_run - graph_json = asyncio.run( - self._get_flow_run_graph(str(flow_run_ctx.flow_run.id)) - ) + graph_json = asyncio.run( + self._get_flow_run_graph(str(flow_run_ctx.flow_run.id)) + ) - if graph_json is None: - return + if graph_json is None: + return - task_run_key_map: Dict[str, str] = {} + task_run_key_map: Dict[str, str] = {} - for prefect_future in flow_run_ctx.task_run_futures: - if prefect_future.task_run is not None: - task_run_key_map[ - str(prefect_future.task_run.id) - ] = prefect_future.task_run.task_key + for prefect_future in flow_run_ctx.task_run_futures: + if prefect_future.task_run is not None: + task_run_key_map[ + str(prefect_future.task_run.id) + ] = prefect_future.task_run.task_key - get_run_logger().info("Emitting tasks to datahub...") + for node in graph_json: + datajob_urn = DataJobUrn.create_from_ids( + data_flow_urn=str(dataflow.urn), + job_id=task_run_key_map[node[ID]], + ) - for node in graph_json: - datajob_urn = DataJobUrn.create_from_ids( - data_flow_urn=str(dataflow.urn), - job_id=task_run_key_map[node[ID]], - ) + datajob: Optional[DataJob] = None - datajob: Optional[DataJob] = None + if str(datajob_urn) in self._datajobs_to_emit: + datajob = cast(DataJob, self._datajobs_to_emit[str(datajob_urn)]) + else: + datajob = self._generate_datajob( + flow_run_ctx=flow_run_ctx, task_key=task_run_key_map[node[ID]] + ) - if str(datajob_urn) in self.datajobs_to_emit: - datajob = cast(DataJob, self.datajobs_to_emit[str(datajob_urn)]) - else: - datajob = self._generate_datajob( - flow_run_ctx=flow_run_ctx, task_key=task_run_key_map[node[ID]] - ) + if datajob is not None: + for each in node[UPSTREAM_DEPENDENCIES]: + upstream_task_urn = DataJobUrn.create_from_ids( + data_flow_urn=str(dataflow.urn), + job_id=task_run_key_map[each[ID]], + ) + datajob.upstream_urns.extend([upstream_task_urn]) - if datajob is not None: - for each in node[UPSTREAM_DEPENDENCIES]: - upstream_task_urn = DataJobUrn.create_from_ids( - data_flow_urn=str(dataflow.urn), - job_id=task_run_key_map[each[ID]], - ) - datajob.upstream_urns.extend([upstream_task_urn]) - datajob.emit(self.emitter) + datajob.emit(self.emitter) - if workspace_name is not None: - self._emit_browsepath(str(datajob.urn), workspace_name) + if workspace_name is not None: + self._emit_browsepath(str(datajob.urn), workspace_name) - self._emit_task_run( - datajob=datajob, - flow_run_name=flow_run_ctx.flow_run.name, - task_run_id=UUID(node[ID]), - ) + self._emit_task_run( + datajob=datajob, + flow_run_name=flow_run_ctx.flow_run.name, + task_run_id=UUID(node[ID]), + ) + except Exception: + get_run_logger().debug(traceback.format_exc()) def _emit_flow_run(self, dataflow: DataFlow, flow_run_id: UUID) -> None: """ @@ -583,22 +590,25 @@ def etl(): datahub_emitter.emit_flow() ``` """ - flow_run_ctx = FlowRunContext.get() - task_run_ctx = TaskRunContext.get() + try: + flow_run_ctx = FlowRunContext.get() + task_run_ctx = TaskRunContext.get() - assert flow_run_ctx - assert task_run_ctx + assert flow_run_ctx + assert task_run_ctx - datajob = self._generate_datajob( - flow_run_ctx=flow_run_ctx, task_run_ctx=task_run_ctx - ) + datajob = self._generate_datajob( + flow_run_ctx=flow_run_ctx, task_run_ctx=task_run_ctx + ) - if datajob is not None: - if inputs is not None: - datajob.inlets.extend(self._entities_to_urn_list(inputs)) - if outputs is not None: - datajob.outlets.extend(self._entities_to_urn_list(outputs)) - self.datajobs_to_emit[str(datajob.urn)] = cast(_Entity, datajob) + 
if datajob is not None: + if inputs is not None: + datajob.inlets.extend(self._entities_to_urn_list(inputs)) + if outputs is not None: + datajob.outlets.extend(self._entities_to_urn_list(outputs)) + self._datajobs_to_emit[str(datajob.urn)] = cast(_Entity, datajob) + except Exception: + get_run_logger().debug(traceback.format_exc()) def emit_flow(self) -> None: """ diff --git a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/flow.py b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/flow.py index 9652ee3f56aa9e..8d65ff0d82dc1c 100644 --- a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/flow.py +++ b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/flow.py @@ -1,4 +1,5 @@ import asyncio +from typing import List, Tuple from prefect import flow, task @@ -8,31 +9,50 @@ async def load_datahub_emitter(): datahub_emitter = DatahubEmitter() - return datahub_emitter.load("datahub-block-7") + emitter = datahub_emitter.load("BLOCK-ID") + print(emitter) + return emitter @task(name="Extract", description="Extract the data") -def extract(): +def extract() -> str: data = "This is data" return data @task(name="Transform", description="Transform the data") -def transform(data, datahub_emitter): - data = data.split(" ") +def transform( + data: str, datahub_emitter: DatahubEmitter +) -> Tuple[List[str], DatahubEmitter]: + data_list_str = data.split(" ") datahub_emitter.add_task( - inputs=[Dataset("snowflake", "mydb.schema.tableX")], - outputs=[Dataset("snowflake", "mydb.schema.tableY")], + inputs=[ + Dataset( + platform="snowflake", + name="mydb.schema.tableA", + env=datahub_emitter.env, + platform_instance=datahub_emitter.platform_instance, + ) + ], + outputs=[ + Dataset( + platform="snowflake", + name="mydb.schema.tableB", + env=datahub_emitter.env, + platform_instance=datahub_emitter.platform_instance, + ) + ], ) - return data + return data_list_str, datahub_emitter @flow(name="ETL", description="Extract transform load flow") -def etl(): +def etl() -> None: datahub_emitter = asyncio.run(load_datahub_emitter()) data = extract() - data = transform(data, datahub_emitter) - datahub_emitter.emit_flow() + return_value = transform(data, datahub_emitter) + emitter = return_value[1] + emitter.emit_flow() etl() diff --git a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/save_block.py b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/save_block.py index 7656b13a4a49fc..d4f7a932b0929e 100644 --- a/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/save_block.py +++ b/metadata-ingestion-modules/prefect-plugin/src/prefect_datahub/example/save_block.py @@ -6,11 +6,12 @@ async def save_datahub_emitter(): datahub_emitter = DatahubEmitter( datahub_rest_url="http://localhost:8080", - env="PROD", + env="DEV", platform_instance="local_prefect", + token=None, # generate auth token in the datahub and provide here if gms endpoint is secure ) - await datahub_emitter.save("datahub-block-7", overwrite=True) + await datahub_emitter.save("BLOCK-ID", overwrite=True) asyncio.run(save_datahub_emitter()) diff --git a/metadata-ingestion-modules/prefect-plugin/tests/unit/test_datahub_emitter.py b/metadata-ingestion-modules/prefect-plugin/tests/unit/test_datahub_emitter.py index c1586a0aa02f47..b7b57df666d2c0 100644 --- a/metadata-ingestion-modules/prefect-plugin/tests/unit/test_datahub_emitter.py +++ b/metadata-ingestion-modules/prefect-plugin/tests/unit/test_datahub_emitter.py @@ -547,8 
+547,8 @@ def test_add_task(mock_emit, mock_run_context): f"(prefect,{flow_run_ctx.flow.name},PROD),{task_run_ctx.task.task_key})" ) - assert expected_datajob_urn in datahub_emitter.datajobs_to_emit.keys() - actual_datajob = datahub_emitter.datajobs_to_emit[expected_datajob_urn] + assert expected_datajob_urn in datahub_emitter._datajobs_to_emit.keys() + actual_datajob = datahub_emitter._datajobs_to_emit[expected_datajob_urn] assert isinstance(actual_datajob, DataJob) assert str(actual_datajob.flow_urn) == "urn:li:dataFlow:(prefect,etl,PROD)" assert actual_datajob.name == task_run_ctx.task.name
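
Usage note (follow-up to the patch, not part of the diff): the sketch below shows one way the new `token` field could be populated when saving the `DatahubEmitter` block against a GMS endpoint that enforces authentication. It mirrors `example/save_block.py` from this patch; the block slug `BLOCK-ID` is the placeholder already used there, and the token string is a placeholder for an access token generated in the DataHub UI.

```python
import asyncio

from pydantic.v1 import SecretStr

from prefect_datahub.datahub_emitter import DatahubEmitter


async def save_secured_datahub_emitter() -> None:
    datahub_emitter = DatahubEmitter(
        datahub_rest_url="http://localhost:8080",
        env="DEV",
        platform_instance="local_prefect",
        # Placeholder: paste a personal access token generated in DataHub.
        # Only needed when the GMS endpoint requires authentication;
        # leave token=None for an unsecured endpoint, as in save_block.py.
        token=SecretStr("<datahub-access-token>"),
    )
    # Same block slug placeholder as in example/save_block.py.
    await datahub_emitter.save("BLOCK-ID", overwrite=True)


asyncio.run(save_secured_datahub_emitter())
```

On the load side (see `example/flow.py`), nothing extra is required: the patched `__init__` unwraps the secret with `token.get_secret_value()` and passes it to `DatahubRestEmitter`, so a block saved with a token works transparently with `DatahubEmitter.load("BLOCK-ID")`.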