Add query, langType, status, error in AsyncQueryJobMetadata (#2958)
* Add query, langType, status, error in AsyncQueryJobMetadata

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Fix test

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

---------

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
Signed-off-by: Tomoyuki MORITA <moritato@amazon.com>
ykmr1224 authored Sep 9, 2024
1 parent b76aa65 commit 1b1a1b5
Showing 12 changed files with 119 additions and 17 deletions.
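
At a high level, the commit threads four new pieces of metadata through the async-query path: the original query text, its language, an initial QueryState, and an error message. The net effect on the stored metadata, sketched using only names that appear in the diffs below (the surrounding storeJobMetadata call and service wiring are elided):

// Sketch only — mirrors the builder calls added in this commit;
// dispatchQueryResponse and createAsyncQueryRequest are provided by the
// surrounding createAsyncQuery method, which is elided here.
AsyncQueryJobMetadata metadata =
    AsyncQueryJobMetadata.builder()
        .query(createAsyncQueryRequest.getQuery())    // original query text
        .langType(createAsyncQueryRequest.getLang())  // e.g. LangType.SQL
        .state(dispatchQueryResponse.getStatus())     // initial QueryState
        .error(dispatchQueryResponse.getError())      // null unless dispatch failed
        .build();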
AsyncQueryExecutorServiceImpl.java
@@ -67,6 +67,10 @@ public CreateAsyncQueryResponse createAsyncQuery(
             .datasourceName(dispatchQueryResponse.getDatasourceName())
             .jobType(dispatchQueryResponse.getJobType())
             .indexName(dispatchQueryResponse.getIndexName())
+            .query(createAsyncQueryRequest.getQuery())
+            .langType(createAsyncQueryRequest.getLang())
+            .state(dispatchQueryResponse.getStatus())
+            .error(dispatchQueryResponse.getError())
             .build(),
         asyncQueryRequestContext);
     return new CreateAsyncQueryResponse(
AsyncQueryJobMetadata.java
@@ -12,6 +12,7 @@
 import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.dispatcher.model.JobType;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
+import org.opensearch.sql.spark.rest.model.LangType;
 import org.opensearch.sql.utils.SerializeUtils;

 /** This class models all the metadata required for a job. */
@@ -35,6 +36,10 @@ public class AsyncQueryJobMetadata extends StateModel {
   private final String datasourceName;
   // null if JobType is INTERACTIVE or null
   private final String indexName;
+  private final String query;
+  private final LangType langType;
+  private final QueryState state;
+  private final String error;

   @Override
   public String toString() {
@@ -54,6 +59,10 @@ public static AsyncQueryJobMetadata copy(
         .datasourceName(copy.datasourceName)
         .jobType(copy.jobType)
         .indexName(copy.indexName)
+        .query(copy.query)
+        .langType(copy.langType)
+        .state(copy.state)
+        .error(copy.error)
         .metadata(metadata)
         .build();
   }
QueryState.java (new file)
@@ -0,0 +1,41 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.asyncquery.model;
+
+import java.util.Arrays;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+import lombok.Getter;
+
+@Getter
+public enum QueryState {
+  WAITING("waiting"),
+  RUNNING("running"),
+  SUCCESS("success"),
+  FAILED("failed"),
+  TIMEOUT("timeout"),
+  CANCELLED("cancelled");
+
+  private final String state;
+
+  QueryState(String state) {
+    this.state = state;
+  }
+
+  private static final Map<String, QueryState> STATES =
+      Arrays.stream(QueryState.values())
+          .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+  public static QueryState fromString(String key) {
+    for (QueryState ss : QueryState.values()) {
+      if (ss.getState().toLowerCase(Locale.ROOT).equals(key)) {
+        return ss;
+      }
+    }
+    throw new IllegalArgumentException("Invalid query state: " + key);
+  }
+}
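
For orientation, a small usage example of the enum above. It assumes the QueryState class from this diff is on the classpath; QueryStateDemo is a hypothetical wrapper, not part of the commit. Note that fromString compares against the lowercase state strings, so uppercase input falls through to the exception:

import org.opensearch.sql.spark.asyncquery.model.QueryState;

public class QueryStateDemo {
  public static void main(String[] args) {
    QueryState state = QueryState.fromString("running");
    System.out.println(state);             // RUNNING
    System.out.println(state.getState());  // "running", via Lombok's @Getter

    try {
      QueryState.fromString("RUNNING");    // only lowercase keys match
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage());  // Invalid query state: RUNNING
    }
  }
}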
BatchQueryHandler.java
@@ -17,6 +17,7 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
@@ -111,6 +112,7 @@ public DispatchQueryResponse submit(
         .resultIndex(dataSourceMetadata.getResultIndex())
         .datasourceName(dataSourceMetadata.getName())
         .jobType(JobType.BATCH)
+        .status(QueryState.WAITING)
         .indexName(getIndexName(context))
         .build();
   }
IndexDMLHandler.java
@@ -18,6 +18,7 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
@@ -83,6 +84,7 @@ public DispatchQueryResponse submit(
           .resultIndex(dataSourceMetadata.getResultIndex())
           .datasourceName(dataSourceMetadata.getName())
           .jobType(JobType.BATCH)
+          .status(QueryState.SUCCESS)
           .build();
     } catch (Exception e) {
       LOG.error(e.getMessage());
@@ -101,6 +103,8 @@ public DispatchQueryResponse submit(
           .resultIndex(dataSourceMetadata.getResultIndex())
           .datasourceName(dataSourceMetadata.getName())
           .jobType(JobType.BATCH)
+          .status(QueryState.FAILED)
+          .error(e.getMessage())
           .build();
     }
   }
InteractiveQueryHandler.java
@@ -17,6 +17,7 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
@@ -151,6 +152,7 @@ public DispatchQueryResponse submit(
         .sessionId(session.getSessionId())
         .datasourceName(dataSourceMetadata.getName())
         .jobType(JobType.INTERACTIVE)
+        .status(QueryState.WAITING)
         .build();
   }
RefreshQueryHandler.java
@@ -9,6 +9,7 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
@@ -85,6 +86,7 @@ public DispatchQueryResponse submit(
         .datasourceName(dataSourceMetadata.getName())
         .jobType(JobType.REFRESH)
         .indexName(context.getIndexQueryDetails().openSearchIndexName())
+        .status(QueryState.WAITING)
         .build();
   }
 }
StreamingQueryHandler.java
@@ -13,6 +13,7 @@
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
@@ -102,6 +103,7 @@ public DispatchQueryResponse submit(
         .datasourceName(dataSourceMetadata.getName())
         .jobType(JobType.STREAMING)
         .indexName(indexQueryDetails.openSearchIndexName())
+        .status(QueryState.WAITING)
         .build();
   }
 }
DispatchQueryResponse.java
@@ -2,6 +2,7 @@

 import lombok.Builder;
 import lombok.Getter;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;

 @Getter
 @Builder
@@ -13,4 +14,6 @@ public class DispatchQueryResponse {
   private final String datasourceName;
   private final JobType jobType;
   private final String indexName;
+  private final QueryState status;
+  private final String error;
 }
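
Taken with the handler diffs above, the new status field encodes where each query starts its lifecycle: WAITING for work handed off to EMR Serverless (batch, interactive, refresh, streaming), and SUCCESS or FAILED for index-DML statements that complete inline, with error carrying the failure message. A hypothetical consumer of the stored metadata — assuming Lombok generates getters for the new fields, as the tests below suggest; QueryStatusReporter is an illustration, not part of the commit:

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.QueryState;

class QueryStatusReporter {
  // Hypothetical helper: reads the outcome straight from the stored metadata
  // instead of polling the Spark job for terminal DML statements.
  void report(AsyncQueryJobMetadata metadata) {
    System.out.printf(
        "query=%s lang=%s state=%s%n",
        metadata.getQuery(), metadata.getLangType(), metadata.getState());
    if (metadata.getState() == QueryState.FAILED) {
      System.out.println("error: " + metadata.getError());
    }
  }
}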
AsyncQueryCoreIntegTest.java
@@ -46,6 +46,7 @@
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata.AsyncQueryJobMetadataBuilder;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
+import org.opensearch.sql.spark.asyncquery.model.QueryState;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
@@ -205,7 +206,7 @@ public void createDropIndexQuery() {
     verifyGetQueryIdCalled();
     verifyCancelJobRunCalled();
     verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
+    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, QueryState.SUCCESS, JobType.BATCH);
   }

   @Test
@@ -227,7 +228,7 @@ public void createDropIndexQueryWithScheduler() {
     assertNull(response.getSessionId());
     verifyGetQueryIdCalled();
     verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
+    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, QueryState.SUCCESS, JobType.BATCH);

     verify(asyncQueryScheduler).unscheduleJob(indexName);
   }
@@ -255,7 +256,7 @@ public void createVacuumIndexQuery() {
     verifyGetSessionIdCalled();
     verify(leaseManager).borrow(any());
     verifyStartJobRunCalled();
-    verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE);
+    verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE);
   }

   @Test
@@ -286,7 +287,7 @@ public void createAlterIndexQuery() {
     assertFalse(flintIndexOptions.autoRefresh());
     verifyCancelJobRunCalled();
     verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
+    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, QueryState.SUCCESS, JobType.BATCH);
   }

   @Test
@@ -320,7 +321,7 @@ public void createAlterIndexQueryWithScheduler() {
     verify(asyncQueryScheduler).unscheduleJob(indexName);

     verifyCreateIndexDMLResultCalled();
-    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH);
+    verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, QueryState.SUCCESS, JobType.BATCH);
   }

   @Test
@@ -345,7 +346,7 @@ public void createStreamingQuery() {
     verifyGetQueryIdCalled();
     verify(leaseManager).borrow(any());
     verifyStartJobRunCalled();
-    verifyStoreJobMetadataCalled(JOB_ID, JobType.STREAMING);
+    verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.STREAMING);
   }

   private void verifyStartJobRunCalled() {
@@ -380,7 +381,7 @@ public void createCreateIndexQuery() {
     assertNull(response.getSessionId());
     verifyGetQueryIdCalled();
     verifyStartJobRunCalled();
-    verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH);
+    verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.BATCH);
   }

   @Test
@@ -402,7 +403,7 @@ public void createRefreshQuery() {
     verifyGetQueryIdCalled();
     verify(leaseManager).borrow(any());
     verifyStartJobRunCalled();
-    verifyStoreJobMetadataCalled(JOB_ID, JobType.REFRESH);
+    verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.REFRESH);
   }

   @Test
@@ -428,7 +429,7 @@ public void createInteractiveQuery() {
     verifyGetSessionIdCalled();
     verify(leaseManager).borrow(any());
     verifyStartJobRunCalled();
-    verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE);
+    verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE);
   }

   @Test
@@ -644,14 +645,17 @@ private void verifyGetSessionIdCalled() {
     assertEquals(APPLICATION_ID, createSessionRequest.getApplicationId());
   }

-  private void verifyStoreJobMetadataCalled(String jobId, JobType jobType) {
+  private void verifyStoreJobMetadataCalled(String jobId, QueryState state, JobType jobType) {
     verify(asyncQueryJobMetadataStorageService)
         .storeJobMetadata(
             asyncQueryJobMetadataArgumentCaptor.capture(), eq(asyncQueryRequestContext));
     AsyncQueryJobMetadata asyncQueryJobMetadata = asyncQueryJobMetadataArgumentCaptor.getValue();
     assertEquals(QUERY_ID, asyncQueryJobMetadata.getQueryId());
     assertEquals(jobId, asyncQueryJobMetadata.getJobId());
     assertEquals(DATASOURCE_NAME, asyncQueryJobMetadata.getDatasourceName());
+    assertNull(asyncQueryJobMetadata.getError());
+    assertEquals(LangType.SQL, asyncQueryJobMetadata.getLangType());
+    assertEquals(state, asyncQueryJobMetadata.getState());
     assertEquals(jobType, asyncQueryJobMetadata.getJobType());
   }
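
The verifyStoreJobMetadataCalled helper above uses Mockito's ArgumentCaptor to grab the metadata object actually handed to the mocked storage service. A self-contained sketch of that pattern, with placeholder types rather than the commit's classes:

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.verify;

import org.mockito.ArgumentCaptor;

// CaptorDemo, MetadataStore, and StoredMetadata are illustrative placeholders,
// not classes from this commit.
class CaptorDemo {
  interface MetadataStore {
    void store(StoredMetadata metadata);
  }

  record StoredMetadata(String state) {}

  void verifyStored(MetadataStore mockStore) {
    ArgumentCaptor<StoredMetadata> captor = ArgumentCaptor.forClass(StoredMetadata.class);
    verify(mockStore).store(captor.capture());  // fails if store() was never called
    StoredMetadata stored = captor.getValue();  // the argument actually passed
    assertEquals("waiting", stored.state());
  }
}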
AsyncQueryExecutorServiceImplTest.java
@@ -47,14 +47,16 @@
 @ExtendWith(MockitoExtension.class)
 public class AsyncQueryExecutorServiceImplTest {

+  private static final String QUERY = "select * from my_glue.default.http_logs";
+  private static final String QUERY_ID = "QUERY_ID";
+
   @Mock private SparkQueryDispatcher sparkQueryDispatcher;
   @Mock private AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService;
   private AsyncQueryExecutorService jobExecutorService;

   @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
   @Mock private SparkSubmitParameterModifier sparkSubmitParameterModifier;
   @Mock private AsyncQueryRequestContext asyncQueryRequestContext;
-  private final String QUERY_ID = "QUERY_ID";

   @BeforeEach
   void setUp() {
@@ -68,8 +70,7 @@ void setUp() {
   @Test
   void testCreateAsyncQuery() {
     CreateAsyncQueryRequest createAsyncQueryRequest =
-        new CreateAsyncQueryRequest(
-            "select * from my_glue.default.http_logs", "my_glue", LangType.SQL);
+        new CreateAsyncQueryRequest(QUERY, "my_glue", LangType.SQL);
     when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
         .thenReturn(
             SparkExecutionEngineConfig.builder()
@@ -82,7 +83,7 @@ void testCreateAsyncQuery() {
     DispatchQueryRequest expectedDispatchQueryRequest =
         DispatchQueryRequest.builder()
             .applicationId(EMRS_APPLICATION_ID)
-            .query("select * from my_glue.default.http_logs")
+            .query(QUERY)
             .datasource("my_glue")
             .langType(LangType.SQL)
             .executionRoleARN(EMRS_EXECUTION_ROLE)
@@ -134,9 +135,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
             .build());

     jobExecutorService.createAsyncQuery(
-        new CreateAsyncQueryRequest(
-            "select * from my_glue.default.http_logs", "my_glue", LangType.SQL),
-        asyncQueryRequestContext);
+        new CreateAsyncQueryRequest(QUERY, "my_glue", LangType.SQL), asyncQueryRequestContext);

     verify(sparkQueryDispatcher, times(1))
         .dispatch(
@@ -237,6 +236,8 @@ private AsyncQueryJobMetadata getAsyncQueryJobMetadata() {
         .queryId(QUERY_ID)
         .applicationId(EMRS_APPLICATION_ID)
         .jobId(EMR_JOB_ID)
+        .query(QUERY)
+        .langType(LangType.SQL)
         .build();
   }
 }
QueryStateTest.java (new file)
@@ -0,0 +1,28 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.asyncquery.model;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import org.junit.jupiter.api.Test;
+
+class QueryStateTest {
+  @Test
+  public void testFromString() {
+    assertEquals(QueryState.WAITING, QueryState.fromString("waiting"));
+    assertEquals(QueryState.RUNNING, QueryState.fromString("running"));
+    assertEquals(QueryState.SUCCESS, QueryState.fromString("success"));
+    assertEquals(QueryState.FAILED, QueryState.fromString("failed"));
+    assertEquals(QueryState.CANCELLED, QueryState.fromString("cancelled"));
+    assertEquals(QueryState.TIMEOUT, QueryState.fromString("timeout"));
+  }
+
+  @Test
+  public void testFromStringWithUnknownState() {
+    assertThrows(IllegalArgumentException.class, () -> QueryState.fromString("UNKNOWN_STATE"));
+  }
+}
