Adds validation to allow only flint queries and sql SELECT queries to security lake type datasource (opensearch-project#2959)

* allows only flint queries and select sql queries to security lake datasource

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* add sql validator for security lake and refactor validateSparkSqlQuery class

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* spotless fixes

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* address review comments.

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* address comment to extract validate logic into a separate method in tests

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* add more tests to get more code coverage

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

---------

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>
eirsep committed Sep 4, 2024
1 parent 729bb13 commit 6c5c685
Showing 4 changed files with 152 additions and 14 deletions.
@@ -54,7 +54,9 @@ public DispatchQueryResponse dispatch(
           dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
     }
 
-    List<String> validationErrors = SQLQueryUtils.validateSparkSqlQuery(query);
+    List<String> validationErrors =
+        SQLQueryUtils.validateSparkSqlQuery(
+            dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query);
     if (!validationErrors.isEmpty()) {
       throw new IllegalArgumentException(
           "Query is not allowed: " + String.join(", ", validationErrors));
@@ -15,16 +15,21 @@
 import org.antlr.v4.runtime.CommonTokenStream;
 import org.antlr.v4.runtime.misc.Interval;
 import org.antlr.v4.runtime.tree.ParseTree;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
 import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
 import org.opensearch.sql.common.antlr.SyntaxCheckException;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser.MaterializedViewQueryContext;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
+import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.StatementContext;
 import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
 import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
@@ -38,13 +43,14 @@
  */
 @UtilityClass
 public class SQLQueryUtils {
+  private static final Logger logger = LogManager.getLogger(SQLQueryUtils.class);
 
   public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
     SqlBaseParser sqlBaseParser =
         new SqlBaseParser(
             new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
     sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
-    SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
+    StatementContext statement = sqlBaseParser.statement();
     SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
     statement.accept(sparkSqlTableNameVisitor);
     return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
@@ -77,32 +83,73 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
     }
   }
 
-  public static List<String> validateSparkSqlQuery(String sqlQuery) {
-    SparkSqlValidatorVisitor sparkSqlValidatorVisitor = new SparkSqlValidatorVisitor();
+  public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
     SqlBaseParser sqlBaseParser =
         new SqlBaseParser(
             new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
     sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
     try {
-      SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
-      sparkSqlValidatorVisitor.visit(statement);
-      return sparkSqlValidatorVisitor.getValidationErrors();
-    } catch (SyntaxCheckException syntaxCheckException) {
+      SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
+      StatementContext statement = sqlBaseParser.statement();
+      sqlParserBaseVisitor.visit(statement);
+      return sqlParserBaseVisitor.getValidationErrors();
+    } catch (SyntaxCheckException e) {
+      logger.error(
+          String.format(
+              "Failed to parse sql statement context while validating sql query %s", sqlQuery),
+          e);
       return Collections.emptyList();
     }
   }
 
-  private static class SparkSqlValidatorVisitor extends SqlBaseParserBaseVisitor<Void> {
+  private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
+    if (datasource != null
+        && datasource.getConnectorType() != null
+        && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
+      return new SparkSqlSecurityLakeValidatorVisitor();
+    } else {
+      return new SparkSqlValidatorVisitor();
+    }
+  }
 
-    @Getter private final List<String> validationErrors = new ArrayList<>();
+  /**
+   * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class
+   * supports accumulating validation errors on visiting sql statement
+   */
+  @Getter
+  private static class SqlBaseValidatorVisitor<T> extends SqlBaseParserBaseVisitor<T> {
+    private final List<String> validationErrors = new ArrayList<>();
+  }
+
+  /** A generic validator impl for Spark Sql Queries */
+  private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
     @Override
     public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
-      validationErrors.add("Creating user-defined functions is not allowed");
+      getValidationErrors().add("Creating user-defined functions is not allowed");
       return super.visitCreateFunction(ctx);
     }
   }
 
+  /** A validator impl specific to Security Lake for Spark Sql Queries */
+  private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
+
+    public SparkSqlSecurityLakeValidatorVisitor() {
+      // only select statement allowed. hence we add the validation error to all types of
+      // statements by default and remove the validation error only for select statement.
+      getValidationErrors()
+          .add(
+              "Unsupported sql statement for security lake data source. Only select queries are"
+                  + " allowed");
+    }
+
+    @Override
+    public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
+      getValidationErrors().clear();
+      return super.visitStatementDefault(ctx);
+    }
+  }
 
   public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
 
     @Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
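Net effect of the SQLQueryUtils changes: for a SECURITY_LAKE datasource the visitor starts out holding a validation error, and only visitStatementDefault (the node a plain SELECT statement parses to) clears it, so every other statement type is rejected; all other datasources keep the original CREATE FUNCTION check. A rough usage sketch of the new entry point — the datasource name and the dataSourceService lookup here are illustrative assumptions, not part of this commit:

// Hypothetical wiring; dataSourceService and "my_security_lake" are assumed.
DataSource securityLake = dataSourceService.getDataSource("my_security_lake");

// A bare SELECT parses to statementDefault, which clears the seeded error:
List<String> ok =
    SQLQueryUtils.validateSparkSqlQuery(securityLake, "SELECT * FROM users WHERE age > 18");
// ok.isEmpty() == true

// Any other statement type leaves the seeded error in place:
List<String> rejected =
    SQLQueryUtils.validateSparkSqlQuery(securityLake, "DROP TABLE users");
// rejected.get(0) == "Unsupported sql statement for security lake data source.
//                     Only select queries are allowed"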
@@ -10,6 +10,7 @@
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.when;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv;
 import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex;
@@ -18,7 +19,10 @@
 import lombok.Getter;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
@@ -27,6 +31,8 @@
 @ExtendWith(MockitoExtension.class)
 public class SQLQueryUtilsTest {
 
+  @Mock private DataSource dataSource;
+
   @Test
   void testExtractionOfTableNameFromSQLQueries() {
     String sqlQuery = "select * from my_glue.default.http_logs";
@@ -404,15 +410,96 @@ void testAutoRefresh() {
 
   @Test
   void testValidateSparkSqlQuery_ValidQuery() {
-    String validQuery = "SELECT * FROM users WHERE age > 18";
-    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(validQuery);
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'",
+            DataSourceType.PROMETHEUS);
 
     assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
   }
 
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);
+
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null);
+
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE);
+
+    assertTrue(errors.isEmpty(), "Syntax check failure should be skipped without validation errors");
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_nullDatasource() {
+    List<String> errors =
+        SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18");
+    assertTrue(errors.isEmpty(), "Valid query should not produce any errors");
+  }
+
+  private List<String> validateSparkSqlQueryForDataSourceType(
+      String query, DataSourceType dataSourceType) {
+    when(this.dataSource.getConnectorType()).thenReturn(dataSourceType);
+
+    return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query);
+  }
+
+  @Test
+  void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() {
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(
+            "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE);
+
+    assertFalse(
+        errors.isEmpty(),
+        "Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+            + " queries. Given query was REFRESH sql query");
+    assertEquals(
+        errors.get(0),
+        "Unsupported sql statement for security lake data source. Only select queries are allowed");
+  }
+
+  @Test
+  void
+      testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() {
+    String query =
+        "CREATE TABLE AccountSummaryOrWhatever AS "
+            + "select taxid, address1, count(address1) from dbo.t "
+            + "group by taxid, address1;";
+
+    List<String> errors =
+        validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE);
+
+    assertFalse(
+        errors.isEmpty(),
+        "Invalid query as Security Lake datasource supports only flint queries and SELECT sql"
+            + " queries. Given query was a CREATE TABLE AS SELECT sql query");
+    assertEquals(
+        errors.get(0),
+        "Unsupported sql statement for security lake data source. Only select queries are allowed");
+  }
+
   @Test
   void testValidateSparkSqlQuery_InvalidQuery() {
+    when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS);
     String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'";
-    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(invalidQuery);
+
+    List<String> errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery);
 
     assertFalse(errors.isEmpty(), "Invalid query should produce errors");
     assertEquals(1, errors.size(), "Should have one error");
     assertEquals(
@@ -7,9 +7,11 @@
 
 import java.util.HashMap;
 import java.util.Map;
+import lombok.EqualsAndHashCode;
 import lombok.RequiredArgsConstructor;
 
 @RequiredArgsConstructor
+@EqualsAndHashCode
 public class DataSourceType {
   public static DataSourceType PROMETHEUS = new DataSourceType("PROMETHEUS");
   public static DataSourceType OPENSEARCH = new DataSourceType("OPENSEARCH");
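The new @EqualsAndHashCode matters for the validation path above: DataSourceType is a plain class rather than an enum, and getSparkSqlValidatorVisitor selects the Security Lake validator via datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE), so equality must be value based rather than reference based. A minimal sketch of what the annotation guarantees, assuming Lombok generates equals/hashCode over the type's single constructor field (as the PROMETHEUS/OPENSEARCH constants suggest):

// An instance rebuilt from its name (e.g. after config deserialization) now
// compares equal to the shared constant instead of failing reference equality:
DataSourceType rebuilt = new DataSourceType("SECURITY_LAKE");
boolean isSecurityLake = rebuilt.equals(DataSourceType.SECURITY_LAKE); // true with this change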
