From 6447ac64db669e6ef93e24edd4d5b131802c85b8 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Tue, 12 Dec 2023 17:17:12 -0800 Subject: [PATCH] [Feature/agent_framework] Add Get Workflow API to retrieve a stored template by workflow id (#273) * renaming status API implementation Signed-off-by: Joshua Palis * Adding GetWorkflow API Signed-off-by: Joshua Palis * addressing PR comments Signed-off-by: Joshua Palis * Adding todo reminder Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- .../flowframework/FlowFrameworkPlugin.java | 5 + .../rest/RestGetWorkflowAction.java | 26 ++- .../rest/RestGetWorkflowStateAction.java | 109 ++++++++++++ .../transport/GetWorkflowResponse.java | 45 +++-- .../transport/GetWorkflowStateAction.java | 29 +++ ...uest.java => GetWorkflowStateRequest.java} | 12 +- .../transport/GetWorkflowStateResponse.java | 67 +++++++ .../GetWorkflowStateTransportAction.java | 99 +++++++++++ .../transport/GetWorkflowTransportAction.java | 97 +++++----- .../FlowFrameworkPluginTests.java | 4 +- .../rest/RestGetWorkflowActionTests.java | 38 ++-- .../rest/RestGetWorkflowStateActionTests.java | 104 +++++++++++ .../GetWorkflowStateTransportActionTests.java | 127 ++++++++++++++ .../GetWorkflowTransportActionTests.java | 166 +++++++++++------- 14 files changed, 742 insertions(+), 186 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java rename src/main/java/org/opensearch/flowframework/transport/{GetWorkflowRequest.java => GetWorkflowStateRequest.java} (83%) create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java create mode 100644 src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java create mode 100644 src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java create mode 100644 src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 513984c68..ec9eb40da 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -29,11 +29,14 @@ import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; import org.opensearch.flowframework.rest.RestGetWorkflowAction; +import org.opensearch.flowframework.rest.RestGetWorkflowStateAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.rest.RestSearchWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; import org.opensearch.flowframework.transport.GetWorkflowAction; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateTransportAction; import org.opensearch.flowframework.transport.GetWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; @@ -126,6 +129,7 @@ public List getRestHandlers( new RestCreateWorkflowAction(flowFrameworkFeatureEnabledSetting, settings, clusterService), new RestProvisionWorkflowAction(flowFrameworkFeatureEnabledSetting), new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting), + new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting), new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting) ); } @@ -136,6 +140,7 @@ public List getRestHandlers( new ActionHandler<>(CreateWorkflowAction.INSTANCE, CreateWorkflowTransportAction.class), new ActionHandler<>(ProvisionWorkflowAction.INSTANCE, ProvisionWorkflowTransportAction.class), new ActionHandler<>(SearchWorkflowAction.INSTANCE, SearchWorkflowTransportAction.class), + new ActionHandler<>(GetWorkflowStateAction.INSTANCE, GetWorkflowStateTransportAction.class), new ActionHandler<>(GetWorkflowAction.INSTANCE, GetWorkflowTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java index 6d9d5e3b5..5a92e9c0e 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowAction.java @@ -20,7 +20,7 @@ import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.GetWorkflowAction; -import org.opensearch.flowframework.transport.GetWorkflowRequest; +import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; @@ -34,7 +34,7 @@ import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; /** - * Rest Action to facilitate requests to get a workflow status + * Rest Action to facilitate requests to get a stored template */ public class RestGetWorkflowAction extends BaseRestHandler { @@ -55,6 +55,11 @@ public String getName() { return GET_WORKFLOW_ACTION; } + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, WORKFLOW_ID))); + } + @Override protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { @@ -68,7 +73,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request // Validate content if (request.hasContent()) { - throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); } // Validate params String workflowId = request.param(WORKFLOW_ID); @@ -76,9 +81,8 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } - boolean all = request.paramAsBoolean("all", false); - GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest(workflowId, all); - return channel -> client.execute(GetWorkflowAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(GetWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); }, exception -> { @@ -88,7 +92,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); } catch (IOException e) { - logger.error("Failed to send back provision workflow exception", e); + logger.error("Failed to send back get workflow exception", e); channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); } })); @@ -99,12 +103,4 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request ); } } - - @Override - public List routes() { - return ImmutableList.of( - // Provision workflow from indexed use case template - new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) - ); - } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java new file mode 100644 index 000000000..ab7335b2d --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/rest/RestGetWorkflowStateAction.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.GetWorkflowStateAction; +import org.opensearch.flowframework.transport.GetWorkflowStateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; + +/** + * Rest Action to facilitate requests to get a workflow status + */ +public class RestGetWorkflowStateAction extends BaseRestHandler { + + private static final String GET_WORKFLOW_STATE_ACTION = "get_workflow_state"; + private static final Logger logger = LogManager.getLogger(RestGetWorkflowStateAction.class); + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + /** + * Instantiates a new RestGetWorkflowStateAction + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled + */ + public RestGetWorkflowStateAction(FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + this.flowFrameworkFeatureEnabledSetting = flowFrameworkFeatureEnabledSetting; + } + + @Override + public String getName() { + return GET_WORKFLOW_STATE_ACTION; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + try { + if (!flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()) { + throw new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + } + + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("No request body present", RestStatus.BAD_REQUEST); + } + // Validate params + String workflowId = request.param(WORKFLOW_ID); + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + + boolean all = request.paramAsBoolean("all", false); + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest(workflowId, all); + return channel -> client.execute(GetWorkflowStateAction.INSTANCE, getWorkflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), e.getMessage())); + } + })); + + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + } + + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, WORKFLOW_ID, "_status")) + ); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java index 922a8a3f5..db70d2cb2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowResponse.java @@ -13,55 +13,52 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.model.Template; import java.io.IOException; /** - * Transport Response from getting a workflow status + * Transport Response from getting a template */ public class GetWorkflowResponse extends ActionResponse implements ToXContentObject { - /** The workflow state */ - public WorkflowState workflowState; - /** Flag to indicate if the entire state should be returned */ - public boolean allStatus; + /** The template */ + private Template template; /** * Instantiates a new GetWorkflowResponse from an input stream * @param in the input stream to read from - * @throws IOException if the workflowId cannot be read from the input stream + * @throws IOException if the template json cannot be read from the input stream */ public GetWorkflowResponse(StreamInput in) throws IOException { super(in); - workflowState = new WorkflowState(in); - allStatus = in.readBoolean(); + this.template = Template.parse(in.readString()); } /** - * Instatiates a new GetWorkflowResponse from an input stream - * @param workflowState the workflow state object - * @param allStatus whether to return all fields in state index + * Instantiates a new GetWorkflowResponse + * @param template the template */ - public GetWorkflowResponse(WorkflowState workflowState, boolean allStatus) { - if (allStatus) { - this.workflowState = workflowState; - } else { - this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) - .error(workflowState.getError()) - .state(workflowState.getState()) - .resourcesCreated(workflowState.resourcesCreated()) - .build(); - } + public GetWorkflowResponse(Template template) { + this.template = template; } @Override public void writeTo(StreamOutput out) throws IOException { - workflowState.writeTo(out); + out.writeString(template.toJson()); } @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - return workflowState.toXContent(xContentBuilder, params); + return this.template.toXContent(xContentBuilder, params); } + + /** + * Gets the template + * @return the template + */ + public Template getTemplate() { + return this.template; + } + } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java new file mode 100644 index 000000000..b8a713685 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.ActionType; + +import static org.opensearch.flowframework.common.CommonValue.TRANSPORT_ACTION_NAME_PREFIX; + +/** + * External Action for public facing RestGetWorkflowStateAction + */ +public class GetWorkflowStateAction extends ActionType { + // TODO : If the template body is returned as part of the GetWorkflowStateAction, + // it is necessary to ensure the user has permissions for workflow/get + /** The name of this action */ + public static final String NAME = TRANSPORT_ACTION_NAME_PREFIX + "workflow_state/get"; + /** An instance of this action */ + public static final GetWorkflowStateAction INSTANCE = new GetWorkflowStateAction(); + + private GetWorkflowStateAction() { + super(NAME, GetWorkflowStateResponse::new); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java similarity index 83% rename from src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java rename to src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java index c7594eb77..7fd546c25 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateRequest.java @@ -17,9 +17,9 @@ import java.io.IOException; /** - * Transport Request to get a workflow or workflow status + * Transport Request to get a workflow status */ -public class GetWorkflowRequest extends ActionRequest { +public class GetWorkflowStateRequest extends ActionRequest { /** * The documentId of the workflow entry within the Global Context index @@ -33,21 +33,21 @@ public class GetWorkflowRequest extends ActionRequest { private boolean all; /** - * Instantiates a new GetWorkflowRequest + * Instantiates a new GetWorkflowStateRequest * @param workflowId the documentId of the workflow * @param all whether the get request is looking for all fields in status */ - public GetWorkflowRequest(@Nullable String workflowId, boolean all) { + public GetWorkflowStateRequest(@Nullable String workflowId, boolean all) { this.workflowId = workflowId; this.all = all; } /** - * Instantiates a new GetWorkflowRequest request + * Instantiates a new GetWorkflowStateRequest request * @param in The input stream to read from * @throws IOException If the stream cannot be read properly */ - public GetWorkflowRequest(StreamInput in) throws IOException { + public GetWorkflowStateRequest(StreamInput in) throws IOException { super(in); this.workflowId = in.readString(); this.all = in.readBoolean(); diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java new file mode 100644 index 000000000..fe155237e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.model.WorkflowState; + +import java.io.IOException; + +/** + * Transport Response from getting a workflow status + */ +public class GetWorkflowStateResponse extends ActionResponse implements ToXContentObject { + + /** The workflow state */ + public WorkflowState workflowState; + /** Flag to indicate if the entire state should be returned */ + public boolean allStatus; + + /** + * Instantiates a new GetWorkflowStateResponse from an input stream + * @param in the input stream to read from + * @throws IOException if the workflowId cannot be read from the input stream + */ + public GetWorkflowStateResponse(StreamInput in) throws IOException { + super(in); + workflowState = new WorkflowState(in); + allStatus = in.readBoolean(); + } + + /** + * Instatiates a new GetWorkflowStateResponse from an input stream + * @param workflowState the workflow state object + * @param allStatus whether to return all fields in state index + */ + public GetWorkflowStateResponse(WorkflowState workflowState, boolean allStatus) { + if (allStatus) { + this.workflowState = workflowState; + } else { + this.workflowState = new WorkflowState.Builder().workflowId(workflowState.getWorkflowId()) + .error(workflowState.getError()) + .state(workflowState.getState()) + .resourcesCreated(workflowState.resourcesCreated()) + .build(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + workflowState.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return workflowState.toXContent(xContentBuilder, params); + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java new file mode 100644 index 000000000..57fcc2b89 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; + +//TODO: Currently we only get the workflow status but we should change to be able to get the +// full template as well +/** + * Transport Action to get a specific workflow. Currently, we only support the action with _status + * in the API path but will add the ability to get the workflow and not just the status in the future + */ +public class GetWorkflowStateTransportAction extends HandledTransportAction { + + private final Logger logger = LogManager.getLogger(GetWorkflowStateTransportAction.class); + + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + /** + * Intantiates a new GetWorkflowStateTransportAction + * @param transportService The TransportService + * @param actionFilters action filters + * @param client The client used to make the request to OS + * @param xContentRegistry contentRegister to parse get response + */ + @Inject + public GetWorkflowStateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry + ) { + super(GetWorkflowStateAction.NAME, transportService, actionFilters, GetWorkflowStateRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void doExecute(Task task, GetWorkflowStateRequest request, ActionListener listener) { + String workflowId = request.getWorkflowId(); + User user = ParseUtils.getUserContext(client); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowStateResponse(workflowState, request.getAll())); + } catch (Exception e) { + logger.error("Failed to parse workflowState" + r.getId(), e); + listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST)); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); + } else { + logger.error("Failed to get workflow status of: " + workflowId, e); + listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND)); + } + }), () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to get workflow: " + workflowId, e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index f3bc1dd9e..e2a9b1931 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -17,83 +17,78 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.model.WorkflowState; -import org.opensearch.flowframework.util.ParseUtils; -import org.opensearch.index.IndexNotFoundException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.Template; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -//TODO: Currently we only get the workflow status but we should change to be able to get the -// full template as well /** - * Transport Action to get a specific workflow. Currently, we only support the action with _status - * in the API path but will add the ability to get the workflow and not just the status in the future + * Transport action to retrieve a use case template within the Global Context */ -public class GetWorkflowTransportAction extends HandledTransportAction { +public class GetWorkflowTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(GetWorkflowTransportAction.class); - + private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private final Client client; - private final NamedXContentRegistry xContentRegistry; /** - * Intantiates a new CreateWorkflowTransportAction - * @param transportService The TransportService + * Instantiates a new GetWorkflowTransportAction instance + * @param transportService the transport service * @param actionFilters action filters - * @param client The client used to make the request to OS - * @param xContentRegistry contentRegister to parse get response + * @param flowFrameworkIndicesHandler The Flow Framework indices handler + * @param client the Opensearch Client */ @Inject public GetWorkflowTransportAction( TransportService transportService, ActionFilters actionFilters, - Client client, - NamedXContentRegistry xContentRegistry + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, + Client client ) { - super(GetWorkflowAction.NAME, transportService, actionFilters, GetWorkflowRequest::new); + super(GetWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; this.client = client; - this.xContentRegistry = xContentRegistry; } @Override - protected void doExecute(Task task, GetWorkflowRequest request, ActionListener listener) { - String workflowId = request.getWorkflowId(); - User user = ParseUtils.getUserContext(client); - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - listener.onResponse(new GetWorkflowResponse(workflowState, request.getAll())); - } catch (Exception e) { - logger.error("Failed to parse workflowState" + r.getId(), e); - listener.onFailure(new FlowFrameworkException("Failed to parse workflowState" + r.getId(), RestStatus.BAD_REQUEST)); + protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { + if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + // Retrieve workflow by ID + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + listener.onFailure( + new FlowFrameworkException( + "Failed to retrieve template (" + workflowId + ") from global context.", + RestStatus.NOT_FOUND + ) + ); + } else { + listener.onResponse(new GetWorkflowResponse(Template.parse(response.getSourceAsString()))); } - } else { - listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new FlowFrameworkException("Fail to find workflow", RestStatus.NOT_FOUND)); - } else { - logger.error("Failed to get workflow status of: " + workflowId, e); - listener.onFailure(new FlowFrameworkException("Failed to get workflow status of: " + workflowId, RestStatus.NOT_FOUND)); - } - }), () -> context.restore())); - } catch (Exception e) { - logger.error("Failed to get workflow: " + workflowId, e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + }, exception -> { + logger.error("Failed to retrieve template from global context.", exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + })); + } catch (Exception e) { + logger.error("Failed to retrieve template from global context.", e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + + } else { + listener.onFailure(new FlowFrameworkException("There are no templates in the global_context", RestStatus.NOT_FOUND)); } + } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 2585ffb09..6370d2312 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -82,8 +82,8 @@ public void testPlugin() throws IOException { 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); - assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); - assertEquals(4, ffp.getActions().size()); + assertEquals(5, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); + assertEquals(5, ffp.getActions().size()); assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(5, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java index 0f6ddab59..3a51f1a9e 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowActionTests.java @@ -29,25 +29,20 @@ public class RestGetWorkflowActionTests extends OpenSearchTestCase { private RestGetWorkflowAction restGetWorkflowAction; private String getPath; - private NodeClient nodeClient; private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + private NodeClient nodeClient; @Override public void setUp() throws Exception { super.setUp(); - this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + this.getPath = String.format(Locale.ROOT, "%s/{%s}", WORKFLOW_URI, "workflow_id"); flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); this.restGetWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); this.nodeClient = mock(NodeClient.class); } - public void testConstructor() { - RestGetWorkflowAction getWorkflowAction = new RestGetWorkflowAction(flowFrameworkFeatureEnabledSetting); - assertNotNull(getWorkflowAction); - } - public void testRestGetWorkflowActionName() { String name = restGetWorkflowAction.getName(); assertEquals("get_workflow", name); @@ -60,6 +55,19 @@ public void testRestGetWorkflowActionRoutes() { assertEquals(this.getPath, routes.get(0).getPath()); } + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowAction.handleRequest(request, channel, nodeClient); + }); + assertEquals("request [POST /_plugins/_flow_framework/workflow/{workflow_id}] does not support having a body", ex.getMessage()); + } + public void testNullWorkflowId() throws Exception { // Request with no params @@ -75,22 +83,6 @@ public void testNullWorkflowId() throws Exception { assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() { - RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) - .withPath(this.getPath) - .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) - .build(); - - FakeRestChannel channel = new FakeRestChannel(request, false, 1); - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { - restGetWorkflowAction.handleRequest(request, channel, nodeClient); - }); - assertEquals( - "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", - ex.getMessage() - ); - } - public void testFeatureFlagNotEnabled() throws Exception { when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) diff --git a/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java new file mode 100644 index 000000000..dc605a5cd --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/rest/RestGetWorkflowStateActionTests.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.rest; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RestGetWorkflowStateActionTests extends OpenSearchTestCase { + private RestGetWorkflowStateAction restGetWorkflowStateAction; + private String getPath; + private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; + + @Override + public void setUp() throws Exception { + super.setUp(); + + this.getPath = String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOW_URI, "workflow_id", "_status"); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restGetWorkflowStateAction = new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting); + this.nodeClient = mock(NodeClient.class); + } + + public void testConstructor() { + RestGetWorkflowStateAction getWorkflowAction = new RestGetWorkflowStateAction(flowFrameworkFeatureEnabledSetting); + assertNotNull(getWorkflowAction); + } + + public void testRestGetWorkflowStateActionName() { + String name = restGetWorkflowStateAction.getName(); + assertEquals("get_workflow_state", name); + } + + public void testRestGetWorkflowStateActionRoutes() { + List routes = restGetWorkflowStateAction.routes(); + assertEquals(1, routes.size()); + assertEquals(RestRequest.Method.GET, routes.get(0).getMethod()); + assertEquals(this.getPath, routes.get(0).getPath()); + } + + public void testNullWorkflowId() throws Exception { + + // Request with no params + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); + } + + public void testInvalidRequestWithContent() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) + .build(); + + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + }); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_status] does not support having a body", + ex.getMessage() + ); + } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.getPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restGetWorkflowStateAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java new file mode 100644 index 000000000..7aa0323b4 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.junit.Assert; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +import org.mockito.Mockito; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GetWorkflowStateTransportActionTests extends OpenSearchTestCase { + + private GetWorkflowStateTransportAction getWorkflowStateTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Client client; + private ThreadPool threadPool; + private ThreadContext threadContext; + private ActionListener response; + private Task task; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.client = mock(Client.class); + this.threadPool = mock(ThreadPool.class); + this.getWorkflowStateTransportAction = new GetWorkflowStateTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + xContentRegistry() + ); + task = Mockito.mock(Task.class); + ThreadPool clientThreadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + + when(client.threadPool()).thenReturn(clientThreadPool); + when(clientThreadPool.getThreadContext()).thenReturn(threadContext); + + response = new ActionListener() { + @Override + public void onResponse(GetWorkflowStateResponse getResponse) { + assertTrue(true); + } + + @Override + public void onFailure(Exception e) {} + }; + + } + + public void testGetTransportAction() throws IOException { + GetWorkflowStateRequest getWorkflowRequest = new GetWorkflowStateRequest("1234", false); + getWorkflowStateTransportAction.doExecute(task, getWorkflowRequest, response); + } + + public void testGetAction() { + Assert.assertNotNull(GetWorkflowStateAction.INSTANCE.name()); + Assert.assertEquals(GetWorkflowStateAction.INSTANCE.name(), GetWorkflowStateAction.NAME); + } + + public void testGetWorkflowStateRequest() throws IOException { + GetWorkflowStateRequest request = new GetWorkflowStateRequest("1234", false); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetWorkflowStateRequest newRequest = new GetWorkflowStateRequest(input); + Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); + Assert.assertEquals(request.getAll(), newRequest.getAll()); + Assert.assertNull(newRequest.validate()); + } + + public void testGetWorkflowStateResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + String workflowId = randomAlphaOfLength(5); + WorkflowState workFlowState = new WorkflowState( + workflowId, + "test", + "PROVISIONING", + "IN_PROGRESS", + Instant.now(), + Instant.now(), + TestHelpers.randomUser(), + Collections.emptyMap(), + Collections.emptyList() + ); + + GetWorkflowStateResponse response = new GetWorkflowStateResponse(workFlowState, false); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetWorkflowStateResponse newResponse = new GetWorkflowStateResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertEquals(map.get("state"), workFlowState.getState()); + Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java index ab6d0a68f..d7db8a2c9 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -8,115 +8,151 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; -import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; -import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import org.junit.Assert; -import java.io.IOException; -import java.time.Instant; -import java.util.Collections; +import java.util.List; import java.util.Map; -import org.mockito.Mockito; +import org.mockito.ArgumentCaptor; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class GetWorkflowTransportActionTests extends OpenSearchTestCase { - private GetWorkflowTransportAction getWorkflowTransportAction; + private ThreadPool threadPool; private Client client; - private ActionListener response; - private Task task; + private GetWorkflowTransportAction getTemplateTransportAction; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private Template template; @Override public void setUp() throws Exception { super.setUp(); + this.threadPool = mock(ThreadPool.class); this.client = mock(Client.class); - this.getWorkflowTransportAction = new GetWorkflowTransportAction( + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.getTemplateTransportAction = new GetWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - client, - xContentRegistry() + flowFrameworkIndicesHandler, + client ); - task = Mockito.mock(Task.class); + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); + WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); + WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); + List nodes = List.of(nodeA, nodeB); + List edges = List.of(edgeAB); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); + + this.template = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + Map.of(), + TestHelpers.randomUser() + ); + ThreadPool clientThreadPool = mock(ThreadPool.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(client.threadPool()).thenReturn(clientThreadPool); when(clientThreadPool.getThreadContext()).thenReturn(threadContext); - response = new ActionListener() { - @Override - public void onResponse(GetWorkflowResponse getResponse) { - assertTrue(true); - } + } - @Override - public void onFailure(Exception e) {} - }; + public void testGetWorkflowNoGlobalContext() { - } + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest("1", null); + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - public void testGetTransportAction() throws IOException { - GetWorkflowRequest getWorkflowRequest = new GetWorkflowRequest("1234", false); - getWorkflowTransportAction.doExecute(task, getWorkflowRequest, response); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("There are no templates in the global_context")); } - public void testGetAction() { - Assert.assertNotNull(GetWorkflowAction.INSTANCE.name()); - Assert.assertEquals(GetWorkflowAction.INSTANCE.name(), GetWorkflowAction.NAME); - } + public void testGetWorkflowSuccess() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - public void testGetAnomalyDetectorRequest() throws IOException { - GetWorkflowRequest request = new GetWorkflowRequest("1234", false); - BytesStreamOutput out = new BytesStreamOutput(); - request.writeTo(out); - StreamInput input = out.bytes().streamInput(); - GetWorkflowRequest newRequest = new GetWorkflowRequest(input); - Assert.assertEquals(request.getWorkflowId(), newRequest.getWorkflowId()); - Assert.assertEquals(request.getAll(), newRequest.getAll()); - Assert.assertNull(newRequest.validate()); + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.get to force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + this.template.toXContent(builder, null); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + + ArgumentCaptor templateCaptor = ArgumentCaptor.forClass(GetWorkflowResponse.class); + verify(listener, times(1)).onResponse(templateCaptor.capture()); + assertEquals(this.template.name(), templateCaptor.getValue().getTemplate().name()); } - public void testGetAnomalyDetectorResponse() throws IOException { - BytesStreamOutput out = new BytesStreamOutput(); - String workflowId = randomAlphaOfLength(5); - WorkflowState workFlowState = new WorkflowState( - workflowId, - "test", - "PROVISIONING", - "IN_PROGRESS", - Instant.now(), - Instant.now(), - TestHelpers.randomUser(), - Collections.emptyMap(), - Collections.emptyList() - ); + public void testGetWorkflowFailure() { + String workflowId = "12345"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + + when(flowFrameworkIndicesHandler.doesIndexExist(anyString())).thenReturn(true); + + // Stub client.get to force on failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to retrieve template from global context.")); + return null; + }).when(client).get(any(GetRequest.class), any()); - GetWorkflowResponse response = new GetWorkflowResponse(workFlowState, false); - response.writeTo(out); - NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); - GetWorkflowResponse newResponse = new GetWorkflowResponse(input); - XContentBuilder builder = TestHelpers.builder(); - Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + getTemplateTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - Map map = TestHelpers.XContentBuilderToMap(builder); - Assert.assertEquals(map.get("state"), workFlowState.getState()); - Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to retrieve template from global context.", exceptionCaptor.getValue().getMessage()); } }