Skip to content

Commit

Permalink
Translate CAST to connector expression
Browse files Browse the repository at this point in the history
Co-authored-by: Raunaq Morarka <raunaqmorarka@gmail.com>
  • Loading branch information
2 people authored and findepi committed Apr 8, 2022
1 parent 7302afa commit b7cd83c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.VarcharType;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
Expand Down Expand Up @@ -79,6 +78,7 @@
import static io.trino.SystemSessionProperties.isComplexExpressionPushdown;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME;
Expand All @@ -97,7 +97,10 @@
import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -225,9 +228,13 @@ protected Optional<Expression> translateCall(Call call)
if (IS_NULL_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) {
return translateIsNull(call.getArguments().get(0));
}

if (NULLIF_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) {
return translateNullIf(call.getArguments().get(0), call.getArguments().get(1));
}
if (CAST_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) {
return translateCast(call.getType(), call.getArguments().get(0));
}

// comparisons
if (call.getArguments().size() == 2) {
Expand Down Expand Up @@ -303,6 +310,16 @@ private Optional<Expression> translateNot(ConnectorExpression argument)
if (argument.getType().equals(BOOLEAN) && translatedArgument.isPresent()) {
return Optional.of(new NotExpression(translatedArgument.get()));
}
return Optional.empty();
}

private Optional<Expression> translateCast(Type type, ConnectorExpression expression)
{
Optional<Expression> translatedExpression = translate(expression);

if (translatedExpression.isPresent()) {
return Optional.of(new Cast(translatedExpression.get(), toSqlType(type)));
}

return Optional.empty();
}
Expand Down Expand Up @@ -573,6 +590,22 @@ protected Optional<ConnectorExpression> visitCast(Cast node, Void context)
if (isEffectivelyLiteral(plannerContext, session, node)) {
return Optional.of(constantFor(node));
}

if (node.isSafe()) {
// try_cast would need to be modeled separately
return Optional.empty();
}

if (!isComplexExpressionPushdown(session)) {
return Optional.empty();
}

Optional<ConnectorExpression> translatedExpression = process(node.getExpression());
if (translatedExpression.isPresent()) {
Type type = plannerContext.getTypeManager().getType(toTypeSignature(node.getType()));
return Optional.of(new Call(type, CAST_FUNCTION_NAME, List.of(translatedExpression.get())));
}

return Optional.empty();
}

Expand All @@ -598,11 +631,11 @@ protected Optional<ConnectorExpression> visitFunctionCall(FunctionCall node, Voi
Object value = evaluateConstant(node);
if (value instanceof JoniRegexp) {
Slice pattern = ((JoniRegexp) value).pattern();
return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern))));
return Optional.of(new Constant(pattern, createVarcharType(countCodePoints(pattern))));
}
if (value instanceof Re2JRegexp) {
Slice pattern = Slices.utf8Slice(((Re2JRegexp) value).pattern());
return Optional.of(new Constant(pattern, VarcharType.createVarcharType(countCodePoints(pattern))));
return Optional.of(new Constant(pattern, createVarcharType(countCodePoints(pattern))));
}
return Optional.of(new Constant(value, types.get(NodeRef.of(node))));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.DoubleLiteral;
import io.trino.sql.tree.Expression;
Expand Down Expand Up @@ -55,6 +57,7 @@
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
Expand All @@ -73,6 +76,7 @@
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.ConnectorExpressionTranslator.translate;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
Expand All @@ -86,7 +90,7 @@ public class TestConnectorExpressionTranslator
private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT);
private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5)));
private static final Type VARCHAR_TYPE = createVarcharType(25);
private static final VarcharType VARCHAR_TYPE = createVarcharType(25);
private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT);

private static final Map<Symbol, Type> symbols = ImmutableMap.<Symbol, Type>builder()
Expand Down Expand Up @@ -330,6 +334,37 @@ public void testTranslateNullIf()
new Variable("varchar_symbol_1", VARCHAR_TYPE))));
}

