Skip to content

Commit

Permalink
Merge branch 'main' into feature/unparse-timstamp
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Jun 26, 2024
2 parents cdd4139 + 82f7bf4 commit c2e27e0
Show file tree
Hide file tree
Showing 20 changed files with 390 additions and 182 deletions.
9 changes: 5 additions & 4 deletions datafusion-cli/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,14 @@ mod tests {
fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
let mut ctx = SessionContext::new();
ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new(
ctx.state().catalog_list(),
ctx.state().catalog_list().clone(),
ctx.state_weak_ref(),
)));

let provider =
&DynamicFileCatalog::new(ctx.state().catalog_list(), ctx.state_weak_ref())
as &dyn CatalogProviderList;
let provider = &DynamicFileCatalog::new(
ctx.state().catalog_list().clone(),
ctx.state_weak_ref(),
) as &dyn CatalogProviderList;
let catalog = provider
.catalog(provider.catalog_names().first().unwrap())
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion datafusion-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ async fn main_inner() -> Result<()> {
ctx.refresh_catalogs().await?;
// install dynamic catalog provider that knows how to open files
ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new(
ctx.state().catalog_list(),
ctx.state().catalog_list().clone(),
ctx.state_weak_ref(),
)));
// register `parquet_metadata` table function to get metadata from parquet files
Expand Down
5 changes: 3 additions & 2 deletions datafusion-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@

# DataFusion Examples

This crate includes several examples of how to use various DataFusion APIs and help you on your way.
This crate includes end to end, highly commented examples of how to use
various DataFusion APIs to help you get started.

## Prerequisites:

Run `git submodule update --init` to init test files.

## Running Examples

To run the examples, use the `cargo run` command, such as:
To run an example, use the `cargo run` command, such as:

