diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 38145a143e..cf5a0c6c59 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -73,7 +73,7 @@ public String cancelJob( @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { - leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource())); + leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource())); DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -83,7 +83,7 @@ public DispatchQueryResponse submit( .resultIndex(resp.getResultIndex()) .sessionId(resp.getSessionId()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.BATCH) + .jobType(JobType.REFRESH) .indexName(context.getIndexQueryDetails().openSearchIndexName()) .build(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 710f472acb..6c207df45e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -177,7 +177,7 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); - } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { + } else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) { return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java index 01f5f422e9..af1f69d74b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -8,6 +8,7 @@ public enum JobType { INTERACTIVE("interactive"), STREAMING("streaming"), + REFRESH("refresh"), BATCH("batch"); private String text; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index b6369292a6..ba32647d19 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -377,6 +377,7 @@ void testDispatchCreateManualRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -661,6 +662,7 @@ void testRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -831,12 +833,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); - when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) - .thenReturn( - new CancelJobRunResult() - .withJobRunId(EMR_JOB_ID) - .withApplicationId(EMRS_APPLICATION_ID)); + givenCancelJobRunSucceed(); String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); @@ -897,17 +894,32 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + @Test + void testCancelBatchJob() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + private void givenCancelJobRunSucceed() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - - String queryId = - sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); - - Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -1154,11 +1166,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str } private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return asyncQueryJobMetadata(JobType.INTERACTIVE); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) { return AsyncQueryJobMetadata.builder() .queryId(QUERY_ID) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .datasourceName(MY_GLUE) + .jobType(jobType) .build(); }