Skip to content

Commit

Permalink
Support arithmetic binary in Delta check constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Apr 3, 2023
1 parent 7777ba6 commit 20dbb02
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ predicate[ParserRuleContext value]

valueExpression
: primaryExpression #valueExpressionDefault
| left=valueExpression operator=(ASTERISK | SLASH | PERCENT) right=valueExpression #arithmeticBinary
| left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
| left=valueExpression operator=(AMPERSAND | CIRCUMFLEX) right=valueExpression #arithmeticBinary
;

primaryExpression
Expand Down Expand Up @@ -85,7 +88,13 @@ LTE: '<=';
GT: '>';
GTE: '>=';

PLUS: '+';
MINUS: '-';
ASTERISK: '*';
SLASH: '/';
PERCENT: '%';
AMPERSAND: '&';
CIRCUMFLEX: '^';

STRING
: '\'' ( ~'\'' | '\'\'' )* '\''
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.deltalake.expression;

import java.util.Objects;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public class ArithmeticBinaryExpression
extends SparkExpression
{
public enum Operator
{
ADD("+"),
SUBTRACT("-"),
MULTIPLY("*"),
DIVIDE("/"),
MODULUS("%"),
BITWISE_AND("&"),
BITWISE_XOR("^");

private final String value;

Operator(String value)
{
this.value = value;
}

public String getValue()
{
return value;
}
}

private final Operator operator;
private final SparkExpression left;
private final SparkExpression right;

public ArithmeticBinaryExpression(Operator operator, SparkExpression left, SparkExpression right)
{
this.operator = requireNonNull(operator, "operator is null");
this.left = requireNonNull(left, "left is null");
this.right = requireNonNull(right, "right is null");
}

public Operator getOperator()
{
return operator;
}

public SparkExpression getLeft()
{
return left;
}

public SparkExpression getRight()
{
return right;
}

@Override
<R, C> R accept(SparkExpressionTreeVisitor<R, C> visitor, C context)
{
return visitor.visitArithmeticBinary(this, context);
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ArithmeticBinaryExpression that = (ArithmeticBinaryExpression) o;
return operator == that.operator &&
Objects.equals(left, that.left) &&
Objects.equals(right, that.right);
}

@Override
public int hashCode()
{
return Objects.hash(operator, left, right);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("operator", operator)
.add("left", left)
.add("right", right)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,37 @@ public Object visitPredicated(SparkExpressionBaseParser.PredicatedContext contex
return visit(context.valueExpression);
}

@Override
public Object visitArithmeticBinary(SparkExpressionBaseParser.ArithmeticBinaryContext context)
{
return new ArithmeticBinaryExpression(
getArithmeticBinaryOperator(context.operator),
(SparkExpression) visit(context.left),
(SparkExpression) visit(context.right));
}

private static ArithmeticBinaryExpression.Operator getArithmeticBinaryOperator(Token operator)
{
switch (operator.getType()) {
case SparkExpressionBaseParser.PLUS:
return ArithmeticBinaryExpression.Operator.ADD;
case SparkExpressionBaseParser.MINUS:
return ArithmeticBinaryExpression.Operator.SUBTRACT;
case SparkExpressionBaseParser.ASTERISK:
return ArithmeticBinaryExpression.Operator.MULTIPLY;
case SparkExpressionBaseParser.SLASH:
return ArithmeticBinaryExpression.Operator.DIVIDE;
case SparkExpressionBaseParser.PERCENT:
return ArithmeticBinaryExpression.Operator.MODULUS;
case SparkExpressionBaseParser.AMPERSAND:
return ArithmeticBinaryExpression.Operator.BITWISE_AND;
case SparkExpressionBaseParser.CIRCUMFLEX:
return ArithmeticBinaryExpression.Operator.BITWISE_XOR;
}

throw new UnsupportedOperationException("Unsupported operator: " + operator.getText());
}

@Override
public Object visitComparison(SparkExpressionBaseParser.ComparisonContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
*/
package io.trino.plugin.deltalake.expression;

import static io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression.Operator.BITWISE_AND;
import static io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression.Operator.BITWISE_XOR;

public final class SparkExpressionConverter
{
private SparkExpressionConverter() {}
Expand All @@ -37,6 +40,18 @@ protected String visitComparisonExpression(ComparisonExpression node, Void conte
return "(%s %s %s)".formatted(process(node.getLeft(), context), node.getOperator().getValue(), process(node.getRight(), context));
}

@Override
protected String visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
{
if (node.getOperator() == BITWISE_AND) {
return "(bitwise_and(%s, %s))".formatted(process(node.getLeft(), context), process(node.getRight(), context));
}
if (node.getOperator() == BITWISE_XOR) {
return "(bitwise_xor(%s, %s))".formatted(process(node.getLeft(), context), process(node.getRight(), context));
}
return "(%s %s %s)".formatted(process(node.getLeft(), context), node.getOperator().getValue(), process(node.getRight(), context));
}

@Override
protected String visitLogicalExpression(LogicalExpression node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ protected R visitLogicalExpression(LogicalExpression node, C context)
return visitExpression(node, context);
}

protected R visitArithmeticBinary(ArithmeticBinaryExpression node, C context)
{
return visitExpression(node, context);
}

protected R visitIdentifier(Identifier node, C context)
{
return visitExpression(node, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package io.trino.plugin.deltalake.expression;

import io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression.Operator;
import org.testng.annotations.Test;

import static io.trino.plugin.deltalake.expression.SparkExpressionParser.createExpression;
Expand Down Expand Up @@ -51,6 +52,20 @@ public void testUnsupportedStringLiteral()
assertParseFailure("r 'two spaces after prefix'", "extraneous input ''two spaces after prefix'' expecting <EOF>");
}

@Test
public void testArithmeticBinary()
{
assertEquals(createExpression("a + b * c"), new ArithmeticBinaryExpression(
Operator.ADD,
new Identifier("a"),
new ArithmeticBinaryExpression(Operator.MULTIPLY, new Identifier("b"), new Identifier("c"))));

assertEquals(createExpression("a * b + c"), new ArithmeticBinaryExpression(
Operator.ADD,
new ArithmeticBinaryExpression(Operator.MULTIPLY, new Identifier("a"), new Identifier("b")),
new Identifier("c")));
}

private static void assertStringLiteral(String sparkExpression, String expected)
{
SparkExpression expression = createExpression(sparkExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,22 @@ public void testIdentifier()
assertParseFailure("`a`b` = 1");
}

@Test
public void testArithmeticBinary()
{
assertExpressionTranslates("a = b % 1", "(\"a\" = (\"b\" % 1))");
assertExpressionTranslates("a = b * 1", "(\"a\" = (\"b\" * 1))");
assertExpressionTranslates("a = b + 1", "(\"a\" = (\"b\" + 1))");
assertExpressionTranslates("a = b - 1", "(\"a\" = (\"b\" - 1))");
assertExpressionTranslates("a = b / 1", "(\"a\" = (\"b\" / 1))");
assertExpressionTranslates("a = b & 1", "(\"a\" = (bitwise_and(\"b\", 1)))");
assertExpressionTranslates("a = b ^ 1", "(\"a\" = (bitwise_xor(\"b\", 1)))");
}

@Test
public void testInvalidNotBoolean()
{
assertParseFailure("a + a");
assertParseFailure("'Spark' || 'SQL'");
}

// TODO: Support following expressions
Expand All @@ -152,13 +164,6 @@ public void testUnsupportedOperator()
{
assertParseFailure("a <=> 1");
assertParseFailure("a == 1");
assertParseFailure("a = b % 1");
assertParseFailure("a = b & 1");
assertParseFailure("a = b * 1");
assertParseFailure("a = b + 1");
assertParseFailure("a = b - 1");
assertParseFailure("a = b / 1");
assertParseFailure("a = b ^ 1");
assertParseFailure("a = b::INTEGER");
assertParseFailure("a = json_column:root");
assertParseFailure("a BETWEEN 1 AND 10");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ public static Object[][] checkConstraints()
{"a INT", "a <= 1", "1", row(1), "2"},
{"a INT", "a <> 1", "2", row(2), "1"},
{"a INT", "a != 1", "2", row(2), "1"},
// Arithmetic binary
{"a INT, b INT", "a = b + 1", "2, 1", row(2, 1), "2, 2"},
{"a INT, b INT", "a = b - 1", "1, 2", row(1, 2), "1, 3"},
{"a INT, b INT", "a = b * 2", "4, 2", row(4, 2), "4, 3"},
{"a INT, b INT", "a = b / 2", "2, 4", row(2, 4), "2, 6"},
{"a INT, b INT", "a = b % 2", "1, 5", row(1, 5), "1, 6"},
{"a INT, b INT", "a = b & 5", "1, 3", row(1, 3), "1, 4"},
{"a INT, b INT", "a = b ^ 5", "6, 3", row(6, 3), "6, 4"},
// Supported types
{"a INT", "a < 100", "1", row(1), "100"},
{"a STRING", "a = 'valid'", "'valid'", row("valid"), "'invalid'"},
Expand Down

0 comments on commit 20dbb02

Please sign in to comment.