Skip to content

Commit

Permalink
feat(core): provide a ForEach task (#4343)
Browse files Browse the repository at this point in the history
Fixes #2137
  • Loading branch information
loicmathieu authored Jul 17, 2024
1 parent da10330 commit bdd916b
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 9 deletions.
90 changes: 81 additions & 9 deletions core/src/main/java/io/kestra/core/runners/FlowableUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ public static List<ResolvedTask> resolveTasks(List<Task> tasks, TaskRun parentTa
.toList();
}

/**
* resolveParallelNexts will resolve both concurrent values and subtasks
* For only concurrent values, see resolveConcurrentNexts()
*/
public static List<NextTaskRun> resolveParallelNexts(
Execution execution,
List<ResolvedTask> tasks,
Expand All @@ -204,6 +208,74 @@ public static List<NextTaskRun> resolveParallelNexts(
);
}

/**
 * Resolves the next task runs to create when values are processed concurrently.
 * <p>
 * Resolved tasks are grouped by value, keeping the order in which values are first seen.
 * On each call, at most one new task run is returned per value, and the number of
 * non-terminated task runs is capped by {@code concurrency}.
 * For both concurrent values and subtasks, see resolveParallelNexts().
 *
 * @param execution the current execution; nothing is resolved while it is {@code KILLING}
 * @param tasks the resolved child tasks
 * @param errors the resolved error tasks
 * @param parentTaskRun the task run of the flowable task owning the child tasks
 * @param concurrency maximum number of non-terminated task runs at any time;
 *                    {@code 0} (or less) means no limit
 * @return the next task runs to create, or an empty list when nothing can be started
 */
public static List<NextTaskRun> resolveConcurrentNexts(
    Execution execution,
    List<ResolvedTask> tasks,
    List<ResolvedTask> errors,
    TaskRun parentTaskRun,
    Integer concurrency
) {
    if (execution.getState().getCurrent() == State.Type.KILLING) {
        return Collections.emptyList();
    }

    List<ResolvedTask> allTasks = execution.findTaskDependingFlowState(
        tasks,
        errors,
        parentTaskRun
    );

    // all task runs already created for these tasks
    List<TaskRun> taskRuns = execution.findTaskRunByTasks(allTasks, parentTaskRun);

    // task runs that still occupy a concurrency slot
    List<TaskRun> nonTerminated = taskRuns
        .stream()
        .filter(taskRun -> !taskRun.getState().isTerminated())
        .toList();

    boolean unlimited = concurrency <= 0;
    if (!unlimited && nonTerminated.size() >= concurrency) {
        return Collections.emptyList();
    }

    // With no limit, `concurrency - running` would be zero or negative: Stream.limit() would
    // then start nothing, or throw IllegalArgumentException, so use an unbounded slot count.
    long concurrencySlots = unlimited ? Long.MAX_VALUE : concurrency - nonTerminated.size();

    // one entry per value, in encounter order, holding the tasks to run for that value
    // NOTE(review): groupingBy rejects null keys; this assumes every resolved task carries a value — confirm.
    Map<String, List<ResolvedTask>> tasksByValue = allTasks
        .stream()
        .collect(Collectors.groupingBy(ResolvedTask::getValue, LinkedHashMap::new, Collectors.toList()));

    // first round: nothing has been created yet, start the first task of as many values as allowed
    if (taskRuns.isEmpty()) {
        return tasksByValue.values().stream()
            .limit(concurrencySlots)
            .map(valueTasks -> valueTasks.getFirst().toNextTaskRun(execution))
            .toList()
            .reversed();
    }

    // later rounds: fill the free slots with the next not-yet-created task of each value
    return tasksByValue.values().stream()
        // the tasks of one value run sequentially: a value that still has a live task run must wait
        .filter(valueTasks -> !hasNonTerminatedTaskRun(valueTasks, nonTerminated, parentTaskRun))
        .map(valueTasks -> filterCreated(valueTasks, taskRuns, parentTaskRun))
        .filter(remaining -> !remaining.isEmpty())
        .limit(concurrencySlots)
        .map(remaining -> remaining.getFirst().toNextTaskRun(execution))
        .toList();
}

/**
 * Tells whether at least one of the given tasks is matched by a non-terminated task run.
 */
private static boolean hasNonTerminatedTaskRun(List<ResolvedTask> valueTasks, List<TaskRun> nonTerminated, TaskRun parentTaskRun) {
    return valueTasks.stream()
        .anyMatch(resolvedTask -> nonTerminated.stream()
            .anyMatch(taskRun -> FlowableUtils.isTaskRunFor(resolvedTask, taskRun, parentTaskRun))
        );
}

/**
 * Keeps only the resolved tasks that no existing task run matches yet.
 *
 * @param tasks the candidate resolved tasks
 * @param taskRuns the task runs that already exist
 * @param parentTaskRun the parent task run used when matching a task run to a task
 * @return the candidates without a matching task run, in their original order
 */
private static List<ResolvedTask> filterCreated(List<ResolvedTask> tasks, List<TaskRun> taskRuns, TaskRun parentTaskRun) {
    return tasks.stream()
        .filter(candidate -> {
            for (TaskRun existing : taskRuns) {
                if (FlowableUtils.isTaskRunFor(candidate, existing, parentTaskRun)) {
                    // a task run already exists for this task: drop it
                    return false;
                }
            }
            return true;
        })
        .toList();
}

public static List<NextTaskRun> resolveDagNexts(
Execution execution,
List<ResolvedTask> tasks,
Expand Down Expand Up @@ -265,15 +337,6 @@ public static List<NextTaskRun> resolveParallelNexts(
// all tasks run
List<TaskRun> taskRuns = execution.findTaskRunByTasks(currentTasks, parentTaskRun);

// find all not created tasks
List<ResolvedTask> notFinds = currentTasks
.stream()
.filter(resolvedTask -> taskRuns
.stream()
.noneMatch(taskRun -> FlowableUtils.isTaskRunFor(resolvedTask, taskRun, parentTaskRun))
)
.toList();

// find all running and deal concurrency
long runningCount = taskRuns
.stream()
Expand All @@ -284,6 +347,15 @@ public static List<NextTaskRun> resolveParallelNexts(
return Collections.emptyList();
}

// find all not created tasks
List<ResolvedTask> notFinds = currentTasks
.stream()
.filter(resolvedTask -> taskRuns
.stream()
.noneMatch(taskRun -> FlowableUtils.isTaskRunFor(resolvedTask, taskRun, parentTaskRun))
)
.toList();

// first created, leave
Optional<TaskRun> lastCreated = execution.findLastCreated(taskRuns);

Expand Down
191 changes: 191 additions & 0 deletions core/src/main/java/io/kestra/plugin/core/flow/ForEach.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package io.kestra.plugin.core.flow;

import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.NextTaskRun;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.models.hierarchies.RelationType;
import io.kestra.core.models.tasks.FlowableTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.VoidOutput;
import io.kestra.core.runners.FlowableUtils;
import io.kestra.core.runners.RunContext;
import io.kestra.core.utils.GraphUtils;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;

import java.util.List;
import java.util.Optional;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "For each value in the list, execute one or more tasks sequentially.",
description = "The list of `tasks` will be executed for each value sequentially, but values can be executed concurrently if concurrency is set to more than one." +
"The values must be a valid JSON string representing an array, e.g. a list of strings `[\"value1\", \"value2\"]` or a list of dictionaries `[{\"key\": \"value1\"}, {\"key\": \"value2\"}]`. \n\n" +
"You can access the current iteration value using the variable `{{ taskrun.value }}`. " +
"The task list will be executed sequentially for each value, if you want them to be executed in parallel you can use a Parallel task.\n\n" +
"We highly recommend triggering a subflow for each value. " +
"This allows much better scalability and modularity. Check the [flow best practices documentation](https://kestra.io/docs/developer-guide/best-practices) " +
"and the [following Blueprint](https://kestra.io/blueprints/128-run-a-subflow-for-each-value-in-parallel-and-wait-for-their-completion-recommended-pattern-to-iterate-over-hundreds-or-thousands-of-list-items) " +
"for more details."
)
@Plugin(
examples = {
@Example(
full = true,
title = "The taskrun.value from the `each_sequential` task is available only to immediate child tasks such as the `before_if` and the `if` tasks. To access the taskrun value in child tasks of the `if` task (such as in the `after_if` task), you need to use the syntax `{{ parent.taskrun.value }}` as this allows you to access the taskrun value of the parent task `each_sequential`.",
code = """
id: loop_example
namespace: company.team
tasks:
- id: for_each
type: io.kestra.plugin.core.flow.ForEach
values: ["value 1", "value 2", "value 3"]
tasks:
- id: before_if
type: io.kestra.plugin.core.debug.Return
format: 'Before if {{ taskrun.value }}'
- id: if
type: io.kestra.plugin.core.flow.If
condition: '{{ taskrun.value == "value 2" }}'
then:
- id: after_if
type: io.kestra.plugin.core.debug.Return
format: 'After if {{ parent.taskrun.value }}'"""
),
@Example(
full = true,
title = "This task shows that the value can be a bullet-style list. The task iterates over the list of values and executes the `each-value` child task for each value. It also process the values concurrently 2 by 2.",
code = {
"id: for_each",
"namespace: company.team",
"",
"tasks:",
" - id: for_each",
" type: io.kestra.plugin.core.flow.ForEach",
" values: ",
" - value 1",
" - value 2",
" - value 3",
" concurrency: 2",
" tasks:",
" - id: each-value",
" type: io.kestra.plugin.core.debug.Return",
" format: \"{{ task.id }} with value '{{ taskrun.value }}'\"",
}
),
@Example(
full = true,
title = "This task shows how to process the sub-tasks in parallel.",
code = """
id: loop_example
namespace: company.team
tasks:
- id: for_each
type: io.kestra.plugin.core.flow.ForEach
values: ["value 1", "value 2", "value 3"]
concurrency: 2
tasks:
- id: parallel
type: io.kestra.plugin.core.flow.Parallel
tasks:
- id: log
type: io.kestra.plugin.core.log.Log
message: Processing
- id: shell
type: io.kestra.plugin.scripts.shell.Commands
commands:
- sleep 5"""
),
}
)
public class ForEach extends Sequential implements FlowableTask<VoidOutput> {
@NotNull
@PluginProperty(dynamic = true)
@Schema(
title = "The list of values for this task.",
description = "The values car be passed as a string, a list of strings, or a list of objects.",
oneOf = {String.class, Object[].class}
)
private Object values;

@NotNull
@Builder.Default
@Schema(
title = "Number of concurrent values that can be running at any point in time.",
description = "If the concurrency is `0`, no limit exist and all the values will start at the same time."
)
@PluginProperty
private final Integer concurrency = 1;

@Override
public GraphCluster tasksTree(Execution execution, TaskRun taskRun, List<String> parentValues) throws IllegalVariableEvaluationException {
GraphCluster subGraph = new GraphCluster(this, taskRun, parentValues, RelationType.DYNAMIC);

GraphUtils.parallel(
subGraph,
this.getTasks(),
this.getErrors(),
taskRun,
execution
);

return subGraph;
}

@Override
public List<ResolvedTask> childTasks(RunContext runContext, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
return FlowableUtils.resolveEachTasks(runContext, parentTaskRun, this.getTasks(), this.values);
}

@Override
public Optional<State.Type> resolveState(RunContext runContext, Execution execution, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
List<ResolvedTask> childTasks = this.childTasks(runContext, parentTaskRun);

if (childTasks.isEmpty()) {
return Optional.of(State.Type.SUCCESS);
}

return FlowableUtils.resolveState(
execution,
childTasks,
FlowableUtils.resolveTasks(this.getErrors(), parentTaskRun),
parentTaskRun,
runContext,
this.isAllowFailure()
);
}

@Override
public List<NextTaskRun> resolveNexts(RunContext runContext, Execution execution, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
if (this.concurrency == 1) {
return FlowableUtils.resolveSequentialNexts(
execution,
this.childTasks(runContext, parentTaskRun),
FlowableUtils.resolveTasks(this.errors, parentTaskRun),
parentTaskRun
);
}

return FlowableUtils.resolveConcurrentNexts(
execution,
FlowableUtils.resolveEachTasks(runContext, parentTaskRun, this.getTasks(), this.values),
FlowableUtils.resolveTasks(this.errors, parentTaskRun),
parentTaskRun,
this.concurrency
);
}
}
38 changes: 38 additions & 0 deletions core/src/test/java/io/kestra/plugin/core/flow/ForEachTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.kestra.plugin.core.flow;

import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.flows.State;
import io.kestra.core.runners.AbstractMemoryRunnerTest;
import org.junit.jupiter.api.Test;

import java.util.concurrent.TimeoutException;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

/**
 * Runs the ForEach test flows and checks that each one ends in SUCCESS
 * with the expected number of task runs.
 */
class ForEachTest extends AbstractMemoryRunnerTest {
    private static final String NAMESPACE = "io.kestra.tests";

    @Test
    void nonConcurrent() throws TimeoutException {
        assertSucceeded(run("foreach-non-concurrent"), 7);
    }

    @Test
    void concurrent() throws TimeoutException {
        assertSucceeded(run("foreach-concurrent"), 7);
    }

    @Test
    void concurrentWithParallel() throws TimeoutException {
        assertSucceeded(run("foreach-concurrent-parallel"), 10);
    }

    /** Runs the flow with the given id from the test namespace. */
    private Execution run(String flowId) throws TimeoutException {
        return runnerUtils.runOne(null, NAMESPACE, flowId);
    }

    /** Checks the final state and the number of task runs of an execution. */
    private static void assertSucceeded(Execution execution, int expectedTaskRuns) {
        assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));
        assertThat(execution.getTaskRunList(), hasSize(expectedTaskRuns));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Test flow: a ForEach over three values with concurrency 2, whose only child is a
# Parallel task running two Log tasks.
id: foreach-concurrent-parallel
namespace: io.kestra.tests

tasks:
  - id: for_each
    type: io.kestra.plugin.core.flow.ForEach
    values: ["value 1", "value 2", "value 3"]
    concurrency: 2
    tasks:
      - id: parallel
        type: io.kestra.plugin.core.flow.Parallel
        tasks:
          - id: log
            type: io.kestra.plugin.core.log.Log
            message: Processing
          - id: shell
            type: io.kestra.plugin.core.log.Log
            message: 2nd task
15 changes: 15 additions & 0 deletions core/src/test/resources/flows/valids/foreach-concurrent.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Test flow: a ForEach over three values with concurrency 2, running two Log tasks
# for each value.
id: foreach-concurrent
namespace: io.kestra.tests

tasks:
  - id: for_each
    type: io.kestra.plugin.core.flow.ForEach
    values: ["value 1", "value 2", "value 3"]
    concurrency: 2
    tasks:
      - id: log
        type: io.kestra.plugin.core.log.Log
        message: Processing {{taskrun.value}}
      - id: shell
        type: io.kestra.plugin.core.log.Log
        message: 2nd task
Loading

0 comments on commit bdd916b

Please sign in to comment.