From 84eb4b3e2ee642e50af82ccc11b8e3a6197de5f5 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Fri, 19 Jul 2024 13:31:27 -0700 Subject: [PATCH] Restrict UDF functions (#2790) Signed-off-by: Vamsi Manohar (cherry picked from commit db2bd6619453a320eadd741df8ca318a905f3fa4) --- .../src/main/antlr/SqlBaseLexer.g4 | 1 + .../src/main/antlr/SqlBaseParser.g4 | 52 +++++++++---- .../dispatcher/SparkQueryDispatcher.java | 62 +++++++++++----- .../sql/spark/utils/SQLQueryUtils.java | 28 +++++++ .../dispatcher/SparkQueryDispatcherTest.java | 73 +++++++++++++++++++ .../sql/spark/utils/SQLQueryUtilsTest.java | 19 +++++ docs/user/interfaces/asyncqueryinterface.rst | 2 + 7 files changed, 206 insertions(+), 31 deletions(-) diff --git a/async-query-core/src/main/antlr/SqlBaseLexer.g4 b/async-query-core/src/main/antlr/SqlBaseLexer.g4 index 85a4633e80..bde298c23e 100644 --- a/async-query-core/src/main/antlr/SqlBaseLexer.g4 +++ b/async-query-core/src/main/antlr/SqlBaseLexer.g4 @@ -316,6 +316,7 @@ NANOSECOND: 'NANOSECOND'; NANOSECONDS: 'NANOSECONDS'; NATURAL: 'NATURAL'; NO: 'NO'; +NONE: 'NONE'; NOT: 'NOT'; NULL: 'NULL'; NULLS: 'NULLS'; diff --git a/async-query-core/src/main/antlr/SqlBaseParser.g4 b/async-query-core/src/main/antlr/SqlBaseParser.g4 index 4d39e1717a..e8c40bec7d 100644 --- a/async-query-core/src/main/antlr/SqlBaseParser.g4 +++ b/async-query-core/src/main/antlr/SqlBaseParser.g4 @@ -52,7 +52,7 @@ singleCompoundStatement ; beginEndCompoundBlock - : BEGIN compoundBody END + : beginLabel? BEGIN compoundBody END endLabel? ; compoundBody @@ -61,11 +61,26 @@ compoundBody compoundStatement : statement + | setStatementWithOptionalVarKeyword | beginEndCompoundBlock ; +setStatementWithOptionalVarKeyword + : SET (VARIABLE | VAR)? assignmentList #setVariableWithOptionalKeyword + | SET (VARIABLE | VAR)? LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword + ; + singleStatement - : statement SEMICOLON* EOF + : (statement|setResetStatement) SEMICOLON* EOF + ; + +beginLabel + : multipartIdentifier COLON + ; + +endLabel + : multipartIdentifier ; singleExpression @@ -175,6 +190,8 @@ statement | ALTER TABLE identifierReference (partitionSpec)? SET locationSpec #setTableLocation | ALTER TABLE identifierReference RECOVER PARTITIONS #recoverPartitions + | ALTER TABLE identifierReference + (clusterBySpec | CLUSTER BY NONE) #alterClusterBy | DROP TABLE (IF EXISTS)? identifierReference PURGE? #dropTable | DROP VIEW (IF EXISTS)? identifierReference #dropView | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? @@ -203,7 +220,7 @@ statement identifierReference dataType? variableDefaultExpression? #createVariable | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? - statement #explain + (statement|setResetStatement) #explain | SHOW TABLES ((FROM | IN) identifierReference)? (LIKE? pattern=stringLit)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) ns=identifierReference)? @@ -242,26 +259,29 @@ statement | (MSCK)? REPAIR TABLE identifierReference (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable | op=(ADD | LIST) identifier .*? #manageResource - | SET COLLATION collationName=identifier #setCollation - | SET ROLE .*? #failNativeCommand + | CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE? + identifierReference (USING indexType=identifier)? + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN + (OPTIONS options=propertyList)? #createIndex + | DROP INDEX (IF EXISTS)? 
identifier ON TABLE? identifierReference #dropIndex + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +setResetStatement + : SET COLLATION collationName=identifier #setCollation + | SET ROLE .*? #failSetRole | SET TIME ZONE interval #setTimeZone | SET TIME ZONE timezone #setTimeZone | SET TIME ZONE .*? #setTimeZone | SET (VARIABLE | VAR) assignmentList #setVariable | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ - LEFT_PAREN query RIGHT_PAREN #setVariable + LEFT_PAREN query RIGHT_PAREN #setVariable | SET configKey EQ configValue #setQuotedConfiguration | SET configKey (EQ .*?)? #setConfiguration | SET .*? EQ configValue #setQuotedConfiguration | SET .*? #setConfiguration | RESET configKey #resetQuotedConfiguration | RESET .*? #resetConfiguration - | CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE? - identifierReference (USING indexType=identifier)? - LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN - (OPTIONS options=propertyList)? #createIndex - | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex - | unsupportedHiveNativeCommands .*? #failNativeCommand ; executeImmediate @@ -854,13 +874,17 @@ identifierComment relationPrimary : identifierReference temporalClause? - sample? tableAlias #tableName + optionsClause? sample? tableAlias #tableName | LEFT_PAREN query RIGHT_PAREN sample? tableAlias #aliasedQuery | LEFT_PAREN relation RIGHT_PAREN sample? tableAlias #aliasedRelation | inlineTable #inlineTableDefault2 | functionTable #tableValuedFunction ; +optionsClause + : WITH options=propertyList + ; + inlineTable : VALUES expression (COMMA expression)* tableAlias ; @@ -1573,6 +1597,7 @@ ansiNonReserved | NANOSECOND | NANOSECONDS | NO + | NONE | NULLS | NUMERIC | OF @@ -1921,6 +1946,7 @@ nonReserved | NANOSECOND | NANOSECONDS | NO + | NONE | NOT | NULL | NULLS diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 3366e21894..0e871f9ddc 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.dispatcher; import java.util.HashMap; +import java.util.List; import java.util.Map; import lombok.AllArgsConstructor; import org.jetbrains.annotations.NotNull; @@ -45,25 +46,50 @@ public DispatchQueryResponse dispatch( this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata( dispatchQueryRequest.getDatasource()); - if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) - && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { - IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest); - DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .indexQueryDetails(indexQueryDetails) - .asyncQueryRequestContext(asyncQueryRequestContext) - .build(); - - return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails) - .submit(dispatchQueryRequest, context); - } else { - DispatchQueryContext context = - getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata) - .asyncQueryRequestContext(asyncQueryRequestContext) - .build(); - return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId()) - .submit(dispatchQueryRequest, 
context);
+    if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
+      String query = dispatchQueryRequest.getQuery();
+
+      if (SQLQueryUtils.isFlintExtensionQuery(query)) {
+        return handleFlintExtensionQuery(
+            dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
+      }
+
+      List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
+      if (!validationErrors.isEmpty()) {
+        throw new IllegalArgumentException(
+            "Query is not allowed: " + String.join(", ", validationErrors));
+      }
     }
+    return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
+  }
+
+  private DispatchQueryResponse handleFlintExtensionQuery(
+      DispatchQueryRequest dispatchQueryRequest,
+      AsyncQueryRequestContext asyncQueryRequestContext,
+      DataSourceMetadata dataSourceMetadata) {
+    IndexQueryDetails indexQueryDetails = getIndexQueryDetails(dispatchQueryRequest);
+    DispatchQueryContext context =
+        getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
+            .indexQueryDetails(indexQueryDetails)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
+            .build();
+
+    return getQueryHandlerForFlintExtensionQuery(dispatchQueryRequest, indexQueryDetails)
+        .submit(dispatchQueryRequest, context);
+  }
+
+  private DispatchQueryResponse handleDefaultQuery(
+      DispatchQueryRequest dispatchQueryRequest,
+      AsyncQueryRequestContext asyncQueryRequestContext,
+      DataSourceMetadata dataSourceMetadata) {
+
+    DispatchQueryContext context =
+        getDefaultDispatchContextBuilder(dispatchQueryRequest, dataSourceMetadata)
+            .asyncQueryRequestContext(asyncQueryRequestContext)
+            .build();
+
+    return getDefaultAsyncQueryHandler(dispatchQueryRequest.getAccountId())
+        .submit(dispatchQueryRequest, context);
   }
 
   private DispatchQueryContext.DispatchQueryContextBuilder getDefaultDispatchContextBuilder(
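
Reviewer note: a minimal sketch of the new guard's round trip. The wrapper class, main method, and sample query below are illustrative scaffolding and are not part of this patch; only validateSparkSqlQuery, the error message, and the dispatcher's IllegalArgumentException wrapping come from the change itself.

    import java.util.List;
    import org.opensearch.sql.spark.utils.SQLQueryUtils;

    // Illustrative scaffolding only; the class name and UDF query are arbitrary.
    class UdfGuardSketch {
      public static void main(String[] args) {
        // Parses the statement and flags CREATE FUNCTION without executing anything.
        List<String> errors =
            SQLQueryUtils.validateSparkSqlQuery("CREATE FUNCTION f AS 'com.example.F'");
        // Per this patch, errors contains "Creating user-defined functions is not allowed";
        // SparkQueryDispatcher.dispatch() then throws IllegalArgumentException(
        //   "Query is not allowed: Creating user-defined functions is not allowed").
        System.out.println(errors);
      }
    }
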
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
index a96e203cea..0bb9cb4b85 100644
--- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
+++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java
@@ -5,6 +5,8 @@
 package org.opensearch.sql.spark.utils;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -73,6 +75,32 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
     }
   }
 
+  public static List<String> validateSparkSqlQuery(String sqlQuery) {
+    SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
+    SqlBaseParser sqlBaseParser =
+        new SqlBaseParser(
+            new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
+    sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
+    try {
+      SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
+      sparkSqlValidatorVisitor.visit(statement);
+      return sparkSqlValidatorVisitor.getValidationErrors();
+    } catch (SyntaxCheckException syntaxCheckException) {
+      return Collections.emptyList();
+    }
+  }
+
+  private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {
+
+    @Getter private final List<String> validationErrors = new ArrayList<>();
+
+    @Override
+    public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
+      validationErrors.add("Creating user-defined functions is not allowed");
+      return super.visitCreateFunction(ctx);
+    }
+  }
+
   public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
 
     @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
 
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 5582de332c..f9a83ef9f6 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -41,9 +41,11 @@
 import com.amazonaws.services.emrserverless.model.GetJobRunResult;
 import com.amazonaws.services.emrserverless.model.JobRun;
 import com.amazonaws.services.emrserverless.model.JobRunState;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
@@ -438,6 +440,77 @@ void testDispatchWithPPLQuery() {
     verifyNoInteractions(flintIndexMetadataService);
   }
 
+  @Test
+  void testDispatchWithSparkUDFQuery() {
+    List<String> udfQueries = new ArrayList<>();
+    udfQueries.add(
+        "CREATE FUNCTION celsius_to_fahrenheit AS 'org.apache.spark.sql.functions.expr(\"(celsius *"
+            + " 9/5) + 32\")'");
+    udfQueries.add(
+        "CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
+    for (String query : udfQueries) {
+      DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+      when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
+          .thenReturn(dataSourceMetadata);
+
+      IllegalArgumentException illegalArgumentException =
+          Assertions.assertThrows(
+              IllegalArgumentException.class,
+              () ->
+                  sparkQueryDispatcher.dispatch(
+                      getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(),
+                      asyncQueryRequestContext));
+      Assertions.assertEquals(
+          "Query is not allowed: Creating user-defined functions is not allowed",
+          illegalArgumentException.getMessage());
+      verifyNoInteractions(emrServerlessClient);
+      verifyNoInteractions(flintIndexMetadataService);
+    }
+  }
+
+  @Test
+  void testInvalidSQLQueryDispatchToSpark() {
+    when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
+    HashMap<String, String> tags = new HashMap<>();
+    tags.put(DATASOURCE_TAG_KEY, MY_GLUE);
+    tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
+    tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText());
+    String query = "myselect 1";
+    String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query);
+    StartJobRequest expected =
+        new StartJobRequest(
+            "TEST_CLUSTER:batch",
+            null,
+            EMRS_APPLICATION_ID,
+            EMRS_EXECUTION_ROLE,
+            sparkSubmitParameters,
+            tags,
+            false,
+            "query_execution_result_my_glue");
+    when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
+    DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
+    when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
+        .thenReturn(dataSourceMetadata);
+
+    DispatchQueryResponse dispatchQueryResponse =
+        sparkQueryDispatcher.dispatch(
+            DispatchQueryRequest.builder()
+                .applicationId(EMRS_APPLICATION_ID)
+                .query(query)
+                .datasource(MY_GLUE)
+                .langType(LangType.SQL)
+                .executionRoleARN(EMRS_EXECUTION_ROLE)
+                .clusterName(TEST_CLUSTER_NAME)
+                .sparkSubmitParameterModifier(sparkSubmitParameterModifier)
+                .build(),
+            asyncQueryRequestContext);
+
+    verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture());
+    Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
+    verifyNoInteractions(flintIndexMetadataService);
+  }
+
   @Test
   void testDispatchQueryWithoutATableAndDataSourceName() {
     when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
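
Reviewer note: the visitor pattern leaves room for blocking further statement types later. The sketch below is hypothetical and not part of this change — it assumes the generated base visitor exposes a DropFunctionContext (the upstream Spark grammar labels a #dropFunction alternative), and it omits imports of the generated SqlBaseParser/SqlBaseParserBaseVisitor classes, whose package is build-specific.

    import java.util.ArrayList;
    import java.util.List;
    import lombok.Getter;

    // Hypothetical extension, not in this patch: also reject DROP FUNCTION.
    class StricterSparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {

      @Getter private final List<String> validationErrors = new ArrayList<>();

      @Override
      public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
        validationErrors.add("Creating user-defined functions is not allowed");
        return super.visitCreateFunction(ctx);
      }

      @Override
      public Void visitDropFunction(SqlBaseParser.DropFunctionContext ctx) {
        // Illustrative message; assumes DropFunctionContext exists in the generated parser.
        validationErrors.add("Dropping user-defined functions is not allowed");
        return super.visitDropFunction(ctx);
      }
    }
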
diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
index 0d7c43fc0d..bf6fe9e5db 100644
--- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
+++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java
@@ -390,6 +390,25 @@ void testAutoRefresh() {
             .autoRefresh());
   }
 
+  @Test
+  void testValidateSparkSqlQuery_ValidQuery() {
+    String validQuery = "SELECT * FROM users WHERE age > 18";
+    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery);
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_InvalidQuery() {
+    String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";
+    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery);
+    assertFalse(errors.isEmpty(), "Invalid query should produce errors");
+    assertEquals(1, errors.size(), "Should have one error");
+    assertEquals(
+        "Creating user-defined functions is not allowed",
+        errors.get(0),
+        "Error message should match");
+  }
+
   @Getter protected static class IndexQuery {
     private String query;
diff --git a/docs/user/interfaces/asyncqueryinterface.rst b/docs/user/interfaces/asyncqueryinterface.rst
index af49a59838..9b889f7f97 100644
--- a/docs/user/interfaces/asyncqueryinterface.rst
+++ b/docs/user/interfaces/asyncqueryinterface.rst
@@ -68,6 +68,8 @@ Async Query Creation API
 ======================================
 If security plugin is enabled, this API can only be invoked by users with permission ``cluster:admin/opensearch/ql/async_query/create``.
 
+Limitation: Spark SQL queries that create User-Defined Functions (UDFs) are not allowed.
+
 HTTP URI: ``_plugins/_async_query``
 HTTP VERB: ``POST``
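
Reviewer note: an illustrative request hitting the new limitation, following the request shape used by the async query creation examples elsewhere in asyncqueryinterface.rst (datasource/lang/query fields). The HTTP status and error wrapping below are assumed for illustration; only the message text comes from this change. The query is rejected during validation, before any EMR Serverless job is started.

    POST _plugins/_async_query
    {
      "datasource" : "my_glue",
      "lang" : "sql",
      "query" : "CREATE FUNCTION celsius_to_fahrenheit AS 'com.example.CelsiusToFahrenheit'"
    }

    Illustrative error response (400 Bad Request):
    "Query is not allowed: Creating user-defined functions is not allowed"
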