Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(query): support decimal #113

Merged
merged 4 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/tests/00-base.result
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ a 1 true
2
3
with comment
3.00 3.00
bye
2 changes: 2 additions & 0 deletions cli/tests/00-base.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ select /* ignore this block */ 'with comment';
select 'in comment block';
*/

select 1.00 + 2.00, 3.00;

select 'bye';
drop table test;
4 changes: 2 additions & 2 deletions driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rustls = ["databend-client/rustls"]
# Enable native-tls for TLS support
native-tls = ["databend-client/native-tls"]

flight-sql = ["dep:arrow", "dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"]
flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"]

[dependencies]
async-trait = "0.1.68"
Expand All @@ -30,7 +30,7 @@ tokio = { version = "1.27.0", features = ["macros"] }
tokio-stream = "0.1.12"
url = { version = "2.3.1", default-features = false }

arrow = { version = "38.0.0", optional = true }
arrow = { version = "38.0.0" }
arrow-array = { version = "38.0.0", optional = true }
arrow-cast = { version = "38.0.0", features = ["prettyprint"], optional = true }
arrow-flight = { version = "38.0.0", features = ["flight-sql-experimental"], optional = true }
Expand Down
2 changes: 1 addition & 1 deletion driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ mod value;

pub use conn::{new_connection, Connection};
pub use rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress};
pub use schema::{DataType, Schema, SchemaRef};
pub use schema::{DataType, DecimalSize, Schema, SchemaRef};
pub use value::{NumberValue, Value};
88 changes: 70 additions & 18 deletions driver/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,26 @@ pub enum NumberDataType {
Float64,
}

// #[derive(Debug, Clone, PartialEq, Eq)]
// pub struct DecimalSize {
// pub precision: u8,
// pub scale: u8,
// }

// #[derive(Debug, Clone, PartialEq, Eq)]
// pub enum DecimalDataType {
// Decimal128(DecimalSize),
// Decimal256(DecimalSize),
// }
/// Precision/scale pair attached to a decimal type or decimal value.
///
/// Rendered as `Decimal(precision, scale)` by the `Display` impl of `DataType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecimalSize {
// First argument of `Decimal(p, s)`. Values of 38 or less select the 128-bit
// representation when a type description is parsed; larger values select 256-bit.
// Presumably the total number of significant digits — TODO confirm.
pub precision: u8,
// Second argument of `Decimal(p, s)`. The decimal display helpers print this
// many digits after the decimal point.
pub scale: u8,
}

/// Storage width of a decimal column, with each variant carrying its [`DecimalSize`].
///
/// When built from a textual type description, a precision of 38 or less yields
/// `Decimal128` and anything larger yields `Decimal256`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecimalDataType {
// Decimal whose values are held in an `i128`.
Decimal128(DecimalSize),
// Decimal whose values are held in an `i256`.
Decimal256(DecimalSize),
}

impl DecimalDataType {
pub fn decimal_size(&self) -> &DecimalSize {
match self {
DecimalDataType::Decimal128(size) => size,
DecimalDataType::Decimal256(size) => size,
}
}
}

