Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made non-copy variable not be forwarded in const folding. #6324

Merged
merged 3 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions crates/cairo-lang-defs/src/patcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ impl RewriteNode {
Self::Text(text.to_string())
}

pub fn mapped_text(text: &str, db: &dyn SyntaxGroup, origin: &impl TypedSyntaxNode) -> Self {
RewriteNode::Text(text.to_string()).mapped(db, origin)
pub fn mapped_text(
text: impl Into<String>,
db: &dyn SyntaxGroup,
origin: &impl TypedSyntaxNode,
) -> Self {
RewriteNode::Text(text.into()).mapped(db, origin)
}

pub fn empty() -> Self {
Expand Down Expand Up @@ -241,13 +245,18 @@ pub struct PatchBuilder<'a> {
origin: CodeOrigin,
}
impl<'a> PatchBuilder<'a> {
/// Creates a new patch builder, originating from `origin` node.
/// Creates a new patch builder, originating from `origin` typed node.
pub fn new(db: &'a dyn SyntaxGroup, origin: &impl TypedSyntaxNode) -> Self {
Self::new_ex(db, &origin.as_syntax_node())
}

/// Creates a new patch builder, originating from `origin` node.
pub fn new_ex(db: &'a dyn SyntaxGroup, origin: &SyntaxNode) -> Self {
Self {
db,
code: String::default(),
code_mappings: vec![],
origin: CodeOrigin::Span(origin.as_syntax_node().span_without_trivia(db)),
origin: CodeOrigin::Span(origin.span_without_trivia(db)),
}
}

Expand Down
16 changes: 12 additions & 4 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ enum VarInfo {
/// The variable is a snapshot of another variable.
Snapshot(Box<VarInfo>),
/// The variable is a struct of other variables.
Struct(Vec<VarInfo>),
/// `None` values represent variables that are not tracked.
Struct(Vec<Option<VarInfo>>),
}

/// Performs constant folding on the lowered program.
Expand Down Expand Up @@ -101,14 +102,19 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
let mut contains_info = false;
for input in inputs.iter() {
let Some(info) = ctx.var_info.get(&input.var_id) else {
all_args.push(VarInfo::Var(*input));
all_args.push(
lowered.variables[input.var_id]
.copyable
.is_ok()
.then(|| VarInfo::Var(*input)),
);
continue;
};
contains_info = true;
if let VarInfo::Const(value) = info {
const_args.push(value.clone());
}
all_args.push(info.clone());
all_args.push(Some(info.clone()));
}
if const_args.len() == inputs.len() {
let value = ConstValue::Struct(const_args, lowered.variables[*output].ty);
Expand Down Expand Up @@ -141,7 +147,9 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
}
VarInfo::Struct(members) => {
for (output, member) in zip_eq(outputs, members.clone()) {
ctx.var_info.insert(*output, wrap_with_snapshots(member));
if let Some(member) = member {
ctx.var_info.insert(*output, wrap_with_snapshots(member));
}
}
}
_ => {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2946,3 +2946,101 @@ End:
Return(v6)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Construct with undroppable.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> bool {
let x: (felt252, Felt252Dict<felt252>) = (0, Default::<Felt252Dict>::default());
let (l, _) = @x;
*l == 0
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::felt252) <- 0
(v1: core::dict::Felt252Dict::<core::felt252>) <- core::dict::felt252_dict_new::<core::felt252>()
(v2: (core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- struct_construct(v0, v1)
(v3: (core::felt252, core::dict::Felt252Dict::<core::felt252>), v4: @(core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- snapshot(v2)
(v5: core::felt252, v6: core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v3)
(v7: core::dict::SquashedFelt252Dict::<core::felt252>) <- core::dict::Felt252DictImpl::<core::felt252, core::Felt252Felt252DictValue>::squash(v6)
(v9: @core::felt252, v10: @core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v4)
(v11: core::felt252) <- desnap(v9)
(v14: core::felt252) <- 0
(v19: core::felt252) <- core::felt252_sub(v11, v14)
End:
Match(match core::felt252_is_zero(v19) {
IsZeroResult::Zero => blk1,
IsZeroResult::NonZero(v20) => blk2,
})

blk1:
Statements:
(v21: ()) <- struct_construct()
(v22: core::bool) <- bool::True(v21)
End:
Goto(blk3, {v22 -> v23})

blk2:
Statements:
(v24: ()) <- struct_construct()
(v25: core::bool) <- bool::False(v24)
End:
Goto(blk3, {v25 -> v23})

blk3:
Statements:
End:
Return(v23)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::felt252) <- 0
(v1: core::dict::Felt252Dict::<core::felt252>) <- core::dict::felt252_dict_new::<core::felt252>()
(v2: (core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- struct_construct(v0, v1)
(v3: (core::felt252, core::dict::Felt252Dict::<core::felt252>), v4: @(core::felt252, core::dict::Felt252Dict::<core::felt252>)) <- snapshot(v2)
(v5: core::felt252, v6: core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v3)
(v7: core::dict::SquashedFelt252Dict::<core::felt252>) <- core::dict::Felt252DictImpl::<core::felt252, core::Felt252Felt252DictValue>::squash(v6)
(v9: @core::felt252, v10: @core::dict::Felt252Dict::<core::felt252>) <- struct_destructure(v4)
(v11: core::felt252) <- desnap(v9)
(v14: core::felt252) <- 0
(v19: core::felt252) <- core::felt252_sub(v11, v14)
End:
Goto(blk1, {})

blk1:
Statements:
(v21: ()) <- struct_construct()
(v22: core::bool) <- bool::True(v21)
End:
Goto(blk3, {v22 -> v23})

blk2:
Statements:
(v24: ()) <- struct_construct()
(v25: core::bool) <- bool::False(v24)
End:
Goto(blk3, {v25 -> v23})

blk3:
Statements:
End:
Return(v23)

//! > lowering_diagnostics
19 changes: 12 additions & 7 deletions crates/cairo-lang-plugins/src/plugins/derive/clone.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Clone` trait.
pub fn handle_clone(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &mut DeriveResult) {
pub fn handle_clone(
info: &DeriveInfo,
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header =
info.format_impl_header("core::clone", "Clone", &["core::clone::Clone", "Destruct"]);
let full_typename = info.full_typename();
Expand Down Expand Up @@ -41,16 +46,16 @@ pub fn handle_clone(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &m
}
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"
Some(formatdoc! {"
{header} {{
fn clone(self: @{full_typename}) -> {full_typename} {{
{body}
}}
}}
"});
"})
}
20 changes: 13 additions & 7 deletions crates/cairo-lang-plugins/src/plugins/derive/debug.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Debug` trait.
pub fn handle_debug(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &mut DeriveResult) {
pub fn handle_debug(
info: &DeriveInfo,
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header = info.format_impl_header("core::fmt", "Debug", &["core::fmt::Debug"]);
let full_typename = info.full_typename();
let name = &info.name;
Expand Down Expand Up @@ -51,16 +56,17 @@ pub fn handle_debug(info: &DeriveInfo, stable_ptr: SyntaxStablePtrId, result: &m
)
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"

Some(formatdoc! {"
{header} {{
fn fmt(self: @{full_typename}, ref f: core::fmt::Formatter) -> core::result::Result::<(), core::fmt::Error> {{
{body}
}}
}}
"});
"})
}
29 changes: 14 additions & 15 deletions crates/cairo-lang-plugins/src/plugins/derive/default.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_syntax::node::TypedSyntaxNode;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

pub const DEFAULT_ATTR: &str = "default";
Expand All @@ -16,9 +15,9 @@ pub const DEFAULT_ATTR: &str = "default";
pub fn handle_default(
db: &dyn SyntaxGroup,
info: &DeriveInfo,
stable_ptr: SyntaxStablePtrId,
result: &mut DeriveResult,
) {
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let header = info.format_impl_header(
"core::traits",
"Default",
Expand All @@ -34,15 +33,15 @@ pub fn handle_default(
Some((variant, variant.attributes.find_attr(db, DEFAULT_ATTR)?))
});
let Some((default_variant, _)) = default_variants.next() else {
result.diagnostics.push(PluginDiagnostic::error(
stable_ptr,
diagnostics.push(PluginDiagnostic::error(
derived,
"derive `Default` for enum only supported with a default variant.".into(),
));
return;
return None;
};
for (_, extra_default_attr) in default_variants {
result.diagnostics.push(PluginDiagnostic::error(
extra_default_attr.as_syntax_node().stable_ptr(),
diagnostics.push(PluginDiagnostic::error(
&extra_default_attr,
"Multiple variants annotated with `#[default]`".into(),
));
}
Expand All @@ -63,16 +62,16 @@ pub fn handle_default(
}
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"
Some(formatdoc! {"
{header} {{
fn default() -> {full_typename} {{
{body}
}}
}}
"});
"})
}
20 changes: 11 additions & 9 deletions crates/cairo-lang-plugins/src/plugins/derive/destruct.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_syntax::node::ast;
use indent::indent_by;
use indoc::formatdoc;
use itertools::Itertools;

use super::{unsupported_for_extern_diagnostic, DeriveInfo, DeriveResult};
use super::{unsupported_for_extern_diagnostic, DeriveInfo};
use crate::plugins::derive::TypeVariantInfo;

/// Adds derive result for the `Destruct` trait.
pub fn handle_destruct(
info: &DeriveInfo,
stable_ptr: SyntaxStablePtrId,
result: &mut DeriveResult,
) {
derived: &ast::ExprPath,
diagnostics: &mut Vec<PluginDiagnostic>,
) -> Option<String> {
let full_typename = info.full_typename();
let ty = &info.name;
let header = info.format_impl_header("core::traits", "Destruct", &["core::traits::Destruct"]);
Expand Down Expand Up @@ -45,16 +46,17 @@ pub fn handle_destruct(
)
}
TypeVariantInfo::Extern => {
result.diagnostics.push(unsupported_for_extern_diagnostic(stable_ptr));
return;
diagnostics.push(unsupported_for_extern_diagnostic(derived));
return None;
}
},
);
result.impls.push(formatdoc! {"

Some(formatdoc! {"
{header} {{
fn destruct(self: {full_typename}) nopanic {{
{body}
}}
}}
"});
"})
}
Loading