Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/agent_framework] Adding installed plugins validation #290

Merged
merged 5 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,13 @@ public Collection<Object> createComponents(
mlClient,
flowFrameworkIndicesHandler
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(
workflowStepFactory,
threadPool,
clusterService,
client,
settings
);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ public class WorkflowStepValidator {
private static final String INPUTS_FIELD = "inputs";
/** Outputs field name */
private static final String OUTPUTS_FIELD = "outputs";
/** Required Plugins field name */
private static final String REQUIRED_PLUGINS = "required_plugins";

private List<String> inputs;
private List<String> outputs;
private List<String> requiredPlugins;

/**
* Intantiate the object representing a Workflow Step validator
* @param inputs the workflow step inputs
* @param outputs the workflow step outputs
* @param requiredPlugins the required plugins for this workflow step
*/
public WorkflowStepValidator(List<String> inputs, List<String> outputs) {
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins) {
this.inputs = inputs;
this.outputs = outputs;
this.requiredPlugins = requiredPlugins;
}

/**
Expand All @@ -48,6 +53,7 @@ public WorkflowStepValidator(List<String> inputs, List<String> outputs) {
public static WorkflowStepValidator parse(XContentParser parser) throws IOException {
List<String> parsedInputs = new ArrayList<>();
List<String> parsedOutputs = new ArrayList<>();
List<String> requiredPlugins = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -66,11 +72,17 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
parsedOutputs.add(parser.text());
}
break;
case REQUIRED_PLUGINS:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
requiredPlugins.add(parser.text());
}
break;
default:
throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object.");
}
}
return new WorkflowStepValidator(parsedInputs, parsedOutputs);
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins);
}

/**
Expand All @@ -88,4 +100,12 @@ public List<String> getInputs() {
public List<String> getOutputs() {
return List.copyOf(outputs);
}

/**
* Get the required plugins
* @return the outputs
*/
public List<String> getRequiredPlugins() {
return List.copyOf(requiredPlugins);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow,
private void validateWorkflows(Template template) throws Exception {
for (Workflow workflow : template.workflows().values()) {
List<ProcessNode> sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null);
workflowProcessSorter.validateGraph(sortedNodes);
workflowProcessSorter.validate(sortedNodes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
workflowProcessSorter.validateGraph(provisionProcessSequence);
workflowProcessSorter.validate(provisionProcessSequence);

flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(
workflowId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.admin.cluster.node.info.NodeInfo;
import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest;
import org.opensearch.action.admin.cluster.node.info.PluginsAndModules;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.model.WorkflowValidator;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.threadpool.ThreadPool;

import java.util.ArrayDeque;
Expand All @@ -31,6 +37,7 @@
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -49,24 +56,30 @@ public class WorkflowProcessSorter {
private WorkflowStepFactory workflowStepFactory;
private ThreadPool threadPool;
private Integer maxWorkflowSteps;
private ClusterService clusterService;
private Client client;

/**
* Instantiate this class.
*
* @param workflowStepFactory The factory which matches template step types to instances.
* @param threadPool The OpenSearch Thread pool to pass to process nodes.
* @param clusterService The OpenSearch cluster service.
* @param client The OpenSearch Client
* @param settings OpenSerch settings
*/
public WorkflowProcessSorter(
WorkflowStepFactory workflowStepFactory,
ThreadPool threadPool,
ClusterService clusterService,
Client client,
Settings settings
) {
this.workflowStepFactory = workflowStepFactory;
this.threadPool = threadPool;
this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings);
this.clusterService = clusterService;
this.client = client;
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it);
}

Expand Down Expand Up @@ -123,13 +136,75 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
}

/**
* Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs
* Validates inputs and ensures the required plugins are installed for each step in a topologically sorted graph
* @param processNodes the topologically sorted list of process nodes
* @throws Exception if validation fails
*/
public void validate(List<ProcessNode> processNodes) throws Exception {
WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
validatePluginsInstalled(processNodes, validator);
validateGraph(processNodes, validator);
}

/**
* Validates a sorted workflow, determines if each process node's required plugins are currently installed
* @param processNodes A list of process nodes
* @param validator The validation definitions for the workflow steps
* @throws Exception on validation failure
*/
public void validateGraph(List<ProcessNode> processNodes) throws Exception {
public void validatePluginsInstalled(List<ProcessNode> processNodes, WorkflowValidator validator) throws Exception {

WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
// Retrieve node information to ascertain installed plugins
NodesInfoRequest nodesInfoRequest = new NodesInfoRequest();
nodesInfoRequest.clear().addMetric(NodesInfoRequest.Metric.PLUGINS.metricName());
CompletableFuture<List<String>> installedPluginsFuture = new CompletableFuture<>();
client.admin().cluster().nodesInfo(nodesInfoRequest, ActionListener.wrap(response -> {
List<String> installedPlugins = new ArrayList<>();

// Retrieve installed plugin names from the local node
String localNodeId = clusterService.state().getNodes().getLocalNodeId();
NodeInfo info = response.getNodesMap().get(localNodeId);
PluginsAndModules plugins = info.getInfo(PluginsAndModules.class);
for (PluginInfo pluginInfo : plugins.getPluginInfos()) {
installedPlugins.add(pluginInfo.getName());
}

installedPluginsFuture.complete(installedPlugins);

}, exception -> {
logger.error("Failed to retrieve installed plugins");
installedPluginsFuture.completeExceptionally(exception);
}));

// Block execution until installed plugin list is returned
List<String> installedPlugins = installedPluginsFuture.get();
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

// Iterate through process nodes in graph
for (ProcessNode processNode : processNodes) {

// Retrieve required plugins of this node based on type
String nodeType = processNode.workflowStep().getName();
List<String> requiredPlugins = new ArrayList<>(validator.getWorkflowStepValidators().get(nodeType).getRequiredPlugins());
if (!installedPlugins.containsAll(requiredPlugins)) {
requiredPlugins.removeAll(installedPlugins);
throw new FlowFrameworkException(
"The workflowStep "
+ processNode.workflowStep().getName()
+ " requires the following plugins to be installed : "
+ requiredPlugins.toString(),
RestStatus.BAD_REQUEST
);
}
}
}

/**
* Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs
* @param processNodes A list of process nodes
* @param validator The validation definitions for the workflow steps
* @throws Exception on validation failure
*/
public void validateGraph(List<ProcessNode> processNodes, WorkflowValidator validator) throws Exception {

// Iterate through process nodes in graph
for (ProcessNode processNode : processNodes) {
Expand Down
42 changes: 39 additions & 3 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"noop": {
"inputs":[],
"outputs":[]
"outputs":[],
"required_plugins":[]
},
"create_index": {
"inputs":[
Expand All @@ -10,7 +11,8 @@
],
"outputs":[
"index_name"
]
],
"required_plugins":[]
},
"create_ingest_pipeline": {
"inputs":[
Expand All @@ -23,7 +25,8 @@
],
"outputs":[
"pipeline_id"
]
],
"required_plugins":[]
},
"create_connector": {
"inputs":[
Expand All @@ -37,6 +40,9 @@
],
"outputs":[
"connector_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_connector": {
Expand All @@ -45,6 +51,9 @@
],
"outputs":[
"connector_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_local_model": {
Expand All @@ -62,6 +71,9 @@
"outputs":[
"model_id",
"register_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_remote_model": {
Expand All @@ -73,6 +85,9 @@
"outputs": [
"model_id",
"register_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_model": {
Expand All @@ -81,6 +96,9 @@
],
"outputs":[
"model_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"deploy_model": {
Expand All @@ -89,6 +107,9 @@
],
"outputs":[
"deploy_model_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"undeploy_model": {
Expand All @@ -97,6 +118,9 @@
],
"outputs":[
"success"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_model_group": {
Expand All @@ -106,6 +130,9 @@
"outputs":[
"model_group_id",
"model_group_status"
],
"required_plugins":[
"opensearch-ml"
]
},
"register_agent": {
Expand All @@ -121,6 +148,9 @@
],
"outputs":[
"agent_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"delete_agent": {
Expand All @@ -129,6 +159,9 @@
],
"outputs":[
"agent_id"
],
"required_plugins":[
"opensearch-ml"
]
},
"create_tool": {
Expand All @@ -137,6 +170,9 @@
],
"outputs": [
"tools"
],
"required_plugins":[
"opensearch-ml"
]
}
}
Loading
Loading