Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send worker metadata with first heartbeat #15898

Merged
merged 10 commits into from
Nov 4, 2024
20 changes: 11 additions & 9 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
TaskRunResult,
Variable,
Worker,
WorkerMetadata,
WorkPool,
WorkQueue,
WorkQueueStatusDetail,
Expand Down Expand Up @@ -2596,6 +2597,7 @@ async def send_worker_heartbeat(
worker_name: str,
heartbeat_interval_seconds: Optional[float] = None,
get_worker_id: bool = False,
worker_metadata: Optional[WorkerMetadata] = None,
) -> Optional[UUID]:
"""
Sends a worker heartbeat for a given work pool.
Expand All @@ -2604,20 +2606,20 @@ async def send_worker_heartbeat(
work_pool_name: The name of the work pool to heartbeat against.
worker_name: The name of the worker sending the heartbeat.
return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`.
worker_metadata: Metadata about the worker to send to the server.
"""

params = {
"name": worker_name,
"heartbeat_interval_seconds": heartbeat_interval_seconds,
}
if worker_metadata:
params["worker_metadata"] = worker_metadata.model_dump(mode="json")
if get_worker_id:
return_dict = {"return_id": get_worker_id}
else:
return_dict = {}
params["return_id"] = get_worker_id

resp = await self._client.post(
f"/work_pools/{work_pool_name}/workers/heartbeat",
json={
"name": worker_name,
"heartbeat_interval_seconds": heartbeat_interval_seconds,
}
| return_dict,
json=params,
)

