From 0b0659f730241143517cd79dacac663425a11ba0 Mon Sep 17 00:00:00 2001 From: Lordworms Date: Thu, 5 Sep 2024 14:19:10 -0700 Subject: [PATCH] Emit input columns before expressions when consuming a Substrait ProjectRel --- .../substrait/src/logical_plan/consumer.rs | 30 ++++- .../substrait/tests/cases/bugs_converage.rs | 54 +++++++++ datafusion/substrait/tests/cases/mod.rs | 1 + .../testdata/extra_projection_with_input.json | 113 ++++++++++++++++++ 4 files changed, 193 insertions(+), 5 deletions(-) create mode 100644 datafusion/substrait/tests/cases/bugs_converage.rs create mode 100644 datafusion/substrait/tests/testdata/extra_projection_with_input.json diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 21bef3c2c98e..3efc3f2fccc6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -228,14 +228,32 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)), + LogicalPlan::Projection(p) => { + Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, + p.input + )? + )) + }, LogicalPlan::Aggregate(a) => { let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + Ok(LogicalPlan::Aggregate( + Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)? + )) }, // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. 
- _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions( + plan.schema().columns().iter().map(|c| col(c.to_owned())), + plan.schema(), + &renamed_schema + )?, + Arc::new(plan) + )? + )), } } }, @@ -363,7 +381,6 @@ fn make_renamed_schema( } let mut name_idx = 0; - let (qualifiers, fields): (_, Vec) = schema .iter() .map(|(q, f)| { @@ -390,7 +407,6 @@ fn make_renamed_schema( name_idx, dfs_names.len()); } - DFSchema::from_field_specific_qualified_schema( qualifiers, &Arc::new(Schema::new(fields)), @@ -412,6 +428,10 @@ pub async fn from_substrait_rel( ); let mut names: HashSet = HashSet::new(); let mut exprs: Vec = vec![]; + input.schema().iter().for_each(|(qualifier, field)| { + exprs.push(col(Column::from((qualifier, field)))) + }); + for e in &p.expressions { let x = from_substrait_rex(ctx, e, input.clone().schema(), extensions) diff --git a/datafusion/substrait/tests/cases/bugs_converage.rs b/datafusion/substrait/tests/cases/bugs_converage.rs new file mode 100644 index 000000000000..da3d9825f541 --- /dev/null +++ b/datafusion/substrait/tests/cases/bugs_converage.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for bugs in substrait + +#[cfg(test)] +mod tests { + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::Result; + use datafusion::datasource::MemTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::Plan; + #[tokio::test] + async fn extra_projection_with_input() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("user_id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("paid_for_service", DataType::Boolean, false), + ]); + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + ctx.register_table("users", Arc::new(memory_table))?; + let path = "tests/testdata/extra_projection_with_input.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{}", plan); + assert_eq!(plan_str, "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\ + \n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: users projection=[user_id, name, 
paid_for_service]"); + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d3ea7695e4b9..816790388660 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod bugs_converage; mod consumer_integration; mod function_test; mod logical_plans; diff --git a/datafusion/substrait/tests/testdata/extra_projection_with_input.json b/datafusion/substrait/tests/testdata/extra_projection_with_input.json new file mode 100644 index 000000000000..41b93a8f2e10 --- /dev/null +++ b/datafusion/substrait/tests/testdata/extra_projection_with_input.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "row_number" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "user_id", + "name", + "paid_for_service" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "users" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 1, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_FIRST" + } + ], + "upperBound": { + "unbounded": {} + }, + "lowerBound": { + "unbounded": {} + }, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, 
+ "invocation": "AGGREGATION_INVOCATION_ALL" + } + } + ] + } + }, + "names": [ + "user_id", + "name", + "paid_for_service", + "row_number" + ] + } + } + ], + "version": { + "minorNumber": 52, + "producer": "spark-substrait-gateway" + } +} \ No newline at end of file