Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds in the ability for the client to get a remote_id from the server #15849

Merged
merged 16 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/3.0/develop/settings-ref.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ Settings for for controlling API client behavior

**TOML dotted key path**: `deployments`

### `experiments`
Settings for controlling experimental features

**Type**: [ExperimentsSettings](#experimentssettings)

**TOML dotted key path**: `experiments`

### `flows`

**Type**: [FlowsSettings](#flowssettings)
Expand Down Expand Up @@ -454,7 +461,18 @@ The default Docker namespace to use when building images.
`PREFECT_DEPLOYMENTS_DEFAULT_DOCKER_BUILD_NAMESPACE`, `PREFECT_DEFAULT_DOCKER_BUILD_NAMESPACE`

---
## ExperimentsSettings
Settings for configuring experimental features
### `worker_logging_to_api_enabled`
Enables the logging of worker logs to Prefect Cloud.

**Type**: `boolean`

**Default**: `False`

**TOML dotted key path**: `experiments.worker_logging_to_api_enabled`

---
## FlowsSettings
Settings for controlling flow behavior
### `default_retries`
Expand Down
9 changes: 8 additions & 1 deletion schemas/settings.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,14 @@
},
"ExperimentsSettings": {
"description": "Settings for configuring experimental features",
"properties": {},
"properties": {
"worker_logging_to_api_enabled": {
"default": false,
"description": "Enables the logging of worker logs to Prefect Cloud.",
"title": "Worker Logging To Api Enabled",
"type": "boolean"
}
},
"title": "ExperimentsSettings",
"type": "object"
},
Expand Down
28 changes: 25 additions & 3 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
PREFECT_CLOUD_API_URL,
PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
PREFECT_TESTING_UNIT_TEST_MODE,
get_current_settings,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -2594,22 +2595,43 @@ async def send_worker_heartbeat(
work_pool_name: str,
worker_name: str,
heartbeat_interval_seconds: Optional[float] = None,
):
get_worker_id: bool = False,
) -> Optional[UUID]:
"""
Sends a worker heartbeat for a given work pool.

Args:
work_pool_name: The name of the work pool to heartbeat against.
worker_name: The name of the worker sending the heartbeat.
return_id: Whether to return the worker ID.
sam-phinizy marked this conversation as resolved.
Show resolved Hide resolved
"""
await self._client.post(

if get_worker_id:
return_dict = {"return_id": get_worker_id}
else:
return_dict = {}

resp = await self._client.post(
f"/work_pools/{work_pool_name}/workers/heartbeat",
json={
"name": worker_name,
"heartbeat_interval_seconds": heartbeat_interval_seconds,
},
}
| return_dict,
)

if (
(
self.server_type == ServerType.CLOUD
or get_current_settings().testing.test_mode
)
and get_worker_id
and resp.status_code == 200
):
return UUID(resp.text)
else:
return None

async def read_workers_for_work_pool(
self,
work_pool_name: str,
Expand Down
7 changes: 7 additions & 0 deletions src/prefect/settings/models/experiments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pydantic import Field

from prefect.settings.base import PrefectBaseSettings, PrefectSettingsConfigDict


Expand All @@ -13,3 +15,8 @@ class ExperimentsSettings(PrefectBaseSettings):
toml_file="prefect.toml",
prefect_toml_table_header=("experiments",),
)

worker_logging_to_api_enabled: bool = Field(
default=False,
description="Enables the logging of worker logs to Prefect Cloud.",
)
41 changes: 36 additions & 5 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import AsyncExitStack
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, Union
from uuid import uuid4
from uuid import UUID, uuid4

import anyio
import anyio.abc
Expand All @@ -15,6 +15,7 @@

import prefect
from prefect._internal.schemas.validators import return_v_or_none
from prefect.client.base import ServerType
from prefect.client.orchestration import PrefectClient, get_client
from prefect.client.schemas.actions import WorkPoolCreate, WorkPoolUpdate
from prefect.client.schemas.objects import StateType, WorkPool
Expand Down Expand Up @@ -421,6 +422,7 @@ def __init__(
heartbeat_interval_seconds or PREFECT_WORKER_HEARTBEAT_SECONDS.value()
)

self.backend_id: Optional[UUID] = None
self._work_pool: Optional[WorkPool] = None
self._exit_stack: AsyncExitStack = AsyncExitStack()
self._runs_task_group: Optional[anyio.abc.TaskGroup] = None
Expand Down Expand Up @@ -710,12 +712,20 @@ async def _update_local_work_pool_info(self):

self._work_pool = work_pool