```bash
git clone https://github.com/apache/datafusion
Expand Down
179 changes: 161 additions & 18 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
as_decimal128_array, as_decimal256_array, as_dictionary_array,
as_fixed_size_binary_array, as_fixed_size_list_array,
Expand Down Expand Up @@ -1077,7 +1078,7 @@ impl ScalarValue {
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
_ => {
return _not_impl_err!(
"Can't create a negative one scalar from data_type \"{datatype:?}\""
"Can't create a ten scalar from data_type \"{datatype:?}\""
);
}
})
Expand Down Expand Up @@ -1168,6 +1169,13 @@ impl ScalarValue {

/// Calculate arithmetic negation for a scalar value
pub fn arithmetic_negate(&self) -> Result<Self> {
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
v: T,
ctx: impl Fn() -> String,
) -> Result<T> {
v.neg_checked()
.map_err(|e| arrow_datafusion_err!(e).context(ctx()))
}
match self {
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
Expand All @@ -1177,40 +1185,91 @@ impl ScalarValue {
| ScalarValue::Float64(None) => Ok(self.clone()),
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
ScalarValue::IntervalYearMonth(Some(v)) => {
Ok(ScalarValue::IntervalYearMonth(Some(-v)))
}
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))),
ScalarValue::IntervalYearMonth(Some(v)) => Ok(
ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || {
format!("In negation of IntervalYearMonth({v})")
})?)),
),
ScalarValue::IntervalDayTime(Some(v)) => {
let (days, ms) = IntervalDayTimeType::to_parts(*v);
let val = IntervalDayTimeType::make_value(-days, -ms);
let val = IntervalDayTimeType::make_value(
neg_checked_with_ctx(days, || {
format!("In negation of days {days} in IntervalDayTime")
})?,
neg_checked_with_ctx(ms, || {
format!("In negation of milliseconds {ms} in IntervalDayTime")
})?,
);
Ok(ScalarValue::IntervalDayTime(Some(val)))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
let val = IntervalMonthDayNanoType::make_value(
neg_checked_with_ctx(months, || {
format!("In negation of months {months} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(days, || {
format!("In negation of days {days} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(nanos, || {
format!("In negation of nanos {nanos} of IntervalMonthDayNano")
})?,
);
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
Ok(ScalarValue::Decimal128(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal128({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal256(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal256({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
),
ScalarValue::TimestampSecond(Some(v), tz) => {
Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampSecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampNanosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampNanosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampNanoSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMicrosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMicrosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMicroSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMillisecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMillisecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMilliSecond({v})")
})?),
tz.clone(),
))
}
value => _internal_err!(
"Can not run arithmetic negative on scalar value {value:?}"
Expand Down Expand Up @@ -3501,6 +3560,7 @@ mod tests {
use crate::assert_batches_eq;
use arrow::buffer::OffsetBuffer;
use arrow::compute::{is_null, kernels};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use arrow_buffer::Buffer;
use arrow_schema::Fields;
Expand Down Expand Up @@ -5494,6 +5554,89 @@ mod tests {
Ok(())
}

#[test]
#[allow(arithmetic_overflow)] // we want to test them
fn test_scalar_negative_overflows() -> Result<()> {
macro_rules! test_overflow_on_value {
($($val:expr),* $(,)?) => {$(
{
let value: ScalarValue = $val;
let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}");
let root_err = err.find_root();
match root_err{
DataFusionError::ArrowError(
ArrowError::ComputeError(_),
_,
) => {}
_ => return Err(err),
};
}
)*};
}
test_overflow_on_value!(
// the integers
i8::MIN.into(),
i16::MIN.into(),
i32::MIN.into(),
i64::MIN.into(),
// for decimals, only value needs to be tested
ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?,
ScalarValue::Decimal256(Some(i256::MIN), 20, 5),
// interval, check all possible values
ScalarValue::IntervalYearMonth(Some(i32::MIN)),
ScalarValue::new_interval_dt(i32::MIN, 999),
ScalarValue::new_interval_dt(1, i32::MIN),
ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456),
ScalarValue::new_interval_mdn(12, i32::MIN, 123_456),
ScalarValue::new_interval_mdn(12, 15, i64::MIN),
// tz doesn't matter when negating
ScalarValue::TimestampSecond(Some(i64::MIN), None),
ScalarValue::TimestampMillisecond(Some(i64::MIN), None),
ScalarValue::TimestampMicrosecond(Some(i64::MIN), None),
ScalarValue::TimestampNanosecond(Some(i64::MIN), None),
);

let float_cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
(f32::MIN.into(), f32::MAX.into()),
(f32::MAX.into(), f32::MIN.into()),
(f64::MIN.into(), f64::MAX.into()),
(f64::MAX.into(), f64::MIN.into()),
];
// skip float 16 because they aren't supported
for (test, expected) in float_cases.into_iter().skip(2) {
assert_eq!(test.arithmetic_negate()?, expected);
}
Ok(())
}

#[test]
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
fn f16_test_overflow() {
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
let cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
];

for (test, expected) in cases {
assert_eq!(test.arithmetic_negate().unwrap(), expected);
}
}

macro_rules! expect_operation_error {
($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
#[test]
Expand Down
38 changes: 13 additions & 25 deletions datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,32 +354,11 @@ macro_rules! get_statistics {
))),
DataType::Timestamp(unit, timezone) =>{
let iter = [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied());

Ok(match unit {
TimeUnit::Second => {
Arc::new(match timezone {
Some(tz) => TimestampSecondArray::from_iter(iter).with_timezone(tz.clone()),
None => TimestampSecondArray::from_iter(iter),
})
}
TimeUnit::Millisecond => {
Arc::new(match timezone {
Some(tz) => TimestampMillisecondArray::from_iter(iter).with_timezone(tz.clone()),
None => TimestampMillisecondArray::from_iter(iter),
})
}
TimeUnit::Microsecond => {
Arc::new(match timezone {
Some(tz) => TimestampMicrosecondArray::from_iter(iter).with_timezone(tz.clone()),
None => TimestampMicrosecondArray::from_iter(iter),
})
}
TimeUnit::Nanosecond => {
Arc::new(match timezone {
Some(tz) => TimestampNanosecondArray::from_iter(iter).with_timezone(tz.clone()),
None => TimestampNanosecondArray::from_iter(iter),
})
}
TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
})
},
DataType::Time32(unit) => {
Expand Down Expand Up @@ -713,6 +692,15 @@ macro_rules! get_data_page_statistics {
)),
Some(DataType::Float32) => Ok(Arc::new(Float32Array::from_iter([<$stat_type_prefix Float32DataPageStatsIterator>]::new($iterator).flatten()))),
Some(DataType::Float64) => Ok(Arc::new(Float64Array::from_iter([<$stat_type_prefix Float64DataPageStatsIterator>]::new($iterator).flatten()))),
Some(DataType::Timestamp(unit, timezone)) => {
let iter = [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten();
Ok(match unit {
TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())),
})
},
_ => unimplemented!()
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1856,7 +1856,7 @@ mod tests {

let catalog_list_weak = {
let state = ctx.state.read();
Arc::downgrade(&state.catalog_list())
Arc::downgrade(state.catalog_list())
};

drop(ctx);
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,8 +807,8 @@ impl SessionState {
}

/// Return catalog list
pub fn catalog_list(&self) -> Arc<dyn CatalogProviderList> {
self.catalog_list.clone()
pub fn catalog_list(&self) -> &Arc<dyn CatalogProviderList> {
&self.catalog_list
}

/// set the catalog list
Expand Down Expand Up @@ -840,8 +840,8 @@ impl SessionState {
}

/// Return [SerializerRegistry] for extensions
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
self.serializer_registry.clone()
pub fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry> {
&self.serializer_registry
}

/// Return version of the cargo package that produced this query
Expand Down
Loading

0 comments on commit c2e27e0

Please sign in to comment.