Restrict UDF functions
Signed-off-by: Vamsi Manohar <reddyvam@amazon.com>
vamsi-amazon committed Jul 18, 2024
1 parent 956ec15 commit 00cc71e
Showing 4 changed files with 155 additions and 18 deletions.
…/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java

@@ -6,6 +6,7 @@
 package org.opensearch.sql.spark.dispatcher;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import lombok.AllArgsConstructor;
 import org.jetbrains.annotations.NotNull;
@@ -45,25 +46,50 @@ public DispatchQueryResponse dispatch(
         this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
             dispatchQueryRequest.getDatasource());
 
-    if (LangType.SQL.equals(dispatchQueryRequest.getLangType())
-        && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) {
-      IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
-      DispatchQueryContext context =
-          getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
-              .indexQueryDetails(indexQueryDetails)
-              .asyncQueryRequestContext(asyncQueryRequestContext)
-              .build();
-
-      return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
-          .submit(dispatchQueryRequest, context);
-    } else {
-      DispatchQueryContext context =
-          getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
-              .asyncQueryRequestContext(asyncQueryRequestContext)
-              .build();
-      return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
-          .submit(dispatchQueryRequest, context);
-    }
-  }
+    if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
+      String query = dispatchQueryRequest.getQuery();
+
+      if (SQLQueryUtils.isFlintExtensionQuery(query)) {
+        return handleFlintExtensionQuery(
+            dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
+      }
+
+      List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
+      if (!validationErrors.isEmpty()) {
+        throw new IllegalArgumentException(
+            "Query is not allowed: " + String.join(", ", validationErrors));
+      }
+    }
+    return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
+  }
+
+  private DispatchQueryResponse handleFlintExtensionQuery(
+      DispatchQueryRequest dispatchQueryRequest,
+      AsyncQueryRequestContext asyncQueryRequestContext,
+      DataSourceMetadata dataSourceMetadata) {
+    IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
+    DispatchQueryContext context =
+        getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
+            .indexQueryDetails(indexQueryDetails)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
+            .build();
+
+    return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
+        .submit(dispatchQueryRequest, context);
+  }
+
+  private DispatchQueryResponse handleDefaultQuery(
+      DispatchQueryRequest dispatchQueryRequest,
+      AsyncQueryRequestContext asyncQueryRequestContext,
+      DataSourceMetadata dataSourceMetadata) {
+
+    DispatchQueryContext context =
+        getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
+            .build();
+
+    return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
+        .submit(dispatchQueryRequest, context);
+  }
 
   private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
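For illustration only (not part of this commit): the reworked dispatch() above surfaces validation failures to callers as an IllegalArgumentException whose message joins all collected errors. A minimal sketch of that behavior, using a made-up UDF query string:

    // Sketch mirroring the validation branch in dispatch() above.
    // The query string is a hypothetical example.
    String query = "CREATE FUNCTION my_udf AS 'com.example.MyUdf'";
    List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
    if (!validationErrors.isEmpty()) {
      // Yields: "Query is not allowed: Creating user-defined functions is not allowed"
      throw new IllegalArgumentException(
          "Query is not allowed: " + String.join(", ", validationErrors));
    }

Note that only LangType.SQL queries reach this check; PPL queries and Flint extension queries take their own paths.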
…/org/opensearch/sql/spark/utils/SQLQueryUtils.java

@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.spark.utils;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -73,6 +75,32 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
     }
   }
 
+  public static List<String> validateSparkSqlQuery(String sqlQuery) {
+    SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
+    SqlBaseParser sqlBaseParser =
+        new SqlBaseParser(
+            new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
+    sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
+    try {
+      SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
+      sparkSqlValidatorVisitor.visit(statement);
+      return sparkSqlValidatorVisitor.getValidationErrors();
+    } catch (SyntaxCheckException syntaxCheckException) {
+      return Collections.emptyList();
+    }
+  }
+
+  private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {
+
+    @Getter private final List<String> validationErrors = new ArrayList<>();
+
+    @Override
+    public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
+      validationErrors.add("Creating user-defined functions is not allowed");
+      return super.visitCreateFunction(ctx);
+    }
+  }
+
   public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
 
     @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
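A note on the design: validateSparkSqlQuery parses the query with the bundled Spark SQL grammar and walks the tree with SparkSqlValidatorVisitor, which records an error whenever it visits a CREATE FUNCTION node. If the grammar cannot parse the query at all (SyntaxCheckException), the method deliberately returns an empty list, so the query is still dispatched and Spark reports its own syntax error (see testInvalidSQLQueryDispatchToSpark below). A usage sketch, with made-up query strings:

    // Usage sketch (assumes the generated SqlBaseParser/SqlBaseLexer used above
    // are on the classpath; query strings are hypothetical examples).
    List<String> udfErrors =
        SQLQueryUtils.validateSparkSqlQuery(
            "CREATE TEMPORARY FUNCTION square AS 'com.example.Square'");
    // udfErrors -> ["Creating user-defined functions is not allowed"]

    List<String> unparsableErrors = SQLQueryUtils.validateSparkSqlQuery("myselect 1");
    // unparsableErrors -> [] : unparsable queries pass through,
    // leaving syntax errors for Spark itself to surface.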
…/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java

@@ -41,9 +41,11 @@
 import com.amazonaws.services.emrserverless.model.GetJobRunResult;
 import com.amazonaws.services.emrserverless.model.JobRun;
 import com.amazonaws.services.emrserverless.model.JobRunState;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
@@ -438,6 +440,85 @@ void testDispatchWithPPLQuery() {
     verifyNoInteractions(flintIndexMetadataService);
   }
 
+  @Test
+  void testDispatchWithSparkUDFQuery() {
+    List<String> udfQueries = new ArrayList<>();
+    udfQueries.add(
+        "CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *"
+            + " 9/5) + 32\")'");
+    udfQueries.add(
+        "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
+    for (String query : udfQueries) {
+      DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+      when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
+          .thenReturn(dataSourceMetadata);
+
+      IllegalArgumentException illegalArgumentException =
+          Assertions.assertThrows(
+              IllegalArgumentException.class,
+              () ->
+                  sparkQueryDispatcher.dispatch(
+                      getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
+                      asyncQueryRequestContext));
+      Assertions.assertEquals(
+          "Query is not allowed: Creating user-defined functions is not allowed",
+          illegalArgumentException.getMessage());
+      verifyNoInteractions(emrServerlessClient);
+      verifyNoInteractions(flintIndexMetadataService);
+    }
+  }
+
+  @Test
+  void testInvalidSQLQueryDispatchToSpark() {
+    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    HashMap<String, String> tags = new HashMap<>();
+    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
+    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
+    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
+    String query = "myselect 1";
+    String sparkSubmitParameters =
+        constructExpectedSparkSubmitParameterString(
+            "sigv4",
+            new HashMap<>() {
+              {
+                put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1");
+              }
+            },
+            query);
+    StartJobRequest expected =
+        new StartJobRequest(
+            "TEST_CLUSTER:batch",
+            null,
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            "query_execution_result_my_glue");
+    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
+        .thenReturn(dataSourceMetadata);
+
+    DispatchQueryResponse dispatchQueryResponse =
+        sparkQueryDispatcher.dispatch(
+            DispatchQueryRequest.builder()
+                .applicationId(EMRS_APPLICATION_ID)
+                .query(query)
+                .datasource(MY_GLUE)
+                .langType(LangType.SQL)
+                .executionRoleARN(EMRS_EXECUTION_ROLE)
+                .clusterName(TEST_CLUSTER_NAME)
+                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+                .build(),
+            asyncQueryRequestContext);
+
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    verifyNoInteractions(flintIndexMetadataService);
+  }
+
   @Test
   void testDispatchQueryWithoutATableAndDataSourceName() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
docs/user/interfaces/asyncqueryinterface.rst (2 additions & 0 deletions)

@@ -68,6 +68,8 @@ Async Query Creation API
 ======================================
 If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``.
 
+Limitation: Spark SQL queries that create User-Defined Functions (UDFs) are not allowed.
+
 HTTP URI: ``_plugins/_async_query``
 
 HTTP VERB: ``POST``
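An illustrative request for this new limitation (a sketch, not verbatim plugin output; the body fields follow the async query creation examples elsewhere in this document, and the UDF query string is a made-up example):

    POST _plugins/_async_query
    {
      "datasource" : "my_glue",
      "lang" : "sql",
      "query" : "CREATE FUNCTION my_udf AS 'com.example.MyUdf'"
    }

Such a request is rejected before any EMR Serverless job is started, with an error message of the form "Query is not allowed: Creating user-defined functions is not allowed".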
