diff --git a/substra/sdk/schemas.py b/substra/sdk/schemas.py index bb9a0a3d..06683e7b 100644 --- a/substra/sdk/schemas.py +++ b/substra/sdk/schemas.py @@ -146,6 +146,24 @@ class DataSampleSpec(_Spec): def is_many(self): return self.paths and len(self.paths) > 0 + @pydantic.field_validator("paths") + @classmethod + def resolve_paths(cls, v: List[pathlib.Path]) -> List[pathlib.Path]: + """Resolve given paths.""" + if v is None: + raise ValueError("'paths' cannot be set to None.") + + return [p.resolve() for p in v] + + @pydantic.field_validator("path") + @classmethod + def resolve_path(cls, v: pathlib.Path) -> pathlib.Path: + """Resolve given path.""" + if v is None: + raise ValueError("'path' cannot be set to None.") + + return v.resolve() + @pydantic.model_validator(mode="before") @classmethod def exclusive_paths(cls, values: typing.Any) -> typing.Any: @@ -156,24 +174,6 @@ def exclusive_paths(cls, values: typing.Any) -> typing.Any: raise ValueError("'path' or 'paths' field must be set.") return values - @pydantic.model_validator(mode="before") - @classmethod - def resolve_paths(cls, values: typing.Any) -> typing.Any: - """Resolve given path is relative.""" - if "paths" in values: - paths = [] - for path in values["paths"]: - path = pathlib.Path(path) - paths.append(path.resolve()) - - values["paths"] = paths - - elif "path" in values: - path = pathlib.Path(values["path"]) - values["path"] = path.resolve() - - return values - @contextlib.contextmanager def build_request_kwargs(self, local): # redefine kwargs builder to handle the local paths diff --git a/tests/sdk/test_schemas.py b/tests/sdk/test_schemas.py index 359a33d5..9c50c757 100644 --- a/tests/sdk/test_schemas.py +++ b/tests/sdk/test_schemas.py @@ -28,3 +28,13 @@ def test_datasample_spec_exclusive_path(): def test_datasample_spec_no_path(): with pytest.raises(ValueError): DataSampleSpec(data_manager_keys=[str(uuid.uuid4())]) + + +def test_datasample_spec_paths_set_to_none(): + with pytest.raises(ValueError): + DataSampleSpec(paths=None, data_manager_keys=[str(uuid.uuid4())]) + + +def test_datasample_spec_path_set_to_none(): + with pytest.raises(ValueError): + DataSampleSpec(path=None, data_manager_keys=[str(uuid.uuid4())])