Skip to content

Commit

Permalink
feat: add interval read support (#291)
Browse files Browse the repository at this point in the history
* feat(arrow): interval read support

* add support for binding intervals

* add chrono support for Duration/TimeDelta

* further tests

* break down units

* add interval to test_all_types

* switch to struct variant

* fix: write out correct extraction logic

* fix: clippy

* fix: add chrono feature

* build: remove duplicate dependency

* chore: improve nanos error message

* chore: overflow nanos into seconds

* chore: add todo

* chore: name magic number
  • Loading branch information
Mause authored Apr 17, 2024
1 parent f85893f commit 3f1ea0a
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 13 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ extensions-full = ["httpfs", "json", "parquet", "vtab-full"]
buildtime_bindgen = ["libduckdb-sys/buildtime_bindgen"]
modern-full = ["chrono", "serde_json", "url", "r2d2", "uuid", "polars"]
polars = ["dep:polars"]
chrono = ["dep:chrono", "num-integer"]

[dependencies]
# time = { version = "0.3.2", features = ["formatting", "parsing"], optional = true }
Expand All @@ -52,14 +53,15 @@ memchr = "2.3"
uuid = { version = "1.0", optional = true }
smallvec = "1.6.1"
cast = { version = "0.3", features = ["std"] }
arrow = { version = "50", default-features = false, features = ["prettyprint", "ffi"] }
arrow = { version = "51", default-features = false, features = ["prettyprint", "ffi"] }
rust_decimal = "1.14"
strum = { version = "0.25", features = ["derive"] }
r2d2 = { version = "0.8.9", optional = true }
calamine = { version = "0.22.0", optional = true }
num = { version = "0.4", optional = true, default-features = false, features = ["std"] }
duckdb-loadable-macros = { version = "0.1.1", path="./duckdb-loadable-macros", optional = true }
polars = { version = "0.35.4", features = ["dtype-full"], optional = true}
num-integer = {version = "0.1.46", optional = true}

[dev-dependencies]
doc-comment = "0.3"
Expand Down
22 changes: 22 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ doc_comment::doctest!("../README.md");

#[cfg(test)]
mod test {
use crate::types::Value;

use super::*;
use std::{error::Error as StdError, fmt};

Expand Down Expand Up @@ -1297,6 +1299,26 @@ mod test {
Ok(())
}

/// Binds an `INTERVAL` value, reads it back, and checks that the value
/// returned by the database equals the one that was inserted.
#[test]
fn round_trip_interval() -> Result<()> {
    let db = checked_memory_handle();
    db.execute_batch("CREATE TABLE foo (t INTERVAL);")?;

    // Binding divides `nanos` by 1_000 to obtain microseconds, so the fixture
    // uses a whole number of microseconds; anything finer would be truncated
    // on the way in and could never compare equal on the way out.
    let expected = Value::Interval {
        months: 1,
        days: 2,
        nanos: 3_000,
    };
    db.execute("INSERT INTO foo VALUES (?)", [expected.clone()])?;

    let mut stmt = db.prepare("SELECT t FROM foo")?;
    let mut rows = stmt.query([])?;
    let row = rows.next()?.unwrap();
    let actual: Value = row.get_unwrap(0);

    // Compare against the inserted value. The previous assertion compared the
    // fetched value with itself and therefore could not fail.
    assert_eq!(actual, expected);
    Ok(())
}

#[test]
fn test_database_name_to_string() -> Result<()> {
assert_eq!(DatabaseName::Main.to_string(), "main");
Expand Down
30 changes: 22 additions & 8 deletions src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,29 @@ impl<'stmt> Row<'stmt> {
}
ValueRef::Time64(types::TimeUnit::Microsecond, array.value(row))
}
DataType::Interval(unit) => match unit {
IntervalUnit::MonthDayNano => {
let array = column
.as_any()
.downcast_ref::<array::IntervalMonthDayNanoArray>()
.unwrap();

if array.is_null(row) {
return ValueRef::Null;
}

let value = array.value(row);

// TODO: remove this manual conversion once arrow-rs bug is fixed
let months = (value) as i32;
let days = (value >> 32) as i32;
let nanos = (value >> 64) as i64;

ValueRef::Interval { months, days, nanos }
}
_ => unimplemented!("{:?}", unit),
},
// TODO: support more data types
// DataType::Interval(unit) => match unit {
// IntervalUnit::DayTime => {
// make_string_interval_day_time!(column, row)
// }
// IntervalUnit::YearMonth => {
// make_string_interval_year_month!(column, row)
// }
// },
// DataType::List(_) => make_string_from_list!(column, row),
// DataType::Dictionary(index_type, _value_type) => match **index_type {
// DataType::Int8 => dict_array_value_to_string::<Int8Type>(column, row),
Expand Down
4 changes: 4 additions & 0 deletions src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,10 @@ impl Statement<'_> {
};
ffi::duckdb_bind_timestamp(ptr, col as u64, ffi::duckdb_timestamp { micros })
},
ValueRef::Interval { months, days, nanos } => unsafe {
let micros = nanos / 1_000;
ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros })
},
_ => unreachable!("not supported: {}", value.data_type()),
};
result_from_duckdb_prepare(rc, ptr)
Expand Down
20 changes: 19 additions & 1 deletion src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ fn test_all_types() -> crate::Result<()> {
// union is currently blocked by https://github.com/duckdb/duckdb/pull/11326
"union",
// these remaining types are not yet supported by duckdb-rs
"interval",
"small_enum",
"medium_enum",
"large_enum",
Expand Down Expand Up @@ -219,6 +218,25 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
1 => assert_eq!(value, ValueRef::Blob(&[3, 245])),
_ => assert_eq!(value, ValueRef::Null),
},
"interval" => match idx {
0 => assert_eq!(
value,
ValueRef::Interval {
months: 0,
days: 0,
nanos: 0
}
),
1 => assert_eq!(
value,
ValueRef::Interval {
months: 999,
days: 999,
nanos: 999999999000
}
),
_ => assert_eq!(value, ValueRef::Null),
},
_ => todo!("{column:?}"),
}
}
80 changes: 77 additions & 3 deletions src/types/chrono.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types.

