Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion #13890

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.expression.builtin;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
Expand All @@ -30,6 +29,7 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
Expand All @@ -39,46 +39,10 @@
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.joda.time.Period;

import java.util.Map;
import java.util.function.Function;

public class CastOperatorConversion implements SqlOperatorConversion
{
private static final Map<SqlTypeName, ExprType> EXPRESSION_TYPES;

static {
final ImmutableMap.Builder<SqlTypeName, ExprType> builder = ImmutableMap.builder();

for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) {
builder.put(type, ExprType.DOUBLE);
}

for (SqlTypeName type : SqlTypeName.INT_TYPES) {
builder.put(type, ExprType.LONG);
}

for (SqlTypeName type : SqlTypeName.STRING_TYPES) {
builder.put(type, ExprType.STRING);
}

// Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false).
builder.put(SqlTypeName.BOOLEAN, ExprType.LONG);

// Timestamps are treated as longs (millis since the epoch) in Druid expressions.
builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG);
builder.put(SqlTypeName.DATE, ExprType.LONG);

for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}

for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}

EXPRESSION_TYPES = builder.build();
}

@Override
public SqlOperator calciteOperator()
{
Expand Down Expand Up @@ -118,28 +82,34 @@ public DruidExpression toDruidExpression(
} else {
// Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls
// downstream; Druid will try its best.
final ExprType fromExprType = SqlTypeName.ANY.equals(fromType)
? EXPRESSION_TYPES.get(toType)
: EXPRESSION_TYPES.get(fromType);
final ExprType toExprType = SqlTypeName.ANY.equals(toType)
? EXPRESSION_TYPES.get(fromType)
: EXPRESSION_TYPES.get(toType);

if (fromExprType == null || toExprType == null) {
final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType());
final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType());

final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType)
? ExpressionType.fromColumnType(toDruidType)
: ExpressionType.fromColumnType(fromDruidType);
final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType)
? ExpressionType.fromColumnType(fromDruidType)
: ExpressionType.fromColumnType(toDruidType);

if (fromExpressionType == null || toExpressionType == null) {
// We have no runtime type for these SQL types.
return null;
}

final DruidExpression typeCastExpression;

if (fromExprType != toExprType) {
if (fromExpressionType.equals(toExpressionType)) {
typeCastExpression = operandExpression;
} else if (SqlTypeName.INTERVAL_TYPES.contains(fromType) && toExpressionType.is(ExprType.LONG)) {
// intervals can be longs without an explicit cast
typeCastExpression = operandExpression;
} else {
// Ignore casts for simple extractions (use Function.identity) since it is ok in many cases.
typeCastExpression = operandExpression.map(
Function.identity(),
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString())
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString())
);
} else {
typeCastExpression = operandExpression;
}

