From 1b728a88fc3ab70ebc5a5d1eee44d48dc185f2d0 Mon Sep 17 00:00:00 2001 From: Elliot Gunton Date: Wed, 12 Jun 2024 17:49:08 +0800 Subject: [PATCH] Add automatic template ref steps/tasks * Within a DAG/Steps function, when calling functions that are decorated as templates belonging to other [Cluster]WorkflowTemplates, we generate a step/task that sets the TemplateRef based on the WorkflowTemplate name and the template name, as well as setting cluster_scope. Arguments are set as normal Signed-off-by: Elliot Gunton --- .../new-decorators-auto-template-refs.yaml | 68 +++++++++++ .../new_decorators_auto_template_refs.py | 82 +++++++++++++ src/hera/workflows/_meta_mixins.py | 114 ++++++++++-------- 3 files changed, 216 insertions(+), 48 deletions(-) create mode 100644 examples/workflows/experimental/new-decorators-auto-template-refs.yaml create mode 100644 examples/workflows/experimental/new_decorators_auto_template_refs.py diff --git a/examples/workflows/experimental/new-decorators-auto-template-refs.yaml b/examples/workflows/experimental/new-decorators-auto-template-refs.yaml new file mode 100644 index 000000000..69b858dd6 --- /dev/null +++ b/examples/workflows/experimental/new-decorators-auto-template-refs.yaml @@ -0,0 +1,68 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: my-workflow- +spec: + entrypoint: worker + templates: + - dag: + tasks: + - name: setup_task + templateRef: + clusterScope: true + name: my-cluster-workflow-template + template: setup + - arguments: + parameters: + - name: word_a + value: '{{inputs.parameters.value_a}}' + - name: word_b + value: '{{tasks.setup_task.outputs.parameters.environment_parameter}}{{tasks.setup_task.outputs.parameters.dummy-param}}' + - name: concat_config + value: '{"reverse": false}' + depends: setup_task + name: task_a + templateRef: + name: my-workflow-template + template: concat + - arguments: + parameters: + - name: word_a + value: '{{inputs.parameters.value_b}}' + - name: word_b + value: '{{tasks.setup_task.outputs.result}}' + - name: concat_config + value: '{"reverse": false}' + depends: setup_task + name: task_b + templateRef: + name: my-workflow-template + template: concat + - arguments: + parameters: + - name: word_a + value: '{{tasks.task_a.outputs.result}}' + - name: word_b + value: '{{tasks.task_b.outputs.result}}' + - name: concat_config + value: '{"reverse": false}' + depends: task_a && task_b + name: final_task + templateRef: + name: my-workflow-template + template: concat + inputs: + parameters: + - default: my default + name: value_a + - name: value_b + - default: '42' + name: an_int_value + - default: '{"param_1": "Hello", "param_2": "world"}' + name: a_basemodel + name: worker + outputs: + parameters: + - name: value + valueFrom: + parameter: '{{tasks.final_task.outputs.result}}' diff --git a/examples/workflows/experimental/new_decorators_auto_template_refs.py b/examples/workflows/experimental/new_decorators_auto_template_refs.py new file mode 100644 index 000000000..98be7492d --- /dev/null +++ b/examples/workflows/experimental/new_decorators_auto_template_refs.py @@ -0,0 +1,82 @@ +from pydantic import BaseModel +from typing_extensions import Annotated + +from hera.shared import global_config +from hera.workflows import ClusterWorkflowTemplate, Input, Output, Parameter, Workflow, WorkflowTemplate + +global_config.experimental_features["decorator_syntax"] = True + +wt = WorkflowTemplate(name="my-workflow-template") +cwt = ClusterWorkflowTemplate(name="my-cluster-workflow-template") + +w = Workflow(generate_name="my-workflow-") + + +class SetupConfig(BaseModel): + a_param: str + + +class SetupOutput(Output): + environment_parameter: str + an_annotated_parameter: Annotated[int, Parameter(name="dummy-param")] # use an annotated non-str + setup_config: Annotated[SetupConfig, Parameter(name="setup-config")] # use a pydantic BaseModel + + +@cwt.script() +def setup() -> SetupOutput: + return SetupOutput( + environment_parameter="linux", + an_annotated_parameter=42, + setup_config=SetupConfig(a_param="test"), + result="Setting things up", + ) + + +class ConcatConfig(BaseModel): + reverse: bool + + +class ConcatInput(Input): + word_a: Annotated[str, Parameter(name="word_a", default="")] + word_b: str + concat_config: ConcatConfig = ConcatConfig(reverse=False) + + +@wt.script() +def concat(concat_input: ConcatInput) -> Output: + res = f"{concat_input.word_a} {concat_input.word_b}" + if concat_input.reverse: + res = res[::-1] + return Output(result=res) + + +class WorkerConfig(BaseModel): + param_1: str + param_2: str + + +class WorkerInput(Input): + value_a: str = "my default" + value_b: str + an_int_value: int = 42 + a_basemodel: WorkerConfig = WorkerConfig(param_1="Hello", param_2="world") + + +class WorkerOutput(Output): + value: str + + +@w.set_entrypoint +@w.dag() +def worker(worker_input: WorkerInput) -> WorkerOutput: + setup_task = setup() + task_a = concat( + ConcatInput( + word_a=worker_input.value_a, + word_b=setup_task.environment_parameter + str(setup_task.an_annotated_parameter), + ) + ) + task_b = concat(ConcatInput(word_a=worker_input.value_b, word_b=setup_task.result)) + final_task = concat(ConcatInput(word_a=task_a.result, word_b=task_b.result)) + + return WorkerOutput(value=final_task.result) diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index b661ac7bd..cd3f4309b 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -23,6 +23,7 @@ from hera.workflows.models import ( Artifact as ModelArtifact, Parameter as ModelParameter, + TemplateRef, ) from hera.workflows.parameter import Parameter from hera.workflows.protocol import TWorkflow @@ -495,50 +496,6 @@ def __init__(self, subnode_type: str, output_class: Type[Union[OutputV1, OutputV self.output_class = output_class -def create_subnode( - subnode_name: str, - func: Callable, - template: Union[str, Template, TemplateMixin, CallableTemplateMixin], - *args, - **kwargs, -) -> Union[Step, Task]: - from hera.workflows.dag import DAG - from hera.workflows.steps import Parallel, Step, Steps - from hera.workflows.task import Task - - subnode_args = None - if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)): - subnode_args = args[0]._get_as_arguments() - - signature = inspect.signature(func) - output_class = signature.return_annotation - - subnode: Union[Step, Task] - - assert _context.pieces - _context.declaring = False - if isinstance(_context.pieces[-1], (Steps, Parallel)): - subnode = Step( - name=subnode_name, - template=template, - arguments=subnode_args, - **kwargs, - ) - elif isinstance(_context.pieces[-1], DAG): - subnode = Task( - name=subnode_name, - template=template, - arguments=subnode_args, - depends=" && ".join(sorted(_context.pieces[-1]._current_task_depends)) or None, - **kwargs, - ) - _context.pieces[-1]._current_task_depends.clear() - - subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) - _context.declaring = True - return subnode - - def _get_underlying_type(annotation: Type): real_type = annotation if get_origin(annotation) is Annotated: @@ -575,6 +532,67 @@ def _check_if_enabled(decorator_name: str): "`varname` is not installed. Install `hera[experimental]` to bring in the extra dependency" ) + def _create_subnode( + self, + subnode_name: str, + func: Callable, + template: Union[str, Template, TemplateMixin, CallableTemplateMixin], + *args, + **kwargs, + ) -> Union[Step, Task]: + from hera.workflows.cluster_workflow_template import ClusterWorkflowTemplate + from hera.workflows.dag import DAG + from hera.workflows.steps import Parallel, Step, Steps + from hera.workflows.task import Task + from hera.workflows.workflow_template import WorkflowTemplate + + subnode_args = None + if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)): + subnode_args = args[0]._get_as_arguments() + + signature = inspect.signature(func) + output_class = signature.return_annotation + + subnode: Union[Step, Task] + + assert _context.pieces + + template_ref = None + if _context.pieces[0] != self and isinstance(self, WorkflowTemplate): + # Using None for cluster_scope means it won't appear in the YAML spec (saving some bytes), + # as cluster_scope=False is the default value + template_ref = TemplateRef( + name=self.name, + template=template.name, + cluster_scope=True if isinstance(self, ClusterWorkflowTemplate) else None, + ) + # Set template to None as it cannot be set alongside template_ref + template = None + + _context.declaring = False + if isinstance(_context.pieces[-1], (Steps, Parallel)): + subnode = Step( + name=subnode_name, + template=template, + template_ref=template_ref, + arguments=subnode_args, + **kwargs, + ) + elif isinstance(_context.pieces[-1], DAG): + subnode = Task( + name=subnode_name, + template=template, + template_ref=template_ref, + arguments=subnode_args, + depends=" && ".join(sorted(_context.pieces[-1]._current_task_depends)) or None, + **kwargs, + ) + _context.pieces[-1]._current_task_depends.clear() + + subnode._build_obj = HeraBuildObj(subnode._subtype, output_class) + _context.declaring = True + return subnode + @_add_type_hints(Script) # type: ignore def script(self, **script_kwargs) -> Callable: """A decorator that wraps a function into a Script object. @@ -629,7 +647,7 @@ def script_decorator(func: Callable[FuncIns, FuncR]) -> Callable: if "constructor" not in script_kwargs and "constructor" not in global_config._get_class_defaults(Script): script_kwargs["constructor"] = RunnerScriptConstructor() - # Open context to add `Script` object automatically + # Open (Workflow) context to add `Script` object automatically with self: script_template = Script(name=name, source=source, **script_kwargs) @@ -650,7 +668,7 @@ def script_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]: subnode_name = kwargs.pop("name", subnode_name) assert isinstance(subnode_name, str) - return create_subnode(subnode_name, func, script_template, *args, **kwargs) + return self._create_subnode(subnode_name, func, script_template, *args, **kwargs) if _context.pieces: return script_template.__call__(*args, **kwargs) @@ -709,7 +727,7 @@ def container_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]: subnode_name = kwargs.pop("name", subnode_name) assert isinstance(subnode_name, str) - return create_subnode(subnode_name, func, container_template, *args, **kwargs) + return self._create_subnode(subnode_name, func, container_template, *args, **kwargs) if _context.pieces: return container_template.__call__(*args, **kwargs) @@ -792,7 +810,7 @@ def call_wrapper(*args, **kwargs): subnode_name = kwargs.pop("name", subnode_name) assert isinstance(subnode_name, str) - return create_subnode(subnode_name, func, template, *args, **kwargs) + return self._create_subnode(subnode_name, func, template, *args, **kwargs) return func(*args, **kwargs)