Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add automatic template ref steps/tasks #1097

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: my-workflow-
spec:
entrypoint: worker
templates:
- dag:
tasks:
- name: setup_task
templateRef:
clusterScope: true
name: my-cluster-workflow-template
template: setup
- arguments:
parameters:
- name: word_a
value: '{{inputs.parameters.value_a}}'
- name: word_b
value: '{{tasks.setup_task.outputs.parameters.environment_parameter}}{{tasks.setup_task.outputs.parameters.dummy-param}}'
- name: concat_config
value: '{"reverse": false}'
depends: setup_task
name: task_a
templateRef:
name: my-workflow-template
template: concat
- arguments:
parameters:
- name: word_a
value: '{{inputs.parameters.value_b}}'
- name: word_b
value: '{{tasks.setup_task.outputs.result}}'
- name: concat_config
value: '{"reverse": false}'
depends: setup_task
name: task_b
templateRef:
name: my-workflow-template
template: concat
- arguments:
parameters:
- name: word_a
value: '{{tasks.task_a.outputs.result}}'
- name: word_b
value: '{{tasks.task_b.outputs.result}}'
- name: concat_config
value: '{"reverse": false}'
depends: task_a && task_b
name: final_task
templateRef:
name: my-workflow-template
template: concat
inputs:
parameters:
- default: my default
name: value_a
- name: value_b
- default: '42'
name: an_int_value
- default: '{"param_1": "Hello", "param_2": "world"}'
name: a_basemodel
name: worker
outputs:
parameters:
- name: value
valueFrom:
parameter: '{{tasks.final_task.outputs.result}}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from pydantic import BaseModel
from typing_extensions import Annotated

from hera.shared import global_config
from hera.workflows import ClusterWorkflowTemplate, Input, Output, Parameter, Workflow, WorkflowTemplate

global_config.experimental_features["decorator_syntax"] = True

wt = WorkflowTemplate(name="my-workflow-template")
cwt = ClusterWorkflowTemplate(name="my-cluster-workflow-template")

w = Workflow(generate_name="my-workflow-")


class SetupConfig(BaseModel):
a_param: str


class SetupOutput(Output):
environment_parameter: str
an_annotated_parameter: Annotated[int, Parameter(name="dummy-param")] # use an annotated non-str
setup_config: Annotated[SetupConfig, Parameter(name="setup-config")] # use a pydantic BaseModel


@cwt.script()
def setup() -> SetupOutput:
return SetupOutput(
environment_parameter="linux",
an_annotated_parameter=42,
setup_config=SetupConfig(a_param="test"),
result="Setting things up",
)


class ConcatConfig(BaseModel):
reverse: bool


class ConcatInput(Input):
word_a: Annotated[str, Parameter(name="word_a", default="")]
word_b: str
concat_config: ConcatConfig = ConcatConfig(reverse=False)


@wt.script()
def concat(concat_input: ConcatInput) -> Output:
res = f"{concat_input.word_a} {concat_input.word_b}"
if concat_input.reverse:
res = res[::-1]
return Output(result=res)


class WorkerConfig(BaseModel):
param_1: str
param_2: str


class WorkerInput(Input):
value_a: str = "my default"
value_b: str
an_int_value: int = 42
a_basemodel: WorkerConfig = WorkerConfig(param_1="Hello", param_2="world")


class WorkerOutput(Output):
value: str


@w.set_entrypoint
@w.dag()
def worker(worker_input: WorkerInput) -> WorkerOutput:
setup_task = setup()
task_a = concat(
ConcatInput(
word_a=worker_input.value_a,
word_b=setup_task.environment_parameter + str(setup_task.an_annotated_parameter),
)
)
task_b = concat(ConcatInput(word_a=worker_input.value_b, word_b=setup_task.result))
final_task = concat(ConcatInput(word_a=task_a.result, word_b=task_b.result))

