[sub]feat: modify computetask failure report (#727)
## Companion PR

- Substra/orchestrator#277
- Substra/substra-frontend#240

## Description

The aim is to allow registering failure reports not only for compute
tasks but also for other kinds of assets (for now, functions that are
not built as part of the execution of a compute task).

- Modifies `ComputeTaskFailureReport`:
    - renamed the model to `AssetFailureReport`
    - renamed field `compute_task_key` to `asset_key` (as we can now have a
      function key)
    - added field `asset_type` to provide
- Updates the protobuf definitions to reflect the previous changes
- Refactors `download_file` in `PermissionMixin` to provide more
  flexibility (and decouple it from DRF)
- Creates a new `FailableTask` (Celery task):
    - centralizes the logic to submit asset failures
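
To make the shape of the change easier to picture, here is a minimal, framework-free sketch of the idea: a failure report keyed by `asset_key` and tagged with an `asset_type`, plus a base task that records the report when its work fails. This is not the code from this commit. The real `AssetFailureReport` is a Django model and the real `FailableTask` is a Celery task. The `error_message` field, the `reports` list, and the `BuildFunction` subclass are hypothetical names used only for illustration, and the enum string values are assumed.

```python
import uuid
from dataclasses import dataclass
from enum import Enum


class FailedAssetKind(str, Enum):
    # Member names come from the diff; the string values are an assumption.
    FAILED_ASSET_COMPUTE_TASK = "FAILED_ASSET_COMPUTE_TASK"
    FAILED_ASSET_FUNCTION = "FAILED_ASSET_FUNCTION"


@dataclass
class AssetFailureReport:
    # Plain stand-in for the Django model; only the two renamed/added fields are real.
    asset_key: uuid.UUID
    asset_type: FailedAssetKind
    error_message: str  # hypothetical field, for illustration only


class FailableTask:
    """Hypothetical base class: each subclass declares the kind of asset it handles."""

    asset_type: FailedAssetKind

    def __init__(self, reports: list) -> None:
        # In-memory store standing in for the database.
        self.reports = reports

    def run(self, asset_key: uuid.UUID) -> str:
        raise NotImplementedError

    def __call__(self, asset_key: uuid.UUID) -> str:
        try:
            return self.run(asset_key)
        except Exception as exc:
            # The failure handling lives in one place, whatever the asset kind.
            self.reports.append(AssetFailureReport(asset_key, self.asset_type, str(exc)))
            raise


class BuildFunction(FailableTask):
    asset_type = FailedAssetKind.FAILED_ASSET_FUNCTION

    def run(self, asset_key: uuid.UUID) -> str:
        raise RuntimeError("image build failed")


reports: list[AssetFailureReport] = []
function_key = uuid.uuid4()
try:
    BuildFunction(reports)(function_key)
except RuntimeError:
    pass  # the failure has already been recorded in `reports`
```

A task working on compute tasks would only need to set `asset_type = FailedAssetKind.FAILED_ASSET_COMPUTE_TASK`; the reporting path stays the same.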

## How has this been tested?

As this will be merged into a branch that will itself be merged into
a POC branch, we use MNIST as a baseline of a working model. We will
deal with failing tests on the POC branch before merging into main.

## Checklist

- [x] [changelog](../CHANGELOG.md) was updated with notable changes
- [ ] documentation was updated

---------

Signed-off-by: Guilhem Barthes <guilhem.barthes@owkin.com>
guilhem-barthes authored Sep 8, 2023
1 parent 1a2c9b8 commit 947f49e
Showing 36 changed files with 639 additions and 315 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -11,11 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- New `SECRET_KEY` optional environment variable ([#671](https://github.com/Substra/substra-backend/pull/671))
- `/api-token-auth/` and the associated tokens can now be disabled through the `EXPIRY_TOKEN_ENABLED` environment variable and `server.allowImplicitLogin` chart value ([#698](https://github.com/Substra/substra-backend/pull/698))
- Tokens issued by `/api-token-auth/` can now be deleted like other API tokens, through a `DELETE` request on the `/active-api-tokens` endpoint ([#698](https://github.com/Substra/substra-backend/pull/698))
- Field `asset_type` on `AssetFailureReport` (based on protobuf enum `orchestrator.FailedAssetKind`) ([#727](https://github.com/Substra/substra-backend/pull/727))
- Celery task `FailableTask` that contains the logic to store the failure report, that can be re-used in different
assets. ([#727](https://github.com/Substra/substra-backend/pull/727))

### Changed

- Increase the number of tasks displayable in frontend workflow [#697](https://github.com/Substra/substra-backend/pull/697)
- BREAKING: Change the format of many API responses from `{"message":...}` to `{"detail":...}` ([#705](https://github.com/Substra/substra-backend/pull/705))
- `ComputeTaskFailureReport` renamed in `AssetFailureReport` ([#727](https://github.com/Substra/substra-backend/pull/727))
- Field `AssetFailureReport.compute_task_key` renamed to `asset_key` ([#727](https://github.com/Substra/substra-backend/pull/727))

### Removed

15 changes: 13 additions & 2 deletions backend/api/events/sync.py
@@ -33,6 +33,7 @@
from api.serializers import PerformanceSerializer
from orchestrator import client as orc_client
from orchestrator import computetask
from orchestrator import failure_report_pb2

logger = structlog.get_logger(__name__)

@@ -89,7 +90,7 @@ def _on_update_function_event(event: dict) -> None:
_update_function(key=event["asset_key"], name=function["name"], status=function["status"])


def _update_function(key: str, *, name: Optional[str], status: Optional[str]) -> None:
def _update_function(key: str, *, name: Optional[str] = None, status: Optional[str] = None) -> None:
"""Process update function event to update local database."""
function = Function.objects.get(key=key)

@@ -382,7 +383,17 @@ def _disable_model(key: str) -> None:
def _on_create_failure_report(event: dict) -> None:
"""Process create failure report event to update local database."""
logger.debug("Syncing failure report create", asset_key=event["asset_key"], event_id=event["id"])
_update_computetask(key=event["asset_key"], failure_report=event["failure_report"])

asset_key = event["asset_key"]
failure_report = event["failure_report"]
asset_type = failure_report_pb2.FailedAssetKind.Value(failure_report["asset_type"])

if asset_type == failure_report_pb2.FAILED_ASSET_FUNCTION:
# Needed as this field is only in ComputeTask
compute_task_key = ComputeTask.objects.values_list("key", flat=True).get(function_id=asset_key)
_update_computetask(key=str(compute_task_key), failure_report={"error_type": failure_report.get("error_type")})
else:
_update_computetask(key=asset_key, failure_report=failure_report)


EVENT_CALLBACKS = {
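
The branching added to `_on_create_failure_report` can be sketched in isolation as follows. This is a simplified illustration under stated assumptions, not the handler itself: it compares the enum name as a string instead of going through `failure_report_pb2`, and `find_task_key_for_function` is a hypothetical callable standing in for the `ComputeTask` lookup by `function_id`. The `"SOME_ERROR"` value is a placeholder.

```python
from typing import Callable

# Enum name as it appears in the diff; compared as a plain string in this sketch.
FAILED_ASSET_FUNCTION = "FAILED_ASSET_FUNCTION"


def route_failure_report(
    event: dict,
    find_task_key_for_function: Callable[[str], str],
) -> tuple[str, dict]:
    """Return the compute task key to update and the payload to apply to it."""
    asset_key = event["asset_key"]
    failure_report = event["failure_report"]

    if failure_report["asset_type"] == FAILED_ASSET_FUNCTION:
        # The report targets a function: resolve the compute task that uses it
        # and forward only the error type.
        task_key = find_task_key_for_function(asset_key)
        return str(task_key), {"error_type": failure_report.get("error_type")}

    # The report targets a compute task directly: forward it unchanged.
    return asset_key, failure_report


# Usage with an in-memory lookup table instead of the database.
tasks_by_function = {"fn-1": "task-9"}
function_event = {
    "asset_key": "fn-1",
    "failure_report": {"asset_type": FAILED_ASSET_FUNCTION, "error_type": "SOME_ERROR"},
}
routed = route_failure_report(function_event, tasks_by_function.__getitem__)
```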
4 changes: 2 additions & 2 deletions backend/api/migrations/0053_function_status.py
@@ -15,14 +15,14 @@ class Migration(migrations.Migration):
name="status",
field=models.CharField(
choices=[
("FUNCTION_STATUS_UNKONWN", "Function Status Unkonwn"),
("FUNCTION_STATUS_UNKNOWN", "Function Status Unknown"),
("FUNCTION_STATUS_CREATED", "Function Status Created"),
("FUNCTION_STATUS_BUILDING", "Function Status Building"),
("FUNCTION_STATUS_READY", "Function Status Ready"),
("FUNCTION_STATUS_CANCELED", "Function Status Canceled"),
("FUNCTION_STATUS_FAILED", "Function Status Failed"),
],
default="FUNCTION_STATUS_UNKONWN",
default="FUNCTION_STATUS_UNKNOWN",
max_length=64,
),
preserve_default=False,
36 changes: 27 additions & 9 deletions backend/api/tests/asset_factory.py
@@ -62,6 +62,7 @@

import datetime
import uuid
from typing import Optional

from django.core import files
from django.utils import timezone
@@ -80,9 +81,10 @@
from api.models import Model
from api.models import Performance
from api.models import TaskProfiling
from substrapp.models import ComputeTaskFailureReport as ComputeTaskLogs
from substrapp.models import AssetFailureReport
from substrapp.models import DataManager as DataManagerFiles
from substrapp.models import DataSample as DataSampleFiles
from substrapp.models import FailedAssetKind
from substrapp.models import Function as FunctionFiles
from substrapp.models import Model as ModelFiles
from substrapp.utils import get_hash
@@ -535,20 +537,36 @@ def create_model_files(
return model_files


def create_computetask_logs(
compute_task_key: uuid.UUID,
logs: files.File = None,
) -> ComputeTaskLogs:
def create_asset_logs(
asset_key: uuid.UUID,
asset_type: FailedAssetKind,
logs: Optional[files.File] = None,
) -> AssetFailureReport:
if logs is None:
logs = files.base.ContentFile("dummy content")

compute_task_logs = ComputeTaskLogs.objects.create(
compute_task_key=compute_task_key,
asset_logs = AssetFailureReport.objects.create(
asset_key=asset_key,
asset_type=asset_type,
logs_checksum=get_hash(logs),
creation_date=timezone.now(),
)
compute_task_logs.logs.save("logs", logs)
return compute_task_logs
asset_logs.logs.save("logs", logs)
return asset_logs


def create_computetask_logs(
compute_task_key: uuid.UUID,
logs: Optional[files.File] = None,
) -> AssetFailureReport:
return create_asset_logs(compute_task_key, FailedAssetKind.FAILED_ASSET_COMPUTE_TASK, logs)


def create_function_logs(
function_key: uuid.UUID,
logs: Optional[files.File] = None,
) -> AssetFailureReport:
return create_asset_logs(function_key, FailedAssetKind.FAILED_ASSET_FUNCTION, logs)


def create_computetask_profiling(compute_task: ComputeTask) -> TaskProfiling:
Original file line number Diff line number Diff line change
@@ -13,11 +13,11 @@
from api.views import utils as view_utils
from organization import authentication as organization_auth
from organization import models as organization_models
from substrapp.models import ComputeTaskFailureReport
from substrapp.models import AssetFailureReport


@pytest.fixture
def compute_task_failure_report() -> tuple[ComputeTask, ComputeTaskFailureReport]:
def asset_failure_report() -> tuple[ComputeTask, AssetFailureReport]:
compute_task = factory.create_computetask(
factory.create_computeplan(),
factory.create_function(),
@@ -41,12 +41,12 @@ def test_download_logs_failure_unauthenticated(api_client: test.APIClient):

@pytest.mark.django_db
def test_download_local_logs_success(
compute_task_failure_report,
asset_failure_report,
authenticated_client: test.APIClient,
):
"""An authorized user download logs located on the organization."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local
assert conf.settings.LEDGER_MSP_ID in compute_task.logs_permission_authorized_ids # allowed

@@ -60,12 +60,12 @@ def test_download_local_logs_success(

@pytest.mark.django_db
def test_download_logs_failure_forbidden(
compute_task_failure_report,
asset_failure_report,
authenticated_client: test.APIClient,
):
"""An authenticated user cannot download logs if he is not authorized."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local
compute_task.logs_permission_authorized_ids = [] # not allowed
compute_task.save()
@@ -77,12 +77,12 @@ def test_download_logs_failure_forbidden(

@pytest.mark.django_db
def test_download_local_logs_failure_not_found(
compute_task_failure_report,
asset_failure_report,
authenticated_client: test.APIClient,
):
"""An authorized user attempt to download logs that are not referenced in the database."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
assert compute_task.owner == conf.settings.LEDGER_MSP_ID # local
assert conf.settings.LEDGER_MSP_ID in compute_task.logs_permission_authorized_ids # allowed
failure_report.delete() # not found
@@ -94,12 +94,12 @@ def test_download_local_logs_failure_not_found(

@pytest.mark.django_db
def test_download_remote_logs_success(
compute_task_failure_report,
asset_failure_report,
authenticated_client: test.APIClient,
):
"""An authorized user download logs on a remote organization by using his organization as proxy."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
outgoing_organization = "outgoing-organization"
compute_task.logs_owner = outgoing_organization # remote
compute_task.logs_permission_authorized_ids = [conf.settings.LEDGER_MSP_ID, outgoing_organization] # allowed
@@ -139,13 +139,13 @@ def get_proxy_headers(channel_name: str) -> dict[str, str]:

@pytest.mark.django_db
def test_organization_download_logs_success(
compute_task_failure_report,
asset_failure_report,
api_client: test.APIClient,
incoming_organization_user: organization_auth.OrganizationUser,
):
"""An authorized organization can download logs from another organization."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
compute_task.logs_owner = conf.settings.LEDGER_MSP_ID # local (incoming request from remote)
compute_task.logs_permission_authorized_ids = [
conf.settings.LEDGER_MSP_ID,
@@ -166,13 +166,13 @@ def test_organization_download_logs_success(

@pytest.mark.django_db
def test_organization_download_logs_forbidden(
compute_task_failure_report,
asset_failure_report,
api_client: test.APIClient,
incoming_organization_user: organization_auth.OrganizationUser,
):
"""An unauthorized organization cannot download logs from another organization."""

compute_task, failure_report = compute_task_failure_report
compute_task, failure_report = asset_failure_report
compute_task.logs_owner = conf.settings.LEDGER_MSP_ID # local (incoming request from remote)
compute_task.logs_permission_authorized_ids = [conf.settings.LEDGER_MSP_ID] # incoming user not allowed
compute_task.channel = incoming_organization_user.username
2 changes: 1 addition & 1 deletion backend/api/urls.py
@@ -25,7 +25,7 @@
router.register(r"compute_plan_metadata", views.ComputePlanMetadataViewSet, basename="compute_plan_metadata")
router.register(r"news_feed", views.NewsFeedViewSet, basename="news_feed")
router.register(r"performance", views.PerformanceViewSet, basename="performance")
router.register(r"logs", views.ComputeTaskLogsViewSet, basename="logs")
router.register(r"logs", views.FailedAssetLogsViewSet, basename="logs")
router.register(r"task_profiling", views.TaskProfilingViewSet, basename="task_profiling")

task_profiling_router = routers.NestedDefaultRouter(router, r"task_profiling", lookup="task_profiling")
4 changes: 2 additions & 2 deletions backend/api/views/__init__.py
@@ -2,10 +2,10 @@
from .computeplan import ComputePlanViewSet
from .computetask import ComputeTaskViewSet
from .computetask import CPTaskViewSet
from .computetask_logs import ComputeTaskLogsViewSet
from .datamanager import DataManagerPermissionViewSet
from .datamanager import DataManagerViewSet
from .datasample import DataSampleViewSet
from .failed_asset_logs import FailedAssetLogsViewSet
from .function import CPFunctionViewSet
from .function import FunctionPermissionViewSet
from .function import FunctionViewSet
@@ -24,14 +24,14 @@
"DataManagerPermissionViewSet",
"ModelViewSet",
"ModelPermissionViewSet",
"FailedAssetLogsViewSet",
"FunctionViewSet",
"FunctionPermissionViewSet",
"ComputeTaskViewSet",
"ComputePlanViewSet",
"CPTaskViewSet",
"CPFunctionViewSet",
"NewsFeedViewSet",
"ComputeTaskLogsViewSet",
"CPPerformanceViewSet",
"ComputePlanMetadataViewSet",
"PerformanceViewSet",
18 changes: 0 additions & 18 deletions backend/api/views/computetask_logs.py

This file was deleted.

16 changes: 14 additions & 2 deletions backend/api/views/datamanager.py
@@ -212,8 +212,20 @@ class DataManagerPermissionViewSet(PermissionMixin, GenericViewSet):

@action(detail=True, url_path="description", url_name="description")
def description_(self, request, *args, **kwargs):
return self.download_file(request, DataManager, "description", "description_address")
return self.download_file(
request,
asset_class=DataManager,
local_file_class=DataManagerFiles,
content_field="description",
address_field="description_address",
)

@action(detail=True)
def opener(self, request, *args, **kwargs):
return self.download_file(request, DataManager, "data_opener", "opener_address")
return self.download_file(
request,
asset_class=DataManager,
local_file_class=DataManagerFiles,
content_field="data_opener",
address_field="opener_address",
)
41 changes: 41 additions & 0 deletions backend/api/views/failed_asset_logs.py
@@ -0,0 +1,41 @@
from rest_framework import response as drf_response
from rest_framework import status
from rest_framework import viewsets
from rest_framework.decorators import action

from api.errors import AssetPermissionError
from api.models import ComputeTask
from api.models import Function
from api.views import utils as view_utils
from substrapp.models import asset_failure_report


class FailedAssetLogsViewSet(view_utils.PermissionMixin, viewsets.GenericViewSet):
queryset = asset_failure_report.AssetFailureReport.objects.all()

@action(detail=True, url_path=asset_failure_report.LOGS_FILE_PATH)
def file(self, request, pk=None) -> drf_response.Response:
report = self.get_object()
channel_name = view_utils.get_channel_name(request)
if report.asset_type == asset_failure_report.FailedAssetKind.FAILED_ASSET_FUNCTION:
asset_class = Function
else:
asset_class = ComputeTask

try:
asset = self.get_asset(request, report.key, channel_name, asset_class)
except AssetPermissionError as e:
return view_utils.ApiResponse({"detail": str(e)}, status=status.HTTP_403_FORBIDDEN)

response = view_utils.get_file_response(
local_file_class=asset_failure_report.AssetFailureReport,
key=report.key,
content_field="logs",
channel_name=channel_name,
url=report.logs_address,
asset_owner=asset.get_owner(),
)

response.headers["Content-Type"] = "text/plain; charset=utf-8"
response.headers["Content-Disposition"] = f'attachment; filename="tuple_logs_{pk}.txt"'
return response
16 changes: 14 additions & 2 deletions backend/api/views/function.py
@@ -202,15 +202,27 @@ class FunctionPermissionViewSet(PermissionMixin, GenericViewSet):

@action(detail=True)
def file(self, request, *args, **kwargs):
return self.download_file(request, Function, "file", "function_address")
return self.download_file(
request,
asset_class=Function,
local_file_class=FunctionFiles,
content_field="file",
address_field="function_address",
)

# actions cannot be named "description"
# https://github.com/encode/django-rest-framework/issues/6490
# for some of the restricted names see:
# https://www.django-rest-framework.org/api-guide/viewsets/#introspecting-viewset-actions
@action(detail=True, url_path="description", url_name="description")
def description_(self, request, *args, **kwargs):
return self.download_file(request, Function, "description", "description_address")
return self.download_file(
request,
asset_class=Function,
local_file_class=FunctionFiles,
content_field="description",
address_field="description_address",
)

@action(detail=True)
def image(self, request, *args, **kwargs):