diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index 04a8650b2..5258fb4d7 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -30,6 +30,8 @@ public enum WorkflowResources { REGISTER_LOCAL_MODEL("register_local_model", "model_id"), /** official workflow step name for registering a model group and associated created resource */ REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + /** official workflow step name for deploying a model and associated created resource */ + DEPLOY_MODEL("deploy_model", "model_id"), /** official workflow step name for creating an ingest-pipeline and associated created resource */ CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), /** official workflow step name for creating an index and associated created resource */ diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index b942ccb16..42d59e07f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -47,7 +47,7 @@ public class WorkflowNode implements ToXContentObject { /** The field defining the timeout value for this node */ public static final String NODE_TIMEOUT_FIELD = "node_timeout"; /** The default timeout value if the template doesn't override it */ - public static final String NODE_TIMEOUT_DEFAULT_VALUE = "10s"; + public static final String NODE_TIMEOUT_DEFAULT_VALUE = "15s"; private final String id; // unique id private final String type; // maps to a WorkflowStep diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index 799edabb9..48b7a0042 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -8,27 +8,133 @@ */ package org.opensearch.flowframework.workflow; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.FutureUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.WorkflowResources; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; /** * Abstract retryable workflow step */ public abstract class AbstractRetryableWorkflowStep implements WorkflowStep { - + private static final Logger logger = LogManager.getLogger(AbstractRetryableWorkflowStep.class); /** The maximum number of transport request retries */ protected volatile Integer maxRetry; + private final MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiates a new Retryable workflow step * @param settings Environment settings * @param clusterService the cluster service + * @param mlClient machine learning client + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public AbstractRetryableWorkflowStep(Settings settings, ClusterService clusterService) { + public AbstractRetryableWorkflowStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { this.maxRetry = MAX_GET_TASK_REQUEST_RETRY.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_GET_TASK_REQUEST_RETRY, it -> maxRetry = it); + this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + } + + /** + * Retryable get ml task + * @param workflowId the workflow id + * @param nodeId the workflow node id + * @param future the workflow step future + * @param taskId the ml task id + * @param retries the current number of request retries + * @param workflowStep the workflow step which requires a retry get ml task functionality + */ + void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture future, + String taskId, + int retries, + String workflowStep + ) { + mlClient.getTask(taskId, ActionListener.wrap(response -> { + MLTaskState currentState = response.getState(); + if (currentState != MLTaskState.COMPLETED) { + if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { + // Model registration failed or completed with errors + String errorMessage = workflowStep + " failed with error : " + response.getError(); + logger.error(errorMessage); + future.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + } else { + // Task still in progress, attempt retry + throw new IllegalStateException(workflowStep + " is not yet completed"); + } + } else { + try { + logger.info(workflowStep + " successful for {} and modelId {}", workflowId, response.getModelId()); + String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); + flowFrameworkIndicesHandler.updateResourceInStateIndex( + workflowId, + nodeId, + getName(), + response.getTaskId(), + ActionListener.wrap(updateResponse -> { + logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); + future.complete( + new WorkflowData( + Map.ofEntries( + Map.entry(resourceName, response.getModelId()), + Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) + ), + workflowId, + nodeId + ) + ); + }, exception -> { + logger.error("Failed to update new created resource", exception); + future.completeExceptionally( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + }) + ); + + } catch (Exception e) { + logger.error("Failed to parse and update new created resource", e); + future.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + } + }, exception -> { + if (retries < maxRetry) { + // Sleep thread prior to retrying request + try { + Thread.sleep(5000); + } catch (Exception e) { + FutureUtils.cancel(future); + } + retryableGetMlTask(workflowId, nodeId, future, taskId, retries + 1, workflowStep); + } else { + logger.error("Failed to retrieve" + workflowStep + ",maximum retries exceeded"); + future.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + })); } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index f878fbdc2..b9307b046 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -11,8 +11,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -27,18 +30,29 @@ /** * Step to deploy a model */ -public class DeployModelStep implements WorkflowStep { +public class DeployModelStep extends AbstractRetryableWorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); private final MachineLearningNodeClient mlClient; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; static final String NAME = "deploy_model"; /** * Instantiate this class + * @param settings The OpenSearch settings + * @param clusterService The cluster service * @param mlClient client to instantiate MLClient + * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices */ - public DeployModelStep(MachineLearningNodeClient mlClient) { + public DeployModelStep( + Settings settings, + ClusterService clusterService, + MachineLearningNodeClient mlClient, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler + ) { + super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @Override @@ -55,13 +69,10 @@ public CompletableFuture execute( @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); - deployModelFuture.complete( - new WorkflowData( - Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), - currentNodeInputs.getWorkflowId(), - currentNodeInputs.getNodeId() - ) - ); + String taskId = mlDeployModelResponse.getTaskId(); + + // Attempt to retrieve the model ID + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, deployModelFuture, taskId, 0, "Deploy model"); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index e2aea19df..fbf907776 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -72,7 +73,7 @@ public CompletableFuture execute( @Override public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse) { try { - logger.info("Remote Model registration successful"); + logger.info("Model group registration successful"); String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); flowFrameworkIndicesHandler.updateResourceInStateIndex( currentNodeInputs.getWorkflowId(), @@ -134,7 +135,7 @@ public void onFailure(Exception e) { if (description != null) { builder.description(description); } - if (!backendRoles.isEmpty()) { + if (!CollectionUtils.isEmpty(backendRoles)) { builder.backendRoles(backendRoles); } if (modelAccessMode != null) { @@ -160,6 +161,9 @@ public String getName() { @SuppressWarnings("unchecked") private List getBackendRoles(Map content) { - return (List) content.get(BACKEND_ROLES_FIELD); + if (content.containsKey(BACKEND_ROLES_FIELD)) { + return (List) content.get(BACKEND_ROLES_FIELD); + } + return null; } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index fb1d383b5..4c01e8fb8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -13,15 +13,12 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -34,7 +31,6 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -45,7 +41,6 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; -import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; @@ -75,7 +70,7 @@ public RegisterLocalModelStep( MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - super(settings, clusterService); + super(settings, clusterService, mlClient, flowFrameworkIndicesHandler); this.mlClient = mlClient; this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; } @@ -98,7 +93,14 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { String taskId = mlRegisterModelResponse.getTaskId(); // Attempt to retrieve the model ID - retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeId, registerLocalModelFuture, taskId, 0); + retryableGetMlTask( + currentNodeInputs.getWorkflowId(), + currentNodeId, + registerLocalModelFuture, + taskId, + 0, + "Local model registration" + ); } @Override @@ -178,84 +180,4 @@ public void onFailure(Exception e) { public String getName() { return NAME; } - - /** - * Retryable get ml task - * @param workflowId the workflow id - * @param nodeId the workflow node id - * @param registerLocalModelFuture the workflow step future - * @param taskId the ml task id - * @param retries the current number of request retries - */ - void retryableGetMlTask( - String workflowId, - String nodeId, - CompletableFuture registerLocalModelFuture, - String taskId, - int retries - ) { - mlClient.getTask(taskId, ActionListener.wrap(response -> { - MLTaskState currentState = response.getState(); - if (currentState != MLTaskState.COMPLETED) { - if (Stream.of(MLTaskState.FAILED, MLTaskState.COMPLETED_WITH_ERROR).anyMatch(x -> x == currentState)) { - // Model registration failed or completed with errors - String errorMessage = "Local model registration failed with error : " + response.getError(); - logger.error(errorMessage); - registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } else { - // Task still in progress, attempt retry - throw new IllegalStateException("Local model registration is not yet completed"); - } - } else { - try { - logger.info("Local Model registration successful"); - String resourceName = WorkflowResources.getResourceByWorkflowStep(getName()); - flowFrameworkIndicesHandler.updateResourceInStateIndex( - workflowId, - nodeId, - getName(), - response.getTaskId(), - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - registerLocalModelFuture.complete( - new WorkflowData( - Map.ofEntries( - Map.entry(resourceName, response.getModelId()), - Map.entry(REGISTER_MODEL_STATUS, response.getState().name()) - ), - workflowId, - nodeId - ) - ); - }, exception -> { - logger.error("Failed to update new created resource", exception); - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - }) - ); - - } catch (Exception e) { - logger.error("Failed to parse and update new created resource", e); - registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - } - }, exception -> { - if (retries < maxRetry) { - // Sleep thread prior to retrying request - try { - Thread.sleep(5000); - } catch (Exception e) { - FutureUtils.cancel(registerLocalModelFuture); - } - final int retryAdd = retries + 1; - retryableGetMlTask(workflowId, nodeId, registerLocalModelFuture, taskId, retryAdd); - } else { - logger.error("Failed to retrieve local model registration task, maximum retries exceeded"); - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); - } - })); - } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3e8ef2981..742fefc36 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -52,7 +52,7 @@ public WorkflowStepFactory( ); stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); - stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient)); + stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler)); stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 670933373..fa27142f1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,27 +10,49 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.mockito.ArgumentMatchers.eq; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_GET_TASK_REQUEST_RETRY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { @@ -40,22 +62,37 @@ public class DeployModelStepTests extends OpenSearchTestCase { @Mock MachineLearningNodeClient machineLearningNodeClient; + private DeployModelStep deployModel; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + @Override public void setUp() throws Exception { super.setUp(); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); + ClusterService clusterService = mock(ClusterService.class); + final Set> settingsSet = Stream.concat( + ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), + Stream.of(MAX_GET_TASK_REQUEST_RETRY) + ).collect(Collectors.toSet()); - MockitoAnnotations.openMocks(this); + // Set max request retry setting to 0 to avoid sleeping the thread during unit test failure cases + Settings testMaxRetrySetting = Settings.builder().put(MAX_GET_TASK_REQUEST_RETRY.getKey(), 0).build(); + ClusterSettings clusterSettings = new ClusterSettings(testMaxRetrySetting, settingsSet); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + this.deployModel = new DeployModelStep(testMaxRetrySetting, clusterService, machineLearningNodeClient, flowFrameworkIndicesHandler); + this.inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); } - public void testDeployModel() throws ExecutionException, InterruptedException { + public void testDeployModel() throws ExecutionException, InterruptedException, IOException { + String modelId = "modelId"; String taskId = "taskId"; - String status = MLTaskState.CREATED.name(); - MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + String status = MLTaskState.COMPLETED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -65,7 +102,36 @@ public void testDeployModel() throws ExecutionException, InterruptedException { MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); CompletableFuture future = deployModel.execute( inputData.getNodeId(), @@ -74,15 +140,19 @@ public void testDeployModel() throws ExecutionException, InterruptedException { Collections.emptyMap() ); - verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertTrue(future.isDone()); - assertEquals(status, future.get().getContent().get("deploy_model_status")); + assertFalse(future.isCompletedExceptionally()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testDeployModelFailure() { - DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + String modelId = "modelId"; + String taskId = "taskId"; @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); @@ -106,4 +176,60 @@ public void testDeployModelFailure() { assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to deploy model", ex.getCause().getMessage()); } + + public void testDeployModelTaskFailure() throws IOException { + String modelId = "modelId"; + String taskId = "taskId"; + + String status = MLTaskState.RUNNING.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + String testErrorMessage = "error"; + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deploy(eq(modelId), actionListenerCaptor.capture()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.FAILED, + null, + null, + null, + null, + null, + null, + testErrorMessage, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + CompletableFuture future = this.deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Deploy model failed with error : " + testErrorMessage, ex.getCause().getMessage()); + + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index ffa6d82d1..afd90786f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -156,7 +156,7 @@ public void testRegisterLocalModelSuccess() throws Exception { verify(machineLearningNodeClient, times(1)).getTask(any(), any()); assertTrue(future.isDone()); - assertTrue(!future.isCompletedExceptionally()); + assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e9e792add..8103f4fbf 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -118,7 +118,7 @@ public void testNodeDetails() throws IOException { ProcessNode node = workflow.get(0); assertEquals("default_timeout", node.id()); assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass()); - assertEquals(10, node.nodeTimeout().seconds()); + assertEquals(15, node.nodeTimeout().seconds()); node = workflow.get(1); assertEquals("custom_timeout", node.id()); assertEquals(CreateIndexStep.class, node.workflowStep().getClass());