From 7bad82de2538c8a3c5ab8c87a7a3fa3ebb3084c7 Mon Sep 17 00:00:00 2001 From: orizi <104711814+orizi@users.noreply.github.com> Date: Thu, 28 Mar 2024 10:48:26 +0200 Subject: [PATCH] Added destructors in cycle detectors. (#5335) --- .../src/borrow_check/mod.rs | 114 +++++++++++++----- crates/cairo-lang-lowering/src/db.rs | 36 ++++-- .../src/graph_algorithms/cycles.rs | 5 +- .../cairo-lang-lowering/src/test_data/cycles | 50 ++++++++ .../src/test_data/destruct | 10 ++ 5 files changed, 173 insertions(+), 42 deletions(-) diff --git a/crates/cairo-lang-lowering/src/borrow_check/mod.rs b/crates/cairo-lang-lowering/src/borrow_check/mod.rs index dd86a997a36..5b2e6a1e58a 100644 --- a/crates/cairo-lang-lowering/src/borrow_check/mod.rs +++ b/crates/cairo-lang-lowering/src/borrow_check/mod.rs @@ -2,8 +2,11 @@ #[path = "test.rs"] mod test; -use cairo_lang_defs::ids::ModuleFileId; +use cairo_lang_defs::ids::{ModuleFileId, TraitFunctionId}; use cairo_lang_diagnostics::{DiagnosticNote, Maybe}; +use cairo_lang_semantic::corelib::get_core_trait; +use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId}; +use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; use itertools::{zip_eq, Itertools}; use self::analysis::{Analyzer, StatementLocation}; @@ -14,7 +17,7 @@ use crate::borrow_check::analysis::BackAnalysis; use crate::db::LoweringGroup; use crate::diagnostic::LoweringDiagnosticKind::*; use crate::diagnostic::LoweringDiagnostics; -use crate::ids::LocationId; +use crate::ids::{FunctionId, LocationId, SemanticFunctionIdEx}; use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId}; pub mod analysis; @@ -26,6 +29,9 @@ pub struct BorrowChecker<'a> { diagnostics: &'a mut LoweringDiagnostics, lowered: &'a FlatLowered, success: Maybe<()>, + potential_destruct_calls: PotentialDestructCalls, + destruct_fn: TraitFunctionId, + panic_destruct_fn: TraitFunctionId, } /// A state saved for each position in the back analysis. @@ -77,12 +83,12 @@ impl DropPosition { impl<'a> DemandReporter for BorrowChecker<'a> { // Note that for in BorrowChecker `IntroducePosition` is used to pass the cause of // the drop. - type IntroducePosition = Option; + type IntroducePosition = (Option, BlockId); type UsePosition = LocationId; fn drop_aux( &mut self, - opt_drop_position: Option, + (opt_drop_position, block_id): (Option, BlockId), var_id: VariableId, panic_state: PanicState, ) { @@ -90,14 +96,36 @@ impl<'a> DemandReporter for BorrowChecker<'a> { let Err(drop_err) = var.droppable.clone() else { return; }; - let Err(destruct_err) = var.destruct_impl.clone() else { - return; + let mut add_called_fn = |impl_id, function| { + self.potential_destruct_calls.entry(block_id).or_default().push( + self.db + .intern_function(cairo_lang_semantic::FunctionLongId { + function: cairo_lang_semantic::ConcreteFunction { + generic_function: GenericFunctionId::Impl(ImplGenericFunctionId { + impl_id, + function, + }), + generic_args: vec![], + }, + }) + .lowered(self.db), + ); }; - let panic_destruct_err = if matches!(panic_state, PanicState::EndsWithPanic) { - let Err(panic_destruct_err) = var.panic_destruct_impl.clone() else { + let destruct_err = match var.destruct_impl.clone() { + Ok(impl_id) => { + add_called_fn(impl_id, self.destruct_fn); return; - }; - Some(panic_destruct_err) + } + Err(err) => err, + }; + let panic_destruct_err = if matches!(panic_state, PanicState::EndsWithPanic) { + match var.panic_destruct_impl.clone() { + Ok(impl_id) => { + add_called_fn(impl_id, self.panic_destruct_fn); + return; + } + Err(err) => Some(err), + } } else { None }; @@ -139,10 +167,10 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> { fn visit_stmt( &mut self, info: &mut Self::Info, - _statement_location: StatementLocation, + (block_id, _): StatementLocation, stmt: &Statement, ) { - info.variables_introduced(self, stmt.outputs(), None); + info.variables_introduced(self, stmt.outputs(), (None, block_id)); match stmt { Statement::Call(stmt) => { if let Ok(signature) = stmt.function.signature(self.db) { @@ -152,11 +180,9 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> { aux: PanicState::EndsWithPanic, ..Default::default() }; + let location = (Some(DropPosition::Panic(stmt.location)), block_id); *info = BorrowCheckerDemand::merge_demands( - &[ - (panic_demand, Some(DropPosition::Panic(stmt.location))), - (info.clone(), Some(DropPosition::Panic(stmt.location))), - ], + &[(panic_demand, location), (info.clone(), location)], self, ); } @@ -198,7 +224,7 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> { fn merge_match( &mut self, - _statement_location: StatementLocation, + (block_id, _): StatementLocation, match_info: &MatchInfo, infos: impl Iterator, ) -> Self::Info { @@ -206,8 +232,8 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> { let arm_demands = zip_eq(match_info.arms(), &infos) .map(|(arm, demand)| { let mut demand = demand.clone(); - demand.variables_introduced(self, &arm.var_ids, None); - (demand, Some(DropPosition::Diverge(*match_info.location()))) + demand.variables_introduced(self, &arm.var_ids, (None, block_id)); + (demand, (Some(DropPosition::Diverge(*match_info.location())), block_id)) }) .collect_vec(); let mut demand = BorrowCheckerDemand::merge_demands(&arm_demands, self); @@ -242,27 +268,53 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> { } } +/// The possible destruct calls per block. +pub type PotentialDestructCalls = UnorderedHashMap>; + /// Report borrow checking diagnostics. +/// Returns the potential destruct function calls per block. pub fn borrow_check( db: &dyn LoweringGroup, module_file_id: ModuleFileId, lowered: &mut FlatLowered, -) { +) -> PotentialDestructCalls { + if lowered.blocks.has_root().is_err() { + return Default::default(); + } let mut diagnostics = LoweringDiagnostics::new(module_file_id.file_id(db.upcast()).unwrap()); diagnostics.diagnostics.extend(std::mem::take(&mut lowered.diagnostics)); + let destruct_trait_id = get_core_trait(db.upcast(), "Destruct".into()); + let destruct_fn = + db.trait_function_by_name(destruct_trait_id, "destruct".into()).unwrap().unwrap(); + let panic_destruct_trait_id = get_core_trait(db.upcast(), "PanicDestruct".into()); + let panic_destruct_fn = db + .trait_function_by_name(panic_destruct_trait_id, "panic_destruct".into()) + .unwrap() + .unwrap(); + let checker = BorrowChecker { + db, + diagnostics: &mut diagnostics, + lowered, + success: Ok(()), + potential_destruct_calls: Default::default(), + destruct_fn, + panic_destruct_fn, + }; + let mut analysis = BackAnalysis::new(lowered, checker); + let mut root_demand = analysis.get_root_info(); + root_demand.variables_introduced( + &mut analysis.analyzer, + &lowered.parameters, + (None, BlockId::root()), + ); + let block_extra_calls = analysis.analyzer.potential_destruct_calls; + let success = analysis.analyzer.success; + assert!(root_demand.finalize(), "Undefined variable should not happen at this stage"); - if lowered.blocks.has_root().is_ok() { - let checker = BorrowChecker { db, diagnostics: &mut diagnostics, lowered, success: Ok(()) }; - let mut analysis = BackAnalysis::new(lowered, checker); - let mut root_demand = analysis.get_root_info(); - root_demand.variables_introduced(&mut analysis.analyzer, &lowered.parameters, None); - let success = analysis.analyzer.success; - assert!(root_demand.finalize(), "Undefined variable should not happen at this stage"); - - if let Err(diag_added) = success { - lowered.blocks = Blocks::new_errored(diag_added); - } + if let Err(diag_added) = success { + lowered.blocks = Blocks::new_errored(diag_added); } lowered.diagnostics = diagnostics.build(); + block_extra_calls } diff --git a/crates/cairo-lang-lowering/src/db.rs b/crates/cairo-lang-lowering/src/db.rs index aa21a4e97d7..b9c275291a0 100644 --- a/crates/cairo-lang-lowering/src/db.rs +++ b/crates/cairo-lang-lowering/src/db.rs @@ -9,6 +9,7 @@ use cairo_lang_semantic::items::enm::SemanticEnumEx; use cairo_lang_semantic::items::structure::SemanticStructEx; use cairo_lang_semantic::{self as semantic, corelib, ConcreteTypeId, TypeId, TypeLongId}; use cairo_lang_utils::ordered_hash_set::OrderedHashSet; +use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; use cairo_lang_utils::unordered_hash_set::UnorderedHashSet; use cairo_lang_utils::{extract_matches, Upcast}; use defs::ids::NamedLanguageElementId; @@ -17,12 +18,12 @@ use num_traits::ToPrimitive; use semantic::items::constant::ConstValue; use crate::add_withdraw_gas::add_withdraw_gas; -use crate::borrow_check::borrow_check; +use crate::borrow_check::{borrow_check, PotentialDestructCalls}; use crate::concretize::concretize_lowered; use crate::destructs::add_destructs; use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind}; use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas; -use crate::ids::FunctionLongId; +use crate::ids::{FunctionId, FunctionLongId}; use crate::inline::get_inline_diagnostics; use crate::lower::{lower_semantic_function, MultiLowering}; use crate::optimizations::config::OptimizationConfig; @@ -68,6 +69,13 @@ pub trait LoweringGroup: SemanticGroup + Upcast { function_id: ids::FunctionWithBodyId, ) -> Maybe>; + /// Computes the lowered representation of a function with a body. + /// Additionally applies borrow checking testing, and returns the possible calls per block. + fn function_with_body_lowering_with_borrow_check( + &self, + function_id: ids::FunctionWithBodyId, + ) -> Maybe<(Arc, Arc)>; + /// Computes the lowered representation of a function with a body. fn function_with_body_lowering( &self, @@ -375,14 +383,21 @@ fn priv_function_with_body_lowering( Ok(Arc::new(lowered)) } -fn function_with_body_lowering( +fn function_with_body_lowering_with_borrow_check( db: &dyn LoweringGroup, function_id: ids::FunctionWithBodyId, -) -> Maybe> { +) -> Maybe<(Arc, Arc)> { let mut lowered = (*db.priv_function_with_body_lowering(function_id)?).clone(); let module_file_id = function_id.base_semantic_function(db).module_file_id(db.upcast()); - borrow_check(db, module_file_id, &mut lowered); - Ok(Arc::new(lowered)) + let block_extra_calls = borrow_check(db, module_file_id, &mut lowered); + Ok((Arc::new(lowered), Arc::new(block_extra_calls))) +} + +fn function_with_body_lowering( + db: &dyn LoweringGroup, + function_id: ids::FunctionWithBodyId, +) -> Maybe> { + Ok(db.function_with_body_lowering_with_borrow_check(function_id)?.0) } // * Concretizes lowered representation (monomorphization). @@ -452,8 +467,8 @@ pub(crate) fn get_direct_callees( db: &dyn LoweringGroup, lowered_function: &FlatLowered, dependency_type: DependencyType, + block_extra_calls: &UnorderedHashMap>, ) -> Vec { - // TODO(orizi): Follow calls for destructors as well. let mut direct_callees = Vec::new(); if lowered_function.blocks.is_empty() { return direct_callees; @@ -478,6 +493,9 @@ pub(crate) fn get_direct_callees( } } } + if let Some(extra_calls) = block_extra_calls.get(&block_id) { + direct_callees.extend(extra_calls.iter().copied()); + } match &block.end { FlatBlockEnd::NotSet | FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(_) => {} FlatBlockEnd::Goto(next, _) => stack.push(*next), @@ -505,7 +523,7 @@ fn concrete_function_with_body_direct_callees( dependency_type: DependencyType, ) -> Maybe> { let lowered_function = db.priv_concrete_function_with_body_lowered_flat(function_id)?; - Ok(get_direct_callees(db, &lowered_function, dependency_type)) + Ok(get_direct_callees(db, &lowered_function, dependency_type, &Default::default())) } fn concrete_function_with_body_postpanic_direct_callees( @@ -514,7 +532,7 @@ fn concrete_function_with_body_postpanic_direct_callees( dependency_type: DependencyType, ) -> Maybe> { let lowered_function = db.concrete_function_with_body_postpanic_lowered(function_id)?; - Ok(get_direct_callees(db, &lowered_function, dependency_type)) + Ok(get_direct_callees(db, &lowered_function, dependency_type, &Default::default())) } /// Given a vector of FunctionIds returns the vector of FunctionWithBodyIds of the diff --git a/crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs b/crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs index 1efea551e19..b1cdd57d377 100644 --- a/crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs +++ b/crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs @@ -12,8 +12,9 @@ pub fn function_with_body_direct_callees( function_id: FunctionWithBodyId, dependency_type: DependencyType, ) -> Maybe> { - let lowered = db.function_with_body_lowering(function_id)?; - Ok(get_direct_callees(db, &lowered, dependency_type).into_iter().collect()) + let (lowered, block_extra_calls) = + db.function_with_body_lowering_with_borrow_check(function_id)?; + Ok(get_direct_callees(db, &lowered, dependency_type, &block_extra_calls).into_iter().collect()) } /// Query implementation of diff --git a/crates/cairo-lang-lowering/src/test_data/cycles b/crates/cairo-lang-lowering/src/test_data/cycles index 1f2e8dde123..6bb863684f4 100644 --- a/crates/cairo-lang-lowering/src/test_data/cycles +++ b/crates/cairo-lang-lowering/src/test_data/cycles @@ -137,3 +137,53 @@ blk2: Statements: End: Return(v6, v7) + +//! > ========================================================================== + +//! > Test destructor basic cycle. + +//! > test_runner_name +test_function_lowering + +//! > function +fn foo() {} + +//! > function_name +foo + +//! > module_code +struct A {} +impl ADestruct of Destruct { + fn destruct(self: A) nopanic { + let A { } = self; + B {}; + } +} + +struct B {} +impl BDestruct of Destruct { + fn destruct(self: B) nopanic { + let B { } = self; + A {}; + } +} + +//! > semantic_diagnostics + +//! > lowering_diagnostics +error: Call cycle of `nopanic` functions is not allowed. + --> lib.cairo:3:5 + fn destruct(self: A) nopanic { + ^****************************^ + +error: Call cycle of `nopanic` functions is not allowed. + --> lib.cairo:11:5 + fn destruct(self: B) nopanic { + ^****************************^ + +//! > lowering_flat +Parameters: +blk0 (root): +Statements: +End: + Return() diff --git a/crates/cairo-lang-lowering/src/test_data/destruct b/crates/cairo-lang-lowering/src/test_data/destruct index ccedb7a73a6..4d61ec9bd8c 100644 --- a/crates/cairo-lang-lowering/src/test_data/destruct +++ b/crates/cairo-lang-lowering/src/test_data/destruct @@ -22,6 +22,7 @@ struct A {} impl ADestruct of Destruct { #[inline(never)] fn destruct(self: A) nopanic { + let A { } = self; // Use RangeCheck, a previously unused implicit. match u128_overflowing_add(1_u128, 2_u128) { Result::Ok(v) => v, @@ -356,6 +357,15 @@ impl ADestruct of Destruct { //! > semantic_diagnostics //! > lowering_diagnostics +error: Call cycle of `nopanic` functions is not allowed. + --> lib.cairo:3:5 + #[inline(always)] + ^***************^ + +error: Cannot inline a function that might call itself. + --> lib.cairo:3:5 + #[inline(always)] + ^***************^ //! > lowering_flat Parameters: