Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: in PQ post processing, revert sort columns just before propagating down pipeline #5098

Merged
merged 4 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ repos:
hooks:
- id: actionlint
- repo: https://github.com/tcort/markdown-link-check
rev: v3.13.6
rev: v3.12.2
hooks:
- id: markdown-link-check
name: markdown-link-check-local
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

- Sort steps in sub-pipelines no longer cause a column lookup error
(@lukapeschke, #5066)
- Dereferencing of sort columns when rendering SQL is now done in the context of
the main pipeline (@kgutwin, #5098)

**Documentation**:

Expand Down
31 changes: 30 additions & 1 deletion prqlc/prqlc/src/sql/pq/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,15 +643,44 @@ impl<'a> CidRedirector<'a> {
ctx: &'a mut AnchorContext,
) -> Vec<ColumnSort<CId>> {
let cid_redirects = ctx.relation_instances[riid].cid_redirects.clone();
log::debug!("redirect sorts {sorts:?} {riid:?} cid_redirects {cid_redirects:?}");
let mut redirector = CidRedirector { ctx, cid_redirects };

fold_column_sorts(&mut redirector, sorts).unwrap()
}

/// Reverts sort columns back to their original pre-split columns.
///
/// Each sort column is looked up in `ctx.column_decls`. When the declaration
/// is a `ColumnDecl::RelationColumn`, the `cid_redirects` of that relation
/// instance are searched for an entry whose target equals the declared
/// column id; if one is found, the sort is re-pointed at that entry's source
/// column and keeps its direction. Every other sort is passed through
/// unchanged. The order of `sorts` is preserved.
///
/// If several redirects share the same target, the first one met while
/// iterating `cid_redirects` is used.
///
/// # Panics
///
/// Panics if a sort column has no entry in `ctx.column_decls`.
pub fn revert_sorts(
    sorts: Vec<ColumnSort<CId>>,
    ctx: &'a mut AnchorContext,
) -> Vec<ColumnSort<CId>> {
    let mut reverted = Vec::with_capacity(sorts.len());

    for sort in sorts {
        let decl = ctx.column_decls.get(&sort.column).unwrap();

        // Only relation columns are checked against the redirect table:
        // find the redirect that points *at* this column and take the
        // column it was redirected *from*.
        let original = if let ColumnDecl::RelationColumn(riid, cid, _) = decl {
            ctx.relation_instances[riid]
                .cid_redirects
                .iter()
                .find(|(_, target)| *target == cid)
                .map(|(source, target)| {
                    log::debug!("reverting {target:?} back to {source:?}");
                    *source
                })
        } else {
            None
        };

        reverted.push(match original {
            Some(column) => ColumnSort {
                direction: sort.direction,
                column,
            },
            None => sort,
        });
    }

    reverted
}
}

impl RqFold for CidRedirector<'_> {
fn fold_cid(&mut self, cid: CId) -> Result<CId> {
Ok(self.cid_redirects.get(&cid).cloned().unwrap_or(cid))
let v = self.cid_redirects.get(&cid).cloned().unwrap_or(cid);
log::debug!("mapping {cid:?} via {0:?} to {v:?}", self.cid_redirects);
Ok(v)
}

fn fold_transform(&mut self, transform: Transform) -> Result<Transform> {
Expand Down
68 changes: 35 additions & 33 deletions prqlc/prqlc/src/sql/pq/postprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fn infer_sorts(query: SqlQuery, ctx: &mut Context) -> SqlQuery {
let mut s = SortingInference {
last_sorting: Vec::new(),
ctes_sorting: HashMap::new(),
main_relation: false,
ctx,
};

Expand All @@ -36,12 +37,12 @@ fn infer_sorts(query: SqlQuery, ctx: &mut Context) -> SqlQuery {
/// State carried by the sort-inference fold over a `SqlQuery`.
struct SortingInference<'a> {
    // Sorting remembered from the most recently folded pipeline. It is
    // drained after each CTE (to be stored in `ctes_sorting`) and after the
    // main relation (to build the trailing `Sort` transform).
    last_sorting: Sorting,
    // Sorting recorded for each already-folded CTE, keyed by the CTE's tid;
    // consulted when a pipeline references that tid.
    ctes_sorting: HashMap<TId, CteSorting>,
    // Starts out false and is set to true just before the main relation is
    // folded; the CTE-only handling of pipelines is gated on it being false.
    main_relation: bool,
    // Compilation context; its `anchor` is passed on when sort columns are
    // reverted.
    ctx: &'a mut Context,
}

struct CteSorting {
sorting: Sorting,
has_been_used: bool,
}

impl RqFold for SortingInference<'_> {}
Expand All @@ -50,51 +51,29 @@ impl PqFold for SortingInference<'_> {
fn fold_sql_query(&mut self, query: SqlQuery) -> Result<SqlQuery> {
let mut ctes = Vec::with_capacity(query.ctes.len());
for cte in query.ctes {
log::debug!("infer_sorts: {0:?}", cte.tid);
let cte = self.fold_cte(cte)?;

// store sorting to be used later in From references
let sorting = self.last_sorting.drain(..).collect();
let sorting = CteSorting {
sorting,
has_been_used: false,
};
log::debug!("--- sorting {sorting:?}");
let sorting = CteSorting { sorting };
self.ctes_sorting.insert(cte.tid, sorting);

ctes.push(cte);
}

// fold main_relation using a made-up tid
// fold main_relation
log::debug!("infer_sorts: main relation");
self.main_relation = true;
let mut main_relation = self.fold_sql_relation(query.main_relation)?;
log::debug!("--== last_sorting {0:?}", self.last_sorting);

// push a sort at the back of the main pipeline
if let SqlRelation::AtomicPipeline(pipeline) = &mut main_relation {
pipeline.push(SqlTransform::Sort(self.last_sorting.drain(..).collect()));
}

// make sure that all CTEs whose sorting was used actually SELECT it
for cte in &mut ctes {
let sorting = self.ctes_sorting.get(&cte.tid).unwrap();
if !sorting.has_been_used {
continue;
}

let CteKind::Normal(sql_relation) = &mut cte.kind else {
continue;
};
let Some(pipeline) = sql_relation.as_atomic_pipeline_mut() else {
continue;
};
let select = pipeline.iter_mut().find_map(|x| x.as_select_mut()).unwrap();

for column_sort in &sorting.sorting {
let cid = column_sort.column;
let is_selected = select.contains(&cid);
if !is_selected {
select.push(cid);
}
}
}

Ok(SqlQuery {
ctes,
main_relation,
Expand All @@ -116,6 +95,7 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
transforms: Vec<SqlTransform<RelationExpr, ()>>,
) -> Result<Vec<SqlTransform<RelationExpr, ()>>> {
let mut sorting = Vec::new();
let mut has_sort_transform = false;

let mut result = Vec::with_capacity(transforms.len() + 1);

Expand All @@ -126,7 +106,6 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
RelationExprKind::Ref(ref tid) => {
// infer sorting from referenced pipeline
if let Some(cte_sorting) = self.ctes_sorting.get_mut(tid) {
cte_sorting.has_been_used = true;
sorting.clone_from(&cte_sorting.sorting);
} else {
sorting = Vec::new();
Expand All @@ -147,8 +126,9 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
}

// just store sorting and don't emit Sort
SqlTransform::Sort(s) => {
sorting.clone_from(&s);
SqlTransform::Sort(expr) => {
sorting.clone_from(&expr);
has_sort_transform = true;
continue;
}

Expand All @@ -166,6 +146,28 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
result.push(transform)
}

if !self.main_relation {
// if this is a CTE, make sure that its SELECT includes the
// columns from the sort
let select = result.iter_mut().find_map(|x| x.as_select_mut()).unwrap();
for column_sort in &sorting {
let cid = column_sort.column;
let is_selected = select.contains(&cid);
if !is_selected {
log::debug!("adding {cid:?} to {select:?}");
select.push(cid);
}
}

if has_sort_transform {
// now revert the sort columns so that the output
// sorting reflects the input column cids, needed to
// ensure proper column reference lookup in the final
// steps
sorting = CidRedirector::revert_sorts(sorting, &mut self.ctx.anchor);
}
}

// remember sorting for this pipeline
self.last_sorting = sorting;

Expand Down
5 changes: 5 additions & 0 deletions prqlc/prqlc/tests/integration/queries/sort_2.prql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from albums
select { AA=album_id, artist_id }
sort AA
filter AA >= 25
join artists (==artist_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
source: prqlc/prqlc/tests/integration/queries.rs
expression: "from albums\nselect { AA=album_id, artist_id }\nsort AA\nfilter AA >= 25\njoin artists (==artist_id)\n"
input_file: prqlc/prqlc/tests/integration/queries/sort_2.prql
---
WITH table_1 AS (
SELECT
album_id AS "AA",
artist_id
FROM
albums
),
table_0 AS (
SELECT
"AA",
artist_id
FROM
table_1
WHERE
"AA" >= 25
)
SELECT
table_0."AA",
table_0.artist_id,
artists.*
FROM
table_0
JOIN artists ON table_0.artist_id = artists.artist_id
ORDER BY
table_0."AA"
Loading