From c76e47ea313ebc19b44e5a423e3855eaf5505fc7 Mon Sep 17 00:00:00 2001 From: Lonng Date: Fri, 2 Aug 2019 14:28:28 +0800 Subject: [PATCH] coprocessor: add a Convert trait and implement convert to decimal (#5167) Signed-off-by: Lonng --- components/tidb_query/src/codec/convert.rs | 74 ++-------- components/tidb_query/src/codec/datum.rs | 42 +++--- .../tidb_query/src/codec/mysql/decimal.rs | 137 ++++++++++++++++-- .../tidb_query/src/codec/mysql/duration.rs | 24 +-- .../tidb_query/src/codec/mysql/json/mod.rs | 4 +- .../tidb_query/src/codec/mysql/time/mod.rs | 33 +++-- .../tidb_query/src/executor/aggregate.rs | 26 ++-- .../tidb_query/src/expr/builtin_cast.rs | 45 +++--- components/tidb_query/src/expr/mod.rs | 2 +- .../src/rpn_expr/impl_arithmetic.rs | 20 +-- .../tidb_query/src/rpn_expr/impl_cast.rs | 26 ++-- .../tidb_query/src/rpn_expr/impl_compare.rs | 11 +- fuzz/targets/mod.rs | 29 +++- 13 files changed, 286 insertions(+), 187 deletions(-) diff --git a/components/tidb_query/src/codec/convert.rs b/components/tidb_query/src/codec/convert.rs index 3262d35ce74..f97ca91d3e7 100644 --- a/components/tidb_query/src/codec/convert.rs +++ b/components/tidb_query/src/codec/convert.rs @@ -1,7 +1,6 @@ // Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0. use std::borrow::Cow; -use std::convert::TryFrom; use std::{self, char, i16, i32, i64, i8, str, u16, u32, u64, u8}; use tidb_query_datatype::{self, FieldTypeTp}; @@ -358,7 +357,8 @@ impl ToInt for DateTime { // TODO: avoid this clone after refactor the `Time` let mut t = self.clone(); t.round_frac(DEFAULT_FSP)?; - let val = t.to_decimal()?.as_i64_with_ctx(ctx)?; + let dec: Decimal = t.convert(ctx)?; + let val = dec.as_i64_with_ctx(ctx)?; val.to_int(ctx, tp) } @@ -367,7 +367,8 @@ impl ToInt for DateTime { // TODO: avoid this clone after refactor the `Time` let mut t = self.clone(); t.round_frac(DEFAULT_FSP)?; - decimal_as_u64(ctx, t.to_decimal()?, tp) + let dec: Decimal = t.convert(ctx)?; + decimal_as_u64(ctx, dec, tp) } } @@ -375,14 +376,16 @@ impl ToInt for Duration { #[inline] fn to_int(&self, ctx: &mut EvalContext, tp: FieldTypeTp) -> Result { let dur = (*self).round_frac(DEFAULT_FSP)?; - let val = Decimal::try_from(dur)?.as_i64_with_ctx(ctx)?; + let dec: Decimal = dur.convert(ctx)?; + let val = dec.as_i64_with_ctx(ctx)?; val.to_int(ctx, tp) } #[inline] fn to_uint(&self, ctx: &mut EvalContext, tp: FieldTypeTp) -> Result { let dur = (*self).round_frac(DEFAULT_FSP)?; - decimal_as_u64(ctx, Decimal::try_from(dur)?, tp) + let dec: Decimal = dur.convert(ctx)?; + decimal_as_u64(ctx, dec, tp) } } @@ -471,23 +474,6 @@ fn decimal_as_u64(ctx: &mut EvalContext, dec: Decimal, tp: FieldTypeTp) -> Resul val.to_uint(ctx, tp) } -/// Converts a bytes slice to a `Decimal` -#[inline] -pub fn convert_bytes_to_decimal(ctx: &mut EvalContext, bytes: &[u8]) -> Result { - let dec = match Decimal::from_bytes(bytes)? { - Res::Ok(d) => d, - Res::Overflow(d) => { - ctx.handle_overflow(Error::overflow("DECIMAL", ""))?; - d - } - Res::Truncated(d) => { - ctx.handle_truncate(true)?; - d - } - }; - Ok(dec) -} - /// `bytes_to_int_without_context` converts a byte arrays to an i64 /// in best effort, but without context. pub fn bytes_to_int_without_context(bytes: &[u8]) -> Result { @@ -548,12 +534,14 @@ pub fn bytes_to_uint_without_context(bytes: &[u8]) -> Result { } impl ConvertTo for i64 { + #[inline] fn convert(&self, _: &mut EvalContext) -> Result { Ok(*self as f64) } } impl ConvertTo for u64 { + #[inline] fn convert(&self, _: &mut EvalContext) -> Result { Ok(*self as f64) } @@ -581,12 +569,14 @@ impl ConvertTo for &[u8] { } impl ConvertTo for std::borrow::Cow<'_, [u8]> { + #[inline] fn convert(&self, ctx: &mut EvalContext) -> Result { self.as_ref().convert(ctx) } } impl ConvertTo for Bytes { + #[inline] fn convert(&self, ctx: &mut EvalContext) -> Result { self.as_slice().convert(ctx) } @@ -798,7 +788,6 @@ mod tests { use std::{f64, i64, isize, u64}; use crate::codec::error::{ERR_DATA_OUT_OF_RANGE, WARN_DATA_TRUNCATED}; - use crate::codec::mysql::decimal::{self, DIGITS_PER_WORD, WORD_BUF_LEN}; use crate::expr::Flag; use crate::expr::{EvalConfig, EvalContext}; @@ -1480,45 +1469,6 @@ mod tests { } } - #[test] - fn test_convert_bytes_to_decimal() { - let cases: Vec<(&[u8], Decimal)> = vec![ - (b"123456.1", Decimal::from_f64(123456.1).unwrap()), - (b"-123456.1", Decimal::from_f64(-123456.1).unwrap()), - (b"123456", Decimal::from(123456)), - (b"-123456", Decimal::from(-123456)), - ]; - let mut ctx = EvalContext::default(); - for (s, expect) in cases { - let got = convert_bytes_to_decimal(&mut ctx, s).unwrap(); - assert_eq!(got, expect, "from {:?}, expect: {} got: {}", s, expect, got); - } - - // OVERFLOWING - let big = (0..85).map(|_| '9').collect::(); - let val = convert_bytes_to_decimal(&mut ctx, big.as_bytes()); - assert!( - val.is_err(), - "expected error, but got {:?}", - val.unwrap().to_string() - ); - assert_eq!(val.unwrap_err().code(), ERR_DATA_OUT_OF_RANGE); - - // OVERFLOW_AS_WARNING - let mut ctx = EvalContext::new(Arc::new(EvalConfig::from_flag(Flag::OVERFLOW_AS_WARNING))); - let val = convert_bytes_to_decimal(&mut ctx, big.as_bytes()).unwrap(); - let max = decimal::max_decimal(WORD_BUF_LEN * DIGITS_PER_WORD, 0); - assert_eq!( - val, - max, - "expect: {}, got: {}", - val.to_string(), - max.to_string() - ); - assert_eq!(ctx.warnings.warning_cnt, 1); - assert_eq!(ctx.warnings.warnings[0].get_code(), ERR_DATA_OUT_OF_RANGE); - } - #[test] fn test_bytes_to_f64() { let tests: Vec<(&'static [u8], Option)> = vec![ diff --git a/components/tidb_query/src/codec/datum.rs b/components/tidb_query/src/codec/datum.rs index a9f0e0320e7..f85d303e48c 100644 --- a/components/tidb_query/src/codec/datum.rs +++ b/components/tidb_query/src/codec/datum.rs @@ -3,10 +3,8 @@ use byteorder::WriteBytesExt; use std::borrow::Cow; use std::cmp::Ordering; -use std::convert::TryFrom; use std::fmt::{self, Debug, Display, Formatter}; use std::io::Write; -use std::str::FromStr; use std::{i64, str}; use tidb_query_datatype::FieldTypeTp; @@ -381,14 +379,14 @@ impl Datum { Datum::Bytes(bs) => ConvertTo::::convert(&bs, ctx).map(From::from), Datum::Time(t) => { // if time has no precision, return int64 - let dec = t.to_decimal()?; + let dec: Decimal = t.convert(ctx)?; if t.get_fsp() == 0 { return Ok(Datum::I64(dec.as_i64().unwrap())); } Ok(Datum::Dec(dec)) } Datum::Dur(d) => { - let dec = Decimal::try_from(d)?; + let dec: Decimal = d.convert(ctx)?; if d.fsp() == 0 { return Ok(Datum::I64(dec.as_i64().unwrap())); } @@ -399,10 +397,11 @@ impl Datum { } /// Keep compatible with TiDB's `ToDecimal` function. + /// FIXME: the `EvalContext` should be passed by caller pub fn into_dec(self) -> Result { match self { - Datum::Time(t) => t.to_decimal(), - Datum::Dur(d) => Decimal::try_from(d).map_err(From::from), + Datum::Time(t) => t.convert(&mut EvalContext::default()), + Datum::Dur(d) => d.convert(&mut EvalContext::default()), d => match d.coerce_to_dec()? { Datum::Dec(d) => Ok(d), d => Err(box_err!("failed to conver {} to decimal", d)), @@ -471,10 +470,13 @@ impl Datum { let dec = match self { Datum::I64(i) => i.into(), Datum::U64(u) => u.into(), - Datum::F64(f) => Decimal::from_f64(f)?, + Datum::F64(f) => { + // FIXME: the `EvalContext` should be passed from caller + f.convert(&mut EvalContext::default())? + } Datum::Bytes(ref bs) => { - let s = box_try!(str::from_utf8(bs)); - Decimal::from_str(s)? + // FIXME: the `EvalContext` should be passed from caller + bs.convert(&mut EvalContext::default())? } d @ Datum::Dec(_) => return Ok(d), _ => return Err(box_err!("failed to convert {} to decimal", self)), @@ -483,15 +485,11 @@ impl Datum { } /// Try its best effort to convert into a f64 datum. - fn coerce_to_f64(self) -> Result { + fn coerce_to_f64(self, ctx: &mut EvalContext) -> Result { match self { Datum::I64(i) => Ok(Datum::F64(i as f64)), Datum::U64(u) => Ok(Datum::F64(u as f64)), - Datum::Dec(d) => { - // TODO: remove this function `coerce_to_f64` - let f = d.convert(&mut EvalContext::default())?; - Ok(Datum::F64(f)) - } + Datum::Dec(d) => Ok(Datum::F64(d.convert(ctx)?)), a => Ok(a), } } @@ -500,11 +498,11 @@ impl Datum { /// If left or right is F64, changes the both to F64. /// Else if left or right is Decimal, changes the both to Decimal. /// Keep compatible with TiDB's `CoerceDatum` function. - pub fn coerce(left: Datum, right: Datum) -> Result<(Datum, Datum)> { + pub fn coerce(ctx: &mut EvalContext, left: Datum, right: Datum) -> Result<(Datum, Datum)> { let res = match (left, right) { a @ (Datum::Dec(_), Datum::Dec(_)) | a @ (Datum::F64(_), Datum::F64(_)) => a, - (l @ Datum::F64(_), r) => (l, r.coerce_to_f64()?), - (l, r @ Datum::F64(_)) => (l.coerce_to_f64()?, r), + (l @ Datum::F64(_), r) => (l, r.coerce_to_f64(ctx)?), + (l, r @ Datum::F64(_)) => (l.coerce_to_f64(ctx)?, r), (l @ Datum::Dec(_), r) => (l, r.coerce_to_dec()?), (l, r @ Datum::Dec(_)) => (l.coerce_to_dec()?, r), p => p, @@ -1011,6 +1009,7 @@ mod tests { use tikv_util::as_slice; use std::cmp::Ordering; + use std::str::FromStr; use std::sync::Arc; use std::{i16, i32, i64, i8, u16, u32, u64, u8}; @@ -1049,7 +1048,7 @@ mod tests { ], vec![ Datum::U64(1), - Datum::Dec(Decimal::from_f64(2.3).unwrap()), + Datum::Dec(2.3.convert(&mut EvalContext::default()).unwrap()), Datum::Dec("-34".parse().unwrap()), ], vec![ @@ -1654,7 +1653,7 @@ mod tests { Some(true), ), ( - Datum::Dec(Decimal::from_f64(0.1415926).unwrap()), + Datum::Dec(0.1415926.convert(&mut EvalContext::default()).unwrap()), Some(false), ), (Datum::Dec(0u64.into()), Some(false)), @@ -1751,8 +1750,9 @@ mod tests { ), ]; + let mut ctx = EvalContext::default(); for (x, y, exp_x, exp_y) in cases { - let (res_x, res_y) = Datum::coerce(x, y).unwrap(); + let (res_x, res_y) = Datum::coerce(&mut ctx, x, y).unwrap(); assert_eq!(res_x, exp_x); assert_eq!(res_y, exp_y); } diff --git a/components/tidb_query/src/codec/mysql/decimal.rs b/components/tidb_query/src/codec/mysql/decimal.rs index 34251a335a6..4bba728a1b0 100644 --- a/components/tidb_query/src/codec/mysql/decimal.rs +++ b/components/tidb_query/src/codec/mysql/decimal.rs @@ -16,6 +16,7 @@ use tikv_util::codec::BytesSlice; use tikv_util::escape; use crate::codec::convert::{self, ConvertTo}; +use crate::codec::data_type::*; use crate::codec::{Error, Result, TEN_POW}; use crate::expr::EvalContext; @@ -92,9 +93,9 @@ impl DerefMut for Res { } // A `Decimal` holds 9 words. -pub const WORD_BUF_LEN: u8 = 9; +const WORD_BUF_LEN: u8 = 9; // A word holds 9 digits. -pub const DIGITS_PER_WORD: u8 = 9; +const DIGITS_PER_WORD: u8 = 9; // A word is 4 bytes i32. const WORD_SIZE: u8 = 4; const DIG_MASK: u32 = TEN_POW[8]; @@ -396,7 +397,7 @@ fn do_sub<'a>(mut lhs: &'a Decimal, mut rhs: &'a Decimal) -> Res { } /// Get the max possible decimal with giving precision and fraction digit count. -pub fn max_decimal(prec: u8, frac_cnt: u8) -> Decimal { +fn max_decimal(prec: u8, frac_cnt: u8) -> Decimal { let int_cnt = prec - frac_cnt; let mut res = Decimal::new(int_cnt, frac_cnt, false); let mut idx = 0; @@ -1503,19 +1504,6 @@ impl Decimal { Res::Ok(x) } - /// Convert a float number to decimal. - /// - /// This function will use float's canonical string representation - /// rather than the accurate value the float represent. - pub fn from_f64(f: f64) -> Result { - if !f.is_finite() { - return Err(invalid_type!("{} can't be convert to decimal'", f)); - } - - let s = format!("{}", f); - s.parse() - } - pub fn from_bytes(s: &[u8]) -> Result> { Decimal::from_bytes_with_word_buf(s, WORD_BUF_LEN) } @@ -1715,6 +1703,75 @@ impl From for Decimal { } } +impl ConvertTo for i64 { + #[inline] + fn convert(&self, _: &mut EvalContext) -> Result { + Ok(Decimal::from(*self)) + } +} + +impl ConvertTo for u64 { + #[inline] + fn convert(&self, _: &mut EvalContext) -> Result { + Ok(Decimal::from(*self)) + } +} + +impl ConvertTo for f64 { + /// Convert a float number to decimal. + /// + /// This function will use float's canonical string representation + /// rather than the accurate value the float represent. + #[inline] + fn convert(&self, _: &mut EvalContext) -> Result { + if !self.is_finite() { + return Err(invalid_type!("{} can't be convert to decimal'", self)); + } + + let s = format!("{}", self); + s.parse() + } +} + +impl ConvertTo for Real { + #[inline] + fn convert(&self, ctx: &mut EvalContext) -> Result { + self.into_inner().convert(ctx) + } +} + +impl ConvertTo for &[u8] { + #[inline] + fn convert(&self, ctx: &mut EvalContext) -> Result { + let dec = match Decimal::from_bytes(self)? { + Res::Ok(d) => d, + Res::Overflow(d) => { + ctx.handle_overflow(Error::overflow("DECIMAL", ""))?; + d + } + Res::Truncated(d) => { + ctx.handle_truncate(true)?; + d + } + }; + Ok(dec) + } +} + +impl ConvertTo for std::borrow::Cow<'_, [u8]> { + #[inline] + fn convert(&self, ctx: &mut EvalContext) -> Result { + self.as_ref().convert(ctx) + } +} + +impl ConvertTo for Bytes { + #[inline] + fn convert(&self, ctx: &mut EvalContext) -> Result { + self.as_slice().convert(ctx) + } +} + /// Get the first non-digit ascii char in `bs` from `start_idx`. fn first_non_digit(bs: &[u8], start_idx: usize) -> usize { bs.iter() @@ -2219,9 +2276,12 @@ mod tests { use super::*; use super::{DEFAULT_DIV_FRAC_INCR, WORD_BUF_LEN}; + use crate::codec::error::ERR_DATA_OUT_OF_RANGE; + use crate::expr::{EvalConfig, Flag}; use std::cmp::Ordering; use std::f64::EPSILON; use std::iter::repeat; + use std::sync::Arc; #[test] fn test_from_i64() { @@ -3394,4 +3454,49 @@ mod tests { assert_eq!(got, exp); } } + + #[test] + fn test_bytes_to_decimal() { + let cases: Vec<(&[u8], Decimal)> = vec![ + ( + b"123456.1", + ConvertTo::::convert(&123456.1, &mut EvalContext::default()).unwrap(), + ), + ( + b"-123456.1", + ConvertTo::::convert(&-123456.1, &mut EvalContext::default()).unwrap(), + ), + (b"123456", Decimal::from(123456)), + (b"-123456", Decimal::from(-123456)), + ]; + let mut ctx = EvalContext::default(); + for (s, expect) in cases { + let got: Decimal = s.convert(&mut ctx).unwrap(); + assert_eq!(got, expect, "from {:?}, expect: {} got: {}", s, expect, got); + } + + // OVERFLOWING + let big = (0..85).map(|_| '9').collect::(); + let val: Result = big.as_bytes().convert(&mut ctx); + assert!( + val.is_err(), + "expected error, but got {:?}", + val.unwrap().to_string() + ); + assert_eq!(val.unwrap_err().code(), ERR_DATA_OUT_OF_RANGE); + + // OVERFLOW_AS_WARNING + let mut ctx = EvalContext::new(Arc::new(EvalConfig::from_flag(Flag::OVERFLOW_AS_WARNING))); + let val: Decimal = big.as_bytes().convert(&mut ctx).unwrap(); + let max = max_decimal(WORD_BUF_LEN * DIGITS_PER_WORD, 0); + assert_eq!( + val, + max, + "expect: {}, got: {}", + val.to_string(), + max.to_string() + ); + assert_eq!(ctx.warnings.warning_cnt, 1); + assert_eq!(ctx.warnings.warnings[0].get_code(), ERR_DATA_OUT_OF_RANGE); + } } diff --git a/components/tidb_query/src/codec/mysql/duration.rs b/components/tidb_query/src/codec/mysql/duration.rs index 67466d4d5e4..c609cd448d0 100644 --- a/components/tidb_query/src/codec/mysql/duration.rs +++ b/components/tidb_query/src/codec/mysql/duration.rs @@ -1,7 +1,6 @@ // Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0. use std::cmp::Ordering; -use std::convert::TryFrom; use std::fmt::{self, Display, Formatter}; use std::io::Write; use std::{i64, u64}; @@ -682,17 +681,17 @@ impl Duration { } impl ConvertTo for Duration { + #[inline] fn convert(&self, _: &mut EvalContext) -> Result { let val = self.to_numeric_string().parse()?; Ok(val) } } -// TODO: define a convert::Convert trait for all conversion -impl TryFrom for Decimal { - type Error = crate::codec::Error; - fn try_from(duration: Duration) -> Result { - duration.to_numeric_string().parse() +impl ConvertTo for Duration { + #[inline] + fn convert(&self, _: &mut EvalContext) -> Result { + self.to_numeric_string().parse() } } @@ -773,7 +772,6 @@ mod tests { use std::f64::EPSILON; use super::*; - use crate::codec::convert::convert_bytes_to_decimal; use crate::codec::data_type::DateTime; use crate::expr::EvalContext; @@ -970,9 +968,11 @@ mod tests { ("-11:30:45.9233456", 0, "-113046"), ]; + let mut ctx = EvalContext::default(); for (input, fsp, exp) in cases { let t = Duration::parse(input.as_bytes(), fsp).unwrap(); - let res = format!("{}", Decimal::try_from(t).unwrap()); + let dec: Decimal = t.convert(&mut ctx).unwrap(); + let res = format!("{}", dec); assert_eq!(exp, res); } let cases = vec![ @@ -984,14 +984,13 @@ mod tests { ("2017-01-05 23:59:59.575601", 0, "000000"), ("0000-00-00 00:00:00", 6, "000000"), ]; - let mut ctx = EvalContext::default(); for (s, fsp, expect) in cases { let t = DateTime::parse_utc_datetime(s, fsp).unwrap(); let du = t.to_duration().unwrap(); - let get = Decimal::try_from(du).unwrap(); + let get: Decimal = du.convert(&mut ctx).unwrap(); assert_eq!( get, - convert_bytes_to_decimal(&mut ctx, expect.as_bytes()).unwrap(), + expect.as_bytes().convert(&mut ctx).unwrap(), "convert duration {} to decimal", s ); @@ -1146,7 +1145,8 @@ mod benches { let duration = Duration::parse(b"-12:34:56.123456", 6).unwrap(); b.iter(|| { let duration = test::black_box(duration); - let _ = test::black_box(Decimal::try_from(duration).unwrap()); + let dec: Result = duration.convert(&mut EvalContext::default()); + let _ = test::black_box(dec.unwrap()); }) } diff --git a/components/tidb_query/src/codec/mysql/json/mod.rs b/components/tidb_query/src/codec/mysql/json/mod.rs index 02068c22c5e..d2ed719e397 100644 --- a/components/tidb_query/src/codec/mysql/json/mod.rs +++ b/components/tidb_query/src/codec/mysql/json/mod.rs @@ -107,8 +107,8 @@ impl ConvertTo for Json { /// Converts a `Json` to a `Decimal` #[inline] fn convert(&self, ctx: &mut EvalContext) -> Result { - let f = self.convert(ctx)?; - Decimal::from_f64(f) + let f: f64 = self.convert(ctx)?; + f.convert(ctx) } } diff --git a/components/tidb_query/src/codec/mysql/time/mod.rs b/components/tidb_query/src/codec/mysql/time/mod.rs index ee985123bbb..18f30fd54f5 100644 --- a/components/tidb_query/src/codec/mysql/time/mod.rs +++ b/components/tidb_query/src/codec/mysql/time/mod.rs @@ -303,16 +303,6 @@ impl Time { } } - /// Returns the `Decimal` representation of the `DateTime/Date` - #[inline] - pub fn to_decimal(&self) -> Result { - if self.is_zero() { - return Ok(0.into()); - } - - self.to_numeric_string().parse() - } - fn parse_datetime_format(s: &str) -> Vec<&str> { let trimmed = s.trim(); if trimmed.is_empty() { @@ -823,6 +813,17 @@ impl ConvertTo for Time { } } +impl ConvertTo for Time { + #[inline] + fn convert(&self, _: &mut EvalContext) -> Result { + if self.is_zero() { + return Ok(0.into()); + } + + self.to_numeric_string().parse() + } +} + impl PartialOrd for Time { fn partial_cmp(&self, right: &Time) -> Option { Some(self.cmp(right)) @@ -976,7 +977,6 @@ mod tests { use chrono::{Duration, Local}; - use crate::codec::convert::convert_bytes_to_decimal; use crate::codec::mysql::{Duration as MyDuration, MAX_FSP, UNSPECIFIED_FSP}; use crate::expr::EvalContext; @@ -1307,10 +1307,10 @@ mod tests { let mut ctx = EvalContext::default(); for (s, fsp, expect) in cases { let t = Time::parse_utc_datetime(s, fsp).unwrap(); - let get = t.to_decimal().unwrap(); + let get: Decimal = t.convert(&mut ctx).unwrap(); assert_eq!( get, - convert_bytes_to_decimal(&mut ctx, expect.as_bytes()).unwrap(), + expect.as_bytes().convert(&mut ctx).unwrap(), "convert datetime {} to decimal", s ); @@ -1358,13 +1358,16 @@ mod tests { for (t_str, fsp, datetime_dec, date_dec) in cases { for_each_tz(move |tz, _offset| { + let mut ctx = EvalContext::default(); let mut t = Time::parse_datetime(t_str, fsp, &tz).unwrap(); - let mut res = format!("{}", t.to_decimal().unwrap()); + let dec: Result = t.convert(&mut ctx); + let mut res = format!("{}", dec.unwrap()); assert_eq!(res, datetime_dec); t = Time::parse_datetime(t_str, 0, &tz).unwrap(); t.set_time_type(TimeType::Date).unwrap(); - res = format!("{}", t.to_decimal().unwrap()); + let dec: Result = t.convert(&mut ctx); + res = format!("{}", dec.unwrap()); assert_eq!(res, date_dec); }); } diff --git a/components/tidb_query/src/executor/aggregate.rs b/components/tidb_query/src/executor/aggregate.rs index c57a8afc093..c381a5fc70e 100644 --- a/components/tidb_query/src/executor/aggregate.rs +++ b/components/tidb_query/src/executor/aggregate.rs @@ -341,6 +341,12 @@ mod tests { assert_eq!(v, Datum::F64(res)); } + fn f64_to_decimal(ctx: &mut EvalContext, f: f64) -> Result { + use crate::codec::convert::ConvertTo; + let val = f.convert(ctx)?; + Ok(val) + } + #[test] fn test_bit_and() { let mut aggr = AggBitAnd { @@ -356,9 +362,9 @@ mod tests { Datum::U64(1), Datum::I64(3), Datum::I64(2), - Datum::Dec(Decimal::from_f64(1.234).unwrap()), - Datum::Dec(Decimal::from_f64(3.012).unwrap()), - Datum::Dec(Decimal::from_f64(2.12345678).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 1.234).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 3.012).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 2.12345678).unwrap()), ]; for v in data { @@ -378,10 +384,10 @@ mod tests { Datum::U64(1), Datum::I64(3), Datum::I64(2), - Datum::Dec(Decimal::from_f64(12.34).unwrap()), - Datum::Dec(Decimal::from_f64(1.012).unwrap()), - Datum::Dec(Decimal::from_f64(15.12345678).unwrap()), - Datum::Dec(Decimal::from_f64(16.000).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 12.34).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 1.012).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 15.12345678).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 16.000).unwrap()), ]; for v in data { @@ -402,9 +408,9 @@ mod tests { Datum::U64(1), Datum::I64(3), Datum::I64(2), - Datum::Dec(Decimal::from_f64(1.234).unwrap()), - Datum::Dec(Decimal::from_f64(1.012).unwrap()), - Datum::Dec(Decimal::from_f64(2.12345678).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 1.234).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 1.012).unwrap()), + Datum::Dec(f64_to_decimal(&mut ctx, 2.12345678).unwrap()), ]; for v in data { diff --git a/components/tidb_query/src/expr/builtin_cast.rs b/components/tidb_query/src/expr/builtin_cast.rs index 685f21d8a58..3b07f7914a3 100644 --- a/components/tidb_query/src/expr/builtin_cast.rs +++ b/components/tidb_query/src/expr/builtin_cast.rs @@ -1,7 +1,7 @@ // Copyright 2017 TiKV Project Authors. Licensed under Apache-2.0. use std::borrow::Cow; -use std::convert::{TryFrom, TryInto}; +use std::convert::TryInto; use std::{i64, str, u64}; use tidb_query_datatype::prelude::*; @@ -105,7 +105,7 @@ impl ScalarFunc { pub fn cast_time_as_int(&self, ctx: &mut EvalContext, row: &[Datum]) -> Result> { let val = try_opt!(self.children[0].eval_time(ctx, row)); - let dec = val.to_decimal()?; + let dec: Decimal = val.convert(ctx)?; let dec = dec .round(mysql::DEFAULT_FSP as i8, RoundMode::HalfEven) .unwrap(); @@ -119,7 +119,7 @@ impl ScalarFunc { row: &[Datum], ) -> Result> { let val = try_opt!(self.children[0].eval_duration(ctx, row)); - let dec = Decimal::try_from(val)?; + let dec: Decimal = val.convert(ctx)?; let dec = dec .round(mysql::DEFAULT_FSP as i8, RoundMode::HalfEven) .unwrap(); @@ -181,7 +181,7 @@ impl ScalarFunc { row: &[Datum], ) -> Result> { let val = try_opt!(self.children[0].eval_duration(ctx, row)); - let val = Decimal::try_from(val)?; + let val: Decimal = val.convert(ctx)?; let res = val.convert(ctx)?; Ok(Some(self.produce_float_with_specified_tp(ctx, res)?)) } @@ -213,8 +213,8 @@ impl ScalarFunc { ctx: &mut EvalContext, row: &'a [Datum], ) -> Result>> { - let val = try_opt!(self.children[0].eval_real(ctx, row)); - let res = Decimal::from_f64(val)?; + let val: f64 = try_opt!(self.children[0].eval_real(ctx, row)); + let res: Decimal = val.convert(ctx)?; self.produce_dec_with_specified_tp(ctx, Cow::Owned(res)) .map(Some) } @@ -254,7 +254,7 @@ impl ScalarFunc { row: &'a [Datum], ) -> Result>> { let val = try_opt!(self.children[0].eval_time(ctx, row)); - let dec = val.to_decimal()?; + let dec = val.convert(ctx)?; self.produce_dec_with_specified_tp(ctx, Cow::Owned(dec)) .map(Some) } @@ -265,7 +265,7 @@ impl ScalarFunc { row: &'a [Datum], ) -> Result>> { let val = try_opt!(self.children[0].eval_duration(ctx, row)); - let dec = Decimal::try_from(val)?; + let dec: Decimal = val.convert(ctx)?; self.produce_dec_with_specified_tp(ctx, Cow::Owned(dec)) .map(Some) } @@ -276,8 +276,8 @@ impl ScalarFunc { row: &'a [Datum], ) -> Result>> { let val = try_opt!(self.children[0].eval_json(ctx, row)); - let val = val.convert(ctx)?; - let dec = Decimal::from_f64(val)?; + let val: f64 = val.convert(ctx)?; + let dec: Decimal = val.convert(ctx)?; self.produce_dec_with_specified_tp(ctx, Cow::Owned(dec)) .map(Some) } @@ -1019,6 +1019,12 @@ mod tests { } } + fn f64_to_decimal(ctx: &mut EvalContext, f: f64) -> Result { + use crate::codec::convert::ConvertTo; + let val = f.convert(ctx)?; + Ok(val) + } + #[test] fn test_cast_as_decimal() { let mut ctx = EvalContext::new(Arc::new(EvalConfig::default_for_test())); @@ -1040,7 +1046,7 @@ mod tests { vec![Datum::I64(1234)], 7, 3, - Decimal::from_f64(1234.000).unwrap(), + f64_to_decimal(&mut ctx, 1234.000).unwrap(), ), ( ScalarFuncSig::CastStringAsDecimal, @@ -1056,7 +1062,7 @@ mod tests { vec![Datum::Bytes(b"1234".to_vec())], 7, 3, - Decimal::from_f64(1234.000).unwrap(), + f64_to_decimal(&mut ctx, 1234.000).unwrap(), ), ( ScalarFuncSig::CastRealAsDecimal, @@ -1072,7 +1078,7 @@ mod tests { vec![Datum::F64(1234.123)], 8, 4, - Decimal::from_f64(1234.1230).unwrap(), + f64_to_decimal(&mut ctx, 1234.1230).unwrap(), ), ( ScalarFuncSig::CastTimeAsDecimal, @@ -1104,7 +1110,7 @@ mod tests { vec![Datum::Dur(duration_t)], 7, 1, - Decimal::from_f64(120023.0).unwrap(), + f64_to_decimal(&mut ctx, 120023.0).unwrap(), ), ( ScalarFuncSig::CastJsonAsDecimal, @@ -1120,7 +1126,7 @@ mod tests { vec![Datum::Json(Json::I64(1))], 2, 1, - Decimal::from_f64(1.0).unwrap(), + f64_to_decimal(&mut ctx, 1.0).unwrap(), ), ( ScalarFuncSig::CastDecimalAsDecimal, @@ -1136,7 +1142,7 @@ mod tests { vec![Datum::Dec(Decimal::from(1))], 2, 1, - Decimal::from_f64(1.0).unwrap(), + f64_to_decimal(&mut ctx, 1.0).unwrap(), ), ]; @@ -1733,7 +1739,7 @@ mod tests { let mut ctx = EvalContext::new(Arc::new(EvalConfig::default_for_test())); let cases = vec![ ( - vec![Datum::Dec(Decimal::from_f64(32.0001).unwrap())], + vec![Datum::Dec(f64_to_decimal(&mut ctx, 32.0001).unwrap())], Some(Json::Double(32.0001)), ), (vec![Datum::Null], None), @@ -1887,18 +1893,19 @@ mod tests { #[test] fn test_dec_as_int_with_overflow() { + let mut ctx = EvalContext::default(); let cases = vec![ ( FieldTypeFlag::empty(), vec![Datum::Dec( - Decimal::from_f64(i64::MAX as f64 + 100.5).unwrap(), + f64_to_decimal(&mut ctx, i64::MAX as f64 + 100.5).unwrap(), )], i64::MAX, ), ( FieldTypeFlag::UNSIGNED, vec![Datum::Dec( - Decimal::from_f64(u64::MAX as f64 + 100.5).unwrap(), + f64_to_decimal(&mut ctx, u64::MAX as f64 + 100.5).unwrap(), )], u64::MAX as i64, ), diff --git a/components/tidb_query/src/expr/mod.rs b/components/tidb_query/src/expr/mod.rs index e0cfc8df0e1..ab650d3f803 100644 --- a/components/tidb_query/src/expr/mod.rs +++ b/components/tidb_query/src/expr/mod.rs @@ -305,7 +305,7 @@ where let left = left.into_arith(ctx)?; let right = right.into_arith(ctx)?; - let (left, right) = Datum::coerce(left, right)?; + let (left, right) = Datum::coerce(ctx, left, right)?; if left == Datum::Null || right == Datum::Null { return Ok(Datum::Null); } diff --git a/components/tidb_query/src/rpn_expr/impl_arithmetic.rs b/components/tidb_query/src/rpn_expr/impl_arithmetic.rs index 84c4c211113..6b93793b4dc 100644 --- a/components/tidb_query/src/rpn_expr/impl_arithmetic.rs +++ b/components/tidb_query/src/rpn_expr/impl_arithmetic.rs @@ -965,21 +965,21 @@ mod tests { #[test] fn test_int_divide_decimal() { let test_cases = vec![ - (Some(11.01), Some(1.1), Some(10)), - (Some(-11.01), Some(1.1), Some(-10)), - (Some(11.01), Some(-1.1), Some(-10)), - (Some(-11.01), Some(-1.1), Some(10)), - (Some(123.0), None, None), - (None, Some(123.0), None), + (Some("11.01"), Some("1.1"), Some(10)), + (Some("-11.01"), Some("1.1"), Some(-10)), + (Some("11.01"), Some("-1.1"), Some(-10)), + (Some("-11.01"), Some("-1.1"), Some(10)), + (Some("123.0"), None, None), + (None, Some("123.0"), None), // divide by zero - (Some(0.0), Some(0.0), None), + (Some("0.0"), Some("0.0"), None), (None, None, None), ]; for (lhs, rhs, expected) in test_cases { let output = RpnFnScalarEvaluator::new() - .push_param(lhs.map(|f| Decimal::from_f64(f).unwrap())) - .push_param(rhs.map(|f| Decimal::from_f64(f).unwrap())) + .push_param(lhs.map(|f| Decimal::from_bytes(f.as_bytes()).unwrap().unwrap())) + .push_param(rhs.map(|f| Decimal::from_bytes(f.as_bytes()).unwrap().unwrap())) .evaluate(ScalarFuncSig::IntDivideDecimal) .unwrap(); @@ -993,7 +993,7 @@ mod tests { (Decimal::from(std::i64::MIN), Decimal::from(-1)), ( Decimal::from(std::i64::MAX), - Decimal::from_f64(0.1).unwrap(), + Decimal::from_bytes(b"0.1").unwrap().unwrap(), ), ]; diff --git a/components/tidb_query/src/rpn_expr/impl_cast.rs b/components/tidb_query/src/rpn_expr/impl_cast.rs index e582877ce54..adc46323ccc 100644 --- a/components/tidb_query/src/rpn_expr/impl_cast.rs +++ b/components/tidb_query/src/rpn_expr/impl_cast.rs @@ -23,13 +23,6 @@ pub fn get_cast_fn_rpn_node( let from = box_try!(EvalType::try_from(from_field_type.tp())); let to = box_try!(EvalType::try_from(to_field_type.tp())); let func_meta = match (from, to) { - (EvalType::Int, EvalType::Decimal) => { - if !from_field_type.is_unsigned() && !to_field_type.is_unsigned() { - cast_int_as_decimal_fn_meta() - } else { - cast_uint_as_decimal_fn_meta() - } - } (EvalType::Int, EvalType::Real) => { if !from_field_type.is_unsigned() { cast_any_as_any_fn_meta::() @@ -41,7 +34,18 @@ pub fn get_cast_fn_rpn_node( (EvalType::Decimal, EvalType::Real) => cast_any_as_any_fn_meta::(), (EvalType::DateTime, EvalType::Real) => cast_any_as_any_fn_meta::(), (EvalType::Duration, EvalType::Real) => cast_any_as_any_fn_meta::(), - (EvalType::Json, EvalType::Real) => cast_any_as_any_fn_meta::(), + (EvalType::Int, EvalType::Decimal) => { + if !from_field_type.is_unsigned() && !to_field_type.is_unsigned() { + cast_any_as_decimal_fn_meta::() + } else { + cast_uint_as_decimal_fn_meta() + } + } + (EvalType::Bytes, EvalType::Decimal) => cast_any_as_decimal_fn_meta::(), + (EvalType::Real, EvalType::Decimal) => cast_any_as_decimal_fn_meta::(), + (EvalType::DateTime, EvalType::Decimal) => cast_any_as_decimal_fn_meta::(), + (EvalType::Duration, EvalType::Decimal) => cast_any_as_decimal_fn_meta::(), + (EvalType::Json, EvalType::Decimal) => cast_any_as_decimal_fn_meta::(), (EvalType::Int, EvalType::Int) => { match (from_field_type.is_unsigned(), to_field_type.is_unsigned()) { (false, false) => cast_any_as_any_fn_meta::(), @@ -167,15 +171,15 @@ pub fn cast_uint_as_decimal( /// The signed int implementation for push down signature `CastIntAsDecimal`. #[rpn_fn(capture = [ctx, extra])] #[inline] -pub fn cast_int_as_decimal( +pub fn cast_any_as_decimal>( ctx: &mut EvalContext, extra: &RpnFnCallExtra<'_>, - val: &Option, + val: &Option, ) -> Result> { match val { None => Ok(None), Some(val) => { - let dec = Decimal::from(*val); + let dec: Decimal = val.convert(ctx)?; Ok(Some(produce_dec_with_specified_tp( ctx, dec, diff --git a/components/tidb_query/src/rpn_expr/impl_compare.rs b/components/tidb_query/src/rpn_expr/impl_compare.rs index 17e1248cbef..55e334c0acf 100644 --- a/components/tidb_query/src/rpn_expr/impl_compare.rs +++ b/components/tidb_query/src/rpn_expr/impl_compare.rs @@ -512,6 +512,13 @@ mod tests { #[test] fn test_compare_decimal() { + use crate::codec::convert::ConvertTo; + use crate::expr::EvalContext; + fn f64_to_decimal(ctx: &mut EvalContext, f: f64) -> Result { + let val = f.convert(ctx)?; + Ok(val) + } + let mut ctx = EvalContext::default(); for (arg0, arg1, cmp_op, expect_output) in generate_numeric_compare_cases() { let sig = match cmp_op { TestCaseCmpOp::GT => ScalarFuncSig::GTDecimal, @@ -523,8 +530,8 @@ mod tests { TestCaseCmpOp::NullEQ => ScalarFuncSig::NullEQDecimal, }; let output = RpnFnScalarEvaluator::new() - .push_param(arg0.map(|v| Decimal::from_f64(v.into_inner()).unwrap())) - .push_param(arg1.map(|v| Decimal::from_f64(v.into_inner()).unwrap())) + .push_param(arg0.map(|v| f64_to_decimal(&mut ctx, v.into_inner()).unwrap())) + .push_param(arg1.map(|v| f64_to_decimal(&mut ctx, v.into_inner()).unwrap())) .evaluate(sig) .unwrap(); assert_eq!(output, expect_output, "{:?}, {:?}, {:?}", arg0, arg1, sig); diff --git a/fuzz/targets/mod.rs b/fuzz/targets/mod.rs index 71b0382ebdc..36a1444d6b3 100644 --- a/fuzz/targets/mod.rs +++ b/fuzz/targets/mod.rs @@ -102,7 +102,9 @@ impl ReadAsDecimalRoundMode for T {} #[inline(always)] pub fn fuzz_coprocessor_codec_decimal(data: &[u8]) -> Result<(), Error> { - use tidb_query::codec::mysql::decimal::Decimal; + use tidb_query::codec::convert::ConvertTo; + use tidb_query::codec::data_type::Decimal; + use tidb_query::expr::EvalContext; fn fuzz(lhs: &Decimal, rhs: &Decimal, cursor: &mut Cursor<&[u8]>) -> Result<(), Error> { let _ = lhs.clone().abs(); @@ -134,8 +136,9 @@ pub fn fuzz_coprocessor_codec_decimal(data: &[u8]) -> Result<(), Error> { } let mut cursor = Cursor::new(data); - let decimal1 = Decimal::from_f64(cursor.read_as_f64()?)?; - let decimal2 = Decimal::from_f64(cursor.read_as_f64()?)?; + let mut ctx = EvalContext::default(); + let decimal1: Decimal = cursor.read_as_f64()?.convert(&mut ctx)?; + let decimal2: Decimal = cursor.read_as_f64()?.convert(&mut ctx)?; let _ = fuzz(&decimal1, &decimal2, &mut cursor); let _ = fuzz(&decimal2, &decimal1, &mut cursor); Ok(()) @@ -154,11 +157,14 @@ trait ReadAsTimeType: ReadLiteralExt { impl ReadAsTimeType for T {} fn fuzz_time(t: tidb_query::codec::mysql::Time, mut cursor: Cursor<&[u8]>) -> Result<(), Error> { + use tidb_query::codec::convert::ConvertTo; + use tidb_query::codec::data_type::Decimal; use tidb_query::codec::mysql::TimeEncoder; + use tidb_query::expr::EvalContext; + let _ = t.clone().set_time_type(cursor.read_as_time_type()?); let _ = t.is_zero(); let _ = t.invalid_zero(); - let _ = t.to_decimal(); let _ = t.to_duration(); let _ = t.to_packed_u64(); let _ = t.clone().round_frac(cursor.read_as_i8()?); @@ -167,6 +173,13 @@ fn fuzz_time(t: tidb_query::codec::mysql::Time, mut cursor: Cursor<&[u8]>) -> Re let _ = t.to_string(); let mut v = Vec::new(); let _ = v.encode_time(&t); + + let mut ctx = EvalContext::default(); + let _: i64 = t.convert(&mut ctx)?; + let _: u64 = t.convert(&mut ctx)?; + let _: f64 = t.convert(&mut ctx)?; + let _: Vec = t.convert(&mut ctx)?; + let _: Decimal = t.convert(&mut ctx)?; Ok(()) } @@ -198,9 +211,10 @@ fn fuzz_duration( t: tidb_query::codec::mysql::Duration, mut cursor: Cursor<&[u8]>, ) -> Result<(), Error> { - use std::convert::TryFrom; + use tidb_query::codec::convert::ConvertTo; use tidb_query::codec::mysql::decimal::Decimal; use tidb_query::codec::mysql::DurationEncoder; + use tidb_query::expr::EvalContext; let _ = t.fsp(); let u = t; @@ -211,12 +225,15 @@ fn fuzz_duration( let _ = t.subsec_micros(); let _ = t.to_secs_f64(); let _ = t.is_zero(); - let _ = Decimal::try_from(t)?; let u = t; u.round_frac(cursor.read_as_i8()?)?; let mut v = Vec::new(); let _ = v.encode_duration(t); + + let mut ctx = EvalContext::default(); + let _: Decimal = t.convert(&mut ctx)?; + Ok(()) }