From 3f1ea0a7dcd22d12fe391a9d8fb3b5871a96444f Mon Sep 17 00:00:00 2001 From: Elliana May Date: Wed, 17 Apr 2024 16:21:43 +0800 Subject: [PATCH] feat: add interval read support (#291) * feat(arrow): interval read support * add support for binding intervals * add chrono support for Duration/TimeDelta * further tests * break down units * add interval to test_all_types * switch to struct variant * fix: write out correct extraction logic * fix: clippy * fix: add chrono feature * build: remove duplicate dependency * chore: improve nanos error message * chore: overflow nanos into seconds * chore: add todo * chore: name magic number --- Cargo.toml | 4 ++- src/lib.rs | 22 ++++++++++++ src/row.rs | 30 +++++++++++----- src/statement.rs | 4 +++ src/test_all_types.rs | 20 ++++++++++- src/types/chrono.rs | 80 ++++++++++++++++++++++++++++++++++++++++-- src/types/mod.rs | 3 ++ src/types/value.rs | 10 ++++++ src/types/value_ref.rs | 12 +++++++ 9 files changed, 172 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62cdf61e..9757360f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ extensions-full = ["httpfs", "json", "parquet", "vtab-full"] buildtime_bindgen = ["libduckdb-sys/buildtime_bindgen"] modern-full = ["chrono", "serde_json", "url", "r2d2", "uuid", "polars"] polars = ["dep:polars"] +chrono = ["dep:chrono", "num-integer"] [dependencies] # time = { version = "0.3.2", features = ["formatting", "parsing"], optional = true } @@ -52,7 +53,7 @@ memchr = "2.3" uuid = { version = "1.0", optional = true } smallvec = "1.6.1" cast = { version = "0.3", features = ["std"] } -arrow = { version = "50", default-features = false, features = ["prettyprint", "ffi"] } +arrow = { version = "51", default-features = false, features = ["prettyprint", "ffi"] } rust_decimal = "1.14" strum = { version = "0.25", features = ["derive"] } r2d2 = { version = "0.8.9", optional = true } @@ -60,6 +61,7 @@ calamine = { version = "0.22.0", optional = true } num = { version = "0.4", optional = true, default-features = false, features = ["std"] } duckdb-loadable-macros = { version = "0.1.1", path="./duckdb-loadable-macros", optional = true } polars = { version = "0.35.4", features = ["dtype-full"], optional = true} +num-integer = {version = "0.1.46", optional = true} [dev-dependencies] doc-comment = "0.3" diff --git a/src/lib.rs b/src/lib.rs index d1541653..81e2a5d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -570,6 +570,8 @@ doc_comment::doctest!("../README.md"); #[cfg(test)] mod test { + use crate::types::Value; + use super::*; use std::{error::Error as StdError, fmt}; @@ -1297,6 +1299,26 @@ mod test { Ok(()) } + #[test] + fn round_trip_interval() -> Result<()> { + let db = checked_memory_handle(); + db.execute_batch("CREATE TABLE foo (t INTERVAL);")?; + + let d = Value::Interval { + months: 1, + days: 2, + nanos: 3, + }; + db.execute("INSERT INTO foo VALUES (?)", [d])?; + + let mut stmt = db.prepare("SELECT t FROM foo")?; + let mut rows = stmt.query([])?; + let row = rows.next()?.unwrap(); + let d: Value = row.get_unwrap(0); + assert_eq!(d, d); + Ok(()) + } + #[test] fn test_database_name_to_string() -> Result<()> { assert_eq!(DatabaseName::Main.to_string(), "main"); diff --git a/src/row.rs b/src/row.rs index 28238526..f5b88a15 100644 --- a/src/row.rs +++ b/src/row.rs @@ -542,15 +542,29 @@ impl<'stmt> Row<'stmt> { } ValueRef::Time64(types::TimeUnit::Microsecond, array.value(row)) } + DataType::Interval(unit) => match unit { + IntervalUnit::MonthDayNano => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + + if array.is_null(row) { + return ValueRef::Null; + } + + let value = array.value(row); + + // TODO: remove this manual conversion once arrow-rs bug is fixed + let months = (value) as i32; + let days = (value >> 32) as i32; + let nanos = (value >> 64) as i64; + + ValueRef::Interval { months, days, nanos } + } + _ => unimplemented!("{:?}", unit), + }, // TODO: support more data types - // DataType::Interval(unit) => match unit { - // IntervalUnit::DayTime => { - // make_string_interval_day_time!(column, row) - // } - // IntervalUnit::YearMonth => { - // make_string_interval_year_month!(column, row) - // } - // }, // DataType::List(_) => make_string_from_list!(column, row), // DataType::Dictionary(index_type, _value_type) => match **index_type { // DataType::Int8 => dict_array_value_to_string::(column, row), diff --git a/src/statement.rs b/src/statement.rs index 810399a9..1a05c2f4 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -497,6 +497,10 @@ impl Statement<'_> { }; ffi::duckdb_bind_timestamp(ptr, col as u64, ffi::duckdb_timestamp { micros }) }, + ValueRef::Interval { months, days, nanos } => unsafe { + let micros = nanos / 1_000; + ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros }) + }, _ => unreachable!("not supported: {}", value.data_type()), }; result_from_duckdb_prepare(rc, ptr) diff --git a/src/test_all_types.rs b/src/test_all_types.rs index 185c6118..bbdcdbf1 100644 --- a/src/test_all_types.rs +++ b/src/test_all_types.rs @@ -18,7 +18,6 @@ fn test_all_types() -> crate::Result<()> { // union is currently blocked by https://github.com/duckdb/duckdb/pull/11326 "union", // these remaining types are not yet supported by duckdb-rs - "interval", "small_enum", "medium_enum", "large_enum", @@ -219,6 +218,25 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) { 1 => assert_eq!(value, ValueRef::Blob(&[3, 245])), _ => assert_eq!(value, ValueRef::Null), }, + "interval" => match idx { + 0 => assert_eq!( + value, + ValueRef::Interval { + months: 0, + days: 0, + nanos: 0 + } + ), + 1 => assert_eq!( + value, + ValueRef::Interval { + months: 999, + days: 999, + nanos: 999999999000 + } + ), + _ => assert_eq!(value, ValueRef::Null), + }, _ => todo!("{column:?}"), } } diff --git a/src/types/chrono.rs b/src/types/chrono.rs index e4f9365c..fecce4bb 100644 --- a/src/types/chrono.rs +++ b/src/types/chrono.rs @@ -1,12 +1,15 @@ //! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types. -use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use num_integer::Integer; use crate::{ types::{FromSql, FromSqlError, FromSqlResult, TimeUnit, ToSql, ToSqlOutput, ValueRef}, Result, }; +use super::Value; + /// ISO 8601 calendar date without timezone => "YYYY-MM-DD" impl ToSql for NaiveDate { #[inline] @@ -126,13 +129,55 @@ impl FromSql for DateTime { } } +impl FromSql for Duration { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::Interval { months, days, nanos } => { + let days = days + (months * 30); + let (additional_seconds, nanos) = nanos.div_mod_floor(&NANOS_PER_SECOND); + let seconds = additional_seconds + (i64::from(days) * 24 * 3600); + + match nanos.try_into() { + Ok(nanos) => { + if let Some(duration) = Duration::new(seconds, nanos) { + Ok(duration) + } else { + Err(FromSqlError::Other("Invalid duration".into())) + } + } + Err(err) => Err(FromSqlError::Other(format!("Invalid duration: {err}").into())), + } + } + _ => Err(FromSqlError::InvalidType), + } + } +} + +const DAYS_PER_MONTH: i64 = 30; +const SECONDS_PER_DAY: i64 = 24 * 3600; +const NANOS_PER_SECOND: i64 = 1_000_000_000; +const NANOS_PER_DAY: i64 = SECONDS_PER_DAY * NANOS_PER_SECOND; + +impl ToSql for Duration { + fn to_sql(&self) -> Result> { + let nanos = self.num_nanoseconds().unwrap(); + let (days, nanos) = nanos.div_mod_floor(&NANOS_PER_DAY); + let (months, days) = days.div_mod_floor(&DAYS_PER_MONTH); + Ok(ToSqlOutput::Owned(Value::Interval { + months: months.try_into().unwrap(), + days: days.try_into().unwrap(), + nanos, + })) + } +} + #[cfg(test)] mod test { use crate::{ - types::{FromSql, ValueRef}, + types::{FromSql, ToSql, ToSqlOutput, ValueRef}, Connection, Result, }; - use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; + use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Utc}; fn checked_memory_handle() -> Result { let db = Connection::open_in_memory()?; @@ -216,6 +261,35 @@ mod test { Ok(()) } + #[test] + fn test_time_delta_roundtrip() { + roundtrip_type(TimeDelta::new(3600, 0).unwrap()); + roundtrip_type(TimeDelta::new(3600, 1000).unwrap()); + } + + #[test] + fn test_time_delta() -> Result<()> { + let db = checked_memory_handle()?; + let td = TimeDelta::new(3600, 0).unwrap(); + + let row: Result = db.query_row("SELECT ?", [td], |row| Ok(row.get(0)))?; + + assert_eq!(row.unwrap(), td); + + Ok(()) + } + + fn roundtrip_type(td: T) { + let sqled = td.to_sql().unwrap(); + let value = match sqled { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + }; + let reversed = FromSql::column_result(value).unwrap(); + + assert_eq!(td, reversed); + } + #[test] fn test_date_time_local() -> Result<()> { let db = checked_memory_handle()?; diff --git a/src/types/mod.rs b/src/types/mod.rs index d93f5974..ea793b86 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -144,6 +144,8 @@ pub enum Type { Date32, /// TIME64 Time64, + /// INTERVAL + Interval, /// Any Any, } @@ -170,6 +172,7 @@ impl fmt::Display for Type { Type::Blob => f.pad("Blob"), Type::Date32 => f.pad("Date32"), Type::Time64 => f.pad("Time64"), + Type::Interval => f.pad("Interval"), Type::Any => f.pad("Any"), } } diff --git a/src/types/value.rs b/src/types/value.rs index f3cf1f7f..58e4cc9d 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -46,6 +46,15 @@ pub enum Value { Date32(i32), /// The value is a time64 Time64(TimeUnit, i64), + /// The value is an interval (month, day, nano) + Interval { + /// months + months: i32, + /// days + days: i32, + /// nanos + nanos: i64, + }, } impl From for Value { @@ -212,6 +221,7 @@ impl Value { Value::Blob(_) => Type::Blob, Value::Date32(_) => Type::Date32, Value::Time64(..) => Type::Time64, + Value::Interval { .. } => Type::Interval, } } } diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index f071dc3b..ae8f71c6 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -61,6 +61,15 @@ pub enum ValueRef<'a> { Date32(i32), /// The value is a time64 Time64(TimeUnit, i64), + /// The value is an interval (month, day, nano) + Interval { + /// months + months: i32, + /// days + days: i32, + /// nanos + nanos: i64, + }, } impl ValueRef<'_> { @@ -87,6 +96,7 @@ impl ValueRef<'_> { ValueRef::Blob(_) => Type::Blob, ValueRef::Date32(_) => Type::Date32, ValueRef::Time64(..) => Type::Time64, + ValueRef::Interval { .. } => Type::Interval, } } } @@ -140,6 +150,7 @@ impl From> for Value { ValueRef::Blob(b) => Value::Blob(b.to_vec()), ValueRef::Date32(d) => Value::Date32(d), ValueRef::Time64(t, d) => Value::Time64(t, d), + ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos }, } } } @@ -181,6 +192,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> { Value::Blob(ref b) => ValueRef::Blob(b), Value::Date32(d) => ValueRef::Date32(d), Value::Time64(t, d) => ValueRef::Time64(t, d), + Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos }, } } }