diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java new file mode 100644 index 000000000..44fc5c8d7 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; + +/** + * Step to delete a model for a remote model + */ +public class DeleteModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteModelStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_model"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteModelFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteModelFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry(MODEL_ID, deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete model"); + deleteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String modelId = inputs.get(MODEL_ID).toString(); + + mlClient.deleteModel(modelId, actionListener); + } catch (FlowFrameworkException e) { + deleteModelFuture.completeExceptionally(e); + } + return deleteModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 1b1875177..3e8ef2981 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -51,6 +51,7 @@ public WorkflowStepFactory( () -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) ); stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient)); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index d0b05e9fa..e3263d9a2 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -75,6 +75,14 @@ "register_model_status" ] }, + "delete_model": { + "inputs": [ + "model_id" + ], + "outputs":[ + "model_id" + ] + }, "deploy_model": { "inputs":[ "model_id" diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index a766d51c9..d94f3d793 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -15,7 +15,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -92,7 +91,7 @@ public void testDeleteConnectorFailure() throws IOException { DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); + ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR)); return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java new file mode 100644 index 000000000..59d92e94b --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteModelStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testDeleteModel() throws IOException, ExecutionException, InterruptedException { + + String modelId = randomAlphaOfLength(5); + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String modelIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, modelIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("model_id", modelId), "workflowId", "nodeId")), + Map.of("step_1", "model_id") + ); + verify(machineLearningNodeClient).deleteModel(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get("model_id")); + } + + public void testNoModelIdInOutput() throws IOException { + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testDeleteModelFailure() throws IOException { + DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); + + CompletableFuture future = deleteModelStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("model_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "model_id") + ); + + verify(machineLearningNodeClient).deleteModel(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete model", ex.getCause().getMessage()); + } +}