Skip to content

Commit

Permalink
Fix multiple bugs in IO mixins (#1193)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [X] Fixes #1190, fixes #1165
- [X] Tests added
- [ ] Documentation/examples added
- [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, the methods that consume annotations in InputMixin and
OutputMixin have multiple bugs:

- ignoring fields with a workflow annotation with no name
- ignoring fields with a Pydantic annotation and no workflow annotation
- `OutputMixin._get_outputs` additionally was:
  * mutating the original annotations without copying them
* ignoring the model default if a workflow annotation was present but
had no default

This PR pulls out a common function to iterate through fields and yield
annotations, which:

- copies the annotations so they cannot be accidentally changed
- ignores unrecognized annotations (this is a requirement of the
Annotated spec, and fixes the issue with Pydantic annotations)
- defaults the annotation name to the field name if unset

This ensures all functions treat an explicit Parameter and a missing
workflow annotation the same way, solving the second issue with
`OutputMixin._get_outputs`.

---------

Signed-off-by: Alice Purcell <alicederyn@gmail.com>
Signed-off-by: Elliot Gunton <elliotgunton@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
  • Loading branch information
alicederyn and elliotgunton authored Sep 17, 2024
1 parent af5b19f commit 3e58f22
Show file tree
Hide file tree
Showing 2 changed files with 618 additions and 93 deletions.
136 changes: 53 additions & 83 deletions src/hera/workflows/io/_io_mixins.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import sys
import warnings
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from pydantic.fields import FieldInfo

from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields
from hera.shared._type_util import get_workflow_annotation, is_annotated
from hera.shared._type_util import get_workflow_annotation
from hera.shared.serialization import MISSING, serialize
from hera.workflows._context import _context
from hera.workflows.artifact import Artifact
Expand Down Expand Up @@ -39,6 +41,23 @@
BaseModel = object # type: ignore


def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
"""Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations.
If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing.
Otherwise, a Parameter object will be constructed.
"""
annotations = get_field_annotations(cls)
for field, field_info in get_fields(cls).items():
if annotation := get_workflow_annotation(annotations[field]):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or field
yield field, field_info, annotation_copy
else:
yield field, field_info, Parameter(name=field)


class InputMixin(BaseModel):
def __new__(cls, **kwargs):
if _context.declaring:
Expand All @@ -62,14 +81,9 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]:
parameters = []
annotations = get_field_annotations(cls)

for field, field_info in get_fields(cls).items():
if (param := get_workflow_annotation(annotations[field])) and isinstance(param, Parameter):
# Copy so as to not modify the Input fields themselves
param = param.copy()
if param.name is None:
param.name = field
for field, field_info, param in _construct_io_from_fields(cls):
if isinstance(param, Parameter):
if param.default is not None:
warnings.warn(
"Using the default field for Parameters in Annotations is deprecated since v5.16"
Expand All @@ -81,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(field_info.default) # type: ignore
parameters.append(param)
elif not is_annotated(annotations[field]):
# Create a Parameter from basic type annotations
default = getattr(object_override, field) if object_override else field_info.default

# For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None
if default is None or default == PydanticUndefined:
default = MISSING

parameters.append(Parameter(name=field, default=default))

return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = get_field_annotations(cls)

for field in get_fields(cls):
if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact):
# Copy so as to not modify the Input fields themselves
artifact = artifact.copy()
if artifact.name is None:
artifact.name = field
for _, _, artifact in _construct_io_from_fields(cls):
if isinstance(artifact, Artifact):
if artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
Expand All @@ -117,42 +117,33 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
def _get_as_templated_arguments(cls) -> Self:
"""Returns the Input with templated values to propagate through a DAG/Steps function."""
object_dict = {}
cls_fields = get_fields(cls)
annotations = get_field_annotations(cls)

for field in cls_fields:
if param_or_artifact := get_workflow_annotation(annotations[field]):
if isinstance(param_or_artifact, Parameter):
object_dict[field] = "{{inputs.parameters." + f"{param_or_artifact.name}" + "}}"
else:
object_dict[field] = "{{inputs.artifacts." + f"{param_or_artifact.name}" + "}}"
elif not is_annotated(annotations[field]):
object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}"
for field, _, annotation in _construct_io_from_fields(cls):
input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts"
object_dict[field] = "{{" + f"inputs.{input_type}.{annotation.name}" + "}}"

return cls.construct(None, **object_dict)

def _get_as_arguments(self) -> ModelArguments:
params = []
artifacts = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _construct_io_from_fields(type(self)):
# The value may be a static value (of any time) if it has a default value, so we need to serialize it
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")``
templated_value = serialize(self_dict[field])
name = annotation.name
assert name is not None # guaranteed by _get_workflow_annotations

if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
if isinstance(param_or_artifact, Parameter):
params.append(ModelParameter(name=param_or_artifact.name, value=templated_value))
else:
artifacts.append(ModelArtifact(name=param_or_artifact.name, from_=templated_value))
elif not is_annotated(annotations[field]):
params.append(ModelParameter(name=field, value=templated_value))
if isinstance(annotation, Parameter):
params.append(ModelParameter(name=name, value=templated_value))
else:
artifacts.append(ModelArtifact(name=name, from_=templated_value))

return ModelArguments(parameters=params or None, artifacts=artifacts or None)

Expand All @@ -178,37 +169,22 @@ def __init__(self, /, **kwargs):
@classmethod
def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(cls)

model_fields = get_fields(cls)

for field in model_fields:
for field, field_info, annotation in _construct_io_from_fields(cls):
if field in {"exit_code", "result"}:
continue
if param_or_artifact := get_workflow_annotation(annotations[field]):
if isinstance(param_or_artifact, Parameter):
if add_missing_path and (
param_or_artifact.value_from is None or param_or_artifact.value_from.path is None
):
param_or_artifact.value_from = ValueFrom(
path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}"
)
outputs.append(param_or_artifact)
else:
if add_missing_path and param_or_artifact.path is None:
param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}"
outputs.append(param_or_artifact)
elif not is_annotated(annotations[field]):
# Create a Parameter from basic type annotations
default = model_fields[field].default
if default is None or default == PydanticUndefined:
default = MISSING

value_from = None
if add_missing_path:
value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}")

outputs.append(Parameter(name=field, default=default, value_from=value_from))
if isinstance(annotation, Parameter):
if annotation.default is None:
default = field_info.default
if default is not None and default != PydanticUndefined:
annotation.default = serialize(default)

if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None):
annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}")
else:
if add_missing_path and annotation.path is None:
annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}"
outputs.append(annotation)
return outputs

@classmethod
Expand All @@ -230,27 +206,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]:
This lets dags and steps hoist task/step outputs into its own outputs.
"""
outputs: List[Union[Artifact, Parameter]] = []
annotations = get_field_annotations(type(self))

if isinstance(self, V1BaseModel):
self_dict = self.dict()
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
self_dict = self.model_dump()

for field in get_fields(type(self)):
for field, _, annotation in _construct_io_from_fields(type(self)):
if field in {"exit_code", "result"}:
continue

templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"`

if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
if isinstance(param_or_artifact, Parameter):
outputs.append(
Parameter(name=param_or_artifact.name, value_from=ValueFrom(parameter=templated_value))
)
else:
outputs.append(Artifact(name=param_or_artifact.name, from_=templated_value))
elif not is_annotated(annotations[field]):
outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value)))
if isinstance(annotation, Parameter):
outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value)))
else:
outputs.append(Artifact(name=annotation.name, from_=templated_value))

return outputs
Loading

0 comments on commit 3e58f22

Please sign in to comment.