Skip to content

Commit

Permalink
feat: check computeplans before running function (#997)
Browse files Browse the repository at this point in the history
* feat: add `image_builder.check_function_is_runnable`

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* feat: add tests

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* fix: add annotation to tests

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* feat: change order which check CP status

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* feat: add  `Function.cancel()`

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* feat: add `BuildCanceledError`

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* fix: use real uuid in `test_check_function_is_runnable`

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* fix: db call on `test_check_function_is_runnable`

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

* doc: change fragment

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>

---------

Signed-off-by: Guilhem Barthés <guilhem.barthes@owkin.com>
  • Loading branch information
guilhem-barthes authored Sep 30, 2024
1 parent b6506e2 commit 07d9311
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 4 deletions.
6 changes: 3 additions & 3 deletions backend/api/models/computeplan.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ def get_task_stats(self) -> dict:
def update_status(self) -> None:
"""Compute cp status from tasks counts."""
stats = self.get_task_stats()
if stats["task_count"] == 0 or stats["waiting_builder_slot_count"] == stats["task_count"]:
if self.cancelation_date or stats["canceled_count"] > 0:
compute_plan_status = self.Status.PLAN_STATUS_CANCELED
elif stats["task_count"] == 0 or stats["waiting_builder_slot_count"] == stats["task_count"]:
compute_plan_status = self.Status.PLAN_STATUS_CREATED
elif stats["done_count"] == stats["task_count"]:
compute_plan_status = self.Status.PLAN_STATUS_DONE
elif stats["failed_count"] > 0:
compute_plan_status = self.Status.PLAN_STATUS_FAILED
elif self.cancelation_date or stats["canceled_count"] > 0:
compute_plan_status = self.Status.PLAN_STATUS_CANCELED
else:
compute_plan_status = self.Status.PLAN_STATUS_DOING

Expand Down
4 changes: 4 additions & 0 deletions backend/api/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ class Function(models.Model, AssetPermissionMixin):

class Meta:
ordering = ["creation_date", "key"] # default order for relations serializations

def cancel(self) -> None:
self.status = Function.Status.FUNCTION_STATUS_CANCELED
self.save()
4 changes: 4 additions & 0 deletions backend/builder/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ class BuildRetryError(_BuildError, CeleryRetryError):
Args:
logs (str): the container image build logs
"""


class BuildCanceledError(CeleryNoRetryError):
"""A function built has been cancelled (for instance, all the linked ocmpute plans has been cancelled or failed)"""
21 changes: 21 additions & 0 deletions backend/builder/image_builder/image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.conf import settings

import orchestrator
from api.models import ComputePlan
from api.models import Function
from builder import docker
from builder import exceptions
from builder.exceptions import BuildError
Expand Down Expand Up @@ -348,3 +350,22 @@ def _build_container_args(dockerfile_mount_path: str, image_tag: str) -> list[st
if REGISTRY_SCHEME == "http":
args.append("--insecure-pull")
return args


def check_function_is_runnable(function_key: str, channel_name: str) -> bool:
compute_plans_statuses = set(
ComputePlan.objects.filter(compute_tasks__function__key=function_key, channel=channel_name)
.values_list("status", flat=True)
.distinct()
)

if len(compute_plans_statuses) == 0:
return True

target_statuses = {ComputePlan.Status.PLAN_STATUS_CANCELED, ComputePlan.Status.PLAN_STATUS_FAILED}
is_runnable = not compute_plans_statuses.issubset(target_statuses)

if not is_runnable:
Function.objects.get(key=function_key).cancel()

return is_runnable
12 changes: 12 additions & 0 deletions backend/builder/tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any

import structlog
from billiard.einfo import ExceptionInfo
from django.conf import settings

import orchestrator
from builder.exceptions import BuildCanceledError
from substrapp.models import FailedAssetKind
from substrapp.tasks.task import FailableTask

Expand Down Expand Up @@ -36,3 +40,11 @@ def get_task_info(self, args: tuple, kwargs: dict) -> tuple[str, str]:
function = orchestrator.Function.model_validate_json(kwargs["function_serialized"])
channel_name = kwargs["channel_name"]
return function.key, channel_name

def on_failure(
self, exc: Exception, task_id: str, args: tuple, kwargs: dict[str, Any], einfo: ExceptionInfo
) -> None:
if isinstance(exc, BuildCanceledError):
return

super().on_failure(exc, task_id, args, kwargs, einfo)
5 changes: 4 additions & 1 deletion backend/builder/tasks/tasks_build_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import orchestrator
from backend.celery import app
from builder.exceptions import BuildCanceledError
from builder.exceptions import BuildError
from builder.exceptions import BuildRetryError
from builder.exceptions import CeleryNoRetryError
Expand All @@ -24,9 +25,11 @@ def build_image(task: BuildTask, function_serialized: str, channel_name: str) ->
timer = Timer()
attempt = 0
while attempt <= task.max_retries:
if not image_builder.check_function_is_runnable(function.key, channel_name):
logger.info("build has been canceled", function_id=function.key)
raise BuildCanceledError
try:
timer.start()

image_builder.build_image_if_missing(channel_name, function)

with orchestrator.get_orchestrator_client(channel_name) as client:
Expand Down
32 changes: 32 additions & 0 deletions backend/builder/tests/test_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytest_mock import MockerFixture

import orchestrator
from api.models import ComputePlan
from builder.exceptions import BuildError
from builder.exceptions import BuildRetryError
from builder.exceptions import PodTimeoutError
Expand Down Expand Up @@ -85,3 +86,34 @@ def test_get_entrypoint_from_dockerfile_invalid_dockerfile(
image_builder._get_entrypoint_from_dockerfile(str(tmp_path))

assert expected_exc_content in bytes.decode(exc.value.logs.read())


@pytest.mark.parametrize(
["statuses", "is_function_runnable"],
[
([], True),
([ComputePlan.Status.PLAN_STATUS_DONE.value], True),
([ComputePlan.Status.PLAN_STATUS_FAILED.value, ComputePlan.Status.PLAN_STATUS_CANCELED.value], False),
(
[
ComputePlan.Status.PLAN_STATUS_DONE.value,
ComputePlan.Status.PLAN_STATUS_FAILED.value,
ComputePlan.Status.PLAN_STATUS_CANCELED.value,
],
True,
),
],
ids=["no cp", "done cp", "failed + canceled cp", "done + failed + canceled cp"],
)
def test_check_function_is_runnable(mocker: MockerFixture, statuses: str, is_function_runnable: bool) -> None:
function_key = "e7f8aed4-f2c9-442d-a02c-8b7858a2ac4f"
channel_name = "channel_name"
compute_plan_getter = mocker.patch("builder.image_builder.image_builder.ComputePlan.objects.filter")
function_cancel = mocker.patch("builder.image_builder.image_builder.Function.objects.get")
compute_plan_getter.return_value.values_list.return_value.distinct.return_value = statuses
result = image_builder.check_function_is_runnable(function_key=function_key, channel_name=channel_name)

assert result == is_function_runnable
compute_plan_getter.assert_called_once_with(compute_tasks__function__key=function_key, channel=channel_name)
if not is_function_runnable:
function_cancel.assert_called_once_with(key=function_key)
4 changes: 4 additions & 0 deletions backend/builder/tests/test_task_build_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_store_failure_build_error():
assert failure_report.logs.read() == str.encode(msg)


@pytest.mark.django_db
def test_catch_all_exceptions(celery_app, celery_worker, mocker):
mocker.patch("builder.tasks.task.orchestrator.get_orchestrator_client")
mocker.patch("builder.image_builder.image_builder.build_image_if_missing", side_effect=Exception("random error"))
Expand All @@ -39,6 +40,7 @@ def test_catch_all_exceptions(celery_app, celery_worker, mocker):
r.get()


@pytest.mark.django_db
@pytest.mark.parametrize("execution_number", range(10))
def test_order_building_success(celery_app, celery_worker, mocker, execution_number):
function_1 = orc_mock.FunctionFactory()
Expand All @@ -63,6 +65,7 @@ def test_order_building_success(celery_app, celery_worker, mocker, execution_num
assert result_2.state == "WAITING"


@pytest.mark.django_db
@pytest.mark.parametrize("execution_number", range(10))
def test_order_building_retry(celery_app, celery_worker, mocker, execution_number):
function_retry = orc_mock.FunctionFactory()
Expand Down Expand Up @@ -100,6 +103,7 @@ def side_effect(*args, **kwargs):
assert result_other.state == "WAITING"


@pytest.mark.django_db
def test_ssl_connection_timeout(celery_app, celery_worker, mocker):
"""
Test that in case of a SSL connection timeout, the task is retried max_retries times,
Expand Down
2 changes: 2 additions & 0 deletions changes/997.added
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Check if function is linked with compute plans (through the compute tasks) before building. If all compute plans have been cancelled or failed, cancels the function.

0 comments on commit 07d9311

Please sign in to comment.