diff --git a/Cargo.toml b/Cargo.toml index 8d04f6799..a38ce8bca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ iceberg = { version = "0.3.0", path = "./crates/iceberg" } iceberg-catalog-rest = { version = "0.3.0", path = "./crates/catalog/rest" } iceberg-catalog-hms = { version = "0.3.0", path = "./crates/catalog/hms" } iceberg-catalog-memory = { version = "0.3.0", path = "./crates/catalog/memory" } +iceberg-datafusion = { version = "0.3.0", path = "./crates/integrations/datafusion" } itertools = "0.13" log = "0.4" mockito = "1" diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index c50b32efb..576acea6b 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -18,6 +18,7 @@ use std::any::Any; use std::pin::Pin; use std::sync::Arc; +use std::vec; use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; @@ -44,17 +45,25 @@ pub(crate) struct IcebergTableScan { /// Stores certain, often expensive to compute, /// plan properties used in query optimization. plan_properties: PlanProperties, + /// Projection column names, None means all columns + projection: Option<Vec<String>>, } impl IcebergTableScan { /// Creates a new [`IcebergTableScan`] object.
- pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self { + pub(crate) fn new( + table: Table, + schema: ArrowSchemaRef, + projection: Option<&Vec<usize>>, + ) -> Self { let plan_properties = Self::compute_properties(schema.clone()); + let projection = get_column_names(schema.clone(), projection); Self { table, schema, plan_properties, + projection, } } @@ -100,7 +109,7 @@ impl ExecutionPlan for IcebergTableScan { _partition: usize, _context: Arc<TaskContext>, ) -> DFResult<SendableRecordBatchStream> { - let fut = get_batch_stream(self.table.clone()); + let fut = get_batch_stream(self.table.clone(), self.projection.clone()); let stream = futures::stream::once(fut).try_flatten(); Ok(Box::pin(RecordBatchStreamAdapter::new( @@ -116,7 +125,13 @@ impl DisplayAs for IcebergTableScan { _t: datafusion::physical_plan::DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - write!(f, "IcebergTableScan") + write!( + f, + "IcebergTableScan projection:[{}]", + self.projection + .clone() + .map_or(String::new(), |v| v.join(",")) + ) } } @@ -127,8 +142,13 @@ impl DisplayAs for IcebergTableScan { /// and then converts it into a stream of Arrow [`RecordBatch`]es.
async fn get_batch_stream( table: Table, + column_names: Option<Vec<String>>, ) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> { - let table_scan = table.scan().build().map_err(to_datafusion_error)?; + let scan_builder = match column_names { + Some(column_names) => table.scan().select(column_names), + None => table.scan().select_all(), + }; + let table_scan = scan_builder.build().map_err(to_datafusion_error)?; let stream = table_scan .to_arrow() @@ -138,3 +158,14 @@ async fn get_batch_stream( Ok(Box::pin(stream)) } + +fn get_column_names( + schema: ArrowSchemaRef, + projection: Option<&Vec<usize>>, +) -> Option<Vec<String>> { + projection.map(|v| { + v.iter() + .map(|p| schema.field(*p).name().clone()) + .collect::<Vec<String>>() + }) +} diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs index 7ff7b2211..8d70d9488 100644 --- a/crates/integrations/datafusion/src/table.rs +++ b/crates/integrations/datafusion/src/table.rs @@ -75,13 +75,14 @@ impl TableProvider for IcebergTableProvider { async fn scan( &self, _state: &dyn Session, - _projection: Option<&Vec<usize>>, + projection: Option<&Vec<usize>>, _filters: &[Expr], _limit: Option<usize>, ) -> DFResult<Arc<dyn ExecutionPlan>> { Ok(Arc::new(IcebergTableScan::new( self.table.clone(), self.schema.clone(), + projection, ))) } } diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs index 9e62930fd..d6e22d044 100644 --- a/crates/integrations/datafusion/tests/integration_datafusion_test.rs +++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs @@ -19,11 +19,13 @@ use std::collections::HashMap; use std::sync::Arc; +use std::vec; +use datafusion::arrow::array::{Array, StringArray}; use datafusion::arrow::datatypes::DataType; use datafusion::execution::context::SessionContext; use iceberg::io::FileIOBuilder; -use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; +use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Type}; use iceberg::{Catalog,
NamespaceIdent, Result, TableCreation}; use iceberg_catalog_memory::MemoryCatalog; use iceberg_datafusion::IcebergCatalogProvider; @@ -39,6 +41,13 @@ fn get_iceberg_catalog() -> MemoryCatalog { MemoryCatalog::new(file_io, Some(temp_path())) } +fn get_struct_type() -> StructType { + StructType::new(vec![ + NestedField::required(4, "s_foo1", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(5, "s_foo2", Type::Primitive(PrimitiveType::String)).into(), + ]) +} + async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent) -> Result<()> { let properties = HashMap::new(); @@ -47,14 +56,21 @@ async fn set_test_namespace(catalog: &MemoryCatalog, namespace: &NamespaceIdent) Ok(()) } -fn set_table_creation(location: impl ToString, name: impl ToString) -> Result<TableCreation> { - let schema = Schema::builder() - .with_schema_id(0) - .with_fields(vec![ - NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), - NestedField::required(2, "bar", Type::Primitive(PrimitiveType::String)).into(), - ]) - .build()?; +fn get_table_creation( + location: impl ToString, + name: impl ToString, + schema: Option<Schema>, +) -> Result<TableCreation> { + let schema = match schema { + None => Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?, + Some(schema) => schema, + }; let creation = TableCreation::builder() .location(location.to_string()) @@ -72,7 +88,7 @@ async fn test_provider_get_table_schema() -> Result<()> { let namespace = NamespaceIdent::new("test_provider_get_table_schema".to_string()); set_test_namespace(&iceberg_catalog, &namespace).await?; - let creation = set_table_creation(temp_path(), "my_table")?; + let creation = get_table_creation(temp_path(), "my_table", None)?; iceberg_catalog.create_table(&namespace, creation).await?; let client = Arc::new(iceberg_catalog); @@
-87,7 +103,7 @@ async fn test_provider_get_table_schema() -> Result<()> { let table = schema.table("my_table").await.unwrap().unwrap(); let table_schema = table.schema(); - let expected = [("foo", &DataType::Int32), ("bar", &DataType::Utf8)]; + let expected = [("foo1", &DataType::Int32), ("foo2", &DataType::Utf8)]; for (field, exp) in table_schema.fields().iter().zip(expected.iter()) { assert_eq!(field.name(), exp.0); @@ -104,7 +120,7 @@ async fn test_provider_list_table_names() -> Result<()> { let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string()); set_test_namespace(&iceberg_catalog, &namespace).await?; - let creation = set_table_creation(temp_path(), "my_table")?; + let creation = get_table_creation(temp_path(), "my_table", None)?; iceberg_catalog.create_table(&namespace, creation).await?; let client = Arc::new(iceberg_catalog); @@ -130,7 +146,6 @@ async fn test_provider_list_schema_names() -> Result<()> { let namespace = NamespaceIdent::new("test_provider_list_schema_names".to_string()); set_test_namespace(&iceberg_catalog, &namespace).await?; - set_table_creation("test_provider_list_schema_names", "my_table")?; let client = Arc::new(iceberg_catalog); let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); @@ -147,3 +162,71 @@ async fn test_provider_list_schema_names() -> Result<()> { .all(|item| result.contains(&item.to_string()))); Ok(()) } + +#[tokio::test] +async fn test_table_projection() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = NamespaceIdent::new("ns".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + + let schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo1", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "foo2", Type::Primitive(PrimitiveType::String)).into(), + NestedField::optional(3, "foo3", Type::Struct(get_struct_type())).into(), + ]) + .build()?; + let creation = 
get_table_creation(temp_path(), "t1", Some(schema))?; + iceberg_catalog.create_table(&namespace, creation).await?; + + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + + let ctx = SessionContext::new(); + ctx.register_catalog("catalog", catalog); + let table_df = ctx.table("catalog.ns.t1").await.unwrap(); + + let records = table_df + .clone() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(1, records.len()); + let record = &records[0]; + // the first column is plan_type, the second column plan string. + let s = record + .column(1) + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(2, s.len()); + // the first row is logical_plan, the second row is physical_plan + assert_eq!( + "IcebergTableScan projection:[foo1,foo2,foo3]", + s.value(1).trim() + ); + + // datafusion doesn't support query foo3.s_foo1, use foo3 instead + let records = table_df + .select_columns(&["foo1", "foo3"]) + .unwrap() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(1, records.len()); + let record = &records[0]; + let s = record + .column(1) + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + assert_eq!(2, s.len()); + assert_eq!("IcebergTableScan projection:[foo1,foo3]", s.value(1).trim()); + + Ok(()) +}