Skip to content

Commit

Permalink
Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetada…
Browse files Browse the repository at this point in the history
… method (#2866) (#2872)

* Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetadata method

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Add comments

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Fix style

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

---------

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

(cherry picked from commit ba82e12)
  • Loading branch information
ykmr1224 committed Jul 31, 2024
1 parent 0e07bec commit 3030ad5
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.spark.asyncquery.model;

import org.opensearch.sql.datasource.RequestContext;

/** Context interface to provide additional request related information */
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
public interface AsyncQueryRequestContext extends RequestContext {}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public DispatchQueryResponse dispatch(
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
&& SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) {
}

private void givenValidDataSourceMetadataExist() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
DATASOURCE_NAME, asyncQueryRequestContext))
.thenReturn(
new DataSourceMetadata.Builder()
.setName(DATASOURCE_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ void testDispatchSelectQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -221,7 +222,8 @@ void testDispatchSelectQueryWithLakeFormation() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -253,7 +255,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -276,7 +279,8 @@ void testDispatchSelectQueryCreateNewSession() {
doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any());
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -302,7 +306,8 @@ void testDispatchSelectQueryReuseSession() {
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
when(session.isOperationalForDataSource(any())).thenReturn(true);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -322,7 +327,8 @@ void testDispatchSelectQueryFailedCreateSession() {
doReturn(true).when(sessionManager).isEnabled();
doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any());
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

Assertions.assertThrows(
Expand Down Expand Up @@ -356,7 +362,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -391,7 +398,8 @@ void testDispatchCreateManualRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -424,7 +432,8 @@ void testDispatchWithPPLQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -459,7 +468,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -495,7 +505,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -516,8 +527,7 @@ void testDispatchMaterializedViewQuery() {
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
String query =
"CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH"
+ " (auto_refresh = true)";
"CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)";
String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
StartJobRequest expected =
new StartJobRequest(
Expand All @@ -531,7 +541,8 @@ void testDispatchMaterializedViewQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -564,7 +575,8 @@ void testDispatchShowMVQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -597,7 +609,8 @@ void testRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -630,7 +643,8 @@ void testDispatchDescribeIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -666,7 +680,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -689,7 +704,8 @@ void testDispatchAlterToManualRefreshIndexQuery() {
"ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+ " (auto_refresh = false)";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -712,7 +728,8 @@ void testDispatchDropIndexQuery() {

String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -735,7 +752,8 @@ void testDispatchVacuumIndexQuery() {

String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -751,7 +769,8 @@ void testDispatchVacuumIndexQuery() {

@Test
void testDispatchWithUnSupportedDataSourceType() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_prometheus", asyncQueryRequestContext))
.thenReturn(constructPrometheusDataSourceType());
String query = "select * from my_prometheus.default.http_logs";

Expand Down Expand Up @@ -945,7 +964,8 @@ void testGetQueryResponseWithSuccess() {
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

String extraParameters = "--conf spark.dynamicAllocation.enabled=false";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public interface DataSourceService {
* Specifically for addressing use cases in SparkQueryDispatcher.
*
* @param dataSourceName of the {@link DataSource}
* @param context request context used by the implementation. It is passed by async-query-core.
* refer {@link RequestContext}
*/
DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName);
DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.datasource;

/**
* Context interface to provide additional request related information. It is introduced to allow
* async-query-core library user to pass request context information to implementations of data
* accessors.
*/
public interface RequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.sql.config.TestConfig;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
Expand Down Expand Up @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.*;
import java.util.stream.Collectors;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand Down Expand Up @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName);
verifyDataSourceAccess(dataSourceMetadata);
return dataSourceMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand All @@ -52,6 +53,7 @@ class DataSourceServiceImplTest {
@Mock private DataSourceFactory dataSourceFactory;
@Mock private StorageEngine storageEngine;
@Mock private DataSourceMetadataStorage dataSourceMetadataStorage;
@Mock private RequestContext requestContext;

@Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper;

Expand Down Expand Up @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() {
DatasourceDisabledException datasourceDisabledException =
Assertions.assertThrows(
DatasourceDisabledException.class,
() -> dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"));
() ->
dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"testDS", requestContext));
Assertions.assertEquals(
"Datasource testDS is disabled.", datasourceDisabledException.getMessage());
}
Expand All @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() {
when(dataSourceMetadataStorage.getDataSourceMetadata("testDS"))
.thenReturn(Optional.of(dataSourceMetadata));
DataSourceMetadata dataSourceMetadata1 =
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS");
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext);
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username"));
Expand Down

0 comments on commit 3030ad5

Please sign in to comment.