Skip to content

Commit

Permalink
feat: improve codegen (#3737)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Oct 24, 2023
1 parent fda495f commit 7e8e8f9
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 42 deletions.
81 changes: 49 additions & 32 deletions prqlc/prql-compiler/src/codegen/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,8 @@ use crate::utils::VALID_IDENT;

use super::{WriteOpt, WriteSource};

pub fn write_stmts(stmts: &Vec<Stmt>) -> String {
let mut opt = WriteOpt::default();

loop {
if let Some(s) = stmts.write(opt.clone()) {
break s;
} else {
opt.max_width += opt.max_width / 2;
}
}
}

pub fn write_expr(expr: &Expr) -> String {
let opt = WriteOpt::new_width(u16::MAX);
expr.write(opt).unwrap()
pub(crate) fn write_expr(expr: &Expr) -> String {
expr.write(WriteOpt::new_width(u16::MAX)).unwrap()
}

fn write_within<T: WriteSource>(node: &T, parent: &ExprKind, mut opt: WriteOpt) -> Option<String> {
Expand All @@ -39,19 +26,25 @@ impl WriteSource for Expr {
let mut r = String::new();

if let Some(alias) = &self.alias {
r += alias;
r += " = ";
r += opt.consume(alias)?;
r += opt.consume(" = ")?;
opt.unbound_expr = false;
}

let needs_parenthesis = (opt.unbound_expr && can_bind_left(&self.kind))
|| (opt.context_strength >= binding_strength(&self.kind));

if needs_parenthesis {
r += &self.kind.write_between("(", ")", opt)?;
if !needs_parenthesis {
r += &self.kind.write(opt.clone())?;
} else {
r += &self.kind.write(opt)?;
}
let value = self.kind.write_between("(", ")", opt.clone());

if let Some(value) = value {
r += &value;
} else {
r += &break_line_within_parenthesis(&self.kind, opt)?;
}
};
Some(r)
}
}
Expand Down Expand Up @@ -145,23 +138,29 @@ impl WriteSource for ExprKind {
Func(c) => {
let mut r = String::new();
for param in &c.params {
r += &write_ident_part(&param.name);
r += " ";
r += opt.consume(&write_ident_part(&param.name))?;
r += opt.consume(" ")?;
}
for param in &c.named_params {
r += &write_ident_part(&param.name);
r += ":";
r += &param.default_value.as_ref().unwrap().write(opt.clone())?;
r += " ";
r += opt.consume(&write_ident_part(&param.name))?;
r += opt.consume(":")?;
r += opt.consume(&param.default_value.as_ref().unwrap().write(opt.clone())?)?;
r += opt.consume(" ")?;
}
r += opt.consume("-> ")?;

// try a single line
if let Some(body) = c.body.write(opt.clone()) {
r += &body;
} else {
r += &break_line_within_parenthesis(c.body.as_ref(), opt)?;
}
r += "-> ";
r += &c.body.write(opt)?;

Some(r)
}
SString(parts) => display_interpolation("s", parts, opt),
FString(parts) => display_interpolation("f", parts, opt),
Literal(literal) => Some(literal.to_string()),
Literal(literal) => opt.consume(literal.to_string()),
Case(cases) => {
let mut r = String::new();
r += "case ";
Expand All @@ -179,6 +178,19 @@ impl WriteSource for ExprKind {
}
}

fn break_line_within_parenthesis<T: WriteSource>(expr: &T, mut opt: WriteOpt) -> Option<String> {
let mut r = "(\n".to_string();
opt.indent += 1;
r += &opt.write_indent();
opt.reset_line()?;
r += &expr.write(opt.clone())?;
r += "\n";
opt.indent -= 1;
r += &opt.write_indent();
r += ")";
Some(r)
}

fn binding_strength(expr: &ExprKind) -> u8 {
match expr {
// For example, if it's an Ident, it's basically infinite — a simple
Expand Down Expand Up @@ -269,7 +281,7 @@ impl WriteSource for Vec<Stmt> {
}

r += &opt.write_indent();
r += &stmt.write(opt.clone())?;
r += &stmt.write_or_expand(opt.clone());
}
Some(r)
}
Expand Down Expand Up @@ -463,7 +475,12 @@ mod test {

#[test]
fn test_simple() {
assert_is_formatted(r#"aggregate average_country_salary = (average salary)"#);
assert_is_formatted(
r#"
aggregate average_country_salary = (
average salary
)"#,
);
}

#[test]
Expand Down
24 changes: 18 additions & 6 deletions prqlc/prql-compiler/src/codegen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod ast;
mod pl;

pub use ast::{write_expr, write_stmts};
pub(crate) use ast::write_expr;

pub trait WriteSource {
/// Converts self to its source representation according to specified
Expand All @@ -27,6 +27,17 @@ pub trait WriteSource {
r += opt.consume(suffix)?;
Some(r)
}

fn write_or_expand(&self, mut opt: WriteOpt) -> String {
loop {
if let Some(s) = self.write(opt.clone()) {
return s;
} else {
opt.max_width += opt.max_width / 2;
opt.reset_line();
}
}
}
}

impl<T: WriteSource> WriteSource for &T {
Expand Down Expand Up @@ -96,13 +107,14 @@ impl WriteOpt {
Some(())
}

fn consume<'a>(&mut self, source: &'a str) -> Option<&'a str> {
let width = if let Some(new_line) = source.rfind('\n') {
source.len() - new_line
/// Subtracts the width of the source from the remaining width and returns the source unchanged.
fn consume<S: AsRef<str>>(&mut self, source: S) -> Option<S> {
let width = if let Some(new_line) = source.as_ref().rfind('\n') {
source.as_ref().len() - new_line
} else {
source.len()
source.as_ref().len()
};
self.consume_width(width as u16);
self.consume_width(width as u16)?;
Some(source)
}

Expand Down
2 changes: 1 addition & 1 deletion prqlc/prql-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ pub fn rq_to_sql(rq: ir::rq::RelationalQuery, options: &Options) -> Result<Strin

/// Generate PRQL code from PL AST
pub fn pl_to_prql(pl: Vec<prqlc_ast::stmt::Stmt>) -> Result<String, ErrorMessages> {
Ok(codegen::write_stmts(&pl))
Ok(codegen::WriteSource::write(&pl, codegen::WriteOpt::default()).unwrap())
}

/// JSON serialization and deserialization functions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
---
source: prql-compiler/tests/integration/main.rs
source: prqlc/prql-compiler/tests/integration/main.rs
expression: "# mssql:test\nfrom a=albums\ntake 10\njoin tracks (==album_id)\ngroup {a.album_id, a.title} (aggregate price = ((sum tracks.unit_price)))\nsort album_id\n"
input_file: prql-compiler/tests/integration/queries/group_all.prql
input_file: prqlc/prql-compiler/tests/integration/queries/group_all.prql
---
from a = albums
take 10
join tracks (==album_id)
group {a.album_id, a.title} (aggregate price = (sum tracks.unit_price))
group {a.album_id, a.title} (aggregate price = (
sum tracks.unit_price
))
sort album_id

0 comments on commit 7e8e8f9

Please sign in to comment.