if (
Expand Down
21 changes: 21 additions & 0 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,3 +1689,24 @@ class CsrfToken(ObjectBaseModel):


__getattr__ = getattr_migration(__name__)


class Integration(PrefectBaseModel):
"""A representation of an installed Prefect integration."""

name: str = Field(description="The name of the Prefect integration.")
version: str = Field(description="The version of the Prefect integration.")


class WorkerMetadata(PrefectBaseModel):
"""
Worker metadata.

We depend on the structure of `integrations`, but otherwise, worker classes
should support flexible metadata.
"""

integrations: List[Integration] = Field(
default=..., description="Prefect integrations installed in the worker."
)
model_config = ConfigDict(extra="allow")
94 changes: 68 additions & 26 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import anyio.abc
import httpx
import pendulum
from importlib_metadata import distributions
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Literal
Expand All @@ -19,7 +20,12 @@
from prefect.client.base import ServerType
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.actions import WorkPoolCreate, WorkPoolUpdate
from prefect.client.schemas.objects import StateType, WorkPool
from prefect.client.schemas.objects import (
Integration,
StateType,
WorkerMetadata,
WorkPool,
)
from prefect.client.utilities import inject_client
from prefect.events import Event, RelatedResource, emit_event
from prefect.events.related import object_as_related_resource, tags_as_related_resources
Expand Down Expand Up @@ -438,6 +444,7 @@ def __init__(
self._submitting_flow_run_ids = set()
self._cancelling_flow_run_ids = set()
self._scheduled_task_scopes = set()
self._worker_metadata_sent = False

@classmethod
def get_documentation_url(cls) -> str:
Expand Down Expand Up @@ -717,47 +724,81 @@ async def _update_local_work_pool_info(self):

self._work_pool = work_pool

async def _send_worker_heartbeat(
self, get_worker_id: bool = False
) -> Optional[UUID]:
async def _worker_metadata(self) -> Optional[WorkerMetadata]:
"""
Sends a heartbeat to the API.

If `get_worker_id` is True, the worker ID will be retrieved from the API.
Returns metadata about installed Prefect collections for the worker.
"""
if self._work_pool:
return await self._client.send_worker_heartbeat(
work_pool_name=self._work_pool_name,
worker_name=self.name,
heartbeat_interval_seconds=self.heartbeat_interval_seconds,
get_worker_id=get_worker_id,
)
installed_integrations = load_prefect_collections().keys()

async def sync_with_backend(self):
integration_versions = [
Integration(name=dist.metadata["Name"], version=dist.version)
for dist in distributions()
# PyPI packages often use dashes, but Python package names use underscores
# because they must be valid identifiers.
if dist.metadata.get("Name").replace("_", "-") in installed_integrations
]

if integration_versions:
return WorkerMetadata(integrations=integration_versions)
return None

async def _send_worker_heartbeat(self) -> Optional[UUID]:
"""
Updates the worker's local information about it's current work pool and
queues. Sends a worker heartbeat to the API.
Sends a heartbeat to the API.
"""
await self._update_local_work_pool_info()
if not self._client:
self._logger.warning("Client has not been initialized; skipping heartbeat.")
return None
if not self._work_pool:
self._logger.debug("Worker has no work pool; skipping heartbeat.")
return None

should_get_worker_id = self._should_get_worker_id()

params = {
"work_pool_name": self._work_pool_name,
"worker_name": self.name,
"heartbeat_interval_seconds": self.heartbeat_interval_seconds,
"get_worker_id": should_get_worker_id,
}
if (
self._client.server_type == ServerType.CLOUD
and not self._worker_metadata_sent
):
worker_metadata = await self._worker_metadata()
if worker_metadata:
params["worker_metadata"] = worker_metadata
self._worker_metadata_sent = True

worker_id = None
try:
remote_id = await self._send_worker_heartbeat(
get_worker_id=(self._should_get_worker_id())
)
worker_id = await self._client.send_worker_heartbeat(**params)
except httpx.HTTPStatusError as e:
if e.response.status_code == 422 and self._should_get_worker_id():
if e.response.status_code == 422 and should_get_worker_id:
self._logger.warning(
"Failed to retrieve worker ID from the Prefect API server."
)
await self._send_worker_heartbeat(get_worker_id=False)
remote_id = None
params["get_worker_id"] = False
worker_id = await self._client.send_worker_heartbeat(**params)
else:
raise e

if self._should_get_worker_id() and remote_id is None:
if should_get_worker_id and worker_id is None:
self._logger.warning(
"Failed to retrieve worker ID from the Prefect API server."
)
elif self.backend_id is None and remote_id is not None:

return worker_id

async def sync_with_backend(self):
"""
Updates the worker's local information about it's current work pool and
queues. Sends a worker heartbeat to the API.
"""
await self._update_local_work_pool_info()

remote_id = await self._send_worker_heartbeat()
if remote_id:
self.backend_id = remote_id
self._logger = get_worker_logger(self)

Expand All @@ -769,6 +810,7 @@ def _should_get_worker_id(self):
"""Determines if the worker should request an ID from the API server."""
return (
get_current_settings().experiments.worker_logging_to_api_enabled
and self._client
and self._client.server_type == ServerType.CLOUD
and self.backend_id is None
)
Expand Down
60 changes: 60 additions & 0 deletions tests/client/test_prefect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Generator, List
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
from uuid import UUID, uuid4

Expand Down Expand Up @@ -55,9 +56,11 @@
Flow,
FlowRunNotificationPolicy,
FlowRunPolicy,
Integration,
StateType,
TaskRun,
Variable,
WorkerMetadata,
WorkQueue,
)
from prefect.client.schemas.responses import (
Expand All @@ -69,6 +72,7 @@
from prefect.client.utilities import inject_client
from prefect.events import AutomationCore, EventTrigger, Posture
from prefect.server.api.server import create_app
from prefect.server.database.orm_models import WorkPool
from prefect.settings import (
PREFECT_API_DATABASE_MIGRATE_ON_START,
PREFECT_API_KEY,
Expand Down Expand Up @@ -2698,3 +2702,59 @@ def test_raise_for_api_version_mismatch_with_incompatible_versions(
f"Found incompatible versions: client: {client_version}, server: {api_version}. "
in str(e.value)
)


class TestPrefectClientWorkerHeartbeat:
async def test_worker_heartbeat(
self, prefect_client: PrefectClient, work_pool: WorkPool
):
work_pool_name = str(work_pool.name)
await prefect_client.send_worker_heartbeat(
work_pool_name=work_pool_name,
worker_name="test-worker",
heartbeat_interval_seconds=10,
)
workers = await prefect_client.read_workers_for_work_pool(work_pool_name)
assert len(workers) == 1
assert workers[0].name == "test-worker"
assert workers[0].heartbeat_interval_seconds == 10

async def test_worker_heartbeat_sends_metadata_if_passed(
self, prefect_client: PrefectClient
):
with mock.patch(
"prefect.client.orchestration.PrefectHttpxAsyncClient.post",
return_value=httpx.Response(status_code=204),
) as mock_post:
await prefect_client.send_worker_heartbeat(
work_pool_name="work-pool",
worker_name="test-worker",
heartbeat_interval_seconds=10,
worker_metadata=WorkerMetadata(
integrations=[Integration(name="prefect-aws", version="1.0.0")]
),
)
assert mock_post.call_args[1]["json"] == {
"name": "test-worker",
"heartbeat_interval_seconds": 10,
"worker_metadata": {
"integrations": [{"name": "prefect-aws", "version": "1.0.0"}]
},
}

async def test_worker_heartbeat_does_not_send_metadata_if_not_passed(
self, prefect_client: PrefectClient
):
with mock.patch(
"prefect.client.orchestration.PrefectHttpxAsyncClient.post",
return_value=httpx.Response(status_code=204),
) as mock_post:
await prefect_client.send_worker_heartbeat(
work_pool_name="work-pool",
worker_name="test-worker",
heartbeat_interval_seconds=10,
)
assert mock_post.call_args[1]["json"] == {
"name": "test-worker",
"heartbeat_interval_seconds": 10,
}
Loading