Skip to content

Commit

Permalink
feat: resolve relative datasample spec path (#392)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com>
  • Loading branch information
ThibaultFy committed Jan 11, 2024
1 parent 96d6f60 commit 29a4b0f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- BREAKING: Renamed `function` field of the Function pydantic model to `archive`([#393](https://github.com/Substra/substra/pull/393))

### Added

- Paths are now resolved on `DataSampleSpec` objects, which means that users can pass relative paths ([#392](https://github.com/Substra/substra/pull/392))

## [0.49.0](https://github.com/Substra/substra/releases/tag/0.49.0) - 2023-10-18

Expand Down
9 changes: 6 additions & 3 deletions substra/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def allowed_filters() -> List[str]:
return ["key", "name", "owner", "permissions", "compute_plan_key", "dataset_key", "data_sample_key"]

@pydantic.field_validator("inputs", mode="before")
def dict_input_to_list(cls, v): # noqa: N805
@classmethod
def dict_input_to_list(cls, v):
if isinstance(v, dict):
# Transform the inputs dict to a list
return [
Expand All @@ -203,7 +204,8 @@ def dict_input_to_list(cls, v): # noqa: N805
return v

@pydantic.field_validator("outputs", mode="before")
def dict_output_to_list(cls, v): # noqa: N805
@classmethod
def dict_output_to_list(cls, v):
if isinstance(v, dict):
# Transform the outputs dict to a list
return [
Expand Down Expand Up @@ -397,7 +399,8 @@ class OutputAsset(_TaskAsset):

# Deal with remote returning the actual performance object
@pydantic.field_validator("asset", mode="before")
def convert_remote_performance(cls, value, values): # noqa: N805
@classmethod
def convert_remote_performance(cls, value, values):
if values.data.get("kind") == schemas.AssetKind.performance and isinstance(value, dict):
return value.get("performance_value")

Expand Down
33 changes: 28 additions & 5 deletions substra/sdk/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,33 @@ def is_many(self):
return self.paths and len(self.paths) > 0

@pydantic.model_validator(mode="before")
def exclusive_paths(cls, values): # noqa: N805
@classmethod
def exclusive_paths(cls, values: typing.Any) -> typing.Any:
    """Check that exactly one of the ``path`` / ``paths`` fields is provided.

    Args:
        values: Raw input handed to the model before field validation.

    Returns:
        The input, unchanged.

    Raises:
        ValueError: If both ``path`` and ``paths`` are given, or neither is.
    """
    # A "before" validator receives the raw input, which is not necessarily a
    # dict. Leave anything else untouched so that the regular model validation
    # handles it instead of failing here on the key lookups.
    if not isinstance(values, dict):
        return values

    has_paths = "paths" in values
    has_path = "path" in values
    if has_paths and has_path:
        raise ValueError("'path' and 'paths' fields are exclusive.")
    if not (has_paths or has_path):
        raise ValueError("'path' or 'paths' field must be set.")
    return values

@pydantic.model_validator(mode="before")
@classmethod
def resolve_paths(cls, values: typing.Any) -> typing.Any:
    """Resolve the ``path`` / ``paths`` entries to absolute paths.

    Relative paths are resolved against the current working directory, so
    callers may pass either relative or absolute locations.

    Args:
        values: Raw input handed to the model before field validation.

    Returns:
        A shallow copy of the input dict with ``path`` or ``paths`` resolved,
        or the input itself when it is not a dict.
    """
    # A "before" validator receives the raw input, which is not necessarily a
    # dict; leave anything else for the regular model validation.
    if not isinstance(values, dict):
        return values

    # Work on a shallow copy so the caller's dict is never modified in place.
    resolved = dict(values)

    if "paths" in resolved:
        raw_paths = resolved["paths"]
        # None and scalar values (a single str/bytes/path would otherwise be
        # iterated character by character) are left as-is so that field
        # validation reports them properly.
        is_scalar = isinstance(raw_paths, (str, bytes, pathlib.PurePath))
        if raw_paths is not None and not is_scalar:
            resolved["paths"] = [pathlib.Path(item).resolve() for item in raw_paths]
    elif resolved.get("path") is not None:
        resolved["path"] = pathlib.Path(resolved["path"]).resolve()

    return resolved

@contextlib.contextmanager
def build_request_kwargs(self, local):
# redefine kwargs builder to handle the local paths
Expand Down Expand Up @@ -293,7 +312,8 @@ class FunctionInputSpec(_Spec):
kind: AssetKind

@pydantic.model_validator(mode="before")
def _check_identifiers(cls, values): # noqa: N805
@classmethod
def _check_identifiers(cls, values):
"""Checks that the multiplicity and the optionality of a data manager is always set to False"""
if values["kind"] == AssetKind.data_manager:
if values["multiple"]:
Expand Down Expand Up @@ -327,7 +347,8 @@ class FunctionOutputSpec(_Spec):
multiple: bool

@pydantic.model_validator(mode="before")
def _check_performance(cls, values): # noqa: N805
@classmethod
def _check_performance(cls, values):
"""Checks that the performance is always set to False"""
if values == AssetKind.performance and values["multiple"]:
raise ValueError("Performance can't be multiple.")
Expand All @@ -352,15 +373,17 @@ class FunctionSpec(_Spec):
type_: typing.ClassVar[Type] = Type.Function

@pydantic.field_validator("inputs")
def _check_inputs(cls, v): # noqa: N805
@classmethod
def _check_inputs(cls, v):
inputs = v or []
identifiers = {value.identifier for value in inputs}
if len(identifiers) != len(inputs):
raise ValueError("Several function inputs cannot have the same identifier.")
return v

@pydantic.field_validator("outputs")
def _check_outputs(cls, v): # noqa: N805
@classmethod
def _check_outputs(cls, v):
outputs = v or []
identifiers = {value.identifier for value in outputs}
if len(identifiers) != len(outputs):
Expand Down
30 changes: 30 additions & 0 deletions tests/sdk/test_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pathlib
import uuid

import pytest

from substra.sdk.schemas import DataSampleSpec


@pytest.mark.parametrize("path", [pathlib.Path() / "data", "./data", pathlib.Path.cwd() / "data"])
def test_datasample_spec_resolve_path(path):
    """A relative or absolute ``path`` ends up as the absolute path under the cwd."""
    expected = pathlib.Path.cwd() / "data"
    manager_key = str(uuid.uuid4())

    spec = DataSampleSpec(path=path, data_manager_keys=[manager_key])

    assert spec.path == expected


def test_datasample_spec_resolve_paths():
    """Every entry of ``paths`` is resolved to the absolute path under the cwd."""
    expected = pathlib.Path.cwd() / "data"
    paths = [pathlib.Path() / "data", "./data", pathlib.Path.cwd() / "data"]
    datasample_spec = DataSampleSpec(paths=paths, data_manager_keys=[str(uuid.uuid4())])

    # Guard against a vacuous pass: all() over an empty list would be True.
    assert len(datasample_spec.paths) == len(paths)
    assert all(path == expected for path in datasample_spec.paths)


def test_datasample_spec_exclusive_path():
    """Passing both ``path`` and ``paths`` is rejected."""
    # Match on the message so an unrelated validation failure cannot make the
    # test pass by accident.
    with pytest.raises(ValueError, match="fields are exclusive"):
        DataSampleSpec(paths=["fake_paths"], path="fake_paths", data_manager_keys=[str(uuid.uuid4())])


def test_datasample_spec_no_path():
    """Passing neither ``path`` nor ``paths`` is rejected."""
    # Match on the message so an unrelated validation failure cannot make the
    # test pass by accident.
    with pytest.raises(ValueError, match="field must be set"):
        DataSampleSpec(data_manager_keys=[str(uuid.uuid4())])

0 comments on commit 29a4b0f

Please sign in to comment.