Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor "one or many" typed Hera values #905

Merged
merged 6 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ help: ## Showcase the help instructions for all the available `make` commands
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

.PHONY: install
install: ## Run poetry install with default behaviour
install: ## Run poetry install with all extras for development
@poetry env use system
@poetry install
@poetry install --all-extras

.PHONY: install-3.8
install-3.8: ## Install python3.8 for generating test data
Expand Down
107 changes: 65 additions & 42 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@
from hera.workflows._inspect import get_annotations # type: ignore
from collections import ChainMap
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)

try:
from typing import Annotated, get_args, get_origin # type: ignore
Expand Down Expand Up @@ -85,11 +97,36 @@
except ImportError:
_yaml = None


T = TypeVar("T")
OneOrMany = Union[T, List[T]]
"""OneOrMany is provided as a convenience to allow Hera models to accept single values or lists of
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved
values, and so that our code is more readable. It is used by the 'normalize' validators below."""


def normalize_to_list(v: Optional[OneOrMany]):
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved
"""Normalize given value to a list if not None."""
if v is None or isinstance(v, list):
return v
return [v]


def normalize_to_list_or(*valid_types: Type):
elliotgunton marked this conversation as resolved.
Show resolved Hide resolved
"""Normalize given value to a list if not None."""

def normalize_to_list_if_not_valid_type(v: Optional[OneOrMany]):
"""Normalize given value to a list if not None or already a valid type."""
if v is None or isinstance(v, (list, *valid_types)):
return v
return [v]

return normalize_to_list_if_not_valid_type


InputsT = Optional[
Union[
ModelInputs,
Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]],
List[Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]]],
OneOrMany[Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]]],
]
]
"""`InputsT` is the main type associated with inputs that can be specified in Hera workflows, dags, steps, etc.
Expand All @@ -103,8 +140,7 @@
OutputsT = Optional[
Union[
ModelOutputs,
Union[Parameter, ModelParameter, Artifact, ModelArtifact],
List[Union[Parameter, ModelParameter, Artifact, ModelArtifact]],
OneOrMany[Union[Parameter, ModelParameter, Artifact, ModelArtifact]],
]
]
"""`OutputsT` is the main type associated with outputs the can be specified in Hera workflows, dags, steps, etc.
Expand All @@ -117,8 +153,7 @@
ArgumentsT = Optional[
Union[
ModelArguments,
Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]],
List[Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]]],
OneOrMany[Union[Parameter, ModelParameter, Artifact, ModelArtifact, Dict[str, Any]]],
]
]
"""`ArgumentsT` is the main type associated with arguments that can be used on DAG tasks, steps, etc.
Expand All @@ -130,12 +165,10 @@

MetricsT = Optional[
Union[
_BaseMetric,
List[_BaseMetric],
Metrics,
ModelPrometheus,
List[ModelPrometheus],
ModelMetrics,
OneOrMany[_BaseMetric],
OneOrMany[ModelPrometheus],
]
]
"""`MetricsT` is the core Hera type for Prometheus metrics.
Expand All @@ -144,28 +177,21 @@
the variations of metrics provided by `hera.workflows.metrics.*`
"""

EnvT = Optional[
Union[
_BaseEnv,
EnvVar,
List[Union[_BaseEnv, EnvVar, Dict[str, Any]]],
Dict[str, Any],
]
]
EnvT = Optional[OneOrMany[Union[_BaseEnv, EnvVar, Dict[str, Any]]]]
"""`EnvT` is the core Hera type for environment variables.

The env type enables setting single valued environment variables, lists of environment variables, or dictionary
mappings of env variables names to values, which are automatically parsed by Hera.
"""

EnvFromT = Optional[Union[_BaseEnvFrom, EnvFromSource, List[Union[_BaseEnvFrom, EnvFromSource]]]]
EnvFromT = Optional[OneOrMany[Union[_BaseEnvFrom, EnvFromSource]]]
"""`EnvFromT` is the core Hera type for environment variables derived from Argo/Kubernetes sources.

This env type enables specifying environment variables in base form, as `hera.workflows.env` form, or lists of the
aforementioned objects.
"""

VolumesT = Optional[Union[Union[ModelVolume, _BaseVolume], List[Union[ModelVolume, _BaseVolume]]]]
VolumesT = Optional[OneOrMany[Union[ModelVolume, _BaseVolume]]]
"""`VolumesT` is the core Hera type for volumes.

This volume type is used to specify the configuration of volumes to be automatically created by Argo/K8s and mounted
Expand Down Expand Up @@ -271,6 +297,9 @@ class IOMixin(BaseMixin):

