Skip to content

Commit

Permalink
Add accountId to data models (#2709)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 authored Jun 5, 2024
1 parent c90cf00 commit ffc48fa
Show file tree
Hide file tree
Showing 33 changed files with 392 additions and 323 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,22 @@ public CreateAsyncQueryResponse createAsyncQuery(
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig(requestContext);
DispatchQueryResponse dispatchQueryResponse =
sparkQueryDispatcher.dispatch(
new DispatchQueryRequest(
sparkExecutionEngineConfig.getApplicationId(),
createAsyncQueryRequest.getQuery(),
createAsyncQueryRequest.getDatasource(),
createAsyncQueryRequest.getLang(),
sparkExecutionEngineConfig.getExecutionRoleARN(),
sparkExecutionEngineConfig.getClusterName(),
sparkExecutionEngineConfig.getSparkSubmitParameterModifier(),
createAsyncQueryRequest.getSessionId()));
DispatchQueryRequest.builder()
.accountId(sparkExecutionEngineConfig.getAccountId())
.applicationId(sparkExecutionEngineConfig.getApplicationId())
.query(createAsyncQueryRequest.getQuery())
.datasource(createAsyncQueryRequest.getDatasource())
.langType(createAsyncQueryRequest.getLang())
.executionRoleARN(sparkExecutionEngineConfig.getExecutionRoleARN())
.clusterName(sparkExecutionEngineConfig.getClusterName())
.sparkSubmitParameterModifier(
sparkExecutionEngineConfig.getSparkSubmitParameterModifier())
.sessionId(createAsyncQueryRequest.getSessionId())
.build());
asyncQueryJobMetadataStorageService.storeJobMetadata(
AsyncQueryJobMetadata.builder()
.queryId(dispatchQueryResponse.getQueryId())
.accountId(sparkExecutionEngineConfig.getAccountId())
.applicationId(sparkExecutionEngineConfig.getApplicationId())
.jobId(dispatchQueryResponse.getJobId())
.resultIndex(dispatchQueryResponse.getResultIndex())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
@EqualsAndHashCode(callSuper = false)
public class AsyncQueryJobMetadata extends StateModel {
private final String queryId;
// optional: accountId for EMRS cluster
private final String accountId;
private final String applicationId;
private final String jobId;
private final String resultIndex;
Expand All @@ -44,6 +46,7 @@ public static AsyncQueryJobMetadata copy(
AsyncQueryJobMetadata copy, ImmutableMap<String, Object> metadata) {
return builder()
.queryId(copy.queryId)
.accountId(copy.accountId)
.applicationId(copy.getApplicationId())
.jobId(copy.getJobId())
.resultIndex(copy.getResultIndex())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ private void validateSparkExecutionEngineConfig(
}

private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
// TODO: It does not handle accountId for now. (it creates client for same account)
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class StartJobRequest {
public static final Long DEFAULT_JOB_TIMEOUT = 120L;

private final String jobName;
// optional
private final String accountId;
private final String applicationId;
private final String executionRoleArn;
private final String sparkSubmitParams;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.opensearch.sql.spark.config;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

Expand All @@ -11,8 +10,8 @@
*/
@Data
@Builder
@AllArgsConstructor
public class SparkExecutionEngineConfig {
private String accountId;
private String applicationId;
private String region;
private String executionRoleARN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class SparkExecutionEngineConfigClusterSetting {
// optional
private String accountId;
private String applicationId;
private String region;
private String executionRoleARN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public DispatchQueryResponse submit(
StartJobRequest startJobRequest =
new StartJobRequest(
clusterName + ":" + JobType.BATCH.getText(),
dispatchQueryRequest.getAccountId(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public DispatchQueryResponse submit(
sessionManager.createSession(
new CreateSessionRequest(
clusterName,
dispatchQueryRequest.getAccountId(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public DispatchQueryResponse submit(
StartJobRequest startJobRequest =
new StartJobRequest(
jobName,
dispatchQueryRequest.getAccountId(),
dispatchQueryRequest.getApplicationId(),
dispatchQueryRequest.getExecutionRoleARN(),
SparkSubmitParameters.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
package org.opensearch.sql.spark.dispatcher.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
import org.opensearch.sql.spark.rest.model.LangType;

@AllArgsConstructor
@Data
@RequiredArgsConstructor // required explicitly
@Builder
public class DispatchQueryRequest {
private final String accountId;
private final String applicationId;
private final String query;
private final String datasource;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@Data
public class CreateSessionRequest {
private final String clusterName;
private final String accountId;
private final String applicationId;
private final String executionRoleArn;
private final SparkSubmitParameters sparkSubmitParameters;
Expand All @@ -24,6 +25,7 @@ public class CreateSessionRequest {
public StartJobRequest getStartJobRequest(String sessionId) {
return new InteractiveSessionStartJobRequest(
clusterName + ":" + JobType.INTERACTIVE.getText() + ":" + sessionId,
accountId,
applicationId,
executionRoleArn,
sparkSubmitParameters.toString(),
Expand All @@ -34,12 +36,21 @@ public StartJobRequest getStartJobRequest(String sessionId) {
static class InteractiveSessionStartJobRequest extends StartJobRequest {
public InteractiveSessionStartJobRequest(
String jobName,
String accountId,
String applicationId,
String executionRoleArn,
String sparkSubmitParams,
Map<String, String> tags,
String resultIndex) {
super(jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false, resultIndex);
super(
jobName,
accountId,
applicationId,
executionRoleArn,
sparkSubmitParams,
tags,
false,
resultIndex);
}

/** Interactive query keep running. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ public void open(CreateSessionRequest createSessionRequest) {
createSessionRequest.getStartJobRequest(sessionId.getSessionId());
String jobID = serverlessClient.startJobRun(startJobRequest);
String applicationId = startJobRequest.getApplicationId();
String accountId = createSessionRequest.getAccountId();

sessionModel =
initInteractiveSession(
applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
accountId, applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
sessionStorageService.createSession(sessionModel);
} catch (VersionConflictEngineException e) {
String errorMsg = "session already exist. " + sessionId;
Expand Down Expand Up @@ -99,6 +100,7 @@ public StatementId submit(QueryRequest request) {
Statement st =
Statement.builder()
.sessionId(sessionId)
.accountId(sessionModel.getAccountId())
.applicationId(sessionModel.getApplicationId())
.jobId(sessionModel.getJobId())
.statementStorageService(statementStorageService)
Expand Down Expand Up @@ -130,6 +132,7 @@ public Optional<Statement> get(StatementId stID) {
model ->
Statement.builder()
.sessionId(sessionId)
.accountId(model.getAccountId())
.applicationId(model.getApplicationId())
.jobId(model.getJobId())
.statementId(model.getStatementId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class SessionModel extends StateModel {
private final SessionType sessionType;
private final SessionId sessionId;
private final SessionState sessionState;
// optional: accountId for EMRS cluster
private final String accountId;
private final String applicationId;
private final String jobId;
private final String datasourceName;
Expand All @@ -37,6 +39,7 @@ public static SessionModel of(SessionModel copy, ImmutableMap<String, Object> me
.sessionId(new SessionId(copy.sessionId.getSessionId()))
.sessionState(copy.sessionState)
.datasourceName(copy.datasourceName)
.accountId(copy.accountId)
.applicationId(copy.getApplicationId())
.jobId(copy.jobId)
.error(UNKNOWN)
Expand All @@ -53,6 +56,7 @@ public static SessionModel copyWithState(
.sessionId(new SessionId(copy.sessionId.getSessionId()))
.sessionState(state)
.datasourceName(copy.datasourceName)
.accountId(copy.getAccountId())
.applicationId(copy.getApplicationId())
.jobId(copy.jobId)
.error(UNKNOWN)
Expand All @@ -62,13 +66,14 @@ public static SessionModel copyWithState(
}

public static SessionModel initInteractiveSession(
String applicationId, String jobId, SessionId sid, String datasourceName) {
String accountId, String applicationId, String jobId, SessionId sid, String datasourceName) {
return builder()
.version("1.0")
.sessionType(INTERACTIVE)
.sessionId(sid)
.sessionState(NOT_STARTED)
.datasourceName(datasourceName)
.accountId(accountId)
.applicationId(applicationId)
.jobId(jobId)
.error(UNKNOWN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public class Statement {
private static final Logger LOG = LogManager.getLogger();

private final SessionId sessionId;
// optional
private final String accountId;
private final String applicationId;
private final String jobId;
private final StatementId statementId;
Expand All @@ -42,6 +44,7 @@ public void open() {
statementModel =
submitStatement(
sessionId,
accountId,
applicationId,
jobId,
statementId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class StatementModel extends StateModel {
private final StatementState statementState;
private final StatementId statementId;
private final SessionId sessionId;
// optional: accountId for EMRS cluster
private final String accountId;
private final String applicationId;
private final String jobId;
private final LangType langType;
Expand All @@ -39,6 +41,7 @@ public static StatementModel copy(StatementModel copy, ImmutableMap<String, Obje
.statementState(copy.statementState)
.statementId(copy.statementId)
.sessionId(copy.sessionId)
.accountId(copy.accountId)
.applicationId(copy.applicationId)
.jobId(copy.jobId)
.langType(copy.langType)
Expand All @@ -58,6 +61,7 @@ public static StatementModel copyWithState(
.statementState(state)
.statementId(copy.statementId)
.sessionId(copy.sessionId)
.accountId(copy.accountId)
.applicationId(copy.applicationId)
.jobId(copy.jobId)
.langType(copy.langType)
Expand All @@ -72,6 +76,7 @@ public static StatementModel copyWithState(

public static StatementModel submitStatement(
SessionId sid,
String accountId,
String applicationId,
String jobId,
StatementId statementId,
Expand All @@ -84,6 +89,7 @@ public static StatementModel submitStatement(
.statementState(WAITING)
.statementId(statementId)
.sessionId(sid)
.accountId(accountId)
.applicationId(applicationId)
.jobId(jobId)
.langType(langType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.spark.execution.xcontent;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.ACCOUNT_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.APPLICATION_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.DATASOURCE_NAME;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.JOB_ID;
Expand Down Expand Up @@ -39,6 +40,7 @@ public XContentBuilder toXContent(AsyncQueryJobMetadata jobMetadata, ToXContent.
.field(QUERY_ID, jobMetadata.getQueryId())
.field(TYPE, TYPE_JOBMETA)
.field(JOB_ID, jobMetadata.getJobId())
.field(ACCOUNT_ID, jobMetadata.getAccountId())
.field(APPLICATION_ID, jobMetadata.getApplicationId())
.field(RESULT_INDEX, jobMetadata.getResultIndex())
.field(SESSION_ID, jobMetadata.getSessionId())
Expand All @@ -63,6 +65,9 @@ public AsyncQueryJobMetadata fromXContent(XContentParser parser, long seqNo, lon
case JOB_ID:
builder.jobId(parser.textOrNull());
break;
case ACCOUNT_ID:
builder.accountId(parser.textOrNull());
break;
case APPLICATION_ID:
builder.applicationId(parser.textOrNull());
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.spark.execution.xcontent;

import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.ACCOUNT_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.APPLICATION_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.DATASOURCE_NAME;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.ERROR;
Expand Down Expand Up @@ -38,6 +39,7 @@ public XContentBuilder toXContent(
.field(VERSION, VERSION_1_0)
.field(TYPE, FLINT_INDEX_DOC_TYPE)
.field(STATE, flintIndexStateModel.getIndexState().getState())
.field(ACCOUNT_ID, flintIndexStateModel.getAccountId())
.field(APPLICATION_ID, flintIndexStateModel.getApplicationId())
.field(JOB_ID, flintIndexStateModel.getJobId())
.field(LATEST_ID, flintIndexStateModel.getLatestId())
Expand All @@ -60,6 +62,9 @@ public FlintIndexStateModel fromXContent(XContentParser parser, long seqNo, long
case STATE:
builder.indexState(FlintIndexState.fromString(parser.text()));
break;
case ACCOUNT_ID:
builder.accountId(parser.textOrNull());
break;
case APPLICATION_ID:
builder.applicationId(parser.text());
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.spark.execution.xcontent;

import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.ACCOUNT_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.APPLICATION_ID;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.DATASOURCE_NAME;
import static org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes.ERROR;
Expand Down Expand Up @@ -42,6 +43,7 @@ public XContentBuilder toXContent(SessionModel sessionModel, ToXContent.Params p
.field(SESSION_ID, sessionModel.getSessionId().getSessionId())
.field(STATE, sessionModel.getSessionState().getSessionState())
.field(DATASOURCE_NAME, sessionModel.getDatasourceName())
.field(ACCOUNT_ID, sessionModel.getAccountId())
.field(APPLICATION_ID, sessionModel.getApplicationId())
.field(JOB_ID, sessionModel.getJobId())
.field(LAST_UPDATE_TIME, sessionModel.getLastUpdateTime())
Expand Down Expand Up @@ -77,6 +79,9 @@ public SessionModel fromXContent(XContentParser parser, long seqNo, long primary
case ERROR:
builder.error(parser.text());
break;
case ACCOUNT_ID:
builder.accountId(parser.textOrNull());
break;
case APPLICATION_ID:
builder.applicationId(parser.text());
break;
Expand Down
Loading

0 comments on commit ffc48fa

Please sign in to comment.