diff --git a/plugin/trino-delta-lake/src/main/antlr4/io/trino/plugin/deltalake/expression/SparkExpressionBase.g4 b/plugin/trino-delta-lake/src/main/antlr4/io/trino/plugin/deltalake/expression/SparkExpressionBase.g4 index bbc97f4ace83..bf91194ad03e 100644 --- a/plugin/trino-delta-lake/src/main/antlr4/io/trino/plugin/deltalake/expression/SparkExpressionBase.g4 +++ b/plugin/trino-delta-lake/src/main/antlr4/io/trino/plugin/deltalake/expression/SparkExpressionBase.g4 @@ -40,6 +40,9 @@ predicate[ParserRuleContext value] valueExpression : primaryExpression #valueExpressionDefault + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(CIRCUMFLEX | AMPERSAND) right=valueExpression #arithmeticBinary ; primaryExpression @@ -85,7 +88,13 @@ LTE: '<='; GT: '>'; GTE: '>='; +PLUS: '+'; MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +AMPERSAND: '&'; +CIRCUMFLEX: '^'; STRING : '\'' ( ~'\'' | '\'\'' )* '\'' diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/ArithmeticBinaryExpression.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/ArithmeticBinaryExpression.java new file mode 100644 index 000000000000..21f6cd6f923c --- /dev/null +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/ArithmeticBinaryExpression.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.deltalake.expression; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; + +public class ArithmeticBinaryExpression + extends SparkExpression +{ + public enum Operator + { + ADD("+"), + SUBTRACT("-"), + MULTIPLY("*"), + DIVIDE("/"), + MODULUS("%"), + BITWISE_AND("&"), + BITWISE_XOR("^"); + private final String value; + + Operator(String value) + { + this.value = value; + } + + public String getValue() + { + return value; + } + } + + private final Operator operator; + private final SparkExpression left; + private final SparkExpression right; + + public ArithmeticBinaryExpression(Operator operator, SparkExpression left, SparkExpression right) + { + this.operator = operator; + this.left = left; + this.right = right; + } + + public Operator getOperator() + { + return operator; + } + + public SparkExpression getLeft() + { + return left; + } + + public SparkExpression getRight() + { + return right; + } + + @Override + R accept(SparkExpressionTreeVisitor visitor, C context) + { + return visitor.visitArithmeticBinary(this, context); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArithmeticBinaryExpression that = (ArithmeticBinaryExpression) o; + return operator == that.operator && + Objects.equals(left, that.left) && + Objects.equals(right, that.right); + } + + @Override + public int hashCode() + { + return Objects.hash(operator, left, right); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("operator", operator) + .add("left", left) + .add("right", right) + .toString(); + } +} diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionBuilder.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionBuilder.java index 7bdd4267fcb9..c7b28ee8dcf3 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionBuilder.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionBuilder.java @@ -46,6 +46,37 @@ public Object visitPredicated(SparkExpressionBaseParser.PredicatedContext contex return visit(context.valueExpression); } + @Override + public Object visitArithmeticBinary(SparkExpressionBaseParser.ArithmeticBinaryContext context) + { + return new ArithmeticBinaryExpression( + getArithmeticBinaryOperator(context.operator), + (SparkExpression) visit(context.left), + (SparkExpression) visit(context.right)); + } + + private static ArithmeticBinaryExpression.Operator getArithmeticBinaryOperator(Token operator) + { + switch (operator.getType()) { + case SparkExpressionBaseParser.PLUS: + return ArithmeticBinaryExpression.Operator.ADD; + case SparkExpressionBaseParser.MINUS: + return ArithmeticBinaryExpression.Operator.SUBTRACT; + case SparkExpressionBaseParser.ASTERISK: + return ArithmeticBinaryExpression.Operator.MULTIPLY; + case SparkExpressionBaseParser.SLASH: + return ArithmeticBinaryExpression.Operator.DIVIDE; + case SparkExpressionBaseParser.PERCENT: + return ArithmeticBinaryExpression.Operator.MODULUS; + case SparkExpressionBaseParser.AMPERSAND: + return ArithmeticBinaryExpression.Operator.BITWISE_AND; + case SparkExpressionBaseParser.CIRCUMFLEX: + return ArithmeticBinaryExpression.Operator.BITWISE_XOR; + } + + throw new UnsupportedOperationException("Unsupported operator: " + operator.getText()); + } + @Override public Object visitComparison(SparkExpressionBaseParser.ComparisonContext context) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionConverter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionConverter.java index 564e3c3fa112..0888b533925a 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionConverter.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionConverter.java @@ -13,6 +13,9 @@ */ package io.trino.plugin.deltalake.expression; +import static io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression.Operator.BITWISE_AND; +import static io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression.Operator.BITWISE_XOR; + public final class SparkExpressionConverter { private SparkExpressionConverter() {} @@ -37,6 +40,18 @@ protected String visitComparisonExpression(ComparisonExpression node, Void conte return "(%s %s %s)".formatted(process(node.getLeft(), context), node.getOperator().getValue(), process(node.getRight(), context)); } + @Override + protected String visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + { + if (node.getOperator() == BITWISE_AND) { + return "(bitwise_and(%s, %s))".formatted(process(node.getLeft(), context), process(node.getRight(), context)); + } + if (node.getOperator() == BITWISE_XOR) { + return "(bitwise_xor(%s, %s))".formatted(process(node.getLeft(), context), process(node.getRight(), context)); + } + return "(%s %s %s)".formatted(process(node.getLeft(), context), node.getOperator().getValue(), process(node.getRight(), context)); + } + @Override protected String visitLogicalExpression(LogicalExpression node, Void context) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionTreeVisitor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionTreeVisitor.java index 0010ef12e39a..eb80af1f4b25 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionTreeVisitor.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/expression/SparkExpressionTreeVisitor.java @@ -34,6 +34,11 @@ protected R visitLogicalExpression(LogicalExpression node, C context) return visitExpression(node, context); } + protected R visitArithmeticBinary(ArithmeticBinaryExpression node, C context) + { + return visitExpression(node, context); + } + protected R visitIdentifier(Identifier node, C context) { return visitExpression(node, context); diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/expression/TestSparkExpressions.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/expression/TestSparkExpressions.java index 0d57b4314b44..8d2678ae4901 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/expression/TestSparkExpressions.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/expression/TestSparkExpressions.java @@ -127,10 +127,22 @@ public void testIdentifier() assertParseFailure("`a`b` = 1"); } + @Test + public void testArithmeticBinary() + { + assertExpressionTranslates("a = b % 1", "(\"a\" = (\"b\" % 1))"); + assertExpressionTranslates("a = b * 1", "(\"a\" = (\"b\" * 1))"); + assertExpressionTranslates("a = b + 1", "(\"a\" = (\"b\" + 1))"); + assertExpressionTranslates("a = b - 1", "(\"a\" = (\"b\" - 1))"); + assertExpressionTranslates("a = b / 1", "(\"a\" = (\"b\" / 1))"); + assertExpressionTranslates("a = b & 1", "(\"a\" = (bitwise_and(\"b\", 1)))"); + assertExpressionTranslates("a = b ^ 1", "(\"a\" = (bitwise_xor(\"b\", 1)))"); + } + @Test public void testInvalidNotBoolean() { - assertParseFailure("a + a"); + assertParseFailure("'Spark' || 'SQL'"); } // TODO: Support following expressions @@ -152,13 +164,6 @@ public void testUnsupportedOperator() { assertParseFailure("a <=> 1"); assertParseFailure("a == 1"); - assertParseFailure("a = b % 1"); - assertParseFailure("a = b & 1"); - assertParseFailure("a = b * 1"); - assertParseFailure("a = b + 1"); - assertParseFailure("a = b - 1"); - assertParseFailure("a = b / 1"); - assertParseFailure("a = b ^ 1"); assertParseFailure("a = b::INTEGER"); assertParseFailure("a = json_column:root"); assertParseFailure("a BETWEEN 1 AND 10"); diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeCheckConstraintCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeCheckConstraintCompatibility.java index 48a19fbb73ef..687068a5af09 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeCheckConstraintCompatibility.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeCheckConstraintCompatibility.java @@ -81,6 +81,14 @@ public static Object[][] checkConstraints() {"a INT", "a <= 1", "1", row(1), "2"}, {"a INT", "a <> 1", "2", row(2), "1"}, {"a INT", "a != 1", "2", row(2), "1"}, + // Arithmetic binary + {"a INT, b INT", "a = b + 1", "2, 1", row(2, 1), "2, 2"}, + {"a INT, b INT", "a = b - 1", "1, 2", row(1, 2), "1, 3"}, + {"a INT, b INT", "a = b * 2", "4, 2", row(4, 2), "4, 3"}, + {"a INT, b INT", "a = b / 2", "2, 4", row(2, 4), "2, 6"}, + {"a INT, b INT", "a = b % 2", "1, 5", row(1, 5), "1, 6"}, + {"a INT, b INT", "a = b & 5", "1, 3", row(1, 3), "1, 4"}, + {"a INT, b INT", "a = b ^ 5", "6, 3", row(6, 3), "6, 4"}, // Supported types {"a INT", "a < 100", "1", row(1), "100"}, {"a STRING", "a = 'valid'", "'valid'", row("valid"), "'invalid'"},