Skip to content

Commit

Permalink
Infer edges from previous node inputs
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis committed Dec 28, 2023
1 parent 89c03f9 commit 6d4e94c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
16 changes: 9 additions & 7 deletions src/main/java/org/opensearch/flowframework/model/Workflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand Down Expand Up @@ -119,13 +120,14 @@ public static Workflow parse(XContentParser parser) throws IOException {
if (nodes.isEmpty()) {
throw new IOException("A workflow must have at least one node.");
}
if (edges.isEmpty()) {
// infer edges from sequence of nodes
// Start iteration at 1, will skip for a one-node array
for (int i = 1; i < nodes.size(); i++) {
edges.add(new WorkflowEdge(nodes.get(i - 1).id(), nodes.get(i).id()));
}
}
// Iterate the nodes and infer edges from previous node inputs
List<WorkflowEdge> inferredEdges = nodes.stream()
.flatMap(node -> node.previousNodeInputs().keySet().stream().map(previousNode -> new WorkflowEdge(previousNode, node.id())))
.collect(Collectors.toList());
// Remove any that are already in edges list
inferredEdges.removeAll(edges);
// Then add them to the edges
edges.addAll(inferredEdges);
return new Workflow(userParams, nodes, edges);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.flowframework.workflow.NoOpStep;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -53,6 +54,22 @@ public static String nodeWithTypeAndTimeout(String id, String type, String timeo
+ "\"}}";
}

public static String nodeWithTypeAndPreviousNodes(String id, String type, String... previousNodes) {
return "{\""
+ WorkflowNode.ID_FIELD
+ "\": \""
+ id
+ "\", \""
+ WorkflowNode.TYPE_FIELD
+ "\": \""
+ type
+ "\", \""
+ WorkflowNode.PREVIOUS_NODE_INPUTS_FIELD
+ "\": {"
+ Arrays.stream(previousNodes).map(n -> "\"" + n + "\": \"output_value\"").collect(Collectors.joining(","))
+ "}}";
}

public static String edge(String sourceId, String destId) {
return "{\"" + WorkflowEdge.SOURCE_FIELD + "\": \"" + sourceId + "\", \"" + WorkflowEdge.DEST_FIELD + "\": \"" + destId + "\"}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand Down Expand Up @@ -59,6 +58,7 @@
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndPreviousNodes;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -72,11 +72,14 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase {
private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor.";
private static final String CYCLE_DETECTED = "Cycle detected:";

// Wrap parser into workflow
private static Workflow parseToWorkflow(String json) throws IOException {
return Workflow.parse(TemplateTestJsonUtil.jsonToParser(json));
}

// Wrap parser into node list
private static List<ProcessNode> parseToNodes(String json) throws IOException {
XContentParser parser = TemplateTestJsonUtil.jsonToParser(json);
Workflow w = Workflow.parse(parser);
return workflowProcessSorter.sortProcessNodes(w, "123");
return workflowProcessSorter.sortProcessNodes(parseToWorkflow(json), "123");
}

// Wrap parser into string list
Expand Down Expand Up @@ -242,6 +245,56 @@ public void testNoEdges() throws IOException {
assertTrue(workflow.contains("B"));
}

public void testInferredEdges() throws IOException {
Workflow w = parseToWorkflow(
workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), Collections.emptyList())
);
assertTrue(w.edges().isEmpty());

w = parseToWorkflow(
workflow(List.of(nodeWithTypeAndPreviousNodes("A", "noop"), nodeWithTypeAndPreviousNodes("B", "noop")), List.of(edge("B", "A")))
);
// edge from previous inputs only
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());

w = parseToWorkflow(
workflow(
List.of(nodeWithTypeAndPreviousNodes("A", "noop", "B"), nodeWithTypeAndPreviousNodes("B", "noop")),
Collections.emptyList()
)
);
// edge from edges only
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());

w = parseToWorkflow(
workflow(
List.of(
nodeWithTypeAndPreviousNodes("A", "noop", "B"),
nodeWithTypeAndPreviousNodes("B", "noop"),
nodeWithTypeAndPreviousNodes("C", "noop")
),
List.of(edge("C", "A"))
)
);
// combine sources, order not guaranteed
assertEquals(2, w.edges().size());
assertTrue(w.edges().contains(new WorkflowEdge("B", "A")));
assertTrue(w.edges().contains(new WorkflowEdge("C", "A")));

w = parseToWorkflow(
workflow(
List.of(
nodeWithTypeAndPreviousNodes("A", "noop", "B"),
nodeWithTypeAndPreviousNodes("B", "noop"),
nodeWithTypeAndPreviousNodes("C", "noop")
),
List.of(edge("B", "A"))
)
);
// duplicates, only 1
assertEquals(List.of(new WorkflowEdge("B", "A")), w.edges());
}

public void testExceptions() throws IOException {
Exception ex = assertThrows(
FlowFrameworkException.class,
Expand Down

0 comments on commit 6d4e94c

Please sign in to comment.