Skip to content

Commit

Permalink
Adding another test for update API, input validation, local model reg…
Browse files Browse the repository at this point in the history
…istration. Persiting cluster settings between test runs to ensure plugin apis are enabled. Cleaning up resources after all test runs complete, rather than between test runs

Signed-off-by: Joshua Palis <jpalis@amazon.com>
  • Loading branch information
joshpalis committed Dec 6, 2023
1 parent 17e3d05 commit 9f1433b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;

import javax.net.ssl.SSLEngine;
Expand All @@ -62,6 +62,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE;
Expand Down Expand Up @@ -206,9 +207,10 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE

}

// Cleans up resources after all test execution has been completed
@SuppressWarnings("unchecked")
@After
protected void wipeAllSystemIndices() throws IOException {
@AfterClass
protected static void wipeAllSystemIndices() throws IOException {
Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all"));
MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType());
try (
Expand Down Expand Up @@ -299,6 +301,14 @@ protected boolean preserveIndicesUponCompletion() {
return true;
}

/**
* Required to persist cluster settings between test executions
*/
@Override
protected boolean preserveClusterSettings() {
return true;
}

/**
* Helper method to invoke the Create Workflow Rest Action
* @param template the template to create
Expand All @@ -319,6 +329,24 @@ protected Response createWorkflowDryRun(Template template) throws Exception {
return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI + "?dryrun=true", ImmutableMap.of(), template.toJson(), null);
}

/**
* Helper method to invoke the Update Workflow API
* @param workflowId the document id
* @param template the template used to update
* @throws Exception if the request fails
* @return a rest response
*/
protected Response updateWorkflow(String workflowId, Template template) throws Exception {
return TestHelpers.makeRequest(
client(),
"PUT",
String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, workflowId),
ImmutableMap.of(),
template.toJson(),
null
);
}

/**
* Helper method to invoke the Provision Workflow Rest Action
* @param workflowId the workflow ID to provision
Expand Down Expand Up @@ -376,13 +404,18 @@ protected void getAndAssertWorkflowStatus(String workflowId, State stateStatus,
/**
* Helper method to wait until a workflow provisioning has completed and retrieve any resources created
* @param workflowId the workflow id to retrieve resources from
* @param timeout the max wait time in seconds
* @return a list of created resources
* @throws Exception if the request fails
*/
protected List<ResourceCreated> getResourcesCreated(String workflowId) throws Exception {
protected List<ResourceCreated> getResourcesCreated(String workflowId, int timeout) throws Exception {

// wait and ensure state is completed/done
assertBusy(() -> { getAndAssertWorkflowStatus(workflowId, State.COMPLETED, ProvisioningProgress.DONE); });
assertBusy(
() -> { getAndAssertWorkflowStatus(workflowId, State.COMPLETED, ProvisioningProgress.DONE); },
timeout,
TimeUnit.SECONDS
);

Response response = getWorkflowStatus(workflowId, true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

Expand All @@ -28,6 +30,71 @@

public class FlowFrameworkRestApiIT extends FlowFrameworkRestTestCase {

public void testCreateAndProvisionLocalModelWorkflow() throws Exception {

// Using a 3 step template to create a model group, register a remote model and deploy model
Template template = TestHelpers.createTemplateFromFile("registermodelgroup-registerlocalmodel-deploymodel.json");

// Remove register model input to test validation
Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW);

List<WorkflowNode> modifiednodes = new ArrayList<>();
modifiednodes.add(
new WorkflowNode(
"workflow_step_1",
"model_group",
Map.of(),
Map.of() // empty user inputs
)
);
for (WorkflowNode node : originalWorkflow.nodes()) {
if (!node.id().equals("workflow_step_1")) {
modifiednodes.add(node);
}
}

Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges());

Template templateWithMissingInputs = new Template.Builder().name(template.name())
.description(template.description())
.useCase(template.useCase())
.templateVersion(template.templateVersion())
.compatibilityVersion(template.compatibilityVersion())
.workflows(Map.of(PROVISION_WORKFLOW, missingInputs))
.uiMetadata(template.getUiMetadata())
.user(template.getUser())
.build();

// Hit Create Workflow API with invalid template
Response response = createWorkflow(templateWithMissingInputs);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

// Retrieve workflow ID
Map<String, Object> responseMap = entityAsMap(response);
String workflowId = (String) responseMap.get(WORKFLOW_ID);
getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED);

// Attempt provision
ResponseException exception = expectThrows(ResponseException.class, () -> provisionWorkflow(workflowId));
assertTrue(exception.getMessage().contains("Invalid graph, missing the following required inputs : [name]"));

// update workflow with updated inputs
response = updateWorkflow(workflowId, template);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED);

// Reattempt Provision
response = provisionWorkflow(workflowId);
assertEquals(RestStatus.OK, TestHelpers.restStatus(response));
getAndAssertWorkflowStatus(workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS);

// Wait until provisioning has completed successfully before attempting to retrieve created resources
List<ResourceCreated> resourcesCreated = getResourcesCreated(workflowId, 100);

// TODO : This template should create 2 resources, model_group_id and model_id, need to fix after feature branch is merged
assertEquals(0, resourcesCreated.size());
}

public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {

// Using a 3 step template to create a connector, register remote model and deploy model
Expand Down Expand Up @@ -69,7 +136,7 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception {
getAndAssertWorkflowStatus(workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS);

// Wait until provisioning has completed successfully before attempting to retrieve created resources
List<ResourceCreated> resourcesCreated = getResourcesCreated(workflowId);
List<ResourceCreated> resourcesCreated = getResourcesCreated(workflowId, 10);

// TODO : This template should create 2 resources, connector_id and model_id, need to fix after feature branch is merged
assertEquals(1, resourcesCreated.size());
Expand Down

0 comments on commit 9f1433b

Please sign in to comment.