Skip to content

Commit

Permalink
Add AsyncQueryRequestContext to QueryIdProvider parameter (#2870)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 committed Jul 31, 2024
1 parent aa7a690 commit 53bfeba
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@

package org.opensearch.sql.spark.dispatcher;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.utils.IDUtils;

/** Generates QueryId by embedding Datasource name and random UUID */
public class DatasourceEmbeddedQueryIdProvider implements QueryIdProvider {

@Override
public String getQueryId(DispatchQueryRequest dispatchQueryRequest) {
public String getQueryId(
DispatchQueryRequest dispatchQueryRequest,
AsyncQueryRequestContext asyncQueryRequestContext) {
return IDUtils.encode(dispatchQueryRequest.getDatasource());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.opensearch.sql.spark.dispatcher;

import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;

/** Interface for extension point to specify queryId. Called when new query is executed. */
public interface QueryIdProvider {
String getQueryId(DispatchQueryRequest dispatchQueryRequest);
String getQueryId(
DispatchQueryRequest dispatchQueryRequest, AsyncQueryRequestContext asyncQueryRequestContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ private DispatchQueryResponse handleFlintExtensionQuery(
DataSourceMetadata dataSourceMetadata) {
IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
getDefaultDispatchContextBuilder(
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
.indexQueryDetails(indexQueryDetails)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
Expand All @@ -84,7 +85,8 @@ private DispatchQueryResponse handleDefaultQuery(
DataSourceMetadata dataSourceMetadata) {

DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
getDefaultDispatchContextBuilder(
dispatchQueryRequest, dataSourceMetadata, asyncQueryRequestContext)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

Expand All @@ -93,11 +95,13 @@ private DispatchQueryResponse handleDefaultQuery(
}

private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
DispatchQueryRequest dispatchQueryRequest, DataSourceMetadata dataSourceMetadata) {
DispatchQueryRequest dispatchQueryRequest,
DataSourceMetadata dataSourceMetadata,
AsyncQueryRequestContext asyncQueryRequestContext) {
return DispatchQueryContext.builder()
.dataSourceMetadata(dataSourceMetadata)
.tags(getDefaultTagsForJobSubmission(dispatchQueryRequest))
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest));
.queryId(queryIdProvider.getQueryId(dispatchQueryRequest, asyncQueryRequestContext));
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public void setUp() {
public void createDropIndexQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
String indexName = "flint_datasource_name_table_name_index_name_index";
givenFlintIndexMetadataExists(indexName);
givenCancelJobRunSucceed();
Expand All @@ -209,7 +209,7 @@ public void createDropIndexQuery() {
public void createVacuumIndexQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
String indexName = "flint_datasource_name_table_name_index_name_index";
givenFlintIndexMetadataExists(indexName);

Expand All @@ -231,7 +231,7 @@ public void createVacuumIndexQuery() {
public void createAlterIndexQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
String indexName = "flint_datasource_name_table_name_index_name_index";
givenFlintIndexMetadataExists(indexName);
givenCancelJobRunSucceed();
Expand Down Expand Up @@ -261,7 +261,7 @@ public void createAlterIndexQuery() {
public void createStreamingQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
when(awsemrServerless.startJobRun(any()))
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));

Expand Down Expand Up @@ -297,7 +297,7 @@ private void verifyStartJobRunCalled() {
public void createCreateIndexQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
when(awsemrServerless.startJobRun(any()))
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));

Expand All @@ -321,7 +321,7 @@ public void createCreateIndexQuery() {
public void createRefreshQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
when(awsemrServerless.startJobRun(any()))
.thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID));

Expand All @@ -344,7 +344,7 @@ public void createInteractiveQuery() {
givenSparkExecutionEngineConfigIsSupplied();
givenValidDataSourceMetadataExist();
givenSessionExists();
when(queryIdProvider.getQueryId(any())).thenReturn(QUERY_ID);
when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID);
when(sessionIdProvider.getSessionId(any())).thenReturn(SESSION_ID);
givenSessionExists(); // called twice
when(awsemrServerless.startJobRun(any()))
Expand Down Expand Up @@ -538,7 +538,8 @@ private void givenGetJobRunReturnJobRunWithState(String state) {
}

private void verifyGetQueryIdCalled() {
verify(queryIdProvider).getQueryId(dispatchQueryRequestArgumentCaptor.capture());
verify(queryIdProvider)
.getQueryId(dispatchQueryRequestArgumentCaptor.capture(), eq(asyncQueryRequestContext));
DispatchQueryRequest dispatchQueryRequest = dispatchQueryRequestArgumentCaptor.getValue();
assertEquals(ACCOUNT_ID, dispatchQueryRequest.getAccountId());
assertEquals(APPLICATION_ID, dispatchQueryRequest.getApplicationId());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.dispatcher;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.verifyNoInteractions;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;

@ExtendWith(MockitoExtension.class)
class DatasourceEmbeddedQueryIdProviderTest {
@Mock AsyncQueryRequestContext asyncQueryRequestContext;

DatasourceEmbeddedQueryIdProvider datasourceEmbeddedQueryIdProvider =
new DatasourceEmbeddedQueryIdProvider();

@Test
public void test() {
String queryId =
datasourceEmbeddedQueryIdProvider.getQueryId(
DispatchQueryRequest.builder().datasource("DATASOURCE").build(),
asyncQueryRequestContext);

assertNotNull(queryId);
verifyNoInteractions(asyncQueryRequestContext);
}
}

0 comments on commit 53bfeba

Please sign in to comment.