diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 544f6f3e1..d69c0b588 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -116,7 +116,13 @@ public Collection createComponents( mlClient, flowFrameworkIndicesHandler ); - WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, clusterService, settings); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter( + workflowStepFactory, + threadPool, + clusterService, + client, + settings + ); return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index eb1779e93..c9689b975 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -25,18 +25,23 @@ public class WorkflowStepValidator { private static final String INPUTS_FIELD = "inputs"; /** Outputs field name */ private static final String OUTPUTS_FIELD = "outputs"; + /** Required Plugins field name */ + private static final String REQUIRED_PLUGINS = "required_plugins"; private List inputs; private List outputs; + private List requiredPlugins; /** * Intantiate the object representing a Workflow Step validator * @param inputs the workflow step inputs * @param outputs the workflow step outputs + * @param requiredPlugins the required plugins for this workflow step */ - public WorkflowStepValidator(List inputs, List outputs) { + public WorkflowStepValidator(List inputs, List outputs, List requiredPlugins) { this.inputs = inputs; this.outputs = outputs; + this.requiredPlugins = requiredPlugins; } /** @@ -48,6 +53,7 @@ public WorkflowStepValidator(List inputs, List outputs) { public static WorkflowStepValidator parse(XContentParser parser) throws IOException { List parsedInputs = new ArrayList<>(); List parsedOutputs = new ArrayList<>(); + List requiredPlugins = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -66,11 +72,17 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept parsedOutputs.add(parser.text()); } break; + case REQUIRED_PLUGINS: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + requiredPlugins.add(parser.text()); + } + break; default: throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object."); } } - return new WorkflowStepValidator(parsedInputs, parsedOutputs); + return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins); } /** @@ -88,4 +100,12 @@ public List getInputs() { public List getOutputs() { return List.copyOf(outputs); } + + /** + * Get the required plugins + * @return the outputs + */ + public List getRequiredPlugins() { + return List.copyOf(requiredPlugins); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index 765c9cae5..92f89c082 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -264,7 +264,7 @@ protected void checkMaxWorkflows(TimeValue requestTimeOut, Integer maxWorkflow, private void validateWorkflows(Template template) throws Exception { for (Workflow workflow : template.workflows().values()) { List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow, null); - workflowProcessSorter.validateGraph(sortedNodes); + workflowProcessSorter.validate(sortedNodes); } } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index cd4a54a57..ff36cfd1f 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -122,7 +122,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); - workflowProcessSorter.validateGraph(provisionProcessSequence); + workflowProcessSorter.validate(provisionProcessSequence); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( workflowId, diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index da362383b..e564ad456 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,15 +10,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.plugins.PluginInfo; import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; @@ -31,6 +37,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -49,6 +56,8 @@ public class WorkflowProcessSorter { private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; private Integer maxWorkflowSteps; + private ClusterService clusterService; + private Client client; /** * Instantiate this class. @@ -56,17 +65,21 @@ public class WorkflowProcessSorter { * @param workflowStepFactory The factory which matches template step types to instances. * @param threadPool The OpenSearch Thread pool to pass to process nodes. * @param clusterService The OpenSearch cluster service. + * @param client The OpenSearch Client * @param settings OpenSerch settings */ public WorkflowProcessSorter( WorkflowStepFactory workflowStepFactory, ThreadPool threadPool, ClusterService clusterService, + Client client, Settings settings ) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings); + this.clusterService = clusterService; + this.client = client; clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it); } @@ -123,13 +136,75 @@ public List sortProcessNodes(Workflow workflow, String workflowId) } /** - * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * Validates inputs and ensures the required plugins are installed for each step in a topologically sorted graph + * @param processNodes the topologically sorted list of process nodes + * @throws Exception if validation fails + */ + public void validate(List processNodes) throws Exception { + WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + validatePluginsInstalled(processNodes, validator); + validateGraph(processNodes, validator); + } + + /** + * Validates a sorted workflow, determines if each process node's required plugins are currently installed * @param processNodes A list of process nodes + * @param validator The validation definitions for the workflow steps * @throws Exception on validation failure */ - public void validateGraph(List processNodes) throws Exception { + public void validatePluginsInstalled(List processNodes, WorkflowValidator validator) throws Exception { - WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json"); + // Retrieve node information to ascertain installed plugins + NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); + nodesInfoRequest.clear().addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); + CompletableFuture> installedPluginsFuture = new CompletableFuture<>(); + client.admin().cluster().nodesInfo(nodesInfoRequest, ActionListener.wrap(response -> { + List installedPlugins = new ArrayList<>(); + + // Retrieve installed plugin names from the local node + String localNodeId = clusterService.state().getNodes().getLocalNodeId(); + NodeInfo info = response.getNodesMap().get(localNodeId); + PluginsAndModules plugins = info.getInfo(PluginsAndModules.class); + for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + installedPlugins.add(pluginInfo.getName()); + } + + installedPluginsFuture.complete(installedPlugins); + + }, exception -> { + logger.error("Failed to retrieve installed plugins"); + installedPluginsFuture.completeExceptionally(exception); + })); + + // Block execution until installed plugin list is returned + List installedPlugins = installedPluginsFuture.get(); + + // Iterate through process nodes in graph + for (ProcessNode processNode : processNodes) { + + // Retrieve required plugins of this node based on type + String nodeType = processNode.workflowStep().getName(); + List requiredPlugins = new ArrayList<>(validator.getWorkflowStepValidators().get(nodeType).getRequiredPlugins()); + if (!installedPlugins.containsAll(requiredPlugins)) { + requiredPlugins.removeAll(installedPlugins); + throw new FlowFrameworkException( + "The workflowStep " + + processNode.workflowStep().getName() + + " requires the following plugins to be installed : " + + requiredPlugins.toString(), + RestStatus.BAD_REQUEST + ); + } + } + } + + /** + * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs + * @param processNodes A list of process nodes + * @param validator The validation definitions for the workflow steps + * @throws Exception on validation failure + */ + public void validateGraph(List processNodes, WorkflowValidator validator) throws Exception { // Iterate through process nodes in graph for (ProcessNode processNode : processNodes) { diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 149b1cfce..1c6e73a4c 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -1,7 +1,8 @@ { "noop": { "inputs":[], - "outputs":[] + "outputs":[], + "required_plugins":[] }, "create_index": { "inputs":[ @@ -10,7 +11,8 @@ ], "outputs":[ "index_name" - ] + ], + "required_plugins":[] }, "create_ingest_pipeline": { "inputs":[ @@ -23,7 +25,8 @@ ], "outputs":[ "pipeline_id" - ] + ], + "required_plugins":[] }, "create_connector": { "inputs":[ @@ -37,6 +40,9 @@ ], "outputs":[ "connector_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_connector": { @@ -45,6 +51,9 @@ ], "outputs":[ "connector_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_local_model": { @@ -62,6 +71,9 @@ "outputs":[ "model_id", "register_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_remote_model": { @@ -73,6 +85,9 @@ "outputs": [ "model_id", "register_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_model": { @@ -81,6 +96,9 @@ ], "outputs":[ "model_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "deploy_model": { @@ -89,6 +107,9 @@ ], "outputs":[ "deploy_model_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "undeploy_model": { @@ -97,6 +118,9 @@ ], "outputs":[ "success" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_model_group": { @@ -106,6 +130,9 @@ "outputs":[ "model_group_id", "model_group_status" + ], + "required_plugins":[ + "opensearch-ml" ] }, "register_agent": { @@ -121,6 +148,9 @@ ], "outputs":[ "agent_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "delete_agent": { @@ -129,6 +159,9 @@ ], "outputs":[ "agent_id" + ], + "required_plugins":[ + "opensearch-ml" ] }, "create_tool": { @@ -137,6 +170,9 @@ ], "outputs": [ "tools" + ], + "required_plugins":[ + "opensearch-ml" ] } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 2e67b59d8..0addb7f1f 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -29,8 +29,6 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; -import org.opensearch.flowframework.workflow.WorkflowStepFactory; -import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -58,6 +56,8 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -70,7 +70,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private WorkflowProcessSorter workflowProcessSorter; private Template template; - private Client client = mock(Client.class); + private Client client; private ThreadPool threadPool; private ClusterSettings clusterSettings; private ClusterService clusterService; @@ -79,6 +79,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); + client = mock(Client.class); + threadPool = mock(ThreadPool.class); settings = Settings.builder() .put("plugins.flow_framework.max_workflows", 2) @@ -93,15 +95,10 @@ public void setUp() throws Exception { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); - MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); - WorkflowStepFactory factory = new WorkflowStepFactory( - Settings.EMPTY, - clusterService, - client, - mlClient, - flowFrameworkIndicesHandler - ); - this.workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool, clusterService, settings); + // Validation functionality should not be invoked in these unit tests, mocking instead + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + + // Spy this action to stub check max workflows this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( mock(TransportService.class), @@ -150,7 +147,7 @@ public void testDryRunValidation_withoutProvision_Success() { createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); } - public void testDryRunValidation_Failed() { + public void testDryRunValidation_Failed() throws Exception { WorkflowNode createConnector = new WorkflowNode( "workflow_step_1", @@ -204,12 +201,12 @@ public void testDryRunValidation_Failed() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + // Stub validation failure + doThrow(Exception.class).when(workflowProcessSorter).validate(any()); WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true, false, null, null); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); + verify(listener, times(1)).onFailure(any()); } public void testMaxWorkflow() { @@ -377,12 +374,14 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } - public void testCreateWorkflow_withDryRun_withProvision_Success() { + public void testCreateWorkflow_withDryRun_withProvision_Success() throws Exception { Template validTemplate = generateValidTemplate(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + + doNothing().when(workflowProcessSorter).validate(any()); WorkflowRequest workflowRequest = new WorkflowRequest( null, validTemplate, @@ -436,11 +435,13 @@ public void testCreateWorkflow_withDryRun_withProvision_Success() { assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } - public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() { + public void testCreateWorkflow_withDryRun_withProvision_FailedProvisioning() throws Exception { + Template validTemplate = generateValidTemplate(); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); + doNothing().when(workflowProcessSorter).validate(any()); WorkflowRequest workflowRequest = new WorkflowRequest( null, validTemplate, diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index d1590acd8..2974470aa 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -8,12 +8,20 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -23,7 +31,9 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.plugins.PluginInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -50,6 +60,8 @@ import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -73,12 +85,13 @@ private static List parse(String json) throws IOException { private static TestThreadPool testThreadPool; private static WorkflowProcessSorter workflowProcessSorter; + private static Client client = mock(Client.class); + private static ClusterService clusterService = mock(ClusterService.class); + private static WorkflowValidator validator; @BeforeClass - public static void setup() { + public static void setup() throws IOException { AdminClient adminClient = mock(AdminClient.class); - ClusterService clusterService = mock(ClusterService.class); - Client client = mock(Client.class); MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); @@ -100,7 +113,8 @@ public static void setup() { mlClient, flowFrameworkIndicesHandler ); - workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, settings); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, settings); + validator = WorkflowValidator.parse("mappings/workflow-steps.json"); } @AfterClass @@ -300,7 +314,7 @@ public void testSuccessfulGraphValidation() throws Exception { Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); - workflowProcessSorter.validateGraph(sortedProcessNodes); + workflowProcessSorter.validateGraph(sortedProcessNodes, validator); } public void testFailedGraphValidation() { @@ -324,9 +338,175 @@ public void testFailedGraphValidation() { List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); FlowFrameworkException ex = expectThrows( FlowFrameworkException.class, - () -> workflowProcessSorter.validateGraph(sortedProcessNodes) + () -> workflowProcessSorter.validateGraph(sortedProcessNodes, validator) ); assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } + + public void testSuccessfulInstalledPluginValidation() throws Exception { + + // Mock and stub the cluster admin client to invoke the NodesInfoRequest + AdminClient adminClient = mock(AdminClient.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + // Mock and stub the clusterservice to get the local node + ClusterState clusterState = mock(ClusterState.class); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getLocalNodeId()).thenReturn("123"); + + // Stub cluster admin client's node info request + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + + // Mock and stub Plugin info + PluginInfo mockedFlowPluginInfo = mock(PluginInfo.class); + PluginInfo mockedMlPluginInfo = mock(PluginInfo.class); + when(mockedFlowPluginInfo.getName()).thenReturn("opensearch-flow-framework"); + when(mockedMlPluginInfo.getName()).thenReturn("opensearch-ml"); + + // Mock and stub PluginsAndModules + PluginsAndModules mockedPluginsAndModules = mock(PluginsAndModules.class); + when(mockedPluginsAndModules.getPluginInfos()).thenReturn(List.of(mockedFlowPluginInfo, mockedMlPluginInfo)); + + // Mock and stub NodesInfoResponse to NodeInfo + NodeInfo nodeInfo = mock(NodeInfo.class); + @SuppressWarnings("unchecked") + Map mockedMap = mock(Map.class); + NodesInfoResponse response = mock(NodesInfoResponse.class); + when(response.getNodesMap()).thenReturn(mockedMap); + when(mockedMap.get(any())).thenReturn(nodeInfo); + when(nodeInfo.getInfo(any())).thenReturn(mockedPluginsAndModules); + + // stub on response to pass the mocked NodesInfoRepsonse + listener.onResponse(response); + return null; + + }).when(clusterAdminClient).nodesInfo(any(NodesInfoRequest.class), any()); + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterRemoteModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + + workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, validator); + } + + public void testFailedInstalledPluginValidation() throws Exception { + + // Mock and stub the cluster admin client to invoke the NodesInfoRequest + AdminClient adminClient = mock(AdminClient.class); + ClusterAdminClient clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + // Mock and stub the clusterservice to get the local node + ClusterState clusterState = mock(ClusterState.class); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getLocalNodeId()).thenReturn("123"); + + // Stub cluster admin client's node info request + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + + // Mock and stub Plugin info, We ommit the opensearch-ml info here to trigger validation failure + PluginInfo mockedFlowPluginInfo = mock(PluginInfo.class); + when(mockedFlowPluginInfo.getName()).thenReturn("opensearch-flow-framework"); + + // Mock and stub PluginsAndModules + PluginsAndModules mockedPluginsAndModules = mock(PluginsAndModules.class); + when(mockedPluginsAndModules.getPluginInfos()).thenReturn(List.of(mockedFlowPluginInfo)); + + // Mock and stub NodesInfoResponse to NodeInfo + NodeInfo nodeInfo = mock(NodeInfo.class); + @SuppressWarnings("unchecked") + Map mockedMap = mock(Map.class); + NodesInfoResponse response = mock(NodesInfoResponse.class); + when(response.getNodesMap()).thenReturn(mockedMap); + when(mockedMap.get(any())).thenReturn(nodeInfo); + when(nodeInfo.getInfo(any())).thenReturn(mockedPluginsAndModules); + + // stub on response to pass the mocked NodesInfoRepsonse + listener.onResponse(response); + return null; + + }).when(clusterAdminClient).nodesInfo(any(NodesInfoRequest.class), any()); + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterRemoteModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + DeployModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + + Workflow workflow = new Workflow(Map.of(), List.of(createConnector, registerModel, deployModel), List.of(edge1, edge2)); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); + + FlowFrameworkException exception = expectThrows( + FlowFrameworkException.class, + () -> workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, validator) + ); + + assertEquals( + "The workflowStep create_connector requires the following plugins to be installed : [opensearch-ml]", + exception.getMessage() + ); + } }