From 8069ea8a53bf03e5d73d2fe8cc1b3bf09ce2a35d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 14 Dec 2023 17:21:03 -0800 Subject: [PATCH] [Feature/agent_framework] Deprovision API (#271) * Deprovision REST and Transport Actions Signed-off-by: Daniel Widdis * Fix errors you find actually running the code Signed-off-by: Daniel Widdis * Add test for Rest deprovision action Signed-off-by: Daniel Widdis * Initial copypaste of Deprovision Transport Action Test Signed-off-by: Daniel Widdis * Add some delays to let deletions propagate, reset workflow state Signed-off-by: Daniel Widdis * Improved deprovisioning results and status updates Signed-off-by: Daniel Widdis * Fix bug in resource created parsing Signed-off-by: Daniel Widdis * Completed test implementations Signed-off-by: Daniel Widdis * Fixes after rebase Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../flowframework/FlowFrameworkPlugin.java | 5 + .../common/WorkflowResources.java | 50 ++- .../flowframework/model/ResourceCreated.java | 9 +- .../rest/RestDeprovisionWorkflowAction.java | 108 ++++++ .../transport/DeprovisionWorkflowAction.java | 27 ++ .../DeprovisionWorkflowTransportAction.java | 337 ++++++++++++++++++ .../transport/WorkflowRequest.java | 2 +- .../FlowFrameworkPluginTests.java | 4 +- .../RestDeprovisionWorkflowActionTests.java | 100 ++++++ ...provisionWorkflowTransportActionTests.java | 234 ++++++++++++ 10 files changed, 857 insertions(+), 19 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 40ddee2fa..544f6f3e1 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -28,6 +28,7 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; +import org.opensearch.flowframework.rest.RestDeprovisionWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; @@ -35,6 +36,8 @@ import org.opensearch.flowframework.rest.RestSearchWorkflowStateAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; +import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; +import org.opensearch.flowframework.transport.DeprovisionWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; import org.opensearch.flowframework.transport.GetWorkflowStateAction; import org.opensearch.flowframework.transport.GetWorkflowStateTransportAction; @@ -131,6 +134,7 @@ public List getRestHandlers( return ImmutableList.of( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestDeprovisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting), @@ -143,6 +147,7 @@ public List getRestHandlers( return ImmutableList.of( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), + new ActionHandler<>(DeprovisionWorkflowAction.INSTANCE, DeprovisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class), diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index d43a9e0f9..1246574d7 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -23,32 +23,34 @@ public enum WorkflowResources { /** official workflow step name for creating a connector and associated created resource */ - CREATE_CONNECTOR("create_connector", "connector_id"), + CREATE_CONNECTOR("create_connector", "connector_id", "delete_connector"), /** official workflow step name for registering a remote model and associated created resource */ - REGISTER_REMOTE_MODEL("register_remote_model", "model_id"), + REGISTER_REMOTE_MODEL("register_remote_model", "model_id", "delete_model"), /** official workflow step name for registering a local model and associated created resource */ - REGISTER_LOCAL_MODEL("register_local_model", "model_id"), + REGISTER_LOCAL_MODEL("register_local_model", "model_id", "delete_model"), /** official workflow step name for registering a model group and associated created resource */ - REGISTER_MODEL_GROUP("register_model_group", "model_group_id"), + REGISTER_MODEL_GROUP("register_model_group", "model_group_id", null), // TODO /** official workflow step name for deploying a model and associated created resource */ - DEPLOY_MODEL("deploy_model", "model_id"), + DEPLOY_MODEL("deploy_model", "model_id", "undeploy_model"), /** official workflow step name for creating an ingest-pipeline and associated created resource */ - CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id"), + CREATE_INGEST_PIPELINE("create_ingest_pipeline", "pipeline_id", null), // TODO /** official workflow step name for creating an index and associated created resource */ - CREATE_INDEX("create_index", "index_name"), + CREATE_INDEX("create_index", "index_name", null), // TODO /** official workflow step name for register an agent and the associated created resource */ - REGISTER_AGENT("register_agent", "agent_id"); + REGISTER_AGENT("register_agent", "agent_id", "delete_agent"); private final String workflowStep; private final String resourceCreated; + private final String deprovisionStep; private static final Logger logger = LogManager.getLogger(WorkflowResources.class); private static final Set allResources = Stream.of(values()) .map(WorkflowResources::getResourceCreated) .collect(Collectors.toSet()); - WorkflowResources(String workflowStep, String resourceCreated) { + WorkflowResources(String workflowStep, String resourceCreated, String deprovisionStep) { this.workflowStep = workflowStep; this.resourceCreated = resourceCreated; + this.deprovisionStep = deprovisionStep; } /** @@ -68,7 +70,15 @@ public String getResourceCreated() { } /** - * gets the resources created type based on the workflowStep + * Returns the deprovisionStep for the given enum Constant + * @return the deprovisionStep of this data. + */ + public String getDeprovisionStep() { + return deprovisionStep; + } + + /** + * Gets the resources created type based on the workflowStep. * @param workflowStep workflow step name * @return the resource that will be created * @throws FlowFrameworkException if workflow step doesn't exist in enum @@ -76,7 +86,7 @@ public String getResourceCreated() { public static String getResourceByWorkflowStep(String workflowStep) throws FlowFrameworkException { if (workflowStep != null && !workflowStep.isEmpty()) { for (WorkflowResources mapping : values()) { - if (mapping.getWorkflowStep().equals(workflowStep)) { + if (workflowStep.equals(mapping.getWorkflowStep()) || workflowStep.equals(mapping.getDeprovisionStep())) { return mapping.getResourceCreated(); } } @@ -85,6 +95,24 @@ public static String getResourceByWorkflowStep(String workflowStep) throws FlowF throw new FlowFrameworkException("Unable to find resource type for step: " + workflowStep, RestStatus.BAD_REQUEST); } + /** + * Gets the deprovision step type based on the workflowStep. + * @param workflowStep workflow step name + * @return the corresponding step to deprovision + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static String getDeprovisionStepByWorkflowStep(String workflowStep) throws FlowFrameworkException { + if (workflowStep != null && !workflowStep.isEmpty()) { + for (WorkflowResources mapping : values()) { + if (mapping.getWorkflowStep().equals(workflowStep)) { + return mapping.getDeprovisionStep(); + } + } + } + logger.error("Unable to find deprovision step for step: " + workflowStep); + throw new FlowFrameworkException("Unable to find deprovision step for step: " + workflowStep, RestStatus.BAD_REQUEST); + } + /** * Returns all the possible resource created types in enum * @return a set of all the resource created types diff --git a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java index b12f4d044..9cc096ef6 100644 --- a/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java +++ b/src/main/java/org/opensearch/flowframework/model/ResourceCreated.java @@ -176,15 +176,14 @@ public static ResourceCreated parse(XContentParser parser) throws IOException { @Override public String toString() { - return "resources_Created [workflow_step_name= " + return "resources_Created [workflow_step_name=" + workflowStepName - + ", workflow_step_id= " + + ", workflow_step_id=" + workflowStepId - + ", resource_type= " + + ", resource_type=" + resourceType - + ", resource_id= " + + ", resource_id=" + resourceId + "]"; } - } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java new file mode 100644 index 000000000..467a683ce --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowAction.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.DeprovisionWorkflowAction; +import org.opensearch.flowframework.transport.WorkflowRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to de-provision a workflow + */ +public class RestDeprovisionWorkflowAction extends BaseRestHandler { + + private static final String DEPROVISION_WORKFLOW_ACTION = "deprovision_workflow"; + private static final Logger logger = LogManager.getLogger(RestDeprovisionWorkflowAction.class); + private final FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestDeprovisionWorkflowAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestDeprovisionWorkflowAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + public String getName() { + return DEPROVISION_WORKFLOW_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("No request body is required", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + return channel -> client.execute(DeprovisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = exception instanceof FlowFrameworkException + ? (FlowFrameworkException) exception + : new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } + + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_deprovision")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java new file mode 100644 index 000000000..8efcfbbc3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestDeprovisionWorkflowAction + */ +public class DeprovisionWorkflowAction extends ActionType { + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow/deprovision"; + /** An instance of this action */ + public static final DeprovisionWorkflowAction INSTANCE = new DeprovisionWorkflowAction(); + + private DeprovisionWorkflowAction() { + super(NAME, WorkflowResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java new file mode 100644 index 000000000..784b67374 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -0,0 +1,337 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.time.Instant; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; +import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; +import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; + +/** + * Transport Action to deprovision a workflow from a stored use case template + */ +public class DeprovisionWorkflowTransportAction extends HandledTransportAction { + + private static final String DEPROVISION_SUFFIX = "_deprovision"; + + private final Logger logger = LogManager.getLogger(DeprovisionWorkflowTransportAction.class); + + private final ThreadPool threadPool; + private final Client client; + private final WorkflowProcessSorter workflowProcessSorter; + private final WorkflowStepFactory workflowStepFactory; + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private final EncryptorUtils encryptorUtils; + + /** + * Instantiates a new ProvisionWorkflowTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param threadPool The OpenSearch thread pool + * @param client The node client to retrieve a stored use case template + * @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes + * @param workflowStepFactory The factory instantiating workflow steps + * @param flowFrameworkIndicesHandler Class to handle all internal system indices actions + * @param encryptorUtils Utility class to handle encryption/decryption + */ + @Inject + public DeprovisionWorkflowTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ThreadPool threadPool, + Client client, + WorkflowProcessSorter workflowProcessSorter, + WorkflowStepFactory workflowStepFactory, + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + EncryptorUtils encryptorUtils + ) { + super(DeprovisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.threadPool = threadPool; + this.client = client; + this.workflowProcessSorter = workflowProcessSorter; + this.workflowStepFactory = workflowStepFactory; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.encryptorUtils = encryptorUtils; + } + + @Override + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + // Retrieve use case template from global context + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + // Stash thread context to interact with system index + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + listener.onFailure( + new FlowFrameworkException( + "Failed to retrieve template (" + workflowId + ") from global context.", + RestStatus.NOT_FOUND + ) + ); + return; + } + + // Parse template from document source + Template template = Template.parse(response.getSourceAsString()); + + // Decrypt template + template = encryptorUtils.decryptTemplateCredentials(template); + + // Sort and validate graph + Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); + workflowProcessSorter.validateGraph(provisionProcessSequence); + + // We have a valid template and sorted nodes, get the created resources + getResourcesAndExecute(request.getWorkflowId(), provisionProcessSequence, listener); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + logger.error("Workflow validation failed for workflow : " + workflowId); + listener.onFailure(exception); + } else { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + })); + } catch (Exception e) { + String message = "Failed to retrieve template from global context."; + logger.error(message, e); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(e))); + } + } + + private void getResourcesAndExecute( + String workflowId, + List provisionProcessSequence, + ActionListener listener + ) { + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { + // Get a map of step id to created resources + final Map resourceMap = response.getWorkflowState() + .resourcesCreated() + .stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepId, Function.identity())); + + // Now finally do the deprovision + executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener); + }, exception -> { + String message = "Failed to get workflow state for workflow " + workflowId; + logger.error(message, exception); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception))); + })); + } + + private void executeDeprovisionSequence( + String workflowId, + Map resourceMap, + List provisionProcessSequence, + ActionListener listener + ) { + // Create a list of ProcessNodes with the corresponding deprovision workflow steps + List deprovisionProcessSequence = provisionProcessSequence.stream() + // Only include nodes that created a resource + .filter(pn -> resourceMap.containsKey(pn.id())) + // Create a new ProcessNode with a deprovision step + .map(pn -> { + String stepName = pn.workflowStep().getName(); + String deprovisionStep = getDeprovisionStepByWorkflowStep(stepName); + // Unimplemented steps presently return null, so skip + if (deprovisionStep == null) { + return null; + } + // New ID is old ID with deprovision added + String deprovisionStepId = pn.id() + DEPROVISION_SUFFIX; + return new ProcessNode( + deprovisionStepId, + workflowStepFactory.createStep(deprovisionStep), + Collections.emptyMap(), + new WorkflowData( + Map.of(getResourceByWorkflowStep(stepName), resourceMap.get(pn.id()).resourceId()), + workflowId, + deprovisionStepId + ), + Collections.emptyList(), + this.threadPool, + pn.nodeTimeout() + ); + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + // Deprovision in reverse order of provisioning to minimize risk of dependencies + Collections.reverse(deprovisionProcessSequence); + logger.info("Deprovisioning steps: {}", deprovisionProcessSequence.stream().map(ProcessNode::id).collect(Collectors.joining(", "))); + + // Repeat attempting to delete resources as long as at least one is successful + int resourceCount = deprovisionProcessSequence.size(); + while (resourceCount > 0) { + Iterator iter = deprovisionProcessSequence.iterator(); + while (iter.hasNext()) { + ProcessNode deprovisionNode = iter.next(); + ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourceMap); + String resourceNameAndId = getResourceNameAndId(resource); + CompletableFuture deprovisionFuture = deprovisionNode.execute(); + try { + deprovisionFuture.join(); + logger.info("Successful {} for {}", deprovisionNode.id(), resourceNameAndId); + // Remove from list so we don't try again + iter.remove(); + // Pause briefly before next step + Thread.sleep(100); + } catch (Throwable t) { + logger.info( + "Failed {} for {}: {}", + deprovisionNode.id(), + resourceNameAndId, + t.getCause() == null ? t.getMessage() : t.getCause().getMessage() + ); + } + } + if (deprovisionProcessSequence.size() < resourceCount) { + // If we've deleted something, decrement and try again if not zero + resourceCount = deprovisionProcessSequence.size(); + deprovisionProcessSequence = deprovisionProcessSequence.stream().map(pn -> { + return new ProcessNode( + pn.id(), + workflowStepFactory.createStep(pn.workflowStep().getName()), + pn.previousNodeInputs(), + pn.input(), + pn.predecessors(), + this.threadPool, + pn.nodeTimeout() + ); + }).collect(Collectors.toList()); + // Pause briefly before next loop + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + break; + } + } else { + // If nothing was deleted, exit loop + break; + } + } + // Get corresponding resources + List remainingResources = deprovisionProcessSequence.stream() + .map(pn -> getResourceFromDeprovisionNode(pn, resourceMap)) + .collect(Collectors.toList()); + logger.info("Resources remaining: {}", remainingResources); + updateWorkflowState(workflowId, remainingResources, listener); + } + + private void updateWorkflowState( + String workflowId, + List remainingResources, + ActionListener listener + ) { + if (remainingResources.isEmpty()) { + // Successful deprovision + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.NOT_STARTED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, Collections.emptyList()) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to NOT_STARTED", workflowId); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); + // return workflow ID + listener.onResponse(new WorkflowResponse(workflowId)); + } else { + // Failed deprovision + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.COMPLETED), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.DONE), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, remainingResources) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to COMPLETED", workflowId); + }, exception -> { logger.error("Failed to update workflow state : {}", exception.getMessage()); }) + ); + // give user list of remaining resources + listener.onFailure( + new FlowFrameworkException( + "Failed to deprovision some resources: [" + + remainingResources.stream() + .map(DeprovisionWorkflowTransportAction::getResourceNameAndId) + .filter(Objects::nonNull) + .distinct() + .collect(Collectors.joining(", ")) + + "].", + RestStatus.ACCEPTED + ) + ); + } + } + + private static ResourceCreated getResourceFromDeprovisionNode(ProcessNode deprovisionNode, Map resourceMap) { + String deprovisionId = deprovisionNode.id(); + int pos = deprovisionId.indexOf(DEPROVISION_SUFFIX); + return pos > 0 ? resourceMap.get(deprovisionId.substring(0, pos)) : null; + } + + private static String getResourceNameAndId(ResourceCreated resource) { + if (resource == null) { + return null; + } + return getResourceByWorkflowStep(resource.workflowStepName()) + " " + resource.resourceId(); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 057f13d01..a030dccfa 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -19,7 +19,7 @@ import java.io.IOException; /** - * Transport Request to create and provision a workflow + * Transport Request to create, provision, and deprovision a workflow */ public class WorkflowRequest extends ActionRequest { diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index cbf988eee..9f9529ca5 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(6, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(6, ffp.getActions().size()); + assertEquals(7, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(7, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java new file mode 100644 index 000000000..a9170e35d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestDeprovisionWorkflowActionTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestDeprovisionWorkflowActionTests extends OpenSearchTestCase { + + private RestDeprovisionWorkflowAction deprovisionWorkflowRestAction; + private String deprovisionWorkflowPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + + this.deprovisionWorkflowRestAction = new RestDeprovisionWorkflowAction(flowFrameworkFeatureEnabledSetting); + this.deprovisionWorkflowPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_deprovision"); + this.nodeClient = mock(NodeClient.class); + } + + public void testRestDeprovisionWorkflowActionName() { + String name = deprovisionWorkflowRestAction.getName(); + assertEquals("deprovision_workflow", name); + } + + public void testRestDeprovisiionWorkflowActionRoutes() { + List routes = deprovisionWorkflowRestAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertEquals(this.deprovisionWorkflowPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_deprovision] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.deprovisionWorkflowPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + deprovisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java new file mode 100644 index 000000000..5d21c63d8 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; +import org.opensearch.flowframework.workflow.DeleteConnectorStep; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.AfterClass; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.mockito.ArgumentCaptor; + +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase { + + private static ThreadPool threadPool = new TestThreadPool(DeprovisionWorkflowTransportActionTests.class.getName()); + private Client client; + private WorkflowProcessSorter workflowProcessSorter; + private WorkflowStepFactory workflowStepFactory; + private DeleteConnectorStep deleteConnectorStep; + private DeprovisionWorkflowTransportAction deprovisionWorkflowTransportAction; + private Template template; + private GetResult getResult; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private EncryptorUtils encryptorUtils; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + this.workflowStepFactory = mock(WorkflowStepFactory.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); + + this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + threadPool, + client, + workflowProcessSorter, + workflowStepFactory, + flowFrameworkIndicesHandler, + encryptorUtils + ); + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode node = new WorkflowNode("step_1", "create_connector", Collections.emptyMap(), Collections.emptyMap()); + List nodes = List.of(node); + List edges = Collections.emptyList(); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + this.template = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of(PROVISION_WORKFLOW, workflow), + Map.of(), + TestHelpers.randomUser() + ); + this.getResult = mock(GetResult.class); + + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + ProcessNode processNode = mock(ProcessNode.class); + when(processNode.id()).thenReturn("step_1"); + when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap()); + when(processNode.input()).thenReturn(WorkflowData.EMPTY); + when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5)); + when(this.workflowProcessSorter.sortProcessNodes(any(Workflow.class), any(String.class))).thenReturn(List.of(processNode)); + this.deleteConnectorStep = mock(DeleteConnectorStep.class); + when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); + + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); + } + + public void testDeprovisionWorkflow() throws IOException { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(true); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("create_connector", "step_1", "connector_id", "connectorId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn( + CompletableFuture.completedFuture(WorkflowData.EMPTY) + ); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + } + + public void testFailedToRetrieveTemplateFromGlobalContext() { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(false); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to retrieve template (1) from global context.", exceptionCaptor.getValue().getMessage()); + } + + public void testFailToDeprovision() throws IOException { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(true); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("deploy_model", "step_1", "model_id", "modelId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException("rte")); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage()); + } +}