diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 67eb84bd857fe..b22b0ebd9a7ae 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -2444,8 +2444,9 @@ class BaseQuery { expressions: { column_aliased: '{{expr}} {{quoted_alias}}', case: 'CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END', - binary: '{{ left }} {{ op }} {{ right }}', + binary: '({{ left }} {{ op }} {{ right }})', sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}', + cast: 'CAST({{ expr }} AS {{ data_type }})', }, quotes: { identifiers: '"', diff --git a/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js b/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js index 82a0c65a3dbdc..b1711620fbb5f 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js @@ -149,6 +149,10 @@ export class BigqueryQuery extends BaseQuery { const templates = super.sqlTemplates(); templates.quotes.identifiers = '`'; templates.quotes.escape = '\\`'; + templates.functions.DATETRUNC = 'DATETIME_TRUNC(CAST({{ args[1] }} AS DATETIME), {{ date_part }})'; + templates.expressions.binary = '{% if op == \'%\' %}MOD({{ left }}, {{ right }}){% else %}({{ left }} {{ op }} {{ right }}){% endif %}'; + templates.expressions.interval = 'INTERVAL {{ interval }}'; + templates.expressions.extract = 'EXTRACT({% if date_part == \'DOW\' %}DAYOFWEEK{% else %}{{ date_part }}{% endif %} FROM {{ expr }})'; return templates; } } diff --git a/packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js b/packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js index d0114aca423ae..96849130d856e 100644 --- a/packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js @@ -50,6 +50,10 @@ export class PostgresQuery extends BaseQuery { templates.params.param = '${{ param_index + 1 }}'; templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})'; templates.functions.CONCAT = 'CONCAT({% for arg in args %}CAST({{arg}} AS TEXT){% if not loop.last %},{% endif %}{% endfor %})'; + templates.functions.DATEPART = 'DATE_PART({{ args_concat }})'; + templates.expressions.interval = 'INTERVAL \'{{ interval }}\''; + templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})'; + return templates; } } diff --git a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js index ce048358efc0f..04fdc93fcf234 100644 --- a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js @@ -112,6 +112,9 @@ export class PrestodbQuery extends BaseQuery { sqlTemplates() { const templates = super.sqlTemplates(); templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})'; + templates.functions.DATEPART = 'DATE_PART({{ args_concat }})'; + templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})'; + templates.expressions.interval = 'INTERVAL \'{{ num }}\' {{ date_part }}'; return templates; } } diff --git a/packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js b/packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js index 8e5094f282490..b40b42b043230 100644 --- a/packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js @@ -53,6 +53,9 @@ export class SnowflakeQuery extends BaseQuery { sqlTemplates() { const templates = super.sqlTemplates(); templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})'; + templates.functions.DATEPART = 'DATE_PART({{ args_concat }})'; + templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})'; + templates.expressions.interval = 'INTERVAL \'{{ interval }}\''; return templates; } } diff --git a/rust/cubesql/Cargo.lock b/rust/cubesql/Cargo.lock index 9d85e91aff29e..2909452314560 100644 --- a/rust/cubesql/Cargo.lock +++ b/rust/cubesql/Cargo.lock @@ -970,7 +970,7 @@ dependencies = [ [[package]] name = "cube-ext" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "arrow", "chrono", @@ -1103,7 +1103,7 @@ dependencies = [ [[package]] name = "datafusion" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "ahash", "arrow", @@ -1136,7 +1136,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "arrow", "ordered-float 2.10.0", @@ -1147,7 +1147,7 @@ dependencies = [ [[package]] name = "datafusion-data-access" version = "1.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "async-trait", "chrono", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "ahash", "arrow", @@ -1171,7 +1171,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "7.0.0" -source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=915a600ea4d3b66161cd77ff94747960f840816e#915a600ea4d3b66161cd77ff94747960f840816e" +source = "git+https://github.com/cube-js/arrow-datafusion.git?rev=00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80#00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80" dependencies = [ "ahash", "arrow", diff --git a/rust/cubesql/cubesql/Cargo.toml b/rust/cubesql/cubesql/Cargo.toml index 992306b636d56..83e25dbe94911 100644 --- a/rust/cubesql/cubesql/Cargo.toml +++ b/rust/cubesql/cubesql/Cargo.toml @@ -9,7 +9,7 @@ documentation = "https://cube.dev/docs" homepage = "https://cube.dev" [dependencies] -datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "915a600ea4d3b66161cd77ff94747960f840816e", default-features = false, features = ["regex_expressions", "unicode_expressions"] } +datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80", default-features = false, features = ["regex_expressions", "unicode_expressions"] } anyhow = "1.0" thiserror = "1.0" cubeclient = { path = "../cubeclient" } diff --git a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs index f5c2b9120098b..a361f5df51d54 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs @@ -979,6 +979,7 @@ pub fn transform_response( (FieldValue::String(s), builder) => { let timestamp = NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S.%f") .or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%d %H:%M:%S.%f")) + .or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S")) .map_err(|e| { DataFusionError::Execution(format!( "Can't parse timestamp: '{}': {}", diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index 376b02783d4e8..47c303d96cac4 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -1,6 +1,6 @@ use crate::{ compile::{ - engine::df::scan::{CubeScanNode, MemberField, WrappedSelectNode}, + engine::df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode}, rewrite::WrappedSelectType, }, sql::AuthContextRef, @@ -16,7 +16,7 @@ use datafusion::{ plan::Extension, replace_col, replace_col_to_expr, Column, DFSchema, DFSchemaRef, Expr, LogicalPlan, UserDefinedLogicalNode, }, - physical_plan::aggregates::AggregateFunction, + physical_plan::{aggregates::AggregateFunction, functions::BuiltinScalarFunction}, scalar::ScalarValue, }; use itertools::Itertools; @@ -144,6 +144,10 @@ pub struct SqlGenerationResult { pub request: V1LoadRequestQuery, } +lazy_static! { + static ref DATE_PART_REGEX: Regex = Regex::new("^[A-Za-z_ ]+$").unwrap(); +} + impl CubeScanWrapperNode { pub async fn generate_sql( &self, @@ -934,7 +938,61 @@ impl CubeScanWrapperNode { ); Ok((resulting_sql, sql_query)) } - // Expr::Cast { .. } => {} + Expr::Cast { expr, data_type } => { + let (expr, sql_query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + *expr, + ungrouped_scan_node.clone(), + ) + .await?; + let data_type = match data_type { + DataType::Null => "NULL", + DataType::Boolean => "BOOLEAN", + DataType::Int8 => "INTEGER", + DataType::Int16 => "INTEGER", + DataType::Int32 => "INTEGER", + DataType::Int64 => "INTEGER", + DataType::UInt8 => "INTEGER", + DataType::UInt16 => "INTEGER", + DataType::UInt32 => "INTEGER", + DataType::UInt64 => "INTEGER", + DataType::Float16 => "FLOAT", + DataType::Float32 => "FLOAT", + DataType::Float64 => "DOUBLE", + DataType::Timestamp(_, _) => "TIMESTAMP", + DataType::Date32 => "DATE", + DataType::Date64 => "DATE", + DataType::Time32(_) => "TIME", + DataType::Time64(_) => "TIME", + DataType::Duration(_) => "INTERVAL", + DataType::Interval(_) => "INTERVAL", + DataType::Binary => "BYTEA", + DataType::FixedSizeBinary(_) => "BYTEA", + DataType::Utf8 => "TEXT", + DataType::LargeUtf8 => "TEXT", + x => { + return Err(DataFusionError::Execution(format!( + "Can't generate SQL for cast: type isn't supported: {:?}", + x + ))); + } + }; + let resulting_sql = Self::escape_interpolation_quotes( + sql_generator + .get_sql_templates() + .cast_expr(expr, data_type.to_string()) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for cast: {}", + e + )) + })?, + ungrouped_scan_node.is_some(), + ); + Ok((resulting_sql, sql_query)) + } // Expr::TryCast { .. } => {} Expr::Sort { expr, @@ -1024,7 +1082,41 @@ impl CubeScanWrapperNode { // ScalarValue::TimestampMicrosecond(_, _) => {} // ScalarValue::TimestampNanosecond(_, _) => {} // ScalarValue::IntervalYearMonth(_) => {} - // ScalarValue::IntervalDayTime(_) => {} + ScalarValue::IntervalDayTime(x) => { + if let Some(x) = x { + let days = x >> 32; + let millis = x & 0xFFFFFFFF; + if days > 0 && millis > 0 { + return Err(DataFusionError::Internal(format!( + "Can't generate SQL for interval: mixed intervals aren't supported: {} days {} millis encoded as {}", + days, millis, x + ))); + } + let (num, date_part) = if days > 0 { + (days, "DAY") + } else { + (millis, "MILLISECOND") + }; + let interval = format!("{} {}", num, date_part); + ( + Self::escape_interpolation_quotes( + sql_generator + .get_sql_templates() + .interval_expr(interval, num, date_part.to_string()) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for interval: {}", + e + )) + })?, + ungrouped_scan_node.is_some(), + ), + sql_query, + ) + } else { + ("NULL".to_string(), sql_query) + } + } // ScalarValue::IntervalMonthDayNano(_) => {} // ScalarValue::Struct(_, _) => {} x => { @@ -1036,6 +1128,60 @@ impl CubeScanWrapperNode { }) } Expr::ScalarFunction { fun, args } => { + if let BuiltinScalarFunction::DatePart = &fun { + if args.len() >= 2 { + match &args[0] { + Expr::Literal(ScalarValue::Utf8(Some(date_part))) => { + // Security check to prevent SQL injection + if !DATE_PART_REGEX.is_match(date_part) { + return Err(DataFusionError::Internal(format!( + "Can't generate SQL for scalar function: date part '{}' is not supported", + date_part + ))); + } + let (arg_sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + args[1].clone(), + ungrouped_scan_node.clone(), + ) + .await?; + return Ok(( + Self::escape_interpolation_quotes( + sql_generator + .get_sql_templates() + .extract_expr(date_part.to_string(), arg_sql) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for scalar function: {}", + e + )) + })?, + ungrouped_scan_node.is_some(), + ), + query, + )); + } + _ => {} + } + } + } + let date_part = if let BuiltinScalarFunction::DateTrunc = &fun { + match &args[0] { + Expr::Literal(ScalarValue::Utf8(Some(date_part))) => { + // Security check to prevent SQL injection + if DATE_PART_REGEX.is_match(date_part) { + Some(date_part.to_string()) + } else { + None + } + } + _ => None, + } + } else { + None + }; let mut sql_args = Vec::new(); for arg in args { let (sql, query) = Self::generate_sql_for_expr( @@ -1053,7 +1199,7 @@ impl CubeScanWrapperNode { Self::escape_interpolation_quotes( sql_generator .get_sql_templates() - .scalar_function(fun, sql_args) + .scalar_function(fun, sql_args, date_part) .map_err(|e| { DataFusionError::Internal(format!( "Can't generate SQL for scalar function: {}", diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 7b78ee83706fd..ae049085374c8 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -14292,7 +14292,7 @@ ORDER BY \"COUNT(count)\" DESC" ) .await; - query.unwrap_err(); + query.unwrap(); } #[tokio::test] @@ -18534,6 +18534,35 @@ ORDER BY \"COUNT(count)\" DESC" ); } + #[tokio::test] + async fn test_wrapper_tableau_sunday_week() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_logger(); + + let query_plan = convert_select_to_query_plan( + "SELECT (CAST(DATE_TRUNC('day', CAST(order_date AS TIMESTAMP)) AS DATE) - (((7 + CAST(EXTRACT(DOW FROM order_date) AS BIGINT) - 1) % 7) * INTERVAL '1 DAY')) AS \"twk:date:ok\", AVG(avgPrice) mp FROM KibanaSampleDataEcommerce a GROUP BY 1 ORDER BY 1 DESC" + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + + let logical_plan = query_plan.as_logical_plan(); + assert!(logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + .contains("EXTRACT")); + } + #[tokio::test] async fn test_thoughtspot_pg_date_trunc_year() { init_logger(); diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs index 7e4ad03d3d43b..9548d57ec0b00 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs @@ -3,20 +3,21 @@ use crate::{ compile::{ engine::provider::CubeContext, rewrite::{ - agg_fun_expr, analysis::LogicalPlanAnalysis, binary_expr, cast_expr, + agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, binary_expr, cast_expr, cast_expr_explicit, column_expr, fun_expr, literal_expr, literal_int, literal_string, negative_expr, rewrite, rewriter::RewriteRules, to_day_interval_expr, - transforming_rewrite, udf_expr, CastExprDataType, LiteralExprValue, - LogicalPlanLanguage, + transforming_rewrite, transforming_rewrite_with_root, udf_expr, AliasExprAlias, + CastExprDataType, LiteralExprValue, LogicalPlanLanguage, }, }, var, var_iter, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, + logical_plan::DFSchema, scalar::ScalarValue, }; -use egg::{EGraph, Rewrite, Subst}; +use egg::{EGraph, Id, Rewrite, Subst}; use std::{convert::TryFrom, sync::Arc}; pub struct DateRules { @@ -221,7 +222,8 @@ impl RewriteRules for DateRules { vec![literal_string("day"), column_expr("?column")], ), ), - transforming_rewrite( + // TODO + transforming_rewrite_with_root( "cast-in-date-trunc", fun_expr( "DateTrunc", @@ -230,13 +232,16 @@ impl RewriteRules for DateRules { cast_expr(column_expr("?column"), "?data_type"), ], ), - fun_expr( - "DateTrunc", - vec![literal_expr("?granularity"), column_expr("?column")], + alias_expr( + fun_expr( + "DateTrunc", + vec![literal_expr("?granularity"), column_expr("?column")], + ), + "?alias", ), - self.unwrap_cast_to_timestamp("?data_type", "?granularity"), + self.unwrap_cast_to_timestamp("?data_type", "?granularity", "?alias"), ), - transforming_rewrite( + transforming_rewrite_with_root( "cast-in-date-trunc-double", fun_expr( "DateTrunc", @@ -254,17 +259,20 @@ impl RewriteRules for DateRules { ), ], ), - fun_expr( - "DateTrunc", - vec![ - literal_expr("?granularity"), - fun_expr( - "DateTrunc", - vec![literal_expr("?granularity"), "?expr".to_string()], - ), - ], + alias_expr( + fun_expr( + "DateTrunc", + vec![ + literal_expr("?granularity"), + fun_expr( + "DateTrunc", + vec![literal_expr("?granularity"), "?expr".to_string()], + ), + ], + ), + "?alias", ), - self.unwrap_cast_to_timestamp("?data_type", "?granularity"), + self.unwrap_cast_to_timestamp("?data_type", "?granularity", "?alias"), ), rewrite( "current-timestamp-to-now", @@ -276,12 +284,15 @@ impl RewriteRules for DateRules { udf_expr("localtimestamp", Vec::::new()), fun_expr("UtcTimestamp", Vec::::new()), ), - rewrite( + transforming_rewrite_with_root( "tableau-week", binary_expr( - fun_expr( - "DateTrunc", - vec!["?granularity".to_string(), column_expr("?column")], + alias_expr( + fun_expr( + "DateTrunc", + vec!["?granularity".to_string(), column_expr("?column")], + ), + "?date_trunc_alias", ), "+", negative_expr(binary_expr( @@ -294,10 +305,14 @@ impl RewriteRules for DateRules { literal_expr("?interval_one_day"), )), ), - fun_expr( - "DateTrunc", - vec![literal_string("week"), column_expr("?column")], + alias_expr( + fun_expr( + "DateTrunc", + vec![literal_string("week"), column_expr("?column")], + ), + "?alias", ), + self.transform_root_alias("?alias"), ), rewrite( "metabase-interval-date-range", @@ -603,33 +618,77 @@ impl DateRules { &self, data_type_var: &'static str, granularity_var: &'static str, - ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + alias_var: &'static str, + ) -> impl Fn(&mut EGraph, Id, &mut Subst) -> bool + { let data_type_var = var!(data_type_var); let granularity_var = var!(granularity_var); - move |egraph, subst| { + let alias_var = var!(alias_var); + move |egraph, root, subst| { for data_type in var_iter!(egraph[subst[data_type_var]], CastExprDataType) { - match data_type { - DataType::Timestamp(TimeUnit::Nanosecond, None) => return true, - DataType::Date32 => { - for granularity in - var_iter!(egraph[subst[granularity_var]], LiteralExprValue) - { - if let ScalarValue::Utf8(Some(granularity)) = granularity { - if let (Some(original_granularity), Some(day_granularity)) = ( - utils::granularity_str_to_int_order(&granularity, Some(false)), - utils::granularity_str_to_int_order("day", Some(false)), - ) { - if original_granularity >= day_granularity { - return true; + if let Some(original_expr) = egraph[root].data.original_expr.as_ref() { + let alias = original_expr.name(&DFSchema::empty()).unwrap(); + match data_type { + DataType::Timestamp(TimeUnit::Nanosecond, None) => { + subst.insert( + alias_var, + egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias( + alias.to_string(), + ))), + ); + return true; + } + DataType::Date32 => { + for granularity in + var_iter!(egraph[subst[granularity_var]], LiteralExprValue) + { + if let ScalarValue::Utf8(Some(granularity)) = granularity { + if let (Some(original_granularity), Some(day_granularity)) = ( + utils::granularity_str_to_int_order( + &granularity, + Some(false), + ), + utils::granularity_str_to_int_order("day", Some(false)), + ) { + if original_granularity >= day_granularity { + subst.insert( + alias_var, + egraph.add(LogicalPlanLanguage::AliasExprAlias( + AliasExprAlias(alias.to_string()), + )), + ); + return true; + } } } } } + _ => (), } - _ => (), } } false } } + + pub fn transform_root_alias( + &self, + alias_var: &'static str, + ) -> impl Fn(&mut EGraph, Id, &mut Subst) -> bool + { + let alias_var = var!(alias_var); + move |egraph, root, subst| { + if let Some(original_expr) = egraph[root].data.original_expr.as_ref() { + let alias = original_expr.name(&DFSchema::empty()).unwrap(); + subst.insert( + alias_var, + egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias( + alias.to_string(), + ))), + ); + return true; + } + false + } + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs index 2260ea99f1ae3..b9674e32d9ca9 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs @@ -4,9 +4,9 @@ use crate::{ rewrite::{ agg_fun_expr, aggregate, alias_expr, analysis::LogicalPlanAnalysis, - binary_expr, case_expr_var_arg, column_expr, column_name_to_member_vec, cube_scan, - cube_scan_wrapper, fun_expr_var_arg, limit, literal_expr, original_expr_name, - projection, rewrite, + binary_expr, case_expr_var_arg, cast_expr, column_expr, column_name_to_member_vec, + cube_scan, cube_scan_wrapper, fun_expr, fun_expr_var_arg, limit, literal_expr, + original_expr_name, projection, rewrite, rewriter::RewriteRules, rules::{members::MemberRules, replacer_pull_up_node, replacer_push_down_node}, scalar_fun_expr_args, scalar_fun_expr_args_empty_tail, sort, sort_expr, @@ -610,6 +610,44 @@ impl RewriteRules for WrapperRules { ), self.transform_fun_expr("?fun", "?alias_to_cube"), ), + transforming_rewrite( + "wrapper-pull-up-date-part", + fun_expr_var_arg( + "DatePart", + scalar_fun_expr_args( + wrapper_pullup_replacer( + literal_expr("?date_part"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + scalar_fun_expr_args( + wrapper_pullup_replacer( + "?date", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + scalar_fun_expr_args_empty_tail(), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + ), + ), + ), + wrapper_pullup_replacer( + fun_expr( + "DatePart", + vec![literal_expr("?date_part"), "?date".to_string()], + ), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + self.transform_date_part_expr("?alias_to_cube"), + ), rewrite( "wrapper-push-down-scalar-function-args", wrapper_pushdown_replacer( @@ -857,6 +895,43 @@ impl RewriteRules for WrapperRules { "?cube_members", ), ), + // Cast + rewrite( + "wrapper-push-down-cast", + wrapper_pushdown_replacer( + cast_expr("?expr", "?data_type"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + cast_expr( + wrapper_pushdown_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?data_type", + ), + ), + rewrite( + "wrapper-pull-up-cast", + cast_expr( + wrapper_pullup_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?data_type", + ), + wrapper_pullup_replacer( + cast_expr("?expr", "?data_type"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + ), // Column rewrite( "wrapper-push-down-column", @@ -1404,6 +1479,33 @@ impl WrapperRules { } } + fn transform_date_part_expr( + &self, + alias_to_cube_var: &'static str, + ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + let alias_to_cube_var = var!(alias_to_cube_var); + let meta = self.cube_context.meta.clone(); + move |egraph, subst| { + for alias_to_cube in var_iter!( + egraph[subst[alias_to_cube_var]], + WrapperPullupReplacerAliasToCube + ) + .cloned() + { + if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) { + if sql_generator + .get_sql_templates() + .templates + .contains_key("expressions/extract") + { + return true; + } + } + } + false + } + } + fn transform_case_expr( &self, alias_to_cube_var: &'static str, diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 935394530bd5b..9af135164fa35 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -207,6 +207,9 @@ pub fn get_test_tenant_ctx() -> Arc { "functions/COUNT_DISTINCT".to_string(), "COUNT(DISTINCT {{ args_concat }})".to_string(), ), + ("functions/DATETRUNC".to_string(), "DATE_TRUNC({{ args_concat }})".to_string()), + ("functions/DATEPART".to_string(), "DATE_PART({{ args_concat }})".to_string()), + ("expressions/extract".to_string(), "EXTRACT({{ date_part }} FROM {{ expr }})".to_string()), ( "statements/select".to_string(), r#"SELECT {{ select_concat | map(attribute='aliased') | join(', ') }} @@ -223,6 +226,8 @@ pub fn get_test_tenant_ctx() -> Arc { ("expressions/binary".to_string(), "{{ left }} {{ op }} {{ right }}".to_string()), ("expressions/case".to_string(), "CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END".to_string()), ("expressions/sort".to_string(), "{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST {% endif %}".to_string()), + ("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()), + ("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()), ("quotes/identifiers".to_string(), "\"".to_string()), ("quotes/escape".to_string(), "\"\"".to_string()), ("params/param".to_string(), "${{ param_index + 1 }}".to_string()) diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index 8016ab5eb37fd..f77f0e2ec6f76 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -394,12 +394,13 @@ impl SqlTemplates { &self, scalar_function: BuiltinScalarFunction, args: Vec, + date_part: Option, ) -> Result { let function = scalar_function.to_string().to_uppercase(); let args_concat = args.join(", "); self.render_template( &format!("functions/{}", function), - context! { args_concat => args_concat, args => args }, + context! { args_concat => args_concat, args => args, date_part => date_part }, ) } @@ -439,6 +440,32 @@ impl SqlTemplates { ) } + pub fn extract_expr(&self, date_part: String, expr: String) -> Result { + self.render_template( + "expressions/extract", + context! { date_part => date_part, expr => expr }, + ) + } + + pub fn interval_expr( + &self, + interval: String, + num: i64, + date_part: String, + ) -> Result { + self.render_template( + "expressions/interval", + context! { interval => interval, num => num, date_part => date_part }, + ) + } + + pub fn cast_expr(&self, expr: String, data_type: String) -> Result { + self.render_template( + "expressions/cast", + context! { expr => expr, data_type => data_type }, + ) + } + pub fn param(&self, param_index: usize) -> Result { self.render_template("params/param", context! { param_index => param_index }) }