#[derive(Debug, Clone)]
pub enum DataType {
Expand All @@ -55,9 +64,7 @@ pub enum DataType {
Boolean,
String,
Number(NumberDataType),
Decimal,
// TODO:(everpcpc) fix Decimal internal type
// Decimal(DecimalDataType),
Decimal(DecimalDataType),
Timestamp,
Date,
Nullable(Box<DataType>),
Expand Down Expand Up @@ -98,7 +105,10 @@ impl std::fmt::Display for DataType {
NumberDataType::Float32 => write!(f, "Float32"),
NumberDataType::Float64 => write!(f, "Float64"),
},
DataType::Decimal => write!(f, "Decimal"),
DataType::Decimal(d) => {
let size = d.decimal_size();
write!(f, "Decimal({}, {})", size.precision, size.scale)
}
DataType::Timestamp => write!(f, "Timestamp"),
DataType::Date => write!(f, "Date"),
DataType::Nullable(inner) => write!(f, "Nullable({})", inner),
Expand Down Expand Up @@ -152,7 +162,22 @@ impl TryFrom<&TypeDesc<'_>> for DataType {
"UInt64" => DataType::Number(NumberDataType::UInt64),
"Float32" => DataType::Number(NumberDataType::Float32),
"Float64" => DataType::Number(NumberDataType::Float64),
"Decimal" => DataType::Decimal,
"Decimal" => {
let precision = desc.args[0].name.parse::<u8>()?;
let scale = desc.args[1].name.parse::<u8>()?;

if precision <= 38 {
DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
precision,
scale,
}))
} else {
DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
precision,
scale,
}))
}
}
"Timestamp" => DataType::Timestamp,
"Date" => DataType::Date,
"Nullable" => {
Expand Down Expand Up @@ -247,8 +272,18 @@ impl TryFrom<&Arc<ArrowField>> for Field {
| ArrowDataType::FixedSizeBinary(_) => DataType::String,
ArrowDataType::Timestamp(_, _) => DataType::Timestamp,
ArrowDataType::Date32 => DataType::Date,
ArrowDataType::Decimal128(_, _) => DataType::Decimal,
ArrowDataType::Decimal256(_, _) => DataType::Decimal,
ArrowDataType::Decimal128(p, s) => {
DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
precision: *p,
scale: *s as u8,
}))
}
ArrowDataType::Decimal256(p, s) => {
DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
precision: *p,
scale: *s as u8,
}))
}
_ => {
return Err(Error::Parsing(format!(
"Unsupported datatype for arrow field: {:?}",
Expand Down Expand Up @@ -356,6 +391,23 @@ mod test {
args: vec![],
},
},
TestCase {
desc: "decimal type",
input: "Decimal(42, 42)",
output: TypeDesc {
name: "Decimal",
args: vec![
TypeDesc {
name: "42",
args: vec![],
},
TypeDesc {
name: "42",
args: vec![],
},
],
},
},
TestCase {
desc: "nullable type",
input: "Nullable(Nothing)",
Expand Down
171 changes: 165 additions & 6 deletions driver/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow::datatypes::i256;
sundy-li marked this conversation as resolved.
Show resolved Hide resolved
use arrow_array::{ArrowNativeTypeOp, Decimal128Array, Decimal256Array};
use chrono::{Datelike, NaiveDate, NaiveDateTime};

use crate::error::{ConvertError, Error, Result};
use crate::{
error::{ConvertError, Error, Result},
schema::{DecimalDataType, DecimalSize},
};
use std::fmt::Write;

// Thu 1970-01-01 is R.D. 719163
const DAYS_FROM_CE: i32 = 719_163;
Expand Down Expand Up @@ -45,16 +51,16 @@ pub enum NumberValue {
UInt64(u64),
Float32(f32),
Float64(f64),
Decimal128(i128, DecimalSize),
Decimal256(i256, DecimalSize),
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Null,
Boolean(bool),
String(String),
Number(NumberValue),
// TODO:(everpcpc) Decimal(DecimalValue),
// Decimal(String),
/// Microseconds from 1970-01-01 00:00:00 UTC
Timestamp(i64),
Date(i32),
Expand Down Expand Up @@ -82,9 +88,11 @@ impl Value {
NumberValue::UInt64(_) => DataType::Number(NumberDataType::UInt64),
NumberValue::Float32(_) => DataType::Number(NumberDataType::Float32),
NumberValue::Float64(_) => DataType::Number(NumberDataType::Float64),
NumberValue::Decimal128(_, s) => DataType::Decimal(DecimalDataType::Decimal128(*s)),
NumberValue::Decimal256(_, s) => DataType::Decimal(DecimalDataType::Decimal256(*s)),
},
// Self::Decimal(_) => DataType::Decimal,
Self::Timestamp(_) => DataType::Timestamp,

Self::Date(_) => DataType::Date,
// TODO:(everpcpc) fix nested type
// Self::Array(v) => DataType::Array(Box::new(v[0].get_type())),
Expand Down Expand Up @@ -133,7 +141,17 @@ impl TryFrom<(&DataType, &str)> for Value {
DataType::Number(NumberDataType::Float64) => {
Ok(Self::Number(NumberValue::Float64(v.parse()?)))
}
// DataType::Decimal => Ok(Self::Decimal(v)),

DataType::Decimal(DecimalDataType::Decimal128(size)) => {
let d = parse_decimal(v, *size)?;
Ok(Self::Number(d))
}

DataType::Decimal(DecimalDataType::Decimal256(size)) => {
let d = parse_decimal(v, *size)?;
Ok(Self::Number(d))
}

DataType::Timestamp => Ok(Self::Timestamp(
chrono::NaiveDateTime::parse_from_str(v, "%Y-%m-%d %H:%M:%S%.6f")?
.timestamp_micros(),
Expand Down Expand Up @@ -210,6 +228,32 @@ impl TryFrom<(&ArrowField, &Arc<dyn ArrowArray>, usize)> for Value {
None => Err(ConvertError::new("float64", format!("{:?}", array)).into()),
},

ArrowDataType::Decimal128(p, s) => {
match array.as_any().downcast_ref::<Decimal128Array>() {
Some(array) => Ok(Value::Number(NumberValue::Decimal128(
array.value(seq),
DecimalSize {
precision: *p,
scale: *s as u8,
},
))),
None => Err(ConvertError::new("Decimal128", format!("{:?}", array)).into()),
}
}

ArrowDataType::Decimal256(p, s) => {
match array.as_any().downcast_ref::<Decimal256Array>() {
Some(array) => Ok(Value::Number(NumberValue::Decimal256(
array.value(seq),
DecimalSize {
precision: *p,
scale: *s as u8,
},
))),
None => Err(ConvertError::new("Decimal256", format!("{:?}", array)).into()),
}
}

ArrowDataType::Binary => match array.as_any().downcast_ref::<BinaryArray>() {
Some(array) => Ok(Value::String(String::from_utf8(array.value(seq).to_vec())?)),
None => Err(ConvertError::new("binary", format!("{:?}", array)).into()),
Expand Down Expand Up @@ -418,6 +462,8 @@ impl std::fmt::Display for NumberValue {
NumberValue::UInt64(i) => write!(f, "{}", i),
NumberValue::Float32(i) => write!(f, "{}", i),
NumberValue::Float64(i) => write!(f, "{}", i),
NumberValue::Decimal128(v, s) => write!(f, "{}", display_decimal_128(*v, s.scale)),
NumberValue::Decimal256(v, s) => write!(f, "{}", display_decimal_256(*v, s.scale)),
}
}
}
Expand All @@ -443,3 +489,116 @@ impl std::fmt::Display for Value {
}
}
}

/// Formats a 128-bit decimal, stored as an unscaled integer, as plain text.
///
/// `num` is the unscaled value and `scale` is the number of digits that belong
/// after the decimal point. With `scale == 0` the integer is printed as-is;
/// otherwise the output always has at least one digit before the point and
/// exactly `scale` digits after it (for example `(-5, 3)` gives `"-0.005"`).
///
/// The function never panics: it works on the textual form of the magnitude,
/// so `i128::MIN` and scales larger than 38 (where `10^scale` no longer fits
/// in an `i128`) are handled like any other input.
pub fn display_decimal_128(num: i128, scale: u8) -> String {
    let scale = usize::from(scale);

    // Zero-pad the magnitude to at least `scale + 1` digits so that an integer
    // digit always exists in front of the fractional part.
    // `unsigned_abs` avoids the overflow that negating `i128::MIN` would cause.
    let digits = format!("{:0>width$}", num.unsigned_abs(), width = scale + 1);
    let (int_part, frac_part) = digits.split_at(digits.len() - scale);

    // Sign + digits + decimal point.
    let mut buf = String::with_capacity(digits.len() + 2);
    if num < 0 {
        // Emit the sign explicitly: the integer part of e.g. -0.05 is 0 and
        // would otherwise lose it.
        buf.push('-');
    }
    buf.push_str(int_part);
    if scale != 0 {
        buf.push('.');
        buf.push_str(frac_part);
    }
    buf
}

/// Formats a 256-bit decimal, stored as an unscaled integer, as plain text.
///
/// With `scale == 0` the integer is printed unchanged. Otherwise the value is
/// split around `10^scale` (computed with wrapping arithmetic) into an integer
/// part and a fractional part that is zero-padded to `scale` digits.
pub fn display_decimal_256(num: i256, scale: u8) -> String {
    let mut out = String::new();
    if scale == 0 {
        write!(out, "{}", num).unwrap();
        return out;
    }

    let divisor = i256::from_i128(10i128).pow_wrapping(scale as u32);
    let frac_width = scale as usize;

    // The sign is emitted separately because integer division truncates
    // toward zero: -1 / 10 is 0, which would print without its minus sign.
    let (sign, whole) = if num >= i256::ZERO {
        ("", num / divisor)
    } else {
        ("-", -num / divisor)
    };
    let frac = (num % divisor).wrapping_abs();

    write!(
        out,
        "{}{}.{:0>width$}",
        sign,
        whole,
        frac,
        width = frac_width
    )
    .unwrap();
    out
}

/// assume text is from
/// used only for expr, so put more weight on readability
pub fn parse_decimal(text: &str, size: DecimalSize) -> Result<NumberValue> {
let mut start = 0;
let bytes = text.as_bytes();
while bytes[start] == b'0' {
start += 1
}
let text = &text[start..];
let point_pos = text.find('.');
let e_pos = text.find(|c| c == 'e' || c == 'E');
let (i_part, f_part, e_part) = match (point_pos, e_pos) {
(Some(p1), Some(p2)) => (&text[..p1], &text[(p1 + 1)..p2], Some(&text[(p2 + 1)..])),
(Some(p), None) => (&text[..p], &text[(p + 1)..], None),
(None, Some(p)) => (&text[..p], "", Some(&text[(p + 1)..])),
_ => {
unreachable!()
}
};
let exp = match e_part {
Some(s) => s.parse::<i32>()?,
None => 0,
};
if i_part.len() as i32 + exp > 76 {
Err(ConvertError::new("decimal", format!("{:?}", text)).into())
} else {
let mut digits = Vec::with_capacity(76);
digits.extend_from_slice(i_part.as_bytes());
digits.extend_from_slice(f_part.as_bytes());
if digits.is_empty() {
digits.push(b'0')
}
let scale = f_part.len() as i32 - exp;
if scale < 0 {
// e.g 123.1e3
for _ in 0..(-scale) {
digits.push(b'0')
}
};

let precision = std::cmp::min(digits.len(), 76);
let digits = unsafe { std::str::from_utf8_unchecked(&digits[..precision]) };

if size.precision > 38 {
Ok(NumberValue::Decimal256(
i256::from_string(digits).unwrap(),
size,
))
} else {
Ok(NumberValue::Decimal128(digits.parse::<i128>()?, size))
}
}
}
Loading