@Test
public void testTranslateCast()
{
assertTranslationRoundTrips(
new Cast(new SymbolReference("varchar_symbol_1"), toSqlType(VARCHAR_TYPE)),
new Call(
VARCHAR_TYPE,
CAST_FUNCTION_NAME,
List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));

// type-only
VarcharType longerVarchar = createVarcharType(VARCHAR_TYPE.getBoundedLength() + 1);
assertTranslationToConnectorExpression(
TEST_SESSION,
new Cast(new SymbolReference("varchar_symbol_1"), toSqlType(longerVarchar), false, true),
new Call(
longerVarchar,
CAST_FUNCTION_NAME,
List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));

// TRY_CAST is not translated
assertTranslationToConnectorExpression(
TEST_SESSION,
new Cast(
new SymbolReference("varchar_symbol_1"),
toSqlType(BIGINT),
true,
true),
Optional.empty());
}

@Test
public void testTranslateResolvedFunction()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TableScanRedirectApplicationResult;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
Expand All @@ -53,6 +54,7 @@
import static io.trino.connector.MockConnectorFactory.ApplyFilter;
import static io.trino.connector.MockConnectorFactory.ApplyProjection;
import static io.trino.connector.MockConnectorFactory.ApplyTableScanRedirect;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.predicate.Domain.singleValue;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.IntegerType.INTEGER;
Expand Down Expand Up @@ -366,9 +368,12 @@ private Optional<ProjectionApplicationResult<ConnectorTableHandle>> mockApplyPro

for (ConnectorExpression projection : projections) {
String newVariableName;
ConnectorExpression newVariable;
ColumnHandle newColumnHandle;
Type type = projection.getType();
if (projection instanceof Variable) {
newVariableName = ((Variable) projection).getName();
newVariable = new Variable(newVariableName, type);
newColumnHandle = assignments.get(newVariableName);
}
else if (projection instanceof FieldDereference) {
Expand All @@ -378,16 +383,27 @@ else if (projection instanceof FieldDereference) {
}
String dereferenceTargetName = ((Variable) dereference.getTarget()).getName();
newVariableName = ((MockConnectorColumnHandle) assignments.get(dereferenceTargetName)).getName() + "#" + dereference.getField();
newColumnHandle = new MockConnectorColumnHandle(newVariableName, projection.getType());
newVariable = new Variable(newVariableName, type);
newColumnHandle = new MockConnectorColumnHandle(newVariableName, type);
}
else if (projection instanceof Call) {
Call call = (Call) projection;
if (!(CAST_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1)) {
throw new UnsupportedOperationException();
}
// Avoid CAST pushdown into the connector
newVariableName = ((Variable) call.getArguments().get(0)).getName();
newVariable = projection;
newColumnHandle = assignments.get(newVariableName);
type = call.getArguments().get(0).getType();
}
else {
throw new UnsupportedOperationException();
}

Variable newVariable = new Variable(newVariableName, projection.getType());
newColumnsBuilder.add(newColumnHandle);
outputExpressions.add(newVariable);
outputAssignments.add(new Assignment(newVariableName, newColumnHandle, projection.getType()));
outputAssignments.add(new Assignment(newVariableName, newColumnHandle, type));
}

List<ColumnHandle> newColumns = newColumnsBuilder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ private StandardFunctions() {}
*/
public static final FunctionName NULLIF_FUNCTION_NAME = new FunctionName("$nullif");

/**
* $cast function result type is determined by the {@link Call#getType()}
*/
public static final FunctionName CAST_FUNCTION_NAME = new FunctionName("$cast");

public static final FunctionName EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$equal");
public static final FunctionName NOT_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$not_equal");
public static final FunctionName LESS_THAN_OPERATOR_FUNCTION_NAME = new FunctionName("$less_than");
Expand Down

0 comments on commit b7cd83c

Please sign in to comment.