Skip to content

Commit

Permalink
[Feature/agent_framework] Adding installed plugins validation (opense…
Browse files Browse the repository at this point in the history
…arch-project#290)

* Adding installed plugins validation

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Adding failure success unit tests

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Combining graph and installed plugin validation

Signed-off-by: Joshua Palis <jpalis@amazon.com>

* Removing stray comment

Signed-off-by: Joshua Palis <jpalis@amazon.com>

---------

Signed-off-by: Joshua Palis <jpalis@amazon.com>
  • Loading branch information
joshpalis authored and dbwiddis committed Dec 18, 2023
1 parent ddfec4b commit 4bb1c30
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ public Collection<Object> createComponents(
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(
workflowStepFactory,
threadPool,
clusterService,
client,
settings
);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ public class WorkflowStepValidator {
private static final String INPUTS_FIELD = "inputs";
/** Outputs field name */
private static final String OUTPUTS_FIELD = "outputs";
/** Required Plugins field name */
private static final String REQUIRED_PLUGINS = "required_plugins";

private List<String> inputs;
private List<String> outputs;
private List<String> requiredPlugins;

/**
* Intantiate the object representing a Workflow Step validator
* @param inputs the workflow step inputs
* @param outputs the workflow step outputs
* @param requiredPlugins the required plugins for this workflow step
*/
public WorkflowStepValidator(List<String> inputs, List<String> outputs) {
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins) {
this.inputs = inputs;
this.outputs = outputs;
this.requiredPlugins = requiredPlugins;
}

/**
Expand All @@ -48,6 +53,7 @@ public WorkflowStepValidator(List<String> inputs, List<String> outputs) {
public static WorkflowStepValidator parse(XContentParser parser) throws IOException {
List<String> parsedInputs = new ArrayList<>();
List<String> parsedOutputs = new ArrayList<>();
List<String> requiredPlugins = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -66,11 +72,17 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
parsedOutputs.add(parser.text());
}
break;
case REQUIRED_PLUGINS:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
requiredPlugins.add(parser.text());
}
break;
default:
throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object.");
}
}
return new WorkflowStepValidator(parsedInputs, parsedOutputs);
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins);
}

/**
Expand All @@ -88,4 +100,12 @@ public List<String> getInputs() {
public List<String> getOutputs() {
return List.copyOf(outputs);
}

/**
* Get the required plugins
* @return the outputs
*/
public List<String> getRequiredPlugins() {
return List.copyOf(requiredPlugins);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
workflowProcessSorter.validateGraph(sortedNodes);
workflowProcessSorter.validate(sortedNodes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
workflowProcessSorter.validateGraph(provisionProcessSequence);
workflowProcessSorter.validate(provisionProcessSequence);

flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.admin.cluster.node.info.NodeInfo;
import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest;
import org.opensearch.action.admin.cluster.node.info.PluginsAndModules;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.model.WorkflowValidator;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayDeque;
Expand All @@ -31,6 +37,7 @@
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -49,24 +56,30 @@ public class WorkflowProcessSorter {
private WorkflowStepFactory workflowStepFactory;
private ThreadPool threadPool;
private Integer maxWorkflowSteps;
private ClusterService clusterService;
private Client client;

/**
* Instantiate this class.
*
* @param workflowStepFactory The factory which matches template step types to instances.
* @param threadPool The OpenSearch Thread pool to pass to process nodes.
* @param clusterService The OpenSearch cluster service.
* @param client The OpenSearch Client
* @param settings OpenSerch settings
*/
public WorkflowProcessSorter(
WorkflowStepFactory workflowStepFactory,
ThreadPool threadPool,
ClusterService clusterService,
Client client,
Settings settings
) {
this.workflowStepFactory = workflowStepFactory;
this.threadPool = threadPool;
this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings);
this.clusterService = clusterService;
this.client = client;
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it);
}

Expand Down Expand Up @@ -123,13 +136,75 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
}

/**
* Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs
* Validates inputs and ensures the required plugins are installed for each step in a topologically sorted graph
* @param processNodes the topologically sorted list of process nodes
* @throws Exception if validation fails
*/
public void validate(List<ProcessNode> processNodes) throws Exception {
WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
validatePluginsInstalled(processNodes, validator);
validateGraph(processNodes, validator);
}

/**
* Validates a sorted workflow, determines if each process node's required plugins are currently installed
* @param processNodes A list of process nodes
* @param validator The validation definitions for the workflow steps
* @throws Exception on validation failure
*/
public void validateGraph(List<ProcessNode> processNodes) throws Exception {
public void validatePluginsInstalled(List<ProcessNode> processNodes, WorkflowValidator validator) throws Exception {

WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
// Retrieve node information to ascertain installed plugins
NodesInfoRequest nodesInfoRequest = new NodesInfoRequest();
nodesInfoRequest.clear().addMetric(NodesInfoRequest.Metric.PLUGINS.metricName());
CompletableFuture<List<String>> installedPluginsFuture = new CompletableFuture<>();
client.admin().cluster().nodesInfo(nodesInfoRequest, ActionListener.wrap(response -> {
List<String> installedPlugins = new ArrayList<>();

// Retrieve installed plugin names from the local node
String localNodeId = clusterService.state().getNodes().getLocalNodeId();
NodeInfo info = response.getNodesMap().get(localNodeId);
PluginsAndModules plugins = info.getInfo(PluginsAndModules.class);
for (PluginInfo pluginInfo : plugins.getPluginInfos()) {
installedPlugins.add(pluginInfo.getName());
}

installedPluginsFuture.complete(installedPlugins);

}, exception -> {
logger.error("Failed to retrieve installed plugins");
installedPluginsFuture.completeExceptionally(exception);
}));

// Block execution until installed plugin list is returned
List<String> installedPlugins = installedPluginsFuture.get();

// Iterate through process nodes in graph
for (ProcessNode processNode : processNodes) {

// Retrieve required plugins of this node based on type
String nodeType = processNode.workflowStep().getName();
List<String> requiredPlugins = new ArrayList<>(validator.getWorkflowStepValidators().get(nodeType).getRequiredPlugins());
if (!installedPlugins.containsAll(requiredPlugins)) {
requiredPlugins.removeAll(installedPlugins);
throw new FlowFrameworkException(
"The workflowStep "
+ processNode.workflowStep().getName()
+ " requires the following plugins to be installed : "
+ requiredPlugins.toString(),
RestStatus.BAD_REQUEST
);
}
}
}

/**
* Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs
* @param processNodes A list of process nodes
* @param validator The validation definitions for the workflow steps
* @throws Exception on validation failure
*/
public void validateGraph(List<ProcessNode> processNodes, WorkflowValidator validator) throws Exception {

// Iterate through process nodes in graph
for (ProcessNode processNode : processNodes) {
Expand Down
42 changes: 39 additions & 3 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"noop": {
"inputs":[],
"outputs":[]
"outputs":[],
"required_plugins":[]
},
"create_index": {
"inputs":[
Expand All @@ -10,7 +11,8 @@
],
"outputs":[
"index_name"
]
],
"required_plugins":[]
},
"create_ingest_pipeline": {
"inputs":[
Expand All @@ -23,7 +25,8 @@
],
"outputs":[
"pipeline_id"
]
],
"required_plugins":[]
},
"create_connector": {
"inputs":[
Expand All @@ -37,6 +40,9 @@
],
"outputs":[
"connector_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_connector": {
Expand All @@ -45,6 +51,9 @@
],
"outputs":[
"connector_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_local_model": {
Expand All @@ -62,6 +71,9 @@
"outputs":[
"model_id",
"register_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_remote_model": {
Expand All @@ -73,6 +85,9 @@
"outputs": [
"model_id",
"register_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_model": {
Expand All @@ -81,6 +96,9 @@
],
"outputs":[
"model_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"deploy_model": {
Expand All @@ -89,6 +107,9 @@
],
"outputs":[
"deploy_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"undeploy_model": {
Expand All @@ -97,6 +118,9 @@
],
"outputs":[
"success"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_model_group": {
Expand All @@ -106,6 +130,9 @@
"outputs":[
"model_group_id",
"model_group_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_agent": {
Expand All @@ -121,6 +148,9 @@
],
"outputs":[
"agent_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_agent": {
Expand All @@ -129,6 +159,9 @@
],
"outputs":[
"agent_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"create_tool": {
Expand All @@ -137,6 +170,9 @@
],
"outputs": [
"tools"
],
"required_plugins":[
"opensearch-ml"
]
}
}
Loading

0 comments on commit 4bb1c30

Please sign in to comment.