Skip to content

Commit

Permalink
expression: fix the issue that comparison between Decimal may cause o…
Browse files Browse the repository at this point in the history
…verflow and report `Can't compare`. (pingcap#3097) (pingcap#3366)
  • Loading branch information
ti-chi-bot authored Dec 21, 2021
1 parent a139e02 commit 2e0d16c
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 172 deletions.
92 changes: 47 additions & 45 deletions dbms/src/Core/DecimalComparison.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,53 @@ class DecimalComparison
return applyWithScale(a, b, shift);
}

template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
CompareInt x = static_cast<CompareInt>(a);
CompareInt y = static_cast<CompareInt>(b);

if constexpr (_check_overflow)
{
bool invalid = false;

if constexpr (sizeof(A) > sizeof(CompareInt))
invalid |= (A(x) != a);
if constexpr (sizeof(B) > sizeof(CompareInt))
invalid |= (B(y) != b);
if constexpr (std::is_unsigned_v<A>)
invalid |= (x < 0);
if constexpr (std::is_unsigned_v<B>)
invalid |= (y < 0);

if (invalid)
throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
}

if constexpr (scale_left && scale_right)
throw DB::Exception("Assumption broken: there should only one side need to be multiplied in decimal comparison.", ErrorCodes::LOGICAL_ERROR);
if constexpr (!scale_left && !scale_right)
return Op::apply(x, y);

// overflow means absolute value must be greater.
// we use this variable to mark whether the right side is greater than left side by overflow.
int right_side_greater_by_overflow = 0;
if constexpr (scale_left)
{
int sign = boost::math::sign(x);
right_side_greater_by_overflow = -sign * common::mulOverflow(x, scale, x); // x will be changed.
}
if constexpr (scale_right)
{
int sign = boost::math::sign(y);
right_side_greater_by_overflow = sign * common::mulOverflow(y, scale, y); // y will be changed.
}

if (right_side_greater_by_overflow)
return Op::apply(0, right_side_greater_by_overflow);
return Op::apply(x, y);
}

private:

struct Shift
Expand Down Expand Up @@ -239,51 +286,6 @@ class DecimalComparison
return c_res;
}

template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
CompareInt x = static_cast<CompareInt>(a);
CompareInt y = static_cast<CompareInt>(b);

if constexpr (_check_overflow)
{
bool overflow = false;

if constexpr (sizeof(A) > sizeof(CompareInt))
overflow |= (A(x) != a);
if constexpr (sizeof(B) > sizeof(CompareInt))
overflow |= (B(y) != b);
if constexpr (std::is_unsigned_v<A>)
overflow |= (x < 0);
if constexpr (std::is_unsigned_v<B>)
overflow |= (y < 0);

if constexpr (scale_left) {
if constexpr (std::is_same_v<CompareInt, Int256>)
x = x * scale;
else
overflow |= common::mulOverflow(x, scale, x);
}
if constexpr (scale_right) {
if constexpr (std::is_same_v<CompareInt, Int256>)
y = y * scale;
else
overflow |= common::mulOverflow(y, scale, y);
}
if (overflow)
throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
}
else
{
if constexpr (scale_left)
x *= scale;
if constexpr (scale_right)
y *= scale;
}

return Op::apply(x, y);
}

template <bool scale_left, bool scale_right>
static void NO_INLINE vector_vector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c,
CompareInt scale [[maybe_unused]])
Expand Down
18 changes: 18 additions & 0 deletions dbms/src/Core/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#endif

#include <limits>
#include <common/arithmeticOverflow.h>

#if !defined(__GLIBCXX_BITSIZE_INT_N_0) && defined(__SIZEOF_INT128__)
namespace std
Expand Down Expand Up @@ -166,3 +167,20 @@ template <> struct TypeId<Float32> { static constexpr const TypeIndex value = T
template <> struct TypeId<Float64> { static constexpr const TypeIndex value = TypeIndex::Float64; };

}

namespace common
{
template <>
inline bool mulOverflow(DB::Int256 x, DB::Int256 y, DB::Int256 & res)
{
try
{
res = x * y;
}
catch (std::overflow_error &)
{
return true;
}
return false;
}
} // namespace common
57 changes: 31 additions & 26 deletions dbms/src/DataTypes/DataTypeDecimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

namespace DB
{

namespace ErrorCodes
{
extern const int ARGUMENT_OUT_OF_BOUND;
Expand Down Expand Up @@ -43,9 +42,13 @@ class DataTypeDecimal : public IDataType
static constexpr size_t maxPrecision() { return maxDecimalPrecision<T>(); }

// If scale is omitted, the default is 0. If precision is omitted, the default is 10.
DataTypeDecimal() : DataTypeDecimal(10, 0) {}
DataTypeDecimal()
: DataTypeDecimal(10, 0)
{}

DataTypeDecimal(size_t precision_, size_t scale_) : precision(precision_), scale(scale_)
DataTypeDecimal(size_t precision_, size_t scale_)
: precision(precision_)
, scale(scale_)
{
if (precision > decimal_max_prec || scale > precision || scale > decimal_max_scale)
{
Expand Down Expand Up @@ -129,7 +132,7 @@ class DataTypeDecimal : public IDataType
template <typename U>
typename T::NativeType scaleFactorFor(const DataTypeDecimal<U> & x) const
{
if (scale < x.getScale())
if (getScale() < x.getScale())
{
return 1;
}
Expand Down Expand Up @@ -164,8 +167,8 @@ inline DataTypePtr createDecimal(UInt64 prec, UInt64 scale)

if (static_cast<UInt64>(scale) > prec)
throw Exception("Negative scales and scales larger than precision are not supported. precision:" + DB::toString(prec)
+ ", scale:" + DB::toString(scale),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);
+ ", scale:" + DB::toString(scale),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);

if (prec <= maxDecimalPrecision<Decimal32>())
{
Expand Down Expand Up @@ -195,15 +198,17 @@ inline bool IsDecimalDataType(const DataTypePtr & type)
}
template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimalResultType(
const DataTypeDecimal<T> & tx, const DataTypeDecimal<U> & ty)
const DataTypeDecimal<T> & tx,
const DataTypeDecimal<U> & ty)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
return DataTypeDecimal<T>(maxDecimalPrecision<T>(), scale);
}

template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimalResultType(
const DataTypeDecimal<T> & tx, const DataTypeDecimal<U> & ty)
const DataTypeDecimal<T> & tx,
const DataTypeDecimal<U> & ty)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
return DataTypeDecimal<U>(maxDecimalPrecision<U>(), scale);
Expand All @@ -226,24 +231,24 @@ inline UInt32 leastDecimalPrecisionFor(TypeIndex int_type)
{
switch (int_type)
{
case TypeIndex::Int8:
[[fallthrough]];
case TypeIndex::UInt8:
return 3;
case TypeIndex::Int16:
[[fallthrough]];
case TypeIndex::UInt16:
return 5;
case TypeIndex::Int32:
[[fallthrough]];
case TypeIndex::UInt32:
return 10;
case TypeIndex::Int64:
return 19;
case TypeIndex::UInt64:
return 20;
default:
break;
case TypeIndex::Int8:
[[fallthrough]];
case TypeIndex::UInt8:
return 3;
case TypeIndex::Int16:
[[fallthrough]];
case TypeIndex::UInt16:
return 5;
case TypeIndex::Int32:
[[fallthrough]];
case TypeIndex::UInt32:
return 10;
case TypeIndex::Int64:
return 19;
case TypeIndex::UInt64:
return 20;
default:
break;
}
return 0;
}
Expand Down
Loading

0 comments on commit 2e0d16c

Please sign in to comment.