Skip to content

Commit

Permalink
Made non-copy variable not be forwarded in const folding. (#6324)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi committed Sep 8, 2024
1 parent 9dd5b2f commit c58fd12
Show file tree
Hide file tree
Showing 26 changed files with 391 additions and 265 deletions.
17 changes: 13 additions & 4 deletions crates/cairo-lang-defs/src/patcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ impl RewriteNode {
Self::Text(text.to_string())
}

pub fn mapped_text(text: &str, db: &dyn SyntaxGroup, origin: &impl TypedSyntaxNode) -> Self {
RewriteNode::Text(text.to_string()).mapped(db, origin)
pub fn mapped_text(
text: impl Into<String>,
db: &dyn SyntaxGroup,
origin: &impl TypedSyntaxNode,
) -> Self {
RewriteNode::Text(text.into()).mapped(db, origin)
}

pub fn empty() -> Self {
Expand Down Expand Up @@ -241,13 +245,18 @@ pub struct PatchBuilder<'a> {
origin: CodeOrigin,
}
impl<'a> PatchBuilder<'a> {
/// Creates a new patch builder, originating from `origin` node.
/// Creates a new patch builder, originating from `origin` typed node.
pub fn new(db: &'a dyn SyntaxGroup, origin: &impl TypedSyntaxNode) -> Self {
Self::new_ex(db, &origin.as_syntax_node())
}

/// Creates a new patch builder, originating from `origin` node.
pub fn new_ex(db: &'a dyn SyntaxGroup, origin: &SyntaxNode) -> Self {
Self {
db,
code: String::default(),
code_mappings: vec![],
origin: CodeOrigin::Span(origin.as_syntax_node().span_without_trivia(db)),
origin: CodeOrigin::Span(origin.span_without_trivia(db)),
}
}

Expand Down
16 changes: 12 additions & 4 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ enum VarInfo {
/// The variable is a snapshot of another variable.
Snapshot(Box<VarInfo>),
/// The variable is a struct of other variables.
Struct(Vec<VarInfo>),
/// `None` values represent variables that are not tracked.
Struct(Vec<Option<VarInfo>>),
}

/// Performs constant folding on the lowered program.
Expand Down Expand Up @@ -101,14 +102,19 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
let mut contains_info = false;
for input in inputs.iter() {
let Some(info) = ctx.var_info.get(&input.var_id) else {
all_args.push(VarInfo::Var(*input));
all_args.push(
lowered.variables[input.var_id]
.copyable
.is_ok()
.then(|| VarInfo::Var(*input)),
);
continue;
};
contains_info = true;
if let VarInfo::Const(value) = info {
const_args.push(value.clone());
}
all_args.push(info.clone());
all_args.push(Some(info.clone()));
}
if const_args.len() == inputs.len() {
let value = ConstValue::Struct(const_args, lowered.variables[*output].ty);
Expand Down Expand Up @@ -141,7 +147,9 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
}
VarInfo::Struct(members) => {
for (output, member) in zip_eq(outputs, members.clone()) {
ctx.var_info.insert(*output, wrap_with_snapshots(member));
if let Some(member) = member {
ctx.var_info.insert(*output, wrap_with_snapshots(member));
}
}
}
_ => {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2946,3 +2946,101 @@ End:
Return(v6)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Construct with undroppable.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> bool {
let x: (felt252, Felt252Dict<felt252>) = (0, Default::<Felt252Dict>::default());
let (l, _) = @x;
*l == 0
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::felt252) <- 0
(v1: core::dict::Felt252Dict::<core::felt252>) <- core::dict::felt252_dict_new::<core::felt252>()
(v2: (core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- struct_construct(v0, v1)
(v3: (core::felt252, core::dict::Felt252Dict::<core::felt252>), v4: @(core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- snapshot(v2)
(v5: core::felt252, v6: core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v3)
(v7: core::dict::SquashedFelt252Dict::<core::felt252>) <- core::dict::Felt252DictImpl::<core::felt252, core::Felt252Felt252DictValue>::squash(v6)
(v9: @core::felt252, v10: @core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v4)
(v11: core::felt252) <- desnap(v9)
(v14: core::felt252) <- 0
(v19: core::felt252) <- core::felt252_sub(v11, v14)
End:
Match(match core::felt252_is_zero(v19) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v20) => blk2,
})

blk1:
Statements:
(v21: ()) <- struct_construct()
(v22: core::bool) <- bool::True(v21)
End:
Goto(blk3, {v22 -> v23})

blk2:
Statements:
(v24: ()) <- struct_construct()
(v25: core::bool) <- bool::False(v24)
End:
Goto(blk3, {v25 -> v23})

blk3:
Statements:
End:
Return(v23)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::felt252) <- 0
(v1: core::dict::Felt252Dict::<core::felt252>) <- core::dict::felt252_dict_new::<core::felt252>()
(v2: (core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- struct_construct(v0, v1)
(v3: (core::felt252, core::dict::Felt252Dict::<core::felt252>), v4: @(core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- snapshot(v2)
(v5: core::felt252, v6: core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v3)
(v7: core::dict::SquashedFelt252Dict::<core::felt252>) <- core::dict::Felt252DictImpl::<core::felt252, core::Felt252Felt252DictValue>::squash(v6)
(v9: @core::felt252, v10: @core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v4)
(v11: core::felt252) <- desnap(v9)
(v14: core::felt252) <- 0
(v19: core::felt252) <- core::felt252_sub(v11, v14)
End:
Goto(blk1, {})

blk1:
Statements:
(v21: ()) <- struct_construct()
(v22: core::bool) <- bool::True(v21)
End:
Goto(blk3, {v22 -> v23})

blk2:
Statements:
(v24: ()) <- struct_construct()
(v25: core::bool) <- bool::False(v24)
End:
Goto(blk3, {v25 -> v23})

blk3:
Statements:
End:
Return(v23)

//! > lowering_diagnostics
19 changes: 12 additions & 7 deletions crates/cairo-lang-plugins/src/plugins/derive/clone.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Clone` trait.
pub fn handle_clone(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &mut DeriveResult) {
pub fn handle_clone(
info: &DeriveInfo,
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header =
info.format_impl_header("core::clone", "Clone", &["core::clone::Clone", "Destruct"]);
let full_typename = info.full_typename();
Expand Down Expand Up @@ -41,16 +46,16 @@ pub fn handle_clone(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &m
}
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"
Some(formatdoc! {"
{header} {{
fn clone(self: @{full_typename}) -> {full_typename} {{
{body}
}}
}}
"});
"})
}
20 changes: 13 additions & 7 deletions crates/cairo-lang-plugins/src/plugins/derive/debug.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Debug` trait.
pub fn handle_debug(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &mut DeriveResult) {
pub fn handle_debug(
info: &DeriveInfo,
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header = info.format_impl_header("core::fmt", "Debug", &["core::fmt::Debug"]);
let full_typename = info.full_typename();
let name = &info.name;
Expand Down Expand Up @@ -51,16 +56,17 @@ pub fn handle_debug(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &m
)
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"

Some(formatdoc! {"
{header} {{
fn fmt(self: @{full_typename}, ref f: core::fmt::Formatter) -> core::result::Result::<(), core::fmt::Error> {{
{body}
}}
}}
"});
"})
}
29 changes: 14 additions & 15 deletions crates/cairo-lang-plugins/src/plugins/derive/default.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::TypedSyntaxNode;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

pub const DEFAULT_ATTR: &str = "default";
Expand All @@ -16,9 +15,9 @@ pub const DEFAULT_ATTR: &str = "default";
pub fn handle_default(
db: &dyn SyntaxGroup,
info: &DeriveInfo,
stable_ptr: SyntaxStablePtrId,
result: &mut DeriveResult,
) {
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header = info.format_impl_header(
"core::traits",
"Default",
Expand All @@ -34,15 +33,15 @@ pub fn handle_default(
Some((variant, variant.attributes.find_attr(db, DEFAULT_ATTR)?))
});
let Some((default_variant, _)) = default_variants.next() else {
result.diagnostics.push(PluginDiagnostic::error(
stable_ptr,
diagnostics.push(PluginDiagnostic::error(
derived,
"derive `Default` for enum only supported with a default variant.".into(),
));
return;
return None;
};
for (_, extra_default_attr) in default_variants {
result.diagnostics.push(PluginDiagnostic::error(
extra_default_attr.as_syntax_node().stable_ptr(),
diagnostics.push(PluginDiagnostic::error(
&extra_default_attr,
"Multiple variants annotated with `#[default]`".into(),
));
}
Expand All @@ -63,16 +62,16 @@ pub fn handle_default(
}
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"
Some(formatdoc! {"
{header} {{
fn default() -> {full_typename} {{
{body}
}}
}}
"});
"})
}
20 changes: 11 additions & 9 deletions crates/cairo-lang-plugins/src/plugins/derive/destruct.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Destruct` trait.
pub fn handle_destruct(
info: &DeriveInfo,
stable_ptr: SyntaxStablePtrId,
result: &mut DeriveResult,
) {
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let full_typename = info.full_typename();
let ty = &info.name;
let header = info.format_impl_header("core::traits", "Destruct", &["core::traits::Destruct"]);
Expand Down Expand Up @@ -45,16 +46,17 @@ pub fn handle_destruct(
)
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"

Some(formatdoc! {"
{header} {{
fn destruct(self: {full_typename}) nopanic {{
{body}
}}
}}
"});
"})
}
Loading

0 comments on commit c58fd12

Please sign in to comment.