Skip to content

Commit

Permalink
Pass accountId to EMRServerlessClientFactory.getClient (#2783)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 committed Jun 28, 2024
1 parent 00f82f5 commit e24b51f
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public interface EMRServerlessClientFactory {
/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @param accountId Account ID of the requester. It will be used to decide the cluster.
* @return An {@link EMRServerlessClient} instance.
*/
EMRServerlessClient getClient();
EMRServerlessClient getClient(String accountId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.metrics.MetricsService;

/** Implementation of {@link EMRServerlessClientFactory}. */
@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

Expand All @@ -27,13 +26,8 @@ public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactor
private EMRServerlessClient emrServerlessClient;
private String region;

/**
* Gets an instance of {@link EMRServerlessClient}.
*
* @return An {@link EMRServerlessClient} instance.
*/
@Override
public EMRServerlessClient getClient() {
public EMRServerlessClient getClient(String accountId) {
SparkExecutionEngineConfig sparkExecutionEngineConfig =
this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(
new NullAsyncQueryRequestContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ public class QueryHandlerFactory {
private final EMRServerlessClientFactory emrServerlessClientFactory;
private final MetricsService metricsService;

public RefreshQueryHandler getRefreshQueryHandler() {
public RefreshQueryHandler getRefreshQueryHandler(String accountId) {
return new RefreshQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
flintIndexMetadataService,
leaseManager,
flintIndexOpFactory,
metricsService);
}

public StreamingQueryHandler getStreamingQueryHandler() {
public StreamingQueryHandler getStreamingQueryHandler(String accountId) {
return new StreamingQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
}

public BatchQueryHandler getBatchQueryHandler() {
public BatchQueryHandler getBatchQueryHandler(String accountId) {
return new BatchQueryHandler(
emrServerlessClientFactory.getClient(),
emrServerlessClientFactory.getClient(accountId),
jobExecutionResponseReader,
leaseManager,
metricsService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ public DispatchQueryResponse dispatch(
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();

return getQueryHandlerForFlintExtensionQuery(indexQueryDetails)
return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
.submit(dispatchQueryRequest, context);
} else {
DispatchQueryContext context =
getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
.asyncQueryRequestContext(asyncQueryRequestContext)
.build();
return getDefaultAsyncQueryHandler().submit(dispatchQueryRequest, context);
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
.submit(dispatchQueryRequest, context);
}
}

Expand All @@ -74,28 +75,28 @@ private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchConte
}

private AsyncQueryHandler getQueryHandlerForFlintExtensionQuery(
IndexQueryDetails indexQueryDetails) {
DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) {
if (isEligibleForIndexDMLHandling(indexQueryDetails)) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (isEligibleForStreamingQuery(indexQueryDetails)) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) {
// Create should be handled by batch handler. This is to avoid DROP index incorrectly cancel
// an interactive job.
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(dispatchQueryRequest.getAccountId());
} else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) {
// Manual refresh should be handled by batch handler
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(dispatchQueryRequest.getAccountId());
} else {
return getDefaultAsyncQueryHandler();
return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId());
}
}

@NotNull
private AsyncQueryHandler getDefaultAsyncQueryHandler() {
private AsyncQueryHandler getDefaultAsyncQueryHandler(String accountId) {
return sessionManager.isEnabled()
? queryHandlerFactory.getInteractiveQueryHandler()
: queryHandlerFactory.getBatchQueryHandler();
: queryHandlerFactory.getBatchQueryHandler(accountId);
}

@NotNull
Expand Down Expand Up @@ -143,11 +144,11 @@ private AsyncQueryHandler getAsyncQueryHandlerForExistingQuery(
} else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
return queryHandlerFactory.getIndexDMLHandler();
} else if (asyncQueryJobMetadata.getJobType() == JobType.BATCH) {
return queryHandlerFactory.getRefreshQueryHandler();
return queryHandlerFactory.getRefreshQueryHandler(asyncQueryJobMetadata.getAccountId());
} else if (asyncQueryJobMetadata.getJobType() == JobType.STREAMING) {
return queryHandlerFactory.getStreamingQueryHandler();
return queryHandlerFactory.getStreamingQueryHandler(asyncQueryJobMetadata.getAccountId());
} else {
return queryHandlerFactory.getBatchQueryHandler();
return queryHandlerFactory.getBatchQueryHandler(asyncQueryJobMetadata.getAccountId());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public Session createSession(
.sessionId(sessionIdProvider.getSessionId(request))
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(request.getAccountId()))
.build();
session.open(request, asyncQueryRequestContext);
return session;
Expand Down Expand Up @@ -65,7 +65,7 @@ public Optional<Session> getSession(String sessionId, String dataSourceName) {
.sessionId(sessionId)
.sessionStorageService(sessionStorageService)
.statementStorageService(statementStorageService)
.serverlessClient(emrServerlessClientFactory.getClient())
.serverlessClient(emrServerlessClientFactory.getClient(model.get().getAccountId()))
.sessionModel(model.get())
.sessionInactivityTimeoutMilli(
sessionConfigSupplier.getSessionInactivityTimeoutMillis())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public void cancelStreamingJob(FlintIndexStateModel flintIndexStateModel)
throws InterruptedException, TimeoutException {
String applicationId = flintIndexStateModel.getApplicationId();
String jobId = flintIndexStateModel.getJobId();
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient =
emrServerlessClientFactory.getClient(flintIndexStateModel.getAccountId());
try {
emrServerlessClient.cancelJobRun(
flintIndexStateModel.getApplicationId(), flintIndexStateModel.getJobId(), true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
@ExtendWith(MockitoExtension.class)
public class EMRServerlessClientFactoryImplTest {

public static final String ACCOUNT_ID = "accountId";
@Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
@Mock private MetricsService metricsService;

Expand All @@ -30,7 +31,9 @@ public void testGetClient() {
.thenReturn(createSparkExecutionEngineConfig());
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();

EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);

Assertions.assertNotNull(emrserverlessClient);
}

Expand All @@ -41,16 +44,16 @@ public void testGetClientWithChangeInSetting() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient();
EMRServerlessClient emrserverlessClient = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotNull(emrserverlessClient);

EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient1 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertEquals(emrServerlessClient1, emrserverlessClient);

sparkExecutionEngineConfig.setRegion(TestConstants.US_WEST_REGION);
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any()))
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient();
EMRServerlessClient emrServerlessClient2 = emrServerlessClientFactory.getClient(ACCOUNT_ID);
Assertions.assertNotEquals(emrServerlessClient2, emrserverlessClient);
Assertions.assertNotEquals(emrServerlessClient2, emrServerlessClient1);
}
Expand All @@ -60,9 +63,11 @@ public void testGetClientWithException() {
when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(any())).thenReturn(null);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class, emrServerlessClientFactory::getClient);
IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));

Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",
Expand All @@ -77,9 +82,11 @@ public void testGetClientWithExceptionWithNullRegion() {
.thenReturn(sparkExecutionEngineConfig);
EMRServerlessClientFactory emrServerlessClientFactory =
new EMRServerlessClientFactoryImpl(sparkExecutionEngineConfigSupplier, metricsService);

IllegalArgumentException illegalArgumentException =
Assertions.assertThrows(
IllegalArgumentException.class, emrServerlessClientFactory::getClient);
IllegalArgumentException.class, () -> emrServerlessClientFactory.getClient(ACCOUNT_ID));

Assertions.assertEquals(
"Async Query APIs are disabled. Please configure plugins.query.executionengine.spark.config"
+ " in cluster settings to enable them.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void setUp() {

@Test
void testDispatchSelectQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -179,7 +179,7 @@ void testDispatchSelectQuery() {

@Test
void testDispatchSelectQueryWithLakeFormation() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -220,7 +220,7 @@ void testDispatchSelectQueryWithLakeFormation() {

@Test
void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -262,7 +262,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {

@Test
void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -368,7 +368,7 @@ void testDispatchSelectQueryFailedCreateSession() {

@Test
void testDispatchCreateAutoRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -413,7 +413,7 @@ void testDispatchCreateAutoRefreshIndexQuery() {

@Test
void testDispatchCreateManualRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -456,7 +456,7 @@ void testDispatchCreateManualRefreshIndexQuery() {

@Test
void testDispatchWithPPLQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -499,7 +499,7 @@ void testDispatchWithPPLQuery() {

@Test
void testDispatchQueryWithoutATableAndDataSourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -540,7 +540,7 @@ void testDispatchQueryWithoutATableAndDataSourceName() {

@Test
void testDispatchIndexQueryWithoutADatasourceName() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -585,7 +585,7 @@ void testDispatchIndexQueryWithoutADatasourceName() {

@Test
void testDispatchMaterializedViewQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(INDEX_TAG_KEY, "flint_mv_1");
Expand Down Expand Up @@ -630,7 +630,7 @@ void testDispatchMaterializedViewQuery() {

@Test
void testDispatchShowMVQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -671,7 +671,7 @@ void testDispatchShowMVQuery() {

@Test
void testRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -712,7 +712,7 @@ void testRefreshIndexQuery() {

@Test
void testDispatchDescribeIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
Expand Down Expand Up @@ -753,7 +753,7 @@ void testDispatchDescribeIndexQuery() {

@Test
void testDispatchAlterToAutoRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
HashMap<String, String> tags = new HashMap<>();
tags.put(DATASOURCE_TAG_KEY, "my_glue");
tags.put(INDEX_TAG_KEY, "flint_my_glue_default_http_logs_elb_and_requesturi_index");
Expand Down Expand Up @@ -906,7 +906,7 @@ void testDispatchWithUnSupportedDataSourceType() {

@Test
void testCancelJob() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
Expand Down Expand Up @@ -968,7 +968,7 @@ void testCancelQueryWithInvalidStatementId() {

@Test
void testCancelQueryWithNoSessionId() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.cancelJobRun(EMRS_APPLICATION_ID, EMR_JOB_ID, false))
.thenReturn(
new CancelJobRunResult()
Expand All @@ -982,7 +982,7 @@ void testCancelQueryWithNoSessionId() {

@Test
void testGetQueryResponse() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
when(emrServerlessClient.getJobRunResult(EMRS_APPLICATION_ID, EMR_JOB_ID))
.thenReturn(new GetJobRunResult().withJobRun(new JobRun().withState(JobRunState.PENDING)));
// simulate result index is not created yet
Expand Down Expand Up @@ -1079,7 +1079,7 @@ void testGetQueryResponseWithSuccess() {

@Test
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient()).thenReturn(emrServerlessClient);
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
.thenReturn(dataSourceMetadata);
Expand Down
Loading

0 comments on commit e24b51f

Please sign in to comment.