Skip to content

Commit

Permalink
Fix handler for existing query (opensearch-project#2968)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 committed Sep 4, 2024
1 parent 6c5c685 commit b14a8cb
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public String cancelJob(
@Override
public DispatchQueryResponse submit(
DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) {
leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource()));
leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource()));

DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context);
DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata();
Expand All @@ -83,7 +83,7 @@ public DispatchQueryResponse submit(
.resultIndex(resp.getResultIndex())
.sessionId(resp.getSessionId())
.datasourceName(dataSourceMetadata.getName())
.jobType(JobType.BATCH)
.jobType(JobType.REFRESH)
.indexName(context.getIndexQueryDetails().openSearchIndexName())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
return queryHandlerFactory.getInteractiveQueryHandler();
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
} else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) {
return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
public enum JobType {
INTERACTIVE("interactive"),
STREAMING("streaming"),
REFRESH("refresh"),
BATCH("batch");

private String text;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ public void createRefreshQuery() {
verifyGetQueryIdCalled();
verify(leaseManager).borrow(any());
verifyStartJobRunCalled();
verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH);
verifyStoreJobMetadataCalled(JOB_ID, JobType.REFRESH);
}

@Test
Expand Down Expand Up @@ -541,7 +541,7 @@ public void cancelIndexDMLQuery() {
@Test
public void cancelRefreshQuery() {
givenJobMetadataExists(
getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME));
getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.REFRESH).indexName(INDEX_NAME));
when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext))
.thenReturn(
ImmutableMap.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ void testDispatchCreateManualRefreshIndexQuery() {
verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -685,6 +686,7 @@ void testRefreshIndexQuery() {
verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType());
verifyNoInteractions(flintIndexMetadataService);
}

Expand Down Expand Up @@ -859,12 +861,7 @@ void testDispatchWithUnSupportedDataSourceType() {

@Test
void testCancelJob() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);
Expand Down Expand Up @@ -925,17 +922,32 @@ void testCancelQueryWithInvalidStatementId() {

@Test
void testCancelQueryWithNoSessionId() {
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

@Test
void testCancelBatchJob() {
givenCancelJobRunSucceed();

String queryId =
sparkQueryDispatcher.cancelJob(
asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

private void givenCancelJobRunSucceed() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
.withJobRunId(EMR_JOB_ID)
.withApplicationId(EMRS_APPLICATION_ID));

String queryId =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext);

Assertions.assertEquals(QUERY_ID, queryId);
}

@Test
Expand Down Expand Up @@ -1184,11 +1196,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str
}

private AsyncQueryJobMetadata asyncQueryJobMetadata() {
return asyncQueryJobMetadata(JobType.INTERACTIVE);
}

private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) {
return AsyncQueryJobMetadata.builder()
.queryId(QUERY_ID)
.applicationId(EMRS_APPLICATION_ID)
.jobId(EMR_JOB_ID)
.datasourceName(MY_GLUE)
.jobType(jobType)
.build();
}

Expand Down

0 comments on commit b14a8cb

Please sign in to comment.