if (toType == SqlTypeName.DATE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,16 @@ private ReductionOperatorConversionHelper()
boolean hasDouble = false;
boolean isString = false;
for (int i = 0; i < n; i++) {
RelDataType type = opBinding.getOperandType(i);
SqlTypeName sqlTypeName = type.getSqlTypeName();
ColumnType valueType = Calcites.getColumnTypeForRelDataType(type);
final RelDataType type = opBinding.getOperandType(i);
final SqlTypeName sqlTypeName = type.getSqlTypeName();
final ColumnType valueType;

if (SqlTypeName.INTERVAL_TYPES.contains(type.getSqlTypeName())) {
// handle intervals as a LONG type here, even though their Druid column type is a string
valueType = ColumnType.LONG;
} else {
valueType = Calcites.getColumnTypeForRelDataType(type);
}

// Return types are listed in order of preference:
if (valueType != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public static ColumnType getValueTypeForRelDataTypeFull(final RelDataType type)
return ColumnType.DOUBLE;
} else if (isLongType(sqlTypeName)) {
return ColumnType.LONG;
} else if (SqlTypeName.CHAR_TYPES.contains(sqlTypeName)) {
} else if (isStringType(sqlTypeName)) {
return ColumnType.STRING;
} else if (SqlTypeName.OTHER == sqlTypeName) {
if (type instanceof RowSignatures.ComplexSqlType) {
Expand All @@ -178,6 +178,12 @@ public static ColumnType getValueTypeForRelDataTypeFull(final RelDataType type)
}
}

/**
 * Tells whether the given SQL type name is treated as a string type.
 *
 * @param sqlTypeName the SQL type name to classify
 * @return {@code true} if the name is in {@link SqlTypeName#CHAR_TYPES} or in
 *         {@link SqlTypeName#INTERVAL_TYPES}, {@code false} otherwise
 */
public static boolean isStringType(SqlTypeName sqlTypeName)
{
  if (SqlTypeName.CHAR_TYPES.contains(sqlTypeName)) {
    return true;
  }
  // Interval types are classified as strings as well.
  return SqlTypeName.INTERVAL_TYPES.contains(sqlTypeName);
}

public static boolean isDoubleType(SqlTypeName sqlTypeName)
{
return SqlTypeName.FRACTIONAL_TYPES.contains(sqlTypeName) || SqlTypeName.APPROX_TYPES.contains(sqlTypeName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.ExpressionDimFilter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.LikeDimFilter;
import org.apache.druid.query.filter.OrDimFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupByQueryConfig;
Expand All @@ -39,6 +41,7 @@
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.segment.virtual.ListFilteredVirtualColumn;
import org.apache.druid.sql.SqlPlanningException;
import org.apache.druid.sql.calcite.filtration.Filtration;
Expand Down Expand Up @@ -1847,4 +1850,87 @@ public void testMultiValueToArrayArgsWithArray()
exception -> exception.expect(RuntimeException.class)
);
}

/**
 * Runs a query whose WHERE clause ORs two MV_OVERLAP calls, one wrapping MV_TO_ARRAY(dim3) in COALESCE and the
 * other in NVL, and checks the expected scan query plus the expected rows. The expected rows are chosen by
 * {@code NullHandling.replaceWithDefault()}.
 */
@Test
public void testMultiValueStringOverlapFilterCoalesceNvl()
{
  final String sql =
      "SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
      + "WHERE MV_OVERLAP(COALESCE(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) OR "
      + "MV_OVERLAP(NVL(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5";

  // Expression behind the selected virtual column "v0".
  final String coalescedDim3 = "case_searched(notnull(\"dim3\"),\"dim3\",'other')";

  // The COALESCE branch and the NVL branch are both expected to produce this same filter expression.
  final String overlapOrDefault =
      "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)";

  testQuery(
      sql,
      ImmutableList.of(
          newScanQueryBuilder()
              .dataSource(CalciteTests.DATASOURCE3)
              .eternityInterval()
              .virtualColumns(
                  new ExpressionVirtualColumn("v0", coalescedDim3, ColumnType.STRING, queryFramework().macroTable())
              )
              .filters(
                  new OrDimFilter(
                      new ExpressionDimFilter(overlapOrDefault, null, queryFramework().macroTable()),
                      new ExpressionDimFilter(overlapOrDefault, null, queryFramework().macroTable())
                  )
              )
              .columns("v0")
              .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
              .limit(5)
              .context(QUERY_CONTEXT_DEFAULT)
              .build()
      ),
      NullHandling.replaceWithDefault()
      ? ImmutableList.of(
          new Object[]{"[\"a\",\"b\"]"},
          new Object[]{"[\"b\",\"c\"]"},
          new Object[]{"other"},
          new Object[]{"other"},
          new Object[]{"other"}
      )
      : ImmutableList.of(
          new Object[]{"[\"a\",\"b\"]"},
          new Object[]{"[\"b\",\"c\"]"},
          new Object[]{"other"},
          new Object[]{"other"}
      )
  );
}

/**
 * Expects a {@link SqlPlanningException} with an "Illegal mixing of types" message when COALESCE combines dim3
 * with an ARRAY literal inside MV_OVERLAP.
 */
@Test
public void testMultiValueStringOverlapFilterInconsistentUsage()
{
  final String sql =
      "SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
      + "WHERE MV_OVERLAP(COALESCE(dim3, ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5";

  testQueryThrows(
      sql,
      expected -> {
        expected.expect(SqlPlanningException.class);
        expected.expectMessage("Illegal mixing of types in CASE or COALESCE statement");
      }
  );
}

/**
 * Expects a {@link RuntimeException} reporting that dim3 is used as both a scalar and an array variable when
 * MV_OVERLAP is applied to COALESCE(dim3, 'other').
 */
@Test
public void testMultiValueStringOverlapFilterInconsistentUsage2()
{
  final String sql =
      "SELECT COALESCE(dim3, 'other') FROM druid.numfoo "
      + "WHERE MV_OVERLAP(COALESCE(dim3, 'other'), ARRAY['a', 'b', 'other']) LIMIT 5";

  testQueryThrows(
      sql,
      expected -> {
        expected.expect(RuntimeException.class);
        expected.expectMessage("Invalid expression: (case_searched [(notnull [dim3]), (array_overlap [dim3, [a, b, other]]), 1]); [dim3] used as both scalar and array variables");
      }
  );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1751,8 +1751,7 @@ public void testTimeMinusDayTimeInterval()
(args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")",
ImmutableList.of(
DruidExpression.ofColumn(ColumnType.LONG, "t"),
// RexNode type of "interval day to minute" is not converted to druid long... yet
DruidExpression.ofLiteral(null, "90060000")
DruidExpression.ofLiteral(ColumnType.STRING, "90060000")
)
),
DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis()
Expand All @@ -1779,8 +1778,7 @@ public void testTimeMinusYearMonthInterval()
DruidExpression.functionCall("timestamp_shift"),
ImmutableList.of(
DruidExpression.ofColumn(ColumnType.LONG, "t"),
// RexNode type "interval year to month" is not reported as ColumnType.STRING
DruidExpression.ofLiteral(null, DruidExpression.stringLiteral("P13M")),
DruidExpression.ofLiteral(ColumnType.STRING, DruidExpression.stringLiteral("P13M")),
DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)),
DruidExpression.ofStringLiteral("UTC")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,17 @@ public void testTimestamp()
}

@Test
public void testInvalidType()
public void testIntervalYearMonth()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");

testExpression(
Collections.singletonList(
testHelper.makeLiteral(
new BigDecimal(13), // YEAR-MONTH literals value is months
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
buildExpectedExpression(13),
13L
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,19 +247,17 @@ public void testTimestamp()
}

@Test
public void testInvalidType()
public void testIntervalYearMonth()
{
expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH");

testExpression(
Collections.singletonList(
testHelper.makeLiteral(
new BigDecimal(13), // YEAR-MONTH literals value is months
new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO)
)
),
null,
null
buildExpectedExpression(13),
13L
);
}

Expand Down