From 29a4b0faf1d52fd48277ac55a55ade9acbf42400 Mon Sep 17 00:00:00 2001 From: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com> Date: Thu, 2 Nov 2023 14:36:30 +0100 Subject: [PATCH] feat: resolve relative datasample spec path (#392) Signed-off-by: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com> --- CHANGELOG.md | 3 +++ substra/sdk/models.py | 9 ++++++--- substra/sdk/schemas.py | 33 ++++++++++++++++++++++++++++----- tests/sdk/test_schemas.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 8 deletions(-) create mode 100644 tests/sdk/test_schemas.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3219d8b6..25a45045 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - BREAKING: Renamed `function` field of the Function pydantic model to `archive`([#393](https://github.com/Substra/substra/pull/393)) +### Added + +- Paths are now resolved on DatasampleSpec objects. Which means that users can pass relative paths ([#392](https://github.com/Substra/substra/pull/392)) ## [0.49.0](https://github.com/Substra/substra/releases/tag/0.49.0) - 2023-10-18 diff --git a/substra/sdk/models.py b/substra/sdk/models.py index 1164a1cf..dfa94e8d 100644 --- a/substra/sdk/models.py +++ b/substra/sdk/models.py @@ -187,7 +187,8 @@ def allowed_filters() -> List[str]: return ["key", "name", "owner", "permissions", "compute_plan_key", "dataset_key", "data_sample_key"] @pydantic.field_validator("inputs", mode="before") - def dict_input_to_list(cls, v): # noqa: N805 + @classmethod + def dict_input_to_list(cls, v): if isinstance(v, dict): # Transform the inputs dict to a list return [ @@ -203,7 +204,8 @@ def dict_input_to_list(cls, v): # noqa: N805 return v @pydantic.field_validator("outputs", mode="before") - def dict_output_to_list(cls, v): # noqa: N805 + @classmethod + def dict_output_to_list(cls, v): if isinstance(v, dict): # Transform the outputs dict to a list return [ @@ -397,7 +399,8 @@ class OutputAsset(_TaskAsset): # Deal with remote returning the actual performance object @pydantic.field_validator("asset", mode="before") - def convert_remote_performance(cls, value, values): # noqa: N805 + @classmethod + def convert_remote_performance(cls, value, values): if values.data.get("kind") == schemas.AssetKind.performance and isinstance(value, dict): return value.get("performance_value") diff --git a/substra/sdk/schemas.py b/substra/sdk/schemas.py index b5a715f9..bb9a0a3d 100644 --- a/substra/sdk/schemas.py +++ b/substra/sdk/schemas.py @@ -147,7 +147,8 @@ def is_many(self): return self.paths and len(self.paths) > 0 @pydantic.model_validator(mode="before") - def exclusive_paths(cls, values): # noqa: N805 + @classmethod + def exclusive_paths(cls, values: typing.Any) -> typing.Any: """Check that one and only one path(s) field is defined.""" if "paths" in values and "path" in values: raise ValueError("'path' and 'paths' fields are exclusive.") @@ -155,6 +156,24 @@ def exclusive_paths(cls, values): # noqa: N805 raise ValueError("'path' or 'paths' field must be set.") return values + @pydantic.model_validator(mode="before") + @classmethod + def resolve_paths(cls, values: typing.Any) -> typing.Any: + """Resolve given path is relative.""" + if "paths" in values: + paths = [] + for path in values["paths"]: + path = pathlib.Path(path) + paths.append(path.resolve()) + + values["paths"] = paths + + elif "path" in values: + path = pathlib.Path(values["path"]) + values["path"] = path.resolve() + + return values + @contextlib.contextmanager def build_request_kwargs(self, local): # redefine kwargs builder to handle the local paths @@ -293,7 +312,8 @@ class FunctionInputSpec(_Spec): kind: AssetKind @pydantic.model_validator(mode="before") - def _check_identifiers(cls, values): # noqa: N805 + @classmethod + def _check_identifiers(cls, values): """Checks that the multiplicity and the optionality of a data manager is always set to False""" if values["kind"] == AssetKind.data_manager: if values["multiple"]: @@ -327,7 +347,8 @@ class FunctionOutputSpec(_Spec): multiple: bool @pydantic.model_validator(mode="before") - def _check_performance(cls, values): # noqa: N805 + @classmethod + def _check_performance(cls, values): """Checks that the performance is always set to False""" if values == AssetKind.performance and values["multiple"]: raise ValueError("Performance can't be multiple.") @@ -352,7 +373,8 @@ class FunctionSpec(_Spec): type_: typing.ClassVar[Type] = Type.Function @pydantic.field_validator("inputs") - def _check_inputs(cls, v): # noqa: N805 + @classmethod + def _check_inputs(cls, v): inputs = v or [] identifiers = {value.identifier for value in inputs} if len(identifiers) != len(inputs): @@ -360,7 +382,8 @@ def _check_inputs(cls, v): # noqa: N805 return v @pydantic.field_validator("outputs") - def _check_outputs(cls, v): # noqa: N805 + @classmethod + def _check_outputs(cls, v): outputs = v or [] identifiers = {value.identifier for value in outputs} if len(identifiers) != len(outputs): diff --git a/tests/sdk/test_schemas.py b/tests/sdk/test_schemas.py new file mode 100644 index 00000000..359a33d5 --- /dev/null +++ b/tests/sdk/test_schemas.py @@ -0,0 +1,30 @@ +import pathlib +import uuid + +import pytest + +from substra.sdk.schemas import DataSampleSpec + + +@pytest.mark.parametrize("path", [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"]) +def test_datasample_spec_resolve_path(path): + datasample_spec = DataSampleSpec(path=path, data_manager_keys=[str(uuid.uuid4())]) + + assert datasample_spec.path == pathlib.Path().cwd() / "data" + + +def test_datasample_spec_resolve_paths(): + paths = [pathlib.Path() / "data", "./data", pathlib.Path().cwd() / "data"] + datasample_spec = DataSampleSpec(paths=paths, data_manager_keys=[str(uuid.uuid4())]) + + assert all([path == pathlib.Path().cwd() / "data" for path in datasample_spec.paths]) + + +def test_datasample_spec_exclusive_path(): + with pytest.raises(ValueError): + DataSampleSpec(paths=["fake_paths"], path="fake_paths", data_manager_keys=[str(uuid.uuid4())]) + + +def test_datasample_spec_no_path(): + with pytest.raises(ValueError): + DataSampleSpec(data_manager_keys=[str(uuid.uuid4())])