Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate PPL-builtin functions to Spark-builtin functions #448

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,25 @@ logicalExpression
| left = logicalExpression OR right = logicalExpression # logicalOr
| left = logicalExpression (AND)? right = logicalExpression # logicalAnd
| left = logicalExpression XOR right = logicalExpression # logicalXor
| booleanExpression # booleanExpr
;

comparisonExpression
: left = valueExpression comparisonOperator right = valueExpression # compareExpr
| valueExpression IN valueList # inExpr
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added inExpr back since it is required by position functions.

;

valueExpression
: primaryExpression # valueExpressionDefault
: left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic
| left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic
| primaryExpression # valueExpressionDefault
| positionFunction # positionFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
;

primaryExpression
: fieldExpression
: evalFunctionCall
| fieldExpression
| literalValue
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.SortUtils;
import scala.Option;
Expand Down Expand Up @@ -397,7 +398,21 @@ public Expression visitEval(Eval node, CatalystPlanContext context) {

@Override
public Expression visitFunction(Function node, CatalystPlanContext context) {
throw new IllegalStateException("Not Supported operation : Function");
List<Expression> arguments =
node.getFuncArgs().stream()
.map(
unresolvedExpression -> {
var ret = analyze(unresolvedExpression, context);
if (ret == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", unresolvedExpression));
} else {
return context.popNamedParseExpressions().get();
}
})
.collect(Collectors.toList());
Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments);
return context.getNamedParseExpressions().push(function);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ public UnresolvedExpression visitCompareExpr(OpenSearchPPLParser.CompareExprCont
return new Compare(ctx.comparisonOperator().getText(), visit(ctx.left), visit(ctx.right));
}

/**
* Value Expression.
*/
@Override
public UnresolvedExpression visitBinaryArithmetic(OpenSearchPPLParser.BinaryArithmeticContext ctx) {
return new Function(
ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right)));
}

@Override
public UnresolvedExpression visitParentheticValueExpr(OpenSearchPPLParser.ParentheticValueExprContext ctx) {
return visit(ctx.valueExpression()); // Discard parenthesis around
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

import java.util.List;
import java.util.Locale;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;

public interface BuiltinFunctionTranslator {

static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) {
// TODO change it when UDF is supported
// TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions
throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
} else {
String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT);
return new UnresolvedFunction(seq(name), seq(args), false, empty(),false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
package org.opensearch.sql.ppl.utils;


import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.unsafe.types.UTF8String;
import org.opensearch.sql.ast.expression.SpanUnit;
Expand All @@ -33,8 +38,8 @@
* translate the PPL ast expressions data-types into catalyst data-types
*/
public interface DataTypeTransformer {
static <T> Seq<T> seq(T element) {
return seq(List.of(element));
static <T> Seq<T> seq(T... elements) {
return seq(List.of(elements));
}
static <T> Seq<T> seq(List<T> list) {
return asScalaBufferConverter(list).asScala().seq();
Expand All @@ -46,6 +51,16 @@ static DataType translate(org.opensearch.sql.ast.expression.DataType source) {
return DateType$.MODULE$;
case INTEGER:
return IntegerType$.MODULE$;
case LONG:
return LongType$.MODULE$;
case DOUBLE:
return DoubleType$.MODULE$;
case FLOAT:
return FloatType$.MODULE$;
case BOOLEAN:
return BooleanType$.MODULE$;
case SHORT:
return ShortType$.MODULE$;
case BYTE:
return ByteType$.MODULE$;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,4 @@ trait LogicalPlanTestUtils {
// Return the string representation of the transformed plan
transformedPlan.toString
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.common.antlr.SyntaxCheckException
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}

class PPLLogicalPlanMathFunctionsTranslatorTestSuite
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add additional IT tests that use these commands with a more complex context of full fledge query use cases

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

extends SparkFunSuite
with LogicalPlanTestUtils
with Matchers {

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test abs") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = abs(b)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("abs", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test ceil") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ceil(b)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("ceil", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test floor") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = floor(b)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("floor", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test ln") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = ln(b)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("ln", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test mod") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(plan(pplParser, "source=t a = mod(10, 4)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("mod", seq(Literal(10), Literal(4)), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test pow") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = pow(2, 3)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("pow", seq(Literal(2), Literal(3)), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test sqrt") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = sqrt(b)", false), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("sqrt", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}

test("test arithmetic: + - * / %") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a few more complex test for such operators

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(
pplParser,
"source=t | where sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1 - sqrt(pow(a, 2)) * 1 = sqrt(pow(a, 2)) % 1",
false),
context)
val table = UnresolvedRelation(Seq("t"))
// sqrt(pow(a, 2))
val sqrtPow =
UnresolvedFunction(
"sqrt",
seq(
UnresolvedFunction(
"pow",
seq(UnresolvedAttribute("a"), Literal(2)),
isDistinct = false)),
isDistinct = false)
// sqrt(pow(a, 2)) / 1
val sqrtPowDivide = UnresolvedFunction("divide", seq(sqrtPow, Literal(1)), isDistinct = false)
// sqrt(pow(a, 2)) * 1
val sqrtPowMultiply =
UnresolvedFunction("multiply", seq(sqrtPow, Literal(1)), isDistinct = false)
// sqrt(pow(a, 2)) % 1
val sqrtPowMod = UnresolvedFunction("modulus", seq(sqrtPow, Literal(1)), isDistinct = false)
// sqrt(pow(a, 2)) + sqrt(pow(a, 2)) / 1
val add = UnresolvedFunction("add", seq(sqrtPow, sqrtPowDivide), isDistinct = false)
val sub = UnresolvedFunction("subtract", seq(add, sqrtPowMultiply), isDistinct = false)
val filterExpr = EqualTo(sub, sqrtPowMod)
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
assertEquals(expectedPlan, logPlan)
}
}
Loading
Loading