Skip to content

Commit

Permalink
[Feature/agent_framework] Add Delete Model Step (#237)
Browse files Browse the repository at this point in the history
Add Delete Model Step

Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis committed Dec 18, 2023
1 parent e99f83c commit aae49ab
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;

/**
* Step to delete a model for a remote model
*/
public class DeleteModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(DeleteModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "delete_model";

/**
* Instantiate this class
* @param mlClient Machine Learning client to perform the deletion
*/
public DeleteModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
CompletableFuture<WorkflowData> deleteModelFuture = new CompletableFuture<>();

ActionListener<DeleteResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, deleteResponse.getId())),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to delete model");
deleteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Set<String> requiredKeys = Set.of(MODEL_ID);
Set<String> optionalKeys = Collections.emptySet();

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String modelId = inputs.get(MODEL_ID).toString();

mlClient.deleteModel(modelId, actionListener);
} catch (FlowFrameworkException e) {
deleteModelFuture.completeExceptionally(e);
}
return deleteModelFuture;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public WorkflowStepFactory(
() -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler)
);
stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient));
stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
Expand Down
8 changes: 8 additions & 0 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@
"register_model_status"
]
},
"delete_model": {
"inputs": [
"model_id"
],
"outputs":[
"model_id"
]
},
"deploy_model": {
"inputs":[
"model_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -92,7 +91,7 @@ public void testDeleteConnectorFailure() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class DeleteModelStepTests extends OpenSearchTestCase {
private WorkflowData inputData;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id");
}

public void testDeleteModel() throws IOException, ExecutionException, InterruptedException {

String modelId = randomAlphaOfLength(5);
DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
String modelIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, modelIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteModel(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("model_id", modelId), "workflowId", "nodeId")),
Map.of("step_1", "model_id")
);
verify(machineLearningNodeClient).deleteModel(any(String.class), any());

assertTrue(future.isDone());
assertEquals(modelId, future.get().getContent().get("model_id"));
}

public void testNoModelIdInOutput() throws IOException {
DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient);

CompletableFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage());
}

public void testDeleteModelFailure() throws IOException {
DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteModel(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteModelStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("model_id", "test"), "workflowId", "nodeId")),
Map.of("step_1", "model_id")
);

verify(machineLearningNodeClient).deleteModel(any(String.class), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to delete model", ex.getCause().getMessage());
}
}

0 comments on commit aae49ab

Please sign in to comment.