inputs: InputsT = None
outputs: OutputsT = None
_normalize_fields = validator("inputs", "outputs", allow_reuse=True)(
normalize_to_list_or(ModelInputs, ModelOutputs)
)

def get_parameter(self, name: str) -> Parameter:
"""Finds and returns the parameter with the supplied name.
Expand Down Expand Up @@ -341,8 +370,7 @@ def _build_outputs(self) -> Optional[ModelOutputs]:
return self.outputs

result = ModelOutputs()
outputs = self.outputs if isinstance(self.outputs, list) else [self.outputs]
for value in outputs:
for value in self.outputs:
if isinstance(value, Parameter):
result.parameters = (
[value.as_output()] if result.parameters is None else result.parameters + [value.as_output()]
Expand All @@ -355,7 +383,7 @@ def _build_outputs(self) -> Optional[ModelOutputs]:
if result.artifacts is None
else result.artifacts + [value._build_artifact()]
)
else:
elif isinstance(value, ModelArtifact):
result.artifacts = [value] if result.artifacts is None else result.artifacts + [value]

# returning `None` for `ModelInputs` means the submission to the server will not even have the `outputs` field
Expand All @@ -370,15 +398,15 @@ class EnvMixin(BaseMixin):

env: EnvT = None
env_from: EnvFromT = None
_normalize_fields = validator("env", "env_from", allow_reuse=True)(normalize_to_list)

def _build_env(self) -> Optional[List[EnvVar]]:
"""Processes the `env` field and returns a list of generated `EnvVar` or `None`."""
if self.env is None:
return None

result: List[EnvVar] = []
env = self.env if isinstance(self.env, list) else [self.env]
for e in env:
for e in self.env:
if isinstance(e, EnvVar):
result.append(e)
elif issubclass(e.__class__, _BaseEnv):
Expand All @@ -397,8 +425,7 @@ def _build_env_from(self) -> Optional[List[EnvFromSource]]:
return None

result: List[EnvFromSource] = []
env_from = self.env_from if isinstance(self.env_from, list) else [self.env_from]
for e in env_from:
for e in self.env_from:
if isinstance(e, EnvFromSource):
result.append(e)
elif issubclass(e.__class__, _BaseEnvFrom):
Expand All @@ -413,23 +440,20 @@ class MetricsMixin(BaseMixin):
"""`MetricsMixin` provides the ability to set metrics on a n object."""

metrics: MetricsT = None
_normalize_metrics = validator("metrics", allow_reuse=True)(normalize_to_list_or(Metrics, ModelMetrics))

def _build_metrics(self) -> Optional[ModelMetrics]:
"""Processes the `metrics` field and returns the generated `ModelMetrics` or `None`."""
if self.metrics is None or isinstance(self.metrics, ModelMetrics):
return self.metrics
elif isinstance(self.metrics, ModelPrometheus):
return ModelMetrics(prometheus=[self.metrics])
elif isinstance(self.metrics, Metrics):
return ModelMetrics(prometheus=self.metrics._build_metrics())
elif isinstance(self.metrics, _BaseMetric):
return ModelMetrics(prometheus=[self.metrics._build_metric()])

metrics = []
for m in self.metrics:
if isinstance(m, _BaseMetric):
metrics.append(m._build_metric())
else:
elif isinstance(m, ModelPrometheus):
metrics.append(m)
return ModelMetrics(prometheus=metrics) if metrics else None

Expand Down Expand Up @@ -520,15 +544,15 @@ class VolumeMixin(BaseMixin):
"""

volumes: VolumesT = None
_normalize_fields = validator("volumes", allow_reuse=True)(normalize_to_list)

def _build_volumes(self) -> Optional[List[ModelVolume]]:
"""Processes the `volumes` and creates an optional list of generates `Volume`s."""
if self.volumes is None:
return None

