fix(ingestion/prefect-plugin): auth token with datasets
dushayntAW committed Jul 1, 2024
1 parent f756f62 commit b796f59
Showing 4 changed files with 100 additions and 69 deletions.
@@ -24,6 +24,7 @@
from prefect.client.schemas.objects import Flow
from prefect.context import FlowRunContext, TaskRunContext
from prefect.settings import PREFECT_API_URL
from pydantic.v1 import SecretStr

from prefect_datahub.entities import _Entity

@@ -86,14 +87,18 @@ class DatahubEmitter(Block):
datahub_rest_url: str = "http://localhost:8080"
env: str = builder.DEFAULT_ENV
platform_instance: Optional[str] = None
token: Optional[SecretStr] = None
_datajobs_to_emit: Dict[str, Any] = {}

def __init__(self, *args: Any, **kwargs: Any):
"""
Initialize datahub rest emitter
"""
super().__init__(*args, **kwargs)
self.datajobs_to_emit: Dict[str, _Entity] = {}
self.emitter = DatahubRestEmitter(gms_server=self.datahub_rest_url)
# self._datajobs_to_emit: Dict[str, _Entity] = {}

token = self.token.get_secret_value() if self.token is not None else None
self.emitter = DatahubRestEmitter(gms_server=self.datahub_rest_url, token=token)
self.emitter.test_connection()

def _entities_to_urn_list(self, iolets: List[_Entity]) -> List[DatasetUrn]:
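A minimal sketch of how the new `token` field can be used when constructing the block directly; the import path, GMS endpoint, and token value below are placeholders assumed for illustration, not taken from this commit:

```python
# Sketch only: import path and values are assumptions, not part of this commit.
from pydantic.v1 import SecretStr

from prefect_datahub.datahub_emitter import DatahubEmitter

emitter = DatahubEmitter(
    datahub_rest_url="https://my-datahub-gms.example.com:8080",  # placeholder endpoint
    token=SecretStr("<datahub-personal-access-token>"),          # placeholder token
)
# __init__ resolves the SecretStr and calls test_connection(), so a rejected
# token surfaces immediately instead of at emit time.
```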
@@ -333,57 +338,59 @@ def _emit_tasks(
dataflow (DataFlow): The datahub dataflow entity.
workspace_name (Optional[str]): The Prefect Cloud workspace name.
"""
assert flow_run_ctx.flow_run
try:
assert flow_run_ctx.flow_run

graph_json = asyncio.run(
self._get_flow_run_graph(str(flow_run_ctx.flow_run.id))
)
graph_json = asyncio.run(
self._get_flow_run_graph(str(flow_run_ctx.flow_run.id))
)

if graph_json is None:
return
if graph_json is None:
return

task_run_key_map: Dict[str, str] = {}
task_run_key_map: Dict[str, str] = {}

for prefect_future in flow_run_ctx.task_run_futures:
if prefect_future.task_run is not None:
task_run_key_map[
str(prefect_future.task_run.id)
] = prefect_future.task_run.task_key
for prefect_future in flow_run_ctx.task_run_futures:
if prefect_future.task_run is not None:
task_run_key_map[
str(prefect_future.task_run.id)
] = prefect_future.task_run.task_key

get_run_logger().info("Emitting tasks to datahub...")
for node in graph_json:
datajob_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(dataflow.urn),
job_id=task_run_key_map[node[ID]],
)

for node in graph_json:
datajob_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(dataflow.urn),
job_id=task_run_key_map[node[ID]],
)
datajob: Optional[DataJob] = None

datajob: Optional[DataJob] = None
if str(datajob_urn) in self._datajobs_to_emit:
datajob = cast(DataJob, self._datajobs_to_emit[str(datajob_urn)])
else:
datajob = self._generate_datajob(
flow_run_ctx=flow_run_ctx, task_key=task_run_key_map[node[ID]]
)

if str(datajob_urn) in self.datajobs_to_emit:
datajob = cast(DataJob, self.datajobs_to_emit[str(datajob_urn)])
else:
datajob = self._generate_datajob(
flow_run_ctx=flow_run_ctx, task_key=task_run_key_map[node[ID]]
)
if datajob is not None:
for each in node[UPSTREAM_DEPENDENCIES]:
upstream_task_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(dataflow.urn),
job_id=task_run_key_map[each[ID]],
)
datajob.upstream_urns.extend([upstream_task_urn])

if datajob is not None:
for each in node[UPSTREAM_DEPENDENCIES]:
upstream_task_urn = DataJobUrn.create_from_ids(
data_flow_urn=str(dataflow.urn),
job_id=task_run_key_map[each[ID]],
)
datajob.upstream_urns.extend([upstream_task_urn])
datajob.emit(self.emitter)
datajob.emit(self.emitter)

if workspace_name is not None:
self._emit_browsepath(str(datajob.urn), workspace_name)
if workspace_name is not None:
self._emit_browsepath(str(datajob.urn), workspace_name)

self._emit_task_run(
datajob=datajob,
flow_run_name=flow_run_ctx.flow_run.name,
task_run_id=UUID(node[ID]),
)
self._emit_task_run(
datajob=datajob,
flow_run_name=flow_run_ctx.flow_run.name,
task_run_id=UUID(node[ID]),
)
except Exception:
get_run_logger().debug(traceback.format_exc())

def _emit_flow_run(self, dataflow: DataFlow, flow_run_id: UUID) -> None:
"""
@@ -583,22 +590,25 @@ def etl():
datahub_emitter.emit_flow()
```
"""
flow_run_ctx = FlowRunContext.get()
task_run_ctx = TaskRunContext.get()
try:
flow_run_ctx = FlowRunContext.get()
task_run_ctx = TaskRunContext.get()

assert flow_run_ctx
assert task_run_ctx
assert flow_run_ctx
assert task_run_ctx

datajob = self._generate_datajob(
flow_run_ctx=flow_run_ctx, task_run_ctx=task_run_ctx
)
datajob = self._generate_datajob(
flow_run_ctx=flow_run_ctx, task_run_ctx=task_run_ctx
)

if datajob is not None:
if inputs is not None:
datajob.inlets.extend(self._entities_to_urn_list(inputs))
if outputs is not None:
datajob.outlets.extend(self._entities_to_urn_list(outputs))
self.datajobs_to_emit[str(datajob.urn)] = cast(_Entity, datajob)
if datajob is not None:
if inputs is not None:
datajob.inlets.extend(self._entities_to_urn_list(inputs))
if outputs is not None:
datajob.outlets.extend(self._entities_to_urn_list(outputs))
self._datajobs_to_emit[str(datajob.urn)] = cast(_Entity, datajob)
except Exception:
get_run_logger().debug(traceback.format_exc())

def emit_flow(self) -> None:
"""
@@ -1,4 +1,5 @@
import asyncio
from typing import List, Tuple

from prefect import flow, task

@@ -8,31 +9,50 @@

async def load_datahub_emitter():
datahub_emitter = DatahubEmitter()
return datahub_emitter.load("datahub-block-7")
emitter = datahub_emitter.load("BLOCK-ID")
print(emitter)
return emitter


@task(name="Extract", description="Extract the data")
def extract():
def extract() -> str:
data = "This is data"
return data


@task(name="Transform", description="Transform the data")
def transform(data, datahub_emitter):
data = data.split(" ")
def transform(
data: str, datahub_emitter: DatahubEmitter
) -> Tuple[List[str], DatahubEmitter]:
data_list_str = data.split(" ")
datahub_emitter.add_task(
inputs=[Dataset("snowflake", "mydb.schema.tableX")],
outputs=[Dataset("snowflake", "mydb.schema.tableY")],
inputs=[
Dataset(
platform="snowflake",
name="mydb.schema.tableA",
env=datahub_emitter.env,
platform_instance=datahub_emitter.platform_instance,
)
],
outputs=[
Dataset(
platform="snowflake",
name="mydb.schema.tableB",
env=datahub_emitter.env,
platform_instance=datahub_emitter.platform_instance,
)
],
)
return data
return data_list_str, datahub_emitter


@flow(name="ETL", description="Extract transform load flow")
def etl():
def etl() -> None:
datahub_emitter = asyncio.run(load_datahub_emitter())
data = extract()
data = transform(data, datahub_emitter)
datahub_emitter.emit_flow()
return_value = transform(data, datahub_emitter)
emitter = return_value[1]
emitter.emit_flow()


etl()
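As a rough illustration of what `add_task` records for lineage, each `Dataset` above is turned into a DataHub dataset URN via `_entities_to_urn_list`; the sketch below assumes the entity exposes a `urn` property (not shown in this diff):

```python
# Sketch only: assumes prefect_datahub.entities.Dataset exposes a `urn` property.
from prefect_datahub.entities import Dataset

ds = Dataset(
    platform="snowflake",
    name="mydb.schema.tableA",
    env="DEV",
    platform_instance="local_prefect",
)
print(ds.urn)
# e.g. urn:li:dataset:(urn:li:dataPlatform:snowflake,local_prefect.mydb.schema.tableA,DEV)
```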
@@ -6,11 +6,12 @@
async def save_datahub_emitter():
datahub_emitter = DatahubEmitter(
datahub_rest_url="http://localhost:8080",
env="PROD",
env="DEV",
platform_instance="local_prefect",
token=None,  # generate an auth token in DataHub and provide it here if the GMS endpoint is secured
)

await datahub_emitter.save("datahub-block-7", overwrite=True)
await datahub_emitter.save("BLOCK-ID", overwrite=True)


asyncio.run(save_datahub_emitter())
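For a token-protected GMS endpoint, the same save script might look roughly like this; the environment variable name, block name, and import path are placeholders/assumptions, not part of this commit:

```python
# Sketch only: env var name, block name, and import path are assumptions.
import asyncio
import os

from pydantic.v1 import SecretStr

from prefect_datahub.datahub_emitter import DatahubEmitter


async def save_datahub_emitter() -> None:
    datahub_emitter = DatahubEmitter(
        datahub_rest_url="http://localhost:8080",
        env="DEV",
        platform_instance="local_prefect",
        token=SecretStr(os.environ["DATAHUB_GMS_TOKEN"]),  # assumed env var
    )
    await datahub_emitter.save("BLOCK-ID", overwrite=True)


asyncio.run(save_datahub_emitter())
```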
@@ -547,8 +547,8 @@ def test_add_task(mock_emit, mock_run_context):
f"(prefect,{flow_run_ctx.flow.name},PROD),{task_run_ctx.task.task_key})"
)

assert expected_datajob_urn in datahub_emitter.datajobs_to_emit.keys()
actual_datajob = datahub_emitter.datajobs_to_emit[expected_datajob_urn]
assert expected_datajob_urn in datahub_emitter._datajobs_to_emit.keys()
actual_datajob = datahub_emitter._datajobs_to_emit[expected_datajob_urn]
assert isinstance(actual_datajob, DataJob)
assert str(actual_datajob.flow_urn) == "urn:li:dataFlow:(prefect,etl,PROD)"
assert actual_datajob.name == task_run_ctx.task.name