use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use num_integer::Integer;

use crate::{
types::{FromSql, FromSqlError, FromSqlResult, TimeUnit, ToSql, ToSqlOutput, ValueRef},
Result,
};

use super::Value;

/// ISO 8601 calendar date without timezone => "YYYY-MM-DD"
impl ToSql for NaiveDate {
#[inline]
Expand Down Expand Up @@ -126,13 +129,55 @@ impl FromSql for DateTime<Local> {
}
}

/// Reads an `INTERVAL` column as a [`Duration`].
///
/// An interval carries separate `months`, `days` and `nanos` components,
/// whereas a `Duration` is a single span of time. The components are folded
/// together using fixed-length units: every month counts as
/// [`DAYS_PER_MONTH`] days and every day as [`SECONDS_PER_DAY`] seconds.
///
/// # Errors
///
/// * [`FromSqlError::InvalidType`] if the value is not an interval.
/// * [`FromSqlError::Other`] if the folded value cannot be represented as a
///   `Duration`.
impl FromSql for Duration {
    fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
        match value {
            ValueRef::Interval { months, days, nanos } => {
                // Widen to i64 *before* multiplying: `months * 30` evaluated in
                // i32 overflows once `months` exceeds roughly 71 million.
                let days = i64::from(days) + i64::from(months) * DAYS_PER_MONTH;

                // Floor division keeps the sub-second remainder non-negative,
                // carrying any whole seconds (or a borrow for negative
                // intervals) into `additional_seconds`.
                let (additional_seconds, nanos) = nanos.div_mod_floor(&NANOS_PER_SECOND);
                let seconds = additional_seconds + days * SECONDS_PER_DAY;

                let nanos = u32::try_from(nanos)
                    .map_err(|err| FromSqlError::Other(format!("Invalid duration: {err}").into()))?;

                Duration::new(seconds, nanos).ok_or_else(|| FromSqlError::Other("Invalid duration".into()))
            }
            _ => Err(FromSqlError::InvalidType),
        }
    }
}

/// Number of days a month is counted as when converting between intervals and
/// durations.
const DAYS_PER_MONTH: i64 = 30;
/// Number of seconds in one day.
const SECONDS_PER_DAY: i64 = 24 * 3600;
/// Number of nanoseconds in one second.
const NANOS_PER_SECOND: i64 = 1_000_000_000;
/// Number of nanoseconds in one day.
const NANOS_PER_DAY: i64 = SECONDS_PER_DAY * NANOS_PER_SECOND;

