Skip to content

Commit

Permalink
Remove worker and runner handling for deployment concurrency (#15497)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanluciano authored Sep 27, 2024
1 parent 953b733 commit 8e0410b
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 344 deletions.
42 changes: 5 additions & 37 deletions src/prefect/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def fast_flow():
from prefect.concurrency.asyncio import (
AcquireConcurrencySlotTimeoutError,
ConcurrencySlotAcquisitionError,
concurrency,
)
from prefect.events import DeploymentTriggerTypes, TriggerTypes
from prefect.events.related import tags_as_related_resources
Expand All @@ -92,7 +91,6 @@ def fast_flow():
get_current_settings,
)
from prefect.states import (
AwaitingConcurrencySlot,
Crashed,
Pending,
exception_to_failed_state,
Expand Down Expand Up @@ -1047,22 +1045,12 @@ async def _submit_run_and_capture_errors(
) -> Union[Optional[int], Exception]:
run_logger = self._get_flow_run_logger(flow_run)

if flow_run.deployment_id:
deployment = await self._client.read_deployment(flow_run.deployment_id)
if deployment and deployment.global_concurrency_limit:
limit_name = deployment.global_concurrency_limit.name
concurrency_ctx = concurrency
else:
limit_name = ""
concurrency_ctx = asyncnullcontext

try:
async with concurrency_ctx(limit_name, max_retries=0, strict=True):
status_code = await self._run_process(
flow_run=flow_run,
task_status=task_status,
entrypoint=entrypoint,
)
status_code = await self._run_process(
flow_run=flow_run,
task_status=task_status,
entrypoint=entrypoint,
)
except (
AcquireConcurrencySlotTimeoutError,
ConcurrencySlotAcquisitionError,
Expand Down Expand Up @@ -1164,26 +1152,6 @@ async def _propose_failed_state(self, flow_run: "FlowRun", exc: Exception) -> No
exc_info=True,
)

async def _propose_scheduled_state(self, flow_run: "FlowRun") -> None:
    """
    Propose an `AwaitingConcurrencySlot` state for the given flow run.

    On success the name of the state returned by `propose_state` is logged.
    An `Abort` raised while proposing is logged at info level, and any other
    exception is logged with its traceback; neither is re-raised.

    Args:
        flow_run: The flow run whose state should be updated.
    """
    run_logger = self._get_flow_run_logger(flow_run)
    try:
        awaiting_slot = AwaitingConcurrencySlot()
        state = await propose_state(
            self._client, awaiting_slot, flow_run_id=flow_run.id
        )
        # Kept inside the `try` so a failure while logging is handled below too.
        self._logger.info(f"Flow run {flow_run.id} now has state {state.name}")
    except Abort as exc:
        abort_message = (
            f"Aborted rescheduling of flow run '{flow_run.id}'. "
            f"Server sent an abort signal: {exc}"
        )
        run_logger.info(abort_message)
    except Exception:
        run_logger.exception(f"Failed to update state of flow run '{flow_run.id}'")

async def _propose_crashed_state(self, flow_run: "FlowRun", message: str) -> None:
run_logger = self._get_flow_run_logger(flow_run)
try:
Expand Down
61 changes: 6 additions & 55 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
from prefect.client.schemas.actions import WorkPoolCreate, WorkPoolUpdate
from prefect.client.schemas.objects import StateType, WorkPool
from prefect.client.utilities import inject_client
from prefect.concurrency.asyncio import (
AcquireConcurrencySlotTimeoutError,
ConcurrencySlotAcquisitionError,
concurrency,
)
from prefect.events import Event, RelatedResource, emit_event
from prefect.events.related import object_as_related_resource, tags_as_related_resources
from prefect.exceptions import (
Expand All @@ -41,12 +36,10 @@
get_current_settings,
)
from prefect.states import (
AwaitingConcurrencySlot,
Crashed,
Pending,
exception_to_failed_state,
)
from prefect.utilities.asyncutils import asyncnullcontext
from prefect.utilities.dispatch import get_registry_for_type, register_base_type
from prefect.utilities.engine import propose_state
from prefect.utilities.services import critical_service_loop
Expand Down Expand Up @@ -865,42 +858,15 @@ async def _submit_run_and_capture_errors(
self, flow_run: "FlowRun", task_status: Optional[anyio.abc.TaskStatus] = None
) -> Union[BaseWorkerResult, Exception]:
run_logger = self.get_flow_run_logger(flow_run)
deployment = None

if flow_run.deployment_id:
deployment = await self._client.read_deployment(flow_run.deployment_id)
if deployment and deployment.global_concurrency_limit:
limit_name = deployment.global_concurrency_limit.name
concurrency_ctx = concurrency
else:
limit_name = ""
concurrency_ctx = asyncnullcontext

try:
async with concurrency_ctx(limit_name, max_retries=0, strict=True):
configuration = await self._get_configuration(flow_run, deployment)
submitted_event = self._emit_flow_run_submitted_event(configuration)
result = await self.run(
flow_run=flow_run,
task_status=task_status,
configuration=configuration,
)
except (
AcquireConcurrencySlotTimeoutError,
ConcurrencySlotAcquisitionError,
) as exc:
self._logger.info(
(
"Deployment %s has reached its concurrency limit when submitting flow run %s"
),
flow_run.deployment_id,
flow_run.name,
configuration = await self._get_configuration(flow_run)
submitted_event = self._emit_flow_run_submitted_event(configuration)
result = await self.run(
flow_run=flow_run,
task_status=task_status,
configuration=configuration,
)
await self._propose_scheduled_state(flow_run)

if not task_status._future.done():
task_status.started(exc)
return exc
except Exception as exc:
if not task_status._future.done():
# This flow run was being submitted and did not start successfully
Expand Down Expand Up @@ -1026,21 +992,6 @@ async def _propose_pending_state(self, flow_run: "FlowRun") -> bool:

return True

async def _propose_scheduled_state(self, flow_run: "FlowRun") -> None:
    """
    Propose an `AwaitingConcurrencySlot` state for the given flow run.

    On success the name of the state returned by `propose_state` is logged.
    An `Abort` raised while proposing is ignored, and any other exception is
    logged with its traceback; neither is re-raised.

    Args:
        flow_run: The flow run whose state should be updated.
    """
    run_logger = self.get_flow_run_logger(flow_run)
    try:
        awaiting_slot = AwaitingConcurrencySlot()
        state = await propose_state(
            self._client, awaiting_slot, flow_run_id=flow_run.id
        )
        # Kept inside the `try` so a failure while logging is handled below too.
        self._logger.info(f"Flow run {flow_run.id} now has state {state.name}")
    except Abort:
        # Flow run already marked as failed
        return
    except Exception:
        run_logger.exception(f"Failed to update state of flow run '{flow_run.id}'")

async def _propose_failed_state(self, flow_run: "FlowRun", exc: Exception) -> None:
run_logger = self.get_flow_run_logger(flow_run)
try:
Expand Down
35 changes: 0 additions & 35 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,41 +1111,6 @@ def hello(name: str = "world"):
return deployment


@pytest.fixture
async def worker_deployment_wq1_cl1(
    session,
    flow,
    flow_function,
    work_queue_1,
):
    """
    Create and commit a deployment on `work_queue_1` with `concurrency_limit=2`.

    The deployment belongs to `flow`, carries a single one-day interval
    schedule anchored at 2020-01-01, and takes its parameter schema from a
    local no-op `hello` function. `flow_function` is requested as a fixture
    but is not referenced in the body.

    Returns:
        The object returned by `models.deployments.create_deployment`.
    """

    def hello(name: str = "world"):
        pass

    daily_schedule = schemas.schedules.IntervalSchedule(
        interval=datetime.timedelta(days=1),
        anchor_date=pendulum.datetime(2020, 1, 1),
    )
    deployment_spec = schemas.core.Deployment(
        name="My Deployment 1",
        tags=["test"],
        flow_id=flow.id,
        schedules=[
            schemas.actions.DeploymentScheduleCreate(schedule=daily_schedule),
        ],
        concurrency_limit=2,
        path="./subdir",
        entrypoint="/file.py:flow",
        parameter_openapi_schema=parameter_schema(hello).model_dump_for_openapi(),
        work_queue_id=work_queue_1.id,
    )

    deployment = await models.deployments.create_deployment(
        session=session,
        deployment=deployment_spec,
    )
    await session.commit()
    return deployment


@pytest.fixture
async def worker_deployment_infra_wq1(session, flow, flow_function, work_queue_1):
def hello(name: str = "world"):
Expand Down
131 changes: 0 additions & 131 deletions tests/runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
from prefect.client.schemas.actions import DeploymentScheduleCreate
from prefect.client.schemas.objects import ConcurrencyLimitConfig, StateType
from prefect.client.schemas.schedules import CronSchedule, IntervalSchedule
from prefect.concurrency.asyncio import (
AcquireConcurrencySlotTimeoutError,
_acquire_concurrency_slots,
_release_concurrency_slots,
)
from prefect.deployments.runner import (
DeploymentApplyError,
EntrypointType,
Expand Down Expand Up @@ -636,132 +631,6 @@ async def test_runner_respects_set_limit(
flow_run = await prefect_client.read_flow_run(flow_run_id=bad_run.id)
assert flow_run.state.is_completed()

# Checks that submitting a flow run for a deployment created with a
# ConcurrencyLimitConfig calls the concurrency-slot acquire helper once for the
# limit named "deployment:<deployment_id>", and that the release helper is
# called with that same name.
# NOTE(review): leading indentation is missing in this copy, so the nesting of
# the statements below relative to the `with` / `async with` blocks cannot be
# read off the text — confirm against the repository before editing.
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_runner_enforces_deployment_concurrency_limits(
self, prefect_client: PrefectClient, caplog
):
concurrency_limit_config = ConcurrencyLimitConfig(limit=42)

# Stand-in assigned to `runner.run` below; accepts any arguments and returns 0.
async def test(*args, **kwargs):
return 0

# `wraps=` makes each mock record its calls while passing them through to the
# real helper function.
with mock.patch(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
wraps=_acquire_concurrency_slots,
) as acquire_spy:
with mock.patch(
"prefect.concurrency.asyncio._release_concurrency_slots",
wraps=_release_concurrency_slots,
) as release_spy:
async with Runner(pause_on_shutdown=False) as runner:
deployment = RunnerDeployment.from_flow(
flow=dummy_flow_1,
name=__file__,
concurrency_limit=concurrency_limit_config,
)

deployment_id = await runner.add_deployment(deployment)

flow_run = await prefect_client.create_flow_run_from_deployment(
deployment_id=deployment_id
)

assert flow_run.state.is_scheduled()

runner.run = test # simulate running a flow

await runner._get_and_submit_flow_runs()

# Exactly one acquire call is expected: one slot, no retries, strict mode.
acquire_spy.assert_called_once_with(
[f"deployment:{deployment_id}"],
1,
timeout_seconds=None,
create_if_missing=None,
max_retries=0,
strict=True,
)

# Positional arguments of the most recent release call.
names, occupy, occupy_seconds = release_spy.call_args[0]
assert names == [f"deployment:{deployment_id}"]
assert occupy == 1
assert occupy_seconds > 0

# Checks that when the concurrency-slot acquire helper raises
# AcquireConcurrencySlotTimeoutError, the flow run ends up in a state named
# "AwaitingConcurrencySlot".
# NOTE(review): leading indentation is missing in this copy, so the nesting of
# the statements below relative to the `with` / `async with` blocks cannot be
# read off the text — confirm against the repository before editing.
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_runner_proposes_awaiting_concurrency_limit_state_name(
self, prefect_client: PrefectClient, caplog
):
# Stand-in assigned to `runner.run` below; accepts any arguments and returns 0.
async def test(*args, **kwargs):
return 0

with mock.patch(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
wraps=_acquire_concurrency_slots,
) as acquire_spy:
# Simulate a Locked response from the API
# (an exception class as `side_effect` is raised whenever the mock is called)
acquire_spy.side_effect = AcquireConcurrencySlotTimeoutError

async with Runner(pause_on_shutdown=False) as runner:
deployment = RunnerDeployment.from_flow(
flow=dummy_flow_1,
name=__file__,
concurrency_limit=2,
)

deployment_id = await runner.add_deployment(deployment)

flow_run = await prefect_client.create_flow_run_from_deployment(
deployment_id=deployment_id
)

assert flow_run.state.is_scheduled()

runner.run = test # simulate running a flow

await runner._get_and_submit_flow_runs()

# The acquire attempt itself must still have been made exactly once.
acquire_spy.assert_called_once_with(
[f"deployment:{deployment_id}"],
1,
timeout_seconds=None,
create_if_missing=None,
max_retries=0,
strict=True,
)

# Re-read the run from the API to observe the state recorded there.
flow_run = await prefect_client.read_flow_run(flow_run.id)
assert flow_run.state.name == "AwaitingConcurrencySlot"

# Checks that the concurrency-slot acquire helper is never called when the
# deployment is created without a `concurrency_limit` argument.
# NOTE(review): leading indentation is missing in this copy, so the nesting of
# the statements below relative to the `with` / `async with` blocks cannot be
# read off the text — confirm against the repository before editing.
@pytest.mark.usefixtures("use_hosted_api_server")
async def test_runner_does_not_attempt_to_acquire_limit_if_deployment_has_no_concurrency_limit(
self, prefect_client: PrefectClient, caplog
):
# Stand-in assigned to `runner.run` below; accepts any arguments and returns 0.
async def test(*args, **kwargs):
return 0

# `wraps=` makes the mock record its calls while passing them through to the
# real helper function.
with mock.patch(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
wraps=_acquire_concurrency_slots,
) as acquire_spy:
async with Runner(pause_on_shutdown=False) as runner:
# No `concurrency_limit` is passed here, unlike the sibling tests.
deployment = RunnerDeployment.from_flow(
flow=dummy_flow_1,
name=__file__,
)

deployment_id = await runner.add_deployment(deployment)

flow_run = await prefect_client.create_flow_run_from_deployment(
deployment_id=deployment_id
)

assert flow_run.state.is_scheduled()

runner.run = test # simulate running a flow

await runner._get_and_submit_flow_runs()

acquire_spy.assert_not_called()

async def test_handles_spaces_in_sys_executable(self, monkeypatch, prefect_client):
"""
Regression test for https://github.com/PrefectHQ/prefect/issues/10820
Expand Down
Loading

0 comments on commit 8e0410b

Please sign in to comment.