Skip to content

Commit

Permalink
[feature](nereids)support decimalv2 (#28726)
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 authored Dec 25, 2023
1 parent 9975592 commit c53611d
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ public static Expression divideDouble(DoubleLiteral first, DoubleLiteral second)
return new DoubleLiteral(result);
}

/**
* Executable arithmetic functions divide
*/
@ExecFunction(name = "divide", argTypes = {"DECIMAL", "DECIMAL"}, returnType = "DECIMAL")
public static Expression divideDecimal(DecimalLiteral first, DecimalLiteral second) {
if (first.getValue().compareTo(BigDecimal.ZERO) == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,14 @@ public double getDouble() {
/**
* check precision and scale is enough for value.
*/
public static void checkPrecisionAndScale(int precision, int scale, BigDecimal value) throws AnalysisException {
private static void checkPrecisionAndScale(int precision, int scale, BigDecimal value) throws AnalysisException {
Preconditions.checkNotNull(value);
int realPrecision = value.precision();
int realScale = value.scale();
boolean valid = true;
if (precision != -1 && scale != -1) {
if (precision < realPrecision || scale < realScale) {
if (precision < realPrecision || scale < realScale
|| realPrecision - realScale > DecimalV2Type.MAX_PRECISION - DecimalV2Type.MAX_SCALE) {
valid = false;
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.doris.nereids.trees.expressions.literal;

import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DecimalV3Type;

import com.google.common.base.Preconditions;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Objects;
Expand All @@ -43,7 +46,7 @@ public DecimalV3Literal(BigDecimal value) {
public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) {
super(DecimalV3Type.createDecimalV3TypeLooseCheck(dataType.getPrecision(), dataType.getScale()));
Objects.requireNonNull(value, "value not be null");
DecimalLiteral.checkPrecisionAndScale(dataType.getPrecision(), dataType.getScale(), value);
checkPrecisionAndScale(dataType.getPrecision(), dataType.getScale(), value);
BigDecimal adjustedValue = value.scale() < 0 ? value
: value.setScale(dataType.getScale(), RoundingMode.HALF_UP);
this.value = Objects.requireNonNull(adjustedValue);
Expand Down Expand Up @@ -80,4 +83,27 @@ public DecimalV3Literal roundFloor(int newScale) {
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.FLOOR));
}

/**
* check precision and scale is enough for value.
*/
private static void checkPrecisionAndScale(int precision, int scale, BigDecimal value) throws AnalysisException {
Preconditions.checkNotNull(value);
int realPrecision = value.precision();
int realScale = value.scale();
boolean valid = true;
if (precision != -1 && scale != -1) {
if (precision < realPrecision || scale < realScale) {
valid = false;
}
} else {
valid = false;
}

if (!valid) {
throw new AnalysisException(
String.format("Invalid precision and scale - expect (%d, %d), but (%d, %d)",
precision, scale, realPrecision, realScale));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ public static DataType convertPrimitiveFromStrings(List<String> types, boolean u
throw new AnalysisException("Nereids do not support type: " + type);
}
break;
case "decimalv2":
// NOTICE, maybe convert to decimalv3, so do not truc here.
switch (types.size()) {
case 1:
dataType = DecimalV2Type.CATALOG_DEFAULT_NOT_CONVERSION;
break;
case 2:
dataType = DecimalV2Type.createDecimalV2TypeWithoutTruncate(
Integer.parseInt(types.get(1)), 0, false);
break;
case 3:
dataType = DecimalV2Type.createDecimalV2TypeWithoutTruncate(
Integer.parseInt(types.get(1)), Integer.parseInt(types.get(2)), false);
break;
default:
throw new AnalysisException("Nereids do not support type: " + type);
}
break;
case "decimalv3":
switch (types.size()) {
case 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,21 @@ public class DecimalV2Type extends FractionalType {

public static int MAX_PRECISION = 27;
public static int MAX_SCALE = 9;
public static final DecimalV2Type SYSTEM_DEFAULT = new DecimalV2Type(MAX_PRECISION, MAX_SCALE);
public static final DecimalV2Type CATALOG_DEFAULT = new DecimalV2Type(DEFAULT_PRECISION, DEFAULT_SCALE);

private static final DecimalV2Type BOOLEAN_DECIMAL = new DecimalV2Type(1, 0);
private static final DecimalV2Type TINYINT_DECIMAL = new DecimalV2Type(3, 0);
private static final DecimalV2Type SMALLINT_DECIMAL = new DecimalV2Type(5, 0);
private static final DecimalV2Type INTEGER_DECIMAL = new DecimalV2Type(10, 0);
private static final DecimalV2Type BIGINT_DECIMAL = new DecimalV2Type(20, 0);
private static final DecimalV2Type LARGEINT_DECIMAL = new DecimalV2Type(27, 0);
private static final DecimalV2Type FLOAT_DECIMAL = new DecimalV2Type(14, 7);
private static final DecimalV2Type DOUBLE_DECIMAL = new DecimalV2Type(27, 9);
public static final DecimalV2Type SYSTEM_DEFAULT = new DecimalV2Type(MAX_PRECISION, MAX_SCALE, true);
public static final DecimalV2Type SYSTEM_DEFAULT_NOT_CONVERSION =
new DecimalV2Type(MAX_PRECISION, MAX_SCALE, false);
public static final DecimalV2Type CATALOG_DEFAULT = new DecimalV2Type(DEFAULT_PRECISION, DEFAULT_SCALE, true);
public static final DecimalV2Type CATALOG_DEFAULT_NOT_CONVERSION =
new DecimalV2Type(DEFAULT_PRECISION, DEFAULT_SCALE, false);

private static final DecimalV2Type BOOLEAN_DECIMAL = new DecimalV2Type(1, 0, true);
private static final DecimalV2Type TINYINT_DECIMAL = new DecimalV2Type(3, 0, true);
private static final DecimalV2Type SMALLINT_DECIMAL = new DecimalV2Type(5, 0, true);
private static final DecimalV2Type INTEGER_DECIMAL = new DecimalV2Type(10, 0, true);
private static final DecimalV2Type BIGINT_DECIMAL = new DecimalV2Type(20, 0, true);
private static final DecimalV2Type LARGEINT_DECIMAL = new DecimalV2Type(27, 0, true);
private static final DecimalV2Type FLOAT_DECIMAL = new DecimalV2Type(14, 7, true);
private static final DecimalV2Type DOUBLE_DECIMAL = new DecimalV2Type(27, 9, true);

private static final int WIDTH = 16;

Expand All @@ -68,14 +72,17 @@ public class DecimalV2Type extends FractionalType {
private final int precision;
private final int scale;

private final boolean shouldConversion;

/**
* constructors.
*/
private DecimalV2Type(int precision, int scale) {
private DecimalV2Type(int precision, int scale, boolean shouldConversion) {
Preconditions.checkArgument(precision >= scale, "precision should not smaller than scale,"
+ " but precision is " + precision, ", scale is " + scale);
this.precision = precision;
this.scale = scale;
this.shouldConversion = shouldConversion;
}

/** createDecimalV2Type. */
Expand All @@ -86,7 +93,7 @@ public static DecimalV2Type createDecimalV2Type(int precision, int scale) {
if (precision == CATALOG_DEFAULT.precision && scale == CATALOG_DEFAULT.scale) {
return CATALOG_DEFAULT;
}
return new DecimalV2Type(Math.min(precision, MAX_PRECISION), Math.min(scale, MAX_SCALE));
return new DecimalV2Type(Math.min(precision, MAX_PRECISION), Math.min(scale, MAX_SCALE), true);
}

public static DecimalV2Type createDecimalV2Type(BigDecimal bigDecimal) {
Expand All @@ -105,7 +112,22 @@ public static DecimalV2Type createDecimalV2TypeWithoutTruncate(int precision, in
if (precision == CATALOG_DEFAULT.precision && scale == CATALOG_DEFAULT.scale) {
return CATALOG_DEFAULT;
}
return new DecimalV2Type(precision, scale);
return new DecimalV2Type(precision, scale, true);
}

/**
* create DecimalV2Type with appropriate scale, precision and shouldConversion flag,
* not truncate to MAX_PRECISION, MAX_SCALE.
*/
public static DecimalV2Type createDecimalV2TypeWithoutTruncate(int precision, int scale,
boolean shouldConversion) {
if (precision == SYSTEM_DEFAULT.precision && scale == SYSTEM_DEFAULT.scale) {
return shouldConversion ? SYSTEM_DEFAULT : SYSTEM_DEFAULT_NOT_CONVERSION;
}
if (precision == CATALOG_DEFAULT.precision && scale == CATALOG_DEFAULT.scale) {
return shouldConversion ? CATALOG_DEFAULT : CATALOG_DEFAULT_NOT_CONVERSION;
}
return new DecimalV2Type(precision, scale, shouldConversion);
}

/**
Expand Down Expand Up @@ -153,7 +175,7 @@ public int getScale() {

@Override
public DataType conversion() {
if (Config.enable_decimal_conversion) {
if (Config.enable_decimal_conversion && shouldConversion) {
return DecimalV3Type.createDecimalV3Type(precision, scale);
}
Preconditions.checkArgument(precision > 0 && precision <= MAX_PRECISION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ public void testDatetimev1() {

}

@Test
public void testDecimalv2() {
String decv2 = "SELECT CAST('1.234' AS decimalv2(10,5))";
NereidsParser nereidsParser = new NereidsParser();
LogicalPlan logicalPlan = (LogicalPlan) nereidsParser.parseSingle(decv2).child(0);
Assertions.assertTrue(logicalPlan.getExpressions().get(0).getDataType().isDecimalV2Type());
}

@Test
public void parseSetOperation() {
String union = "select * from t1 union select * from t2 union all select * from t3";
Expand Down
52 changes: 52 additions & 0 deletions regression-test/data/nereids_p0/datatype/test_decimalv2.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sql1 --
1.230

-- !sql2 --
1.230

-- !sql1 --
1 1.230
2 2.340
3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

-- !sql2 --
1 1.230 1 1.230
2 2.340 2 2.340
3 3.450 3 3.450

Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ suite("test_decimalv2_overflow", "nonConcurrent") {
def tblName2 = "test_decimalv2_overflow2"
sql "drop table if exists ${tblName2}"
sql """ CREATE TABLE ${tblName2} (
`c2` decimalv2(20, 2),
`c2` decimalv2(20, 2)
) ENGINE=OLAP
UNIQUE KEY(`c2`)
DISTRIBUTED BY HASH(`c2`) BUCKETS 10
Expand Down
88 changes: 88 additions & 0 deletions regression-test/suites/nereids_p0/datatype/test_decimalv2.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_decimalv2") {
def tbName = "test_decimalv2_exprs"

sql """set enable_nereids_planner=true"""
sql """set enable_fallback_to_original_planner=false"""
sql """set enable_nereids_dml=true"""

sql "DROP TABLE IF EXISTS ${tbName}"
sql """
create table ${tbName}(k1 decimalv2(10,3), k2 int) distributed by hash(k1) buckets 1 properties("replication_num" = "1");
"""
sql """ insert into ${tbName} values("1.23", 1); """

qt_sql1 """ select dt
from
(
select cast(k1 as decimalv2(5,3)) as dt
from ${tbName}
) r; """
qt_sql2 """ select dt
from
(
select cast(k1 as decimal(5,3)) as dt
from ${tbName}
) r; """
sql "DROP TABLE ${tbName}"

tbName = "test_decimalv2_runtime_filter"
sql "DROP TABLE IF EXISTS ${tbName}"
sql """
CREATE TABLE IF NOT EXISTS ${tbName} (
c0 int,
c2 decimalv2(10, 3)
)
DISTRIBUTED BY HASH(c0) BUCKETS 5 properties("replication_num" = "1");
"""
sql "insert into ${tbName} values(1, '1.23')"
sql "insert into ${tbName} values(2, '2.34')"
sql "insert into ${tbName} values(3, '3.45')"

qt_sql1 "select * from ${tbName} ORDER BY c0"

sql " set runtime_filter_type = 1; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 2; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 4; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 8; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_wait_time_ms = 0; "

sql " set runtime_filter_type = 1; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 2; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 4; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql " set runtime_filter_type = 8; "
qt_sql2 "select * from ${tbName} a, ${tbName} b WHERE a.c2 = b.c2 ORDER BY a.c0"

sql "DROP TABLE ${tbName}"
}

0 comments on commit c53611d

Please sign in to comment.