async def _send_worker_heartbeat(self):
async def _send_worker_heartbeat(
self, get_worker_id: bool = False
) -> Optional[UUID]:
"""
Sends a heartbeat to the API.

If `get_worker_id` is True, the worker ID will be retrieved from the API.
"""
if self._work_pool:
await self._client.send_worker_heartbeat(
return await self._client.send_worker_heartbeat(
work_pool_name=self._work_pool_name,
worker_name=self.name,
heartbeat_interval_seconds=self.heartbeat_interval_seconds,
get_worker_id=get_worker_id,
)

async def sync_with_backend(self):
Expand All @@ -724,10 +734,31 @@ async def sync_with_backend(self):
queues. Sends a worker heartbeat to the API.
"""
await self._update_local_work_pool_info()
# Only do this logic if we've enabled the experiment, are connected to cloud and we don't have an ID.
if (
get_current_settings().experiments.worker_logging_to_api_enabled
and (
self._client.server_type == ServerType.CLOUD
or get_current_settings().testing.test_mode
)
and self.backend_id is None
):
get_worker_id = True
else:
get_worker_id = False

await self._send_worker_heartbeat()
remote_id = await self._send_worker_heartbeat(get_worker_id=get_worker_id)

self._logger.debug("Worker synchronized with the Prefect API server.")
if get_worker_id and remote_id is None:
self._logger.warning(
"Failed to retrieve worker ID from the Prefect API server."
)
else:
self.backend_id = remote_id

self._logger.debug(
f"Worker synchronized with the Prefect API server. Remote ID: {self.backend_id}"
)

async def _get_scheduled_flow_runs(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@
},
"PREFECT_EVENTS_WEBSOCKET_BACKFILL_PAGE_SIZE": {"test_value": 10, "legacy": True},
"PREFECT_EXPERIMENTAL_WARN": {"test_value": True},
"PREFECT_EXPERIMENTS_WORKER_LOGGING_TO_API_ENABLED": {"test_value": False},
"PREFECT_FLOW_DEFAULT_RETRIES": {"test_value": 10, "legacy": True},
"PREFECT_FLOWS_DEFAULT_RETRIES": {"test_value": 10},
"PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS": {"test_value": 10, "legacy": True},
Expand Down
47 changes: 47 additions & 0 deletions tests/workers/test_base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Any, Dict, Optional, Type
from unittest.mock import MagicMock

import httpx
import pendulum
import pytest
from packaging import version
from pydantic import Field
from starlette import status

import prefect
import prefect.client.schemas as schemas
Expand All @@ -23,6 +25,7 @@
from prefect.server.schemas.responses import DeploymentResponse
from prefect.settings import (
PREFECT_API_URL,
PREFECT_EXPERIMENTS_WORKER_LOGGING_TO_API_ENABLED,
PREFECT_TEST_MODE,
PREFECT_WORKER_PREFETCH_SECONDS,
get_current_settings,
Expand Down Expand Up @@ -80,6 +83,14 @@ def no_api_url():
yield


@pytest.fixture
def experimental_logging_enabled():
with temporary_settings(
updates={PREFECT_EXPERIMENTS_WORKER_LOGGING_TO_API_ENABLED: True}
):
yield


async def test_worker_requires_api_url_when_not_in_test_mode(no_api_url):
with pytest.raises(ValueError, match="PREFECT_API_URL"):
async with WorkerTestImpl(
Expand Down Expand Up @@ -161,6 +172,42 @@ async def test_worker_sends_heartbeat_messages(
assert second_heartbeat > first_heartbeat


async def test_worker_sends_heartbeat_gets_id(experimental_logging_enabled, respx_mock):
work_pool_name = "test-work-pool"
test_worker_id = uuid.UUID("028EC481-5899-49D7-B8C5-37A2726E9840")
async with WorkerTestImpl(name="test", work_pool_name=work_pool_name) as worker:
# Pass through the non-relevant paths
respx_mock.get(f"api/work_pools/{work_pool_name}").pass_through()
respx_mock.get("api/csrf-token?").pass_through()
respx_mock.post("api/work_pools/").pass_through()
respx_mock.patch(f"api/work_pools/{work_pool_name}").pass_through()

respx_mock.post(
f"api/work_pools/{work_pool_name}/workers/heartbeat",
).mock(
return_value=httpx.Response(status.HTTP_200_OK, text=str(test_worker_id))
)

await worker.sync_with_backend()

assert worker.backend_id == test_worker_id


async def test_worker_sends_heartbeat_only_gets_id_once(
experimental_logging_enabled,
):
async with WorkerTestImpl(name="test", work_pool_name="test-work-pool") as worker:
mock = AsyncMock(return_value="test")
setattr(worker._client, "send_worker_heartbeat", mock)
await worker.sync_with_backend()
await worker.sync_with_backend()

second_call = mock.await_args_list[1]

assert worker.backend_id == "test"
assert not second_call.kwargs["get_worker_id"]


async def test_worker_with_work_pool(
prefect_client: PrefectClient, worker_deployment_wq1, work_pool
):
Expand Down
Loading