diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py index 912f9f6c4..a7f2bdd1a 100644 --- a/src/hera/workflows/io/_io_mixins.py +++ b/src/hera/workflows/io/_io_mixins.py @@ -1,14 +1,16 @@ import sys import warnings -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self +from pydantic.fields import FieldInfo + from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields -from hera.shared._type_util import get_workflow_annotation, is_annotated +from hera.shared._type_util import get_workflow_annotation from hera.shared.serialization import MISSING, serialize from hera.workflows._context import _context from hera.workflows.artifact import Artifact @@ -39,6 +41,23 @@ BaseModel = object # type: ignore +def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]: + """Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations. + + If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing. + Otherwise, a Parameter object will be constructed. + """ + annotations = get_field_annotations(cls) + for field, field_info in get_fields(cls).items(): + if annotation := get_workflow_annotation(annotations[field]): + # Copy so as to not modify the fields themselves + annotation_copy = annotation.copy() + annotation_copy.name = annotation.name or field + yield field, field_info, annotation_copy + else: + yield field, field_info, Parameter(name=field) + + class InputMixin(BaseModel): def __new__(cls, **kwargs): if _context.declaring: @@ -62,14 +81,9 @@ def __init__(self, /, **kwargs): @classmethod def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]: parameters = [] - annotations = get_field_annotations(cls) - for field, field_info in get_fields(cls).items(): - if (param := get_workflow_annotation(annotations[field])) and isinstance(param, Parameter): - # Copy so as to not modify the Input fields themselves - param = param.copy() - if param.name is None: - param.name = field + for field, field_info, param in _construct_io_from_fields(cls): + if isinstance(param, Parameter): if param.default is not None: warnings.warn( "Using the default field for Parameters in Annotations is deprecated since v5.16" @@ -81,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet # Serialize the value (usually done in Parameter's validator) param.default = serialize(field_info.default) # type: ignore parameters.append(param) - elif not is_annotated(annotations[field]): - # Create a Parameter from basic type annotations - default = getattr(object_override, field) if object_override else field_info.default - - # For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None - if default is None or default == PydanticUndefined: - default = MISSING - - parameters.append(Parameter(name=field, default=default)) return parameters @classmethod def _get_artifacts(cls) -> List[Artifact]: artifacts = [] - annotations = get_field_annotations(cls) - for field in get_fields(cls): - if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact): - # Copy so as to not modify the Input fields themselves - artifact = artifact.copy() - if artifact.name is None: - artifact.name = field + for _, _, artifact in _construct_io_from_fields(cls): + if isinstance(artifact, Artifact): if artifact.path is None: artifact.path = artifact._get_default_inputs_path() artifacts.append(artifact) @@ -117,42 +117,33 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]: def _get_as_templated_arguments(cls) -> Self: """Returns the Input with templated values to propagate through a DAG/Steps function.""" object_dict = {} - cls_fields = get_fields(cls) - annotations = get_field_annotations(cls) - for field in cls_fields: - if param_or_artifact := get_workflow_annotation(annotations[field]): - if isinstance(param_or_artifact, Parameter): - object_dict[field] = "{{inputs.parameters." + f"{param_or_artifact.name}" + "}}" - else: - object_dict[field] = "{{inputs.artifacts." + f"{param_or_artifact.name}" + "}}" - elif not is_annotated(annotations[field]): - object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}" + for field, _, annotation in _construct_io_from_fields(cls): + input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts" + object_dict[field] = "{{" + f"inputs.{input_type}.{annotation.name}" + "}}" return cls.construct(None, **object_dict) def _get_as_arguments(self) -> ModelArguments: params = [] artifacts = [] - annotations = get_field_annotations(type(self)) if isinstance(self, V1BaseModel): self_dict = self.dict() elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): self_dict = self.model_dump() - for field in get_fields(type(self)): + for field, _, annotation in _construct_io_from_fields(type(self)): # The value may be a static value (of any time) if it has a default value, so we need to serialize it # If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")`` templated_value = serialize(self_dict[field]) + name = annotation.name + assert name is not None # guaranteed by _get_workflow_annotations - if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name: - if isinstance(param_or_artifact, Parameter): - params.append(ModelParameter(name=param_or_artifact.name, value=templated_value)) - else: - artifacts.append(ModelArtifact(name=param_or_artifact.name, from_=templated_value)) - elif not is_annotated(annotations[field]): - params.append(ModelParameter(name=field, value=templated_value)) + if isinstance(annotation, Parameter): + params.append(ModelParameter(name=name, value=templated_value)) + else: + artifacts.append(ModelArtifact(name=name, from_=templated_value)) return ModelArguments(parameters=params or None, artifacts=artifacts or None) @@ -178,37 +169,22 @@ def __init__(self, /, **kwargs): @classmethod def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]: outputs: List[Union[Artifact, Parameter]] = [] - annotations = get_field_annotations(cls) - - model_fields = get_fields(cls) - for field in model_fields: + for field, field_info, annotation in _construct_io_from_fields(cls): if field in {"exit_code", "result"}: continue - if param_or_artifact := get_workflow_annotation(annotations[field]): - if isinstance(param_or_artifact, Parameter): - if add_missing_path and ( - param_or_artifact.value_from is None or param_or_artifact.value_from.path is None - ): - param_or_artifact.value_from = ValueFrom( - path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}" - ) - outputs.append(param_or_artifact) - else: - if add_missing_path and param_or_artifact.path is None: - param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}" - outputs.append(param_or_artifact) - elif not is_annotated(annotations[field]): - # Create a Parameter from basic type annotations - default = model_fields[field].default - if default is None or default == PydanticUndefined: - default = MISSING - - value_from = None - if add_missing_path: - value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}") - - outputs.append(Parameter(name=field, default=default, value_from=value_from)) + if isinstance(annotation, Parameter): + if annotation.default is None: + default = field_info.default + if default is not None and default != PydanticUndefined: + annotation.default = serialize(default) + + if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None): + annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}") + else: + if add_missing_path and annotation.path is None: + annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}" + outputs.append(annotation) return outputs @classmethod @@ -230,27 +206,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]: This lets dags and steps hoist task/step outputs into its own outputs. """ outputs: List[Union[Artifact, Parameter]] = [] - annotations = get_field_annotations(type(self)) if isinstance(self, V1BaseModel): self_dict = self.dict() elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): self_dict = self.model_dump() - for field in get_fields(type(self)): + for field, _, annotation in _construct_io_from_fields(type(self)): if field in {"exit_code", "result"}: continue templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"` - if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name: - if isinstance(param_or_artifact, Parameter): - outputs.append( - Parameter(name=param_or_artifact.name, value_from=ValueFrom(parameter=templated_value)) - ) - else: - outputs.append(Artifact(name=param_or_artifact.name, from_=templated_value)) - elif not is_annotated(annotations[field]): - outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value))) + if isinstance(annotation, Parameter): + outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value))) + else: + outputs.append(Artifact(name=annotation.name, from_=templated_value)) return outputs diff --git a/tests/test_unit/test_io_mixins.py b/tests/test_unit/test_io_mixins.py index 832d37ea3..f4f3e1161 100644 --- a/tests/test_unit/test_io_mixins.py +++ b/tests/test_unit/test_io_mixins.py @@ -1,21 +1,576 @@ -from hera.workflows.io import Input -from hera.workflows.parameter import Parameter +import sys -try: +if sys.version_info >= (3, 9): from typing import Annotated -except ImportError: +else: from typing_extensions import Annotated +from pydantic import Field -def test_input_mixin_get_parameters(): +from hera.workflows import Artifact, Input, Output, Parameter +from hera.workflows.models import ( + Arguments as ModelArguments, + Artifact as ModelArtifact, + Parameter as ModelParameter, + ValueFrom, +) + + +def test_get_parameters_unannotated(): + class Foo(Input): + foo: int + bar: str = "a default" + + assert Foo._get_parameters() == [ + Parameter(name="foo"), + Parameter(name="bar", default="a default"), + ] + + +def test_get_parameters_with_pydantic_annotations(): + class Foo(Input): + foo: Annotated[int, Field(gt=0)] + bar: Annotated[str, Field(max_length=10)] = "a default" + + assert Foo._get_parameters() == [ + Parameter(name="foo"), + Parameter(name="bar", default="a default"), + ] + + +def test_get_parameters_annotated_with_name(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo")] + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + assert Foo._get_parameters() == [ + Parameter(name="f_oo"), + Parameter(name="b_ar", default="a default"), + ] + + +def test_get_parameters_annotated_with_description(): + class Foo(Input): + foo: Annotated[int, Parameter(description="param foo")] + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + assert Foo._get_parameters() == [ + Parameter(name="foo", description="param foo"), + Parameter(name="bar", default="a default", description="param bar"), + ] + + +def test_get_parameters_with_multiple_annotations(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + assert Foo._get_parameters() == [ + Parameter(name="f_oo"), + Parameter(name="bar", default="a default", description="param bar"), + ] + + +def test_get_artifacts_unannotated(): + class Foo(Input): + foo: int + bar: str = "a default" + + assert Foo._get_artifacts() == [] + + +def test_get_artifacts_with_pydantic_annotations(): + class Foo(Input): + foo: Annotated[int, Field(gt=0)] + bar: Annotated[str, Field(max_length=10)] = "a default" + + assert Foo._get_artifacts() == [] + + +def test_get_artifacts_annotated_with_name(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo")] + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + assert Foo._get_artifacts() == [Artifact(name="b_az", path="/tmp/hera-inputs/artifacts/b_az")] + + +def test_get_artifacts_annotated_with_description(): + class Foo(Input): + foo: Annotated[int, Parameter(description="param foo")] + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + assert Foo._get_artifacts() == [ + Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz", description="artifact baz") + ] + + +def test_get_artifacts_annotated_with_path(): + class Foo(Input): + baz: Annotated[str, Artifact(path="/tmp/hera-inputs/artifacts/bishbosh")] + + assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/bishbosh")] + + +def test_get_artifacts_with_multiple_annotations(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + assert Foo._get_artifacts() == [Artifact(name="baz", path="/tmp/hera-inputs/artifacts/baz")] + + +def test_get_as_arguments_unannotated(): + class Foo(Input): + foo: int + bar: str = "a default" + + foo = Foo(foo=1) + parameters = foo._get_as_arguments() + + assert parameters == ModelArguments( + parameters=[ + ModelParameter(name="foo", value=1), + ModelParameter(name="bar", value="a default"), + ], + ) + + +def test_get_as_arguments_with_pydantic_annotations(): + class Foo(Input): + foo: Annotated[int, Field(gt=0)] + bar: Annotated[str, Field(max_length=10)] = "a default" + + foo = Foo(foo=1) + parameters = foo._get_as_arguments() + + assert parameters == ModelArguments( + parameters=[ + ModelParameter(name="foo", value=1), + ModelParameter(name="bar", value="a default"), + ] + ) + + +def test_get_as_arguments_annotated_with_name(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo")] + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + foo = Foo(foo=1, baz="previous step") + parameters = foo._get_as_arguments() + + assert parameters == ModelArguments( + artifacts=[ + ModelArtifact(name="b_az", from_="previous step"), + ], + parameters=[ + ModelParameter(name="f_oo", value=1), + ModelParameter(name="b_ar", value="a default"), + ], + ) + + +def test_get_as_arguments_annotated_with_description(): + class Foo(Input): + foo: Annotated[int, Parameter(description="param foo")] + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + foo = Foo(foo=1, baz="previous step") + parameters = foo._get_as_arguments() + + assert parameters == ModelArguments( + artifacts=[ + ModelArtifact(name="baz", from_="previous step"), + ], + parameters=[ + ModelParameter(name="foo", value=1), + ModelParameter(name="bar", value="a default"), + ], + ) + + +def test_get_as_arguments_with_multiple_annotations(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + foo = Foo(foo=1, baz="previous step") + parameters = foo._get_as_arguments() + + assert parameters == ModelArguments( + artifacts=[ + ModelArtifact(name="baz", from_="previous step"), + ], + parameters=[ + ModelParameter(name="f_oo", value=1), + ModelParameter(name="bar", value="a default"), + ], + ) + + +def test_get_as_templated_arguments_unannotated(): + class Foo(Input): + foo: int + bar: str = "a default" + + templated_arguments = Foo._get_as_templated_arguments() + + assert templated_arguments == Foo.construct( + foo="{{inputs.parameters.foo}}", + bar="{{inputs.parameters.bar}}", + ) + + +def test_get_as_templated_arguments_with_pydantic_annotations(): + class Foo(Input): + foo: Annotated[int, Field(gt=0)] + bar: Annotated[str, Field(max_length=10)] = "a default" + + templated_arguments = Foo._get_as_templated_arguments() + + assert templated_arguments == Foo.construct( + foo="{{inputs.parameters.foo}}", + bar="{{inputs.parameters.bar}}", + ) + + +def test_get_as_templated_arguments_annotated_with_name(): + class Foo(Input): + foo: Annotated[int, Parameter(name="f_oo")] + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + templated_arguments = Foo._get_as_templated_arguments() + + assert templated_arguments == Foo.construct( + foo="{{inputs.parameters.f_oo}}", + bar="{{inputs.parameters.b_ar}}", + baz="{{inputs.artifacts.b_az}}", + ) + + +def test_get_as_templated_arguments_annotated_with_description(): class Foo(Input): - foo: Annotated[int, Parameter(name="foo")] + foo: Annotated[int, Parameter(description="param foo")] + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] - assert Foo._get_parameters() == [Parameter(name="foo")] + templated_arguments = Foo._get_as_templated_arguments() + assert templated_arguments == Foo.construct( + foo="{{inputs.parameters.foo}}", + bar="{{inputs.parameters.bar}}", + baz="{{inputs.artifacts.baz}}", + ) -def test_input_mixin_get_parameters_default_name(): + +def test_get_as_templated_arguments_with_multiple_annotations(): class Foo(Input): - foo: Annotated[int, Parameter(description="a foo")] + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + templated_arguments = Foo._get_as_templated_arguments() + + assert templated_arguments == Foo.construct( + foo="{{inputs.parameters.f_oo}}", + bar="{{inputs.parameters.bar}}", + baz="{{inputs.artifacts.baz}}", + ) + + +def test_get_outputs_no_path_unannotated(): + class Foo(Output): + foo: int + fum: int = 5 + bar: str = "a default" + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="foo"), + Parameter(name="fum", default=5), + Parameter(name="bar", default="a default"), + ] + + +def test_get_outputs_no_path_with_pydantic_annotations(): + class Foo(Output): + foo: Annotated[int, Field(gt=0)] + fum: Annotated[int, Field(lt=1000)] = 5 + bar: Annotated[str, Field(max_length=10)] = "a default" + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="foo"), + Parameter(name="fum", default=5), + Parameter(name="bar", default="a default"), + ] + + +def test_get_outputs_no_path_annotated_with_name(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo")] + fum: Annotated[int, Parameter(name="f_um")] = 5 + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="f_oo"), + Parameter(name="f_um", default=5), + Parameter(name="b_ar", default="a default"), + Artifact(name="b_az"), + ] + + +def test_get_outputs_no_path_annotated_with_path(): + class Foo(Output): + foo: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/one"))] + fum: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/two"))] = 5 + bar: Annotated[str, Parameter(value_from=ValueFrom(path="/tmp/three"))] = "a default" + baz: Annotated[str, Artifact(path="/tmp/four")] + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(path="/tmp/one")), + Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/two")), + Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/three")), + Artifact(name="baz", path="/tmp/four"), + ] + + +def test_get_outputs_no_path_annotated_with_description(): + class Foo(Output): + foo: Annotated[int, Parameter(description="param foo")] + fum: Annotated[int, Parameter(description="param fum")] = 5 + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="foo", description="param foo"), + Parameter(name="fum", description="param fum", default=5), + Parameter(name="bar", default="a default", description="param bar"), + Artifact(name="baz", description="artifact baz"), + ] + + +def test_get_outputs_no_path_with_multiple_annotations(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + fum: Annotated[int, Field(lt=10000), Parameter(name="f_um")] = 5 + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + parameters = Foo._get_outputs() + + assert parameters == [ + Parameter(name="f_oo"), + Parameter(name="f_um", default=5), + Parameter(name="bar", default="a default", description="param bar"), + Artifact(name="baz"), + ] + + +def test_get_outputs_add_path_unannotated(): + class Foo(Output): + foo: int + fum: int = 5 + bar: str = "a default" + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")), + Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum")), + Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar")), + ] + + +def test_get_outputs_add_path_with_pydantic_annotations(): + class Foo(Output): + foo: Annotated[int, Field(gt=0)] + fum: Annotated[int, Field(lt=1000)] = 5 + bar: Annotated[str, Field(max_length=10)] = "a default" + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")), + Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum")), + Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar")), + ] + + +def test_get_outputs_add_path_annotated_with_name(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo")] + fum: Annotated[int, Parameter(name="f_um")] = 5 + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="f_oo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_oo")), + Parameter(name="f_um", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_um")), + Parameter(name="b_ar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/b_ar")), + Artifact(name="b_az", path="/tmp/hera-outputs/artifacts/b_az"), + ] + + +def test_get_outputs_add_path_annotated_with_path(): + class Foo(Output): + foo: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/one"))] + fum: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/two"))] = 5 + bar: Annotated[str, Parameter(value_from=ValueFrom(path="/tmp/three"))] = "a default" + baz: Annotated[str, Artifact(path="/tmp/four")] + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(path="/tmp/one")), + Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/two")), + Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/three")), + Artifact(name="baz", path="/tmp/four"), + ] + + +def test_get_outputs_add_path_annotated_with_description(): + class Foo(Output): + foo: Annotated[int, Parameter(description="param foo")] + fum: Annotated[int, Parameter(description="param fum")] = 5 + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="foo", description="param foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")), + Parameter( + name="fum", + description="param fum", + default=5, + value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum"), + ), + Parameter( + name="bar", + default="a default", + description="param bar", + value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar"), + ), + Artifact(name="baz", description="artifact baz", path="/tmp/hera-outputs/artifacts/baz"), + ] + + +def test_get_outputs_add_path_with_multiple_annotations(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + fum: Annotated[int, Field(lt=10000), Parameter(name="f_um")] = 5 + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + parameters = Foo._get_outputs(add_missing_path=True) + + assert parameters == [ + Parameter(name="f_oo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_oo")), + Parameter(name="f_um", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_um")), + Parameter( + name="bar", + default="a default", + description="param bar", + value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar"), + ), + Artifact(name="baz", path="/tmp/hera-outputs/artifacts/baz"), + ] + + +def test_get_as_invocator_output_unannotated(): + class Foo(Output): + foo: int + bar: str = "a default" + + foo = Foo.construct(foo="{{...foo}}", bar="{{...bar}}") + parameters = foo._get_as_invocator_output() + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(parameter="{{...foo}}")), + Parameter(name="bar", value_from=ValueFrom(parameter="{{...bar}}")), + ] + + +def test_get_as_invocator_output_with_pydantic_annotations(): + class Foo(Output): + foo: Annotated[int, Field(gt=0)] + bar: Annotated[str, Field(max_length=10)] = "a default" + + foo = Foo.construct(foo="{{...foo}}", bar="{{...bar}}") + parameters = foo._get_as_invocator_output() + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(parameter="{{...foo}}")), + Parameter(name="bar", value_from=ValueFrom(parameter="{{...bar}}")), + ] + + +def test_get_as_invocator_output_annotated_with_name(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo")] + bar: Annotated[str, Parameter(name="b_ar")] = "a default" + baz: Annotated[str, Artifact(name="b_az")] + + foo = Foo.construct(foo="{{...foo}}", bar="{{...bar}}", baz="{{...baz}}") + parameters = foo._get_as_invocator_output() + + assert parameters == [ + Parameter(name="f_oo", value_from=ValueFrom(parameter="{{...foo}}")), + Parameter(name="b_ar", value_from=ValueFrom(parameter="{{...bar}}")), + Artifact(name="b_az", from_="{{...baz}}"), + ] + + +def test_get_as_invocator_output_annotated_with_description(): + class Foo(Output): + foo: Annotated[int, Parameter(description="param foo")] + bar: Annotated[str, Parameter(description="param bar")] = "a default" + baz: Annotated[str, Artifact(description="artifact baz")] + + foo = Foo.construct(foo="{{...foo}}", bar="{{...bar}}", baz="{{...baz}}") + parameters = foo._get_as_invocator_output() + + assert parameters == [ + Parameter(name="foo", value_from=ValueFrom(parameter="{{...foo}}")), + Parameter(name="bar", value_from=ValueFrom(parameter="{{...bar}}")), + Artifact(name="baz", from_="{{...baz}}"), + ] + + +def test_get_as_invocator_output_with_multiple_annotations(): + class Foo(Output): + foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)] + bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default" + baz: Annotated[str, Field(max_length=15), Artifact()] + + foo = Foo.construct(foo="{{...foo}}", bar="{{...bar}}", baz="{{...baz}}") + parameters = foo._get_as_invocator_output() - assert Foo._get_parameters() == [Parameter(name="foo", description="a foo")] + assert parameters == [ + Parameter(name="f_oo", value_from=ValueFrom(parameter="{{...foo}}")), + Parameter(name="bar", value_from=ValueFrom(parameter="{{...bar}}")), + Artifact(name="baz", from_="{{...baz}}"), + ]