impl ToSql for Duration {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
let nanos = self.num_nanoseconds().unwrap();
let (days, nanos) = nanos.div_mod_floor(&NANOS_PER_DAY);
let (months, days) = days.div_mod_floor(&DAYS_PER_MONTH);
Ok(ToSqlOutput::Owned(Value::Interval {
months: months.try_into().unwrap(),
days: days.try_into().unwrap(),
nanos,
}))
}
}

#[cfg(test)]
mod test {
use crate::{
types::{FromSql, ValueRef},
types::{FromSql, ToSql, ToSqlOutput, ValueRef},
Connection, Result,
};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, TimeZone, Utc};

fn checked_memory_handle() -> Result<Connection> {
let db = Connection::open_in_memory()?;
Expand Down Expand Up @@ -216,6 +261,35 @@ mod test {
Ok(())
}

/// Converts a few `TimeDelta` values to their SQL representation and back,
/// with and without a sub-second part.
#[test]
fn test_time_delta_roundtrip() {
    let whole_hour = TimeDelta::new(3600, 0).unwrap();
    let hour_and_a_microsecond = TimeDelta::new(3600, 1000).unwrap();

    for delta in [whole_hour, hour_and_a_microsecond] {
        roundtrip_type(delta);
    }
}

/// Binds a `TimeDelta` as a query parameter and checks that selecting it back
/// from the database yields the same value.
#[test]
fn test_time_delta() -> Result<()> {
    let db = checked_memory_handle()?;
    let td = TimeDelta::new(3600, 0).unwrap();

    // Return the column conversion result directly so a failure propagates
    // through `?` instead of being wrapped in `Ok` and unwrapped afterwards.
    let fetched: TimeDelta = db.query_row("SELECT ?", [td], |row| row.get(0))?;

    assert_eq!(fetched, td);

    Ok(())
}

/// Test helper: converts `td` with `ToSql`, feeds the resulting value back
/// through `FromSql`, and asserts that the original value comes out again.
///
/// Panics if either conversion fails or the values differ.
fn roundtrip_type<T: FromSql + ToSql + Eq + std::fmt::Debug>(td: T) {
    let output = td.to_sql().unwrap();

    // `FromSql` consumes a borrowed `ValueRef`, so view an owned output by
    // reference rather than moving out of it.
    let as_ref = match output {
        ToSqlOutput::Owned(ref owned) => ValueRef::from(owned),
        ToSqlOutput::Borrowed(borrowed) => borrowed,
    };

    let decoded = T::column_result(as_ref).unwrap();
    assert_eq!(td, decoded);
}

#[test]
fn test_date_time_local() -> Result<()> {
let db = checked_memory_handle()?;
Expand Down
3 changes: 3 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ pub enum Type {
Date32,
/// TIME64
Time64,
/// INTERVAL
Interval,
/// Any
Any,
}
Expand All @@ -170,6 +172,7 @@ impl fmt::Display for Type {
Type::Blob => f.pad("Blob"),
Type::Date32 => f.pad("Date32"),
Type::Time64 => f.pad("Time64"),
Type::Interval => f.pad("Interval"),
Type::Any => f.pad("Any"),
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ pub enum Value {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is an interval (month, day, nano)
Interval {
/// months
months: i32,
/// days
days: i32,
/// nanos
nanos: i64,
},
}

impl From<Null> for Value {
Expand Down Expand Up @@ -212,6 +221,7 @@ impl Value {
Value::Blob(_) => Type::Blob,
Value::Date32(_) => Type::Date32,
Value::Time64(..) => Type::Time64,
Value::Interval { .. } => Type::Interval,
}
}
}
12 changes: 12 additions & 0 deletions src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ pub enum ValueRef<'a> {
Date32(i32),
/// The value is a time64
Time64(TimeUnit, i64),
/// The value is an interval (month, day, nano)
Interval {
/// months
months: i32,
/// days
days: i32,
/// nanos
nanos: i64,
},
}

impl ValueRef<'_> {
Expand All @@ -87,6 +96,7 @@ impl ValueRef<'_> {
ValueRef::Blob(_) => Type::Blob,
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
}
}
}
Expand Down Expand Up @@ -140,6 +150,7 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
ValueRef::Date32(d) => Value::Date32(d),
ValueRef::Time64(t, d) => Value::Time64(t, d),
ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos },
}
}
}
Expand Down Expand Up @@ -181,6 +192,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Blob(ref b) => ValueRef::Blob(b),
Value::Date32(d) => ValueRef::Date32(d),
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
}
}
}
Expand Down

0 comments on commit 3f1ea0a

Please sign in to comment.