diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 2343cd305..0863565c0 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -73,6 +73,8 @@ private CommonValue() {} /** The provision workflow thread pool name */ public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision"; + /** Success name field */ + public static final String SUCCESS = "success"; /** Index name field */ public static final String INDEX_NAME = "index_name"; /** Type field */ diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java new file mode 100644 index 000000000..cfb683648 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchException; +import org.opensearch.action.FailedNodeException; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; + +/** + * Step to undeploy model + */ +public class UndeployModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(UndeployModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "undeploy_model"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the undeploy + */ + public UndeployModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture undeployModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { + List failures = mlUndeployModelsResponse.getResponse().failures(); + if (failures.isEmpty()) { + undeployModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(SUCCESS, !mlUndeployModelsResponse.getResponse().hasFailures())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } else { + List failedNodes = failures.stream().map(FailedNodeException::nodeId).collect(Collectors.toList()); + String message = "Failed to undeploy model on nodes " + failedNodes; + logger.error(message); + undeployModelFuture.completeExceptionally(new OpenSearchException(message)); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to unldeploy model"); + undeployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String modelId = inputs.get(MODEL_ID).toString(); + + mlClient.undeploy(new String[] { modelId }, null, actionListener); + } catch (FlowFrameworkException e) { + undeployModelFuture.completeExceptionally(e); + } + return undeployModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index ce0b24d24..4b197d99b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -51,6 +51,7 @@ public WorkflowStepFactory( ); stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); + stepMap.put(UndeployModelStep.NAME, new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index b5d09e8cb..d0b05e9fa 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -83,6 +83,14 @@ "deploy_model_status" ] }, + "undeploy_model": { + "inputs":[ + "model_id" + ], + "outputs":[ + "success" + ] + }, "register_model_group": { "inputs":[ "name" diff --git a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java new file mode 100644 index 000000000..1a5fef445 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class UndeployModelStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testUndeployModel() throws IOException, ExecutionException, InterruptedException { + + String modelId = randomAlphaOfLength(5); + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ClusterName clusterName = new ClusterName("clusterName"); + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( + clusterName, + Collections.emptyList(), + Collections.emptyList() + ); + MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), + Map.of("step_1", MODEL_ID) + ); + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + assertTrue(future.isDone()); + assertTrue((boolean) future.get().getContent().get(SUCCESS)); + } + + public void testNoModelIdInOutput() throws IOException { + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testUndeployModelFailure() throws IOException { + UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ClusterName clusterName = new ClusterName("clusterName"); + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( + clusterName, + Collections.emptyList(), + List.of(new FailedNodeException("failed-node", "Test message", null)) + ); + MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); + actionListener.onResponse(output); + + actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + CompletableFuture future = UndeployModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), + Map.of("step_1", MODEL_ID) + ); + + verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof OpenSearchException); + assertEquals("Failed to undeploy model on nodes [failed-node]", ex.getCause().getMessage()); + } +}