volumes = self.volumes if isinstance(self.volumes, list) else [self.volumes]
# filter volumes for otherwise we're building extra Argo volumes
filtered_volumes = [v for v in volumes if not isinstance(v, Volume)]
filtered_volumes = [cast(_BaseVolume, v) for v in self.volumes if not isinstance(v, Volume)]
# only build volumes if there are any of type `_BaseVolume`, otherwise it must be an autogenerated model
# already, so kept it as it is
result = [v._build_volume() if issubclass(v.__class__, _BaseVolume) else v for v in filtered_volumes]
Expand All @@ -539,8 +563,7 @@ def _build_persistent_volume_claims(self) -> Optional[List[PersistentVolumeClaim
if self.volumes is None:
return None

volumes = self.volumes if isinstance(self.volumes, list) else [self.volumes]
volumes_with_pv_claims = [v for v in volumes if isinstance(v, Volume)]
volumes_with_pv_claims = [v for v in self.volumes if isinstance(v, Volume)]
if not volumes_with_pv_claims:
return None
return [v._build_persistent_volume_claim() for v in volumes_with_pv_claims] or None
Expand All @@ -566,7 +589,7 @@ def _build_volume_mounts(self) -> Optional[List[VolumeMount]]:
if self.volumes is None:
volumes: list = []
else:
volumes = self.volumes if isinstance(self.volumes, list) else [self.volumes]
volumes = cast(list, self.volumes)

result = (
None
Expand All @@ -588,6 +611,7 @@ class ArgumentsMixin(BaseMixin):
"""`ArgumentsMixin` provides the ability to set the `arguments` field on the inheriting object."""

arguments: ArgumentsT = None
_normalize_arguments = validator("arguments", allow_reuse=True)(normalize_to_list_or(ModelArguments))

def _build_arguments(self) -> Optional[ModelArguments]:
"""Processes the `arguments` field and builds the optional generated `Arguments` to set as arguments."""
Expand All @@ -597,8 +621,7 @@ def _build_arguments(self) -> Optional[ModelArguments]:
return self.arguments

result = ModelArguments()
arguments = self.arguments if isinstance(self.arguments, list) else [self.arguments]
for arg in arguments:
for arg in self.arguments:
if isinstance(arg, dict):
for k, v in arg.items():
value = Parameter(name=k, value=v)
Expand Down Expand Up @@ -703,7 +726,7 @@ def _get_arguments(self, **kwargs) -> List:
# uses the user-provided value rather than the inferred value
kwargs_arguments = kwargs.get("arguments", [])
kwargs_arguments = kwargs_arguments if isinstance(kwargs_arguments, List) else [kwargs_arguments] # type: ignore
arguments = self.arguments if isinstance(self.arguments, List) else [self.arguments] + kwargs_arguments # type: ignore
arguments = self.arguments if self.arguments else [] + kwargs_arguments
return list(filter(lambda x: x is not None, arguments))

def _get_parameter_names(self, arguments: List) -> Set[str]:
Expand Down Expand Up @@ -807,7 +830,7 @@ class ItemMixin(BaseMixin):
The items passed in `with_items` must be serializable objects
"""

with_items: Optional[List[Any]] = None
with_items: Optional[OneOrMany[Any]] = None

def _build_with_items(self) -> Optional[List[Item]]:
"""Process the `with_items` field and returns an optional list of corresponding `Item`s.
Expand Down
58 changes: 56 additions & 2 deletions tests/test_unit/test_mixins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest

from hera.workflows import Parameter
from hera.workflows._mixins import ContainerMixin, IOMixin
from hera.workflows import Env, Parameter
from hera.workflows._mixins import ArgumentsMixin, ContainerMixin, EnvMixin, IOMixin
from hera.workflows.models import (
Arguments as ModelArguments,
ImagePullPolicy,
Inputs as ModelInputs,
)
Expand Down Expand Up @@ -53,3 +54,56 @@ def test_build_inputs_from_model_inputs(self):

def test_build_outputs_none(self):
assert self.io_mixin._build_outputs() is None


class TestArgumentsMixin:
def test_list_normalized_to_list(self):
args_mixin = ArgumentsMixin(
arguments=[
Parameter(name="my-param-1"),
Parameter(name="my-param-2"),
]
)

assert isinstance(args_mixin.arguments, list)
assert len(args_mixin.arguments) == 2

def test_single_value_normalized_to_list(self):
args_mixin = ArgumentsMixin(arguments=Parameter(name="my-param"))

assert isinstance(args_mixin.arguments, list)
assert len(args_mixin.arguments) == 1

def test_none_value_is_not_normalized_to_list(self):
args_mixin = ArgumentsMixin(arguments=None)

assert args_mixin.arguments is None

def test_model_arguments_value_is_not_normalized_to_list(self):
args_mixin = ArgumentsMixin(arguments=ModelArguments())

assert args_mixin.arguments == ModelArguments()


class TestEnvMixin:
def test_list_normalized_to_list(self):
env_mixin = EnvMixin(
env=[
Env(name="test-1", value="test"),
Env(name="test-2", value="test"),
]
)

assert isinstance(env_mixin.env, list)
assert len(env_mixin.env) == 2

def test_single_value_normalized_to_list(self):
env_mixin = EnvMixin(env=Env(name="test", value="test"))

assert isinstance(env_mixin.env, list)
assert len(env_mixin.env) == 1

def test_none_value_is_not_normalized_to_list(self):
env_mixin = EnvMixin(env=None)

assert env_mixin.env is None