return WorkerOutput(value=final_task.result)
114 changes: 66 additions & 48 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from hera.workflows.models import (
Artifact as ModelArtifact,
Parameter as ModelParameter,
TemplateRef,
)
from hera.workflows.parameter import Parameter
from hera.workflows.protocol import TWorkflow
Expand Down Expand Up @@ -495,50 +496,6 @@ def __init__(self, subnode_type: str, output_class: Type[Union[OutputV1, OutputV
self.output_class = output_class


def create_subnode(
subnode_name: str,
func: Callable,
template: Union[str, Template, TemplateMixin, CallableTemplateMixin],
*args,
**kwargs,
) -> Union[Step, Task]:
from hera.workflows.dag import DAG
from hera.workflows.steps import Parallel, Step, Steps
from hera.workflows.task import Task

subnode_args = None
if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)):
subnode_args = args[0]._get_as_arguments()

signature = inspect.signature(func)
output_class = signature.return_annotation

subnode: Union[Step, Task]

assert _context.pieces
_context.declaring = False
if isinstance(_context.pieces[-1], (Steps, Parallel)):
subnode = Step(
name=subnode_name,
template=template,
arguments=subnode_args,
**kwargs,
)
elif isinstance(_context.pieces[-1], DAG):
subnode = Task(
name=subnode_name,
template=template,
arguments=subnode_args,
depends=" && ".join(sorted(_context.pieces[-1]._current_task_depends)) or None,
**kwargs,
)
_context.pieces[-1]._current_task_depends.clear()

subnode._build_obj = HeraBuildObj(subnode._subtype, output_class)
_context.declaring = True
return subnode


def _get_underlying_type(annotation: Type):
real_type = annotation
if get_origin(annotation) is Annotated:
Expand Down Expand Up @@ -575,6 +532,67 @@ def _check_if_enabled(decorator_name: str):
"`varname` is not installed. Install `hera[experimental]` to bring in the extra dependency"
)

def _create_subnode(
self,
subnode_name: str,
func: Callable,
template: Union[str, Template, TemplateMixin, CallableTemplateMixin],
*args,
**kwargs,
) -> Union[Step, Task]:
from hera.workflows.cluster_workflow_template import ClusterWorkflowTemplate
from hera.workflows.dag import DAG
from hera.workflows.steps import Parallel, Step, Steps
from hera.workflows.task import Task
from hera.workflows.workflow_template import WorkflowTemplate

subnode_args = None
if len(args) == 1 and isinstance(args[0], (InputV1, InputV2)):
subnode_args = args[0]._get_as_arguments()

signature = inspect.signature(func)
output_class = signature.return_annotation

subnode: Union[Step, Task]

assert _context.pieces

template_ref = None
if _context.pieces[0] != self and isinstance(self, WorkflowTemplate):
# Using None for cluster_scope means it won't appear in the YAML spec (saving some bytes),
# as cluster_scope=False is the default value
template_ref = TemplateRef(
name=self.name,
template=template.name,
cluster_scope=True if isinstance(self, ClusterWorkflowTemplate) else None,
)
# Set template to None as it cannot be set alongside template_ref
template = None

_context.declaring = False
if isinstance(_context.pieces[-1], (Steps, Parallel)):
subnode = Step(
name=subnode_name,
template=template,
template_ref=template_ref,
arguments=subnode_args,
**kwargs,
)
elif isinstance(_context.pieces[-1], DAG):
subnode = Task(
name=subnode_name,
template=template,
template_ref=template_ref,
arguments=subnode_args,
depends=" && ".join(sorted(_context.pieces[-1]._current_task_depends)) or None,
**kwargs,
)
_context.pieces[-1]._current_task_depends.clear()

subnode._build_obj = HeraBuildObj(subnode._subtype, output_class)
_context.declaring = True
return subnode

@_add_type_hints(Script) # type: ignore
def script(self, **script_kwargs) -> Callable:
"""A decorator that wraps a function into a Script object.
Expand Down Expand Up @@ -629,7 +647,7 @@ def script_decorator(func: Callable[FuncIns, FuncR]) -> Callable:
if "constructor" not in script_kwargs and "constructor" not in global_config._get_class_defaults(Script):
script_kwargs["constructor"] = RunnerScriptConstructor()

# Open context to add `Script` object automatically
# Open (Workflow) context to add `Script` object automatically
with self:
script_template = Script(name=name, source=source, **script_kwargs)

Expand All @@ -650,7 +668,7 @@ def script_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]:
subnode_name = kwargs.pop("name", subnode_name)
assert isinstance(subnode_name, str)

return create_subnode(subnode_name, func, script_template, *args, **kwargs)
return self._create_subnode(subnode_name, func, script_template, *args, **kwargs)

if _context.pieces:
return script_template.__call__(*args, **kwargs)
Expand Down Expand Up @@ -709,7 +727,7 @@ def container_call_wrapper(*args, **kwargs) -> Union[FuncR, Step, Task, None]:
subnode_name = kwargs.pop("name", subnode_name)
assert isinstance(subnode_name, str)

return create_subnode(subnode_name, func, container_template, *args, **kwargs)
return self._create_subnode(subnode_name, func, container_template, *args, **kwargs)

if _context.pieces:
return container_template.__call__(*args, **kwargs)
Expand Down Expand Up @@ -792,7 +810,7 @@ def call_wrapper(*args, **kwargs):
subnode_name = kwargs.pop("name", subnode_name)
assert isinstance(subnode_name, str)

return create_subnode(subnode_name, func, template, *args, **kwargs)
return self._create_subnode(subnode_name, func, template, *args, **kwargs)

return func(*args, **kwargs)

Expand Down