Skip to content

Commit

Permalink
feat(cubesql): EXTRACT SQL push down (#7151)
Browse files Browse the repository at this point in the history
* feat(cubesql): `EXTRACT` SQL push down

* Update datafusion
  • Loading branch information
paveltiunov authored Sep 20, 2023
1 parent 6b9ae70 commit e30c4da
Show file tree
Hide file tree
Showing 14 changed files with 445 additions and 61 deletions.
3 changes: 2 additions & 1 deletion packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2444,8 +2444,9 @@ class BaseQuery {
expressions: {
column_aliased: '{{expr}} {{quoted_alias}}',
case: 'CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END',
binary: '{{ left }} {{ op }} {{ right }}',
binary: '({{ left }} {{ op }} {{ right }})',
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
cast: 'CAST({{ expr }} AS {{ data_type }})',
},
quotes: {
identifiers: '"',
Expand Down
4 changes: 4 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/BigqueryQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ export class BigqueryQuery extends BaseQuery {
const templates = super.sqlTemplates();
templates.quotes.identifiers = '`';
templates.quotes.escape = '\\`';
templates.functions.DATETRUNC = 'DATETIME_TRUNC(CAST({{ args[1] }} AS DATETIME), {{ date_part }})';
templates.expressions.binary = '{% if op == \'%\' %}MOD({{ left }}, {{ right }}){% else %}({{ left }} {{ op }} {{ right }}){% endif %}';
templates.expressions.interval = 'INTERVAL {{ interval }}';
templates.expressions.extract = 'EXTRACT({% if date_part == \'DOW\' %}DAYOFWEEK{% else %}{{ date_part }}{% endif %} FROM {{ expr }})';
return templates;
}
}
4 changes: 4 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/PostgresQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ export class PostgresQuery extends BaseQuery {
templates.params.param = '${{ param_index + 1 }}';
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
templates.functions.CONCAT = 'CONCAT({% for arg in args %}CAST({{arg}} AS TEXT){% if not loop.last %},{% endif %}{% endfor %})';
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
templates.expressions.interval = 'INTERVAL \'{{ interval }}\'';
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';

return templates;
}
}
3 changes: 3 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ export class PrestodbQuery extends BaseQuery {
sqlTemplates() {
const templates = super.sqlTemplates();
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
templates.expressions.interval = 'INTERVAL \'{{ num }}\' {{ date_part }}';
return templates;
}
}
3 changes: 3 additions & 0 deletions packages/cubejs-schema-compiler/src/adapter/SnowflakeQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ export class SnowflakeQuery extends BaseQuery {
sqlTemplates() {
const templates = super.sqlTemplates();
templates.functions.DATETRUNC = 'DATE_TRUNC({{ args_concat }})';
templates.functions.DATEPART = 'DATE_PART({{ args_concat }})';
templates.expressions.extract = 'EXTRACT({{ date_part }} FROM {{ expr }})';
templates.expressions.interval = 'INTERVAL \'{{ interval }}\'';
return templates;
}
}
12 changes: 6 additions & 6 deletions rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ documentation = "https://cube.dev/docs"
homepage = "https://cube.dev"

[dependencies]
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "915a600ea4d3b66161cd77ff94747960f840816e", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "00ed6a6d469a69f57dd1f08c1fda3f7c2cf12d80", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
anyhow = "1.0"
thiserror = "1.0"
cubeclient = { path = "../cubeclient" }
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/engine/df/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ pub fn transform_response<V: ValueObject>(
(FieldValue::String(s), builder) => {
let timestamp = NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S.%f")
.or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%d %H:%M:%S.%f"))
.or_else(|_| NaiveDateTime::parse_from_str(s.as_str(), "%Y-%m-%dT%H:%M:%S"))
.map_err(|e| {
DataFusionError::Execution(format!(
"Can't parse timestamp: '{}': {}",
Expand Down
156 changes: 151 additions & 5 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
compile::{
engine::df::scan::{CubeScanNode, MemberField, WrappedSelectNode},
engine::df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode},
rewrite::WrappedSelectType,
},
sql::AuthContextRef,
Expand All @@ -16,7 +16,7 @@ use datafusion::{
plan::Extension, replace_col, replace_col_to_expr, Column, DFSchema, DFSchemaRef, Expr,
LogicalPlan, UserDefinedLogicalNode,
},
physical_plan::aggregates::AggregateFunction,
physical_plan::{aggregates::AggregateFunction, functions::BuiltinScalarFunction},
scalar::ScalarValue,
};
use itertools::Itertools;
Expand Down Expand Up @@ -144,6 +144,10 @@ pub struct SqlGenerationResult {
pub request: V1LoadRequestQuery,
}

// Allow-list pattern for date part names: the whole string must consist of one or more
// ASCII letters, underscores or spaces. The SQL generation code in this file checks
// DATE_PART / DATE_TRUNC date part literals against it before placing them into
// generated SQL (guard against SQL injection).
lazy_static! {
static ref DATE_PART_REGEX: Regex = Regex::new("^[A-Za-z_ ]+$").unwrap();
}

impl CubeScanWrapperNode {
pub async fn generate_sql(
&self,
Expand Down Expand Up @@ -934,7 +938,61 @@ impl CubeScanWrapperNode {
);
Ok((resulting_sql, sql_query))
}
// Expr::Cast { .. } => {}
Expr::Cast { expr, data_type } => {
let (expr, sql_query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*expr,
ungrouped_scan_node.clone(),
)
.await?;
let data_type = match data_type {
DataType::Null => "NULL",
DataType::Boolean => "BOOLEAN",
DataType::Int8 => "INTEGER",
DataType::Int16 => "INTEGER",
DataType::Int32 => "INTEGER",
DataType::Int64 => "INTEGER",
DataType::UInt8 => "INTEGER",
DataType::UInt16 => "INTEGER",
DataType::UInt32 => "INTEGER",
DataType::UInt64 => "INTEGER",
DataType::Float16 => "FLOAT",
DataType::Float32 => "FLOAT",
DataType::Float64 => "DOUBLE",
DataType::Timestamp(_, _) => "TIMESTAMP",
DataType::Date32 => "DATE",
DataType::Date64 => "DATE",
DataType::Time32(_) => "TIME",
DataType::Time64(_) => "TIME",
DataType::Duration(_) => "INTERVAL",
DataType::Interval(_) => "INTERVAL",
DataType::Binary => "BYTEA",
DataType::FixedSizeBinary(_) => "BYTEA",
DataType::Utf8 => "TEXT",
DataType::LargeUtf8 => "TEXT",
x => {
return Err(DataFusionError::Execution(format!(
"Can't generate SQL for cast: type isn't supported: {:?}",
x
)));
}
};
let resulting_sql = Self::escape_interpolation_quotes(
sql_generator
.get_sql_templates()
.cast_expr(expr, data_type.to_string())
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for cast: {}",
e
))
})?,
ungrouped_scan_node.is_some(),
);
Ok((resulting_sql, sql_query))
}
// Expr::TryCast { .. } => {}
Expr::Sort {
expr,
Expand Down Expand Up @@ -1024,7 +1082,41 @@ impl CubeScanWrapperNode {
// ScalarValue::TimestampMicrosecond(_, _) => {}
// ScalarValue::TimestampNanosecond(_, _) => {}
// ScalarValue::IntervalYearMonth(_) => {}
// ScalarValue::IntervalDayTime(_) => {}
ScalarValue::IntervalDayTime(x) => {
if let Some(x) = x {
let days = x >> 32;
let millis = x & 0xFFFFFFFF;
if days > 0 && millis > 0 {
return Err(DataFusionError::Internal(format!(
"Can't generate SQL for interval: mixed intervals aren't supported: {} days {} millis encoded as {}",
days, millis, x
)));
}
let (num, date_part) = if days > 0 {
(days, "DAY")
} else {
(millis, "MILLISECOND")
};
let interval = format!("{} {}", num, date_part);
(
Self::escape_interpolation_quotes(
sql_generator
.get_sql_templates()
.interval_expr(interval, num, date_part.to_string())
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for interval: {}",
e
))
})?,
ungrouped_scan_node.is_some(),
),
sql_query,
)
} else {
("NULL".to_string(), sql_query)
}
}
// ScalarValue::IntervalMonthDayNano(_) => {}
// ScalarValue::Struct(_, _) => {}
x => {
Expand All @@ -1036,6 +1128,60 @@ impl CubeScanWrapperNode {
})
}
Expr::ScalarFunction { fun, args } => {
if let BuiltinScalarFunction::DatePart = &fun {
if args.len() >= 2 {
match &args[0] {
Expr::Literal(ScalarValue::Utf8(Some(date_part))) => {
// Security check to prevent SQL injection
if !DATE_PART_REGEX.is_match(date_part) {
return Err(DataFusionError::Internal(format!(
"Can't generate SQL for scalar function: date part '{}' is not supported",
date_part
)));
}
let (arg_sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
args[1].clone(),
ungrouped_scan_node.clone(),
)
.await?;
return Ok((
Self::escape_interpolation_quotes(
sql_generator
.get_sql_templates()
.extract_expr(date_part.to_string(), arg_sql)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for scalar function: {}",
e
))
})?,
ungrouped_scan_node.is_some(),
),
query,
));
}
_ => {}
}
}
}
let date_part = if let BuiltinScalarFunction::DateTrunc = &fun {
match &args[0] {
Expr::Literal(ScalarValue::Utf8(Some(date_part))) => {
// Security check to prevent SQL injection
if DATE_PART_REGEX.is_match(date_part) {
Some(date_part.to_string())
} else {
None
}
}
_ => None,
}
} else {
None
};
let mut sql_args = Vec::new();
for arg in args {
let (sql, query) = Self::generate_sql_for_expr(
Expand All @@ -1053,7 +1199,7 @@ impl CubeScanWrapperNode {
Self::escape_interpolation_quotes(
sql_generator
.get_sql_templates()
.scalar_function(fun, sql_args)
.scalar_function(fun, sql_args, date_part)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for scalar function: {}",
Expand Down
31 changes: 30 additions & 1 deletion rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14292,7 +14292,7 @@ ORDER BY \"COUNT(count)\" DESC"
)
.await;

query.unwrap_err();
query.unwrap();
}

#[tokio::test]
Expand Down Expand Up @@ -18534,6 +18534,35 @@ ORDER BY \"COUNT(count)\" DESC"
);
}

/// Runs a Tableau-style "week starting on Sunday" query (DATE_TRUNC, EXTRACT(DOW ...),
/// modulo arithmetic and an INTERVAL literal) through the planner and checks that the
/// SQL generated for the cube scan wrapper contains an `EXTRACT` expression.
#[tokio::test]
async fn test_wrapper_tableau_sunday_week() {
    // Nothing to verify when SQL push down is switched off.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_logger();

    let sql = "SELECT (CAST(DATE_TRUNC('day', CAST(order_date AS TIMESTAMP)) AS DATE) - (((7 + CAST(EXTRACT(DOW FROM order_date) AS BIGINT) - 1) % 7) * INTERVAL '1 DAY')) AS \"twk:date:ok\", AVG(avgPrice) mp FROM KibanaSampleDataEcommerce a GROUP BY 1 ORDER BY 1 DESC"
        .to_string();
    let query_plan = convert_select_to_query_plan(sql, DatabaseProtocol::PostgreSQL).await;

    // Building the physical plan must succeed; print it to help with debugging.
    let physical_plan = query_plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical_plan.as_ref()).indent()
    );

    // The wrapper node must carry generated SQL, and that SQL must use EXTRACT.
    let logical_plan = query_plan.as_logical_plan();
    let wrapper = logical_plan.find_cube_scan_wrapper();
    let generated = wrapper.wrapped_sql.unwrap();
    assert!(generated.sql.contains("EXTRACT"));
}

#[tokio::test]
async fn test_thoughtspot_pg_date_trunc_year() {
init_logger();
Expand Down
Loading

0 comments on commit e30c4da

Please sign in to comment.