Skip to content

Commit

Permalink
Fix issues in _get_outputs
Browse files Browse the repository at this point in the history
OutputMixin._get_outputs was mutating the original annotation instead of
copying it; not defaulting the name to the field name; and not copying
the model default into the annotation.

Signed-off-by: Alice Purcell <alicederyn@gmail.com>
  • Loading branch information
alicederyn committed Sep 5, 2024
1 parent a838374 commit 70fd617
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Self

from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation, is_annotated
from hera.shared._type_util import get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
from hera.workflows._context import _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -189,7 +189,13 @@ def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Pa
if field in {"exit_code", "result"}:
continue
if param_or_artifact := get_workflow_annotation(annotations[field]):
param_or_artifact = param_or_artifact.copy()
param_or_artifact.name = param_or_artifact.name or field
if isinstance(param_or_artifact, Parameter):
if param_or_artifact.default is None:
default = model_fields[field].default
if default is not None and default != PydanticUndefined:
param_or_artifact.default = serialize(default)
if add_missing_path and (
param_or_artifact.value_from is None or param_or_artifact.value_from.path is None
):
Expand All @@ -201,7 +207,7 @@ def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Pa
if add_missing_path and param_or_artifact.path is None:
param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}"
outputs.append(param_or_artifact)
elif not is_annotated(annotations[field]):
else:
# Create a Parameter from basic type annotations
default = model_fields[field].default
if default is None or default == PydanticUndefined:
Expand Down
211 changes: 211 additions & 0 deletions tests/test_unit/test_io_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,217 @@ class Foo(Input):
)


def test_get_outputs_no_path_unannotated():
class Foo(Output):
foo: int
fum: int = 5
bar: str = "a default"

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="foo"),
Parameter(name="fum", default=5),
Parameter(name="bar", default="a default"),
]


def test_get_outputs_no_path_with_pydantic_annotations():
class Foo(Output):
foo: Annotated[int, Field(gt=0)]
fum: Annotated[int, Field(lt=1000)] = 5
bar: Annotated[str, Field(max_length=10)] = "a default"

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="foo"),
Parameter(name="fum", default=5),
Parameter(name="bar", default="a default"),
]


def test_get_outputs_no_path_annotated_with_name():
class Foo(Output):
foo: Annotated[int, Parameter(name="f_oo")]
fum: Annotated[int, Parameter(name="f_um")] = 5
bar: Annotated[str, Parameter(name="b_ar")] = "a default"
baz: Annotated[str, Artifact(name="b_az")]

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="f_oo"),
Parameter(name="f_um", default=5),
Parameter(name="b_ar", default="a default"),
Artifact(name="b_az"),
]


def test_get_outputs_no_path_annotated_with_path():
class Foo(Output):
foo: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/one"))]
fum: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/two"))] = 5
bar: Annotated[str, Parameter(value_from=ValueFrom(path="/tmp/three"))] = "a default"
baz: Annotated[str, Artifact(path="/tmp/four")]

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="foo", value_from=ValueFrom(path="/tmp/one")),
Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/two")),
Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/three")),
Artifact(name="baz", path="/tmp/four"),
]


def test_get_outputs_no_path_annotated_with_description():
class Foo(Output):
foo: Annotated[int, Parameter(description="param foo")]
fum: Annotated[int, Parameter(description="param fum")] = 5
bar: Annotated[str, Parameter(description="param bar")] = "a default"
baz: Annotated[str, Artifact(description="artifact baz")]

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="foo", description="param foo"),
Parameter(name="fum", description="param fum", default=5),
Parameter(name="bar", default="a default", description="param bar"),
Artifact(name="baz", description="artifact baz"),
]


def test_get_outputs_no_path_with_multiple_annotations():
class Foo(Output):
foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)]
fum: Annotated[int, Field(lt=10000), Parameter(name="f_um")] = 5
bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default"
baz: Annotated[str, Field(max_length=15), Artifact()]

parameters = Foo._get_outputs()

assert parameters == [
Parameter(name="f_oo"),
Parameter(name="f_um", default=5),
Parameter(name="bar", default="a default", description="param bar"),
Artifact(name="baz"),
]


def test_get_outputs_add_path_unannotated():
class Foo(Output):
foo: int
fum: int = 5
bar: str = "a default"

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")),
Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum")),
Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar")),
]


def test_get_outputs_add_path_with_pydantic_annotations():
class Foo(Output):
foo: Annotated[int, Field(gt=0)]
fum: Annotated[int, Field(lt=1000)] = 5
bar: Annotated[str, Field(max_length=10)] = "a default"

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")),
Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum")),
Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar")),
]


def test_get_outputs_add_path_annotated_with_name():
class Foo(Output):
foo: Annotated[int, Parameter(name="f_oo")]
fum: Annotated[int, Parameter(name="f_um")] = 5
bar: Annotated[str, Parameter(name="b_ar")] = "a default"
baz: Annotated[str, Artifact(name="b_az")]

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="f_oo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_oo")),
Parameter(name="f_um", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_um")),
Parameter(name="b_ar", default="a default", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/b_ar")),
Artifact(name="b_az", path="/tmp/hera-outputs/artifacts/b_az"),
]


def test_get_outputs_add_path_annotated_with_path():
class Foo(Output):
foo: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/one"))]
fum: Annotated[int, Parameter(value_from=ValueFrom(path="/tmp/two"))] = 5
bar: Annotated[str, Parameter(value_from=ValueFrom(path="/tmp/three"))] = "a default"
baz: Annotated[str, Artifact(path="/tmp/four")]

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="foo", value_from=ValueFrom(path="/tmp/one")),
Parameter(name="fum", default=5, value_from=ValueFrom(path="/tmp/two")),
Parameter(name="bar", default="a default", value_from=ValueFrom(path="/tmp/three")),
Artifact(name="baz", path="/tmp/four"),
]


def test_get_outputs_add_path_annotated_with_description():
class Foo(Output):
foo: Annotated[int, Parameter(description="param foo")]
fum: Annotated[int, Parameter(description="param fum")] = 5
bar: Annotated[str, Parameter(description="param bar")] = "a default"
baz: Annotated[str, Artifact(description="artifact baz")]

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="foo", description="param foo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/foo")),
Parameter(
name="fum",
description="param fum",
default=5,
value_from=ValueFrom(path="/tmp/hera-outputs/parameters/fum"),
),
Parameter(
name="bar",
default="a default",
description="param bar",
value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar"),
),
Artifact(name="baz", description="artifact baz", path="/tmp/hera-outputs/artifacts/baz"),
]


def test_get_outputs_add_path_with_multiple_annotations():
class Foo(Output):
foo: Annotated[int, Parameter(name="f_oo"), Field(gt=0)]
fum: Annotated[int, Field(lt=10000), Parameter(name="f_um")] = 5
bar: Annotated[str, Field(max_length=10), Parameter(description="param bar")] = "a default"
baz: Annotated[str, Field(max_length=15), Artifact()]

parameters = Foo._get_outputs(add_missing_path=True)

assert parameters == [
Parameter(name="f_oo", value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_oo")),
Parameter(name="f_um", default=5, value_from=ValueFrom(path="/tmp/hera-outputs/parameters/f_um")),
Parameter(
name="bar",
default="a default",
description="param bar",
value_from=ValueFrom(path="/tmp/hera-outputs/parameters/bar"),
),
Artifact(name="baz", path="/tmp/hera-outputs/artifacts/baz"),
]


def test_get_as_invocator_output_unannotated():
class Foo(Output):
foo: int
Expand Down

0 comments on commit 70fd617

Please sign in to comment.