From 5bcc23b6e1b6af1bdb8384d0e62d25984bc96d85 Mon Sep 17 00:00:00 2001 From: Tomoyuki Morita Date: Wed, 4 Sep 2024 16:09:28 -0700 Subject: [PATCH] Fix handler for existing query (#2968) Signed-off-by: Tomoyuki Morita (cherry picked from commit b14a8cb7eeb5bd74eb1aebfe7cd8d56bfe2c88d8) --- .../spark/dispatcher/RefreshQueryHandler.java | 4 +- .../dispatcher/SparkQueryDispatcher.java | 2 +- .../sql/spark/dispatcher/model/JobType.java | 1 + .../asyncquery/AsyncQueryCoreIntegTest.java | 36 +++++++++++++---- .../dispatcher/SparkQueryDispatcherTest.java | 39 +++++++++++++------ 5 files changed, 60 insertions(+), 22 deletions(-) diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 38145a143e..cf5a0c6c59 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -73,7 +73,7 @@ public String cancelJob( @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { - leaseManager.borrow(new LeaseRequest(JobType.BATCH, dispatchQueryRequest.getDatasource())); + leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource())); DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); @@ -83,7 +83,7 @@ public DispatchQueryResponse submit( .resultIndex(resp.getResultIndex()) .sessionId(resp.getSessionId()) .datasourceName(dataSourceMetadata.getName()) - .jobType(JobType.BATCH) + .jobType(JobType.REFRESH) .indexName(context.getIndexQueryDetails().openSearchIndexName()) .build(); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 710f472acb..6c207df45e 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -177,7 +177,7 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery( return queryHandlerFactory.getInteractiveQueryHandler(); } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) { return queryHandlerFactory.getIndexDMLHandler(); - } else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) { + } else if (asyncQueryJobMetadata.getJobType() == JobType.REFRESH) { return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId()); } else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) { return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId()); diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java index 01f5f422e9..af1f69d74b 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/model/JobType.java @@ -8,6 +8,7 @@ public enum JobType { INTERACTIVE("interactive"), STREAMING("streaming"), + REFRESH("refresh"), BATCH("batch"); private String text; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index feb8c8c0ac..6f463e0f10 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -255,7 +255,7 @@ public void createAlterIndexQuery() { assertFalse(flintIndexOptions.autoRefresh()); verifyCancelJobRunCalled(); verifyCreateIndexDMLResultCalled(); - verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID); + verifyStoreJobMetadataCalled(DML_QUERY_JOB_ID, JobType.BATCH); } @Test @@ -280,7 +280,7 @@ public void createStreamingQuery() { verifyGetQueryIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.STREAMING); } private void verifyStartJobRunCalled() { @@ -315,7 +315,7 @@ public void createCreateIndexQuery() { assertNull(response.getSessionId()); verifyGetQueryIdCalled(); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.BATCH); } @Test @@ -337,7 +337,7 @@ public void createRefreshQuery() { verifyGetQueryIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.REFRESH); } @Test @@ -363,7 +363,7 @@ public void createInteractiveQuery() { verifyGetSessionIdCalled(); verify(leaseManager).borrow(any()); verifyStartJobRunCalled(); - verifyStoreJobMetadataCalled(JOB_ID); + verifyStoreJobMetadataCalled(JOB_ID, JobType.INTERACTIVE); } @Test @@ -454,7 +454,7 @@ public void cancelIndexDMLQuery() { @Test public void cancelRefreshQuery() { givenJobMetadataExists( - getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.BATCH).indexName(INDEX_NAME)); + getBaseAsyncQueryJobMetadataBuilder().jobType(JobType.REFRESH).indexName(INDEX_NAME)); when(flintIndexMetadataService.getFlintIndexMetadata(INDEX_NAME, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( @@ -507,7 +507,8 @@ private void givenSparkExecutionEngineConfigIsSupplied() { .build()); } - private void givenFlintIndexMetadataExists(String indexName) { + private void givenFlintIndexMetadataExists( + String indexName, FlintIndexOptions flintIndexOptions) { when(flintIndexMetadataService.getFlintIndexMetadata(indexName, asyncQueryRequestContext)) .thenReturn( ImmutableMap.of( @@ -516,9 +517,27 @@ private void givenFlintIndexMetadataExists(String indexName) { .appId(APPLICATION_ID) .jobId(JOB_ID) .opensearchIndexName(indexName) + .flintIndexOptions(flintIndexOptions) .build())); } + // Overload method for default FlintIndexOptions usage + private void givenFlintIndexMetadataExists(String indexName) { + givenFlintIndexMetadataExists(indexName, new FlintIndexOptions()); + } + + // Method to set up FlintIndexMetadata with external scheduler + private void givenFlintIndexMetadataExistsWithExternalScheduler(String indexName) { + givenFlintIndexMetadataExists(indexName, createExternalSchedulerFlintIndexOptions()); + } + + // Helper method for creating FlintIndexOptions with external scheduler + private FlintIndexOptions createExternalSchedulerFlintIndexOptions() { + FlintIndexOptions options = new FlintIndexOptions(); + options.setOption(FlintIndexOptions.SCHEDULER_MODE, "external"); + return options; + } + private void givenValidDataSourceMetadataExist() { when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( DATASOURCE_NAME, asyncQueryRequestContext)) @@ -560,7 +579,7 @@ private void verifyGetSessionIdCalled() { assertEquals(APPLICATION_ID, createSessionRequest.getApplicationId()); } - private void verifyStoreJobMetadataCalled(String jobId) { + private void verifyStoreJobMetadataCalled(String jobId, JobType jobType) { verify(asyncQueryJobMetadataStorageService) .storeJobMetadata( asyncQueryJobMetadataArgumentCaptor.capture(), eq(asyncQueryRequestContext)); @@ -568,6 +587,7 @@ private void verifyStoreJobMetadataCalled(String jobId) { assertEquals(QUERY_ID, asyncQueryJobMetadata.getQueryId()); assertEquals(jobId, asyncQueryJobMetadata.getJobId()); assertEquals(DATASOURCE_NAME, asyncQueryJobMetadata.getDatasourceName()); + assertEquals(jobType, asyncQueryJobMetadata.getJobType()); } private void verifyCreateIndexDMLResultCalled() { diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index b6369292a6..ba32647d19 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -377,6 +377,7 @@ void testDispatchCreateManualRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.BATCH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -661,6 +662,7 @@ void testRefreshIndexQuery() { verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); + Assertions.assertEquals(JobType.REFRESH, dispatchQueryResponse.getJobType()); verifyNoInteractions(flintIndexMetadataService); } @@ -831,12 +833,7 @@ void testDispatchWithUnSupportedDataSourceType() { @Test void testCancelJob() { - when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); - when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) - .thenReturn( - new CancelJobRunResult() - .withJobRunId(EMR_JOB_ID) - .withApplicationId(EMRS_APPLICATION_ID)); + givenCancelJobRunSucceed(); String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); @@ -897,17 +894,32 @@ void testCancelQueryWithInvalidStatementId() { @Test void testCancelQueryWithNoSessionId() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + @Test + void testCancelBatchJob() { + givenCancelJobRunSucceed(); + + String queryId = + sparkQueryDispatcher.cancelJob( + asyncQueryJobMetadata(JobType.BATCH), asyncQueryRequestContext); + + Assertions.assertEquals(QUERY_ID, queryId); + } + + private void givenCancelJobRunSucceed() { when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient); when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false)) .thenReturn( new CancelJobRunResult() .withJobRunId(EMR_JOB_ID) .withApplicationId(EMRS_APPLICATION_ID)); - - String queryId = - sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata(), asyncQueryRequestContext); - - Assertions.assertEquals(QUERY_ID, queryId); } @Test @@ -1154,11 +1166,16 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str } private AsyncQueryJobMetadata asyncQueryJobMetadata() { + return asyncQueryJobMetadata(JobType.INTERACTIVE); + } + + private AsyncQueryJobMetadata asyncQueryJobMetadata(JobType jobType) { return AsyncQueryJobMetadata.builder() .queryId(QUERY_ID) .applicationId(EMRS_APPLICATION_ID) .jobId(EMR_JOB_ID) .datasourceName(MY_GLUE) + .jobType(jobType) .build(); }