Skip to content

Commit

Permalink
Added destructors in cycle detectors. (#5335)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Mar 28, 2024
1 parent 0ceee6e commit 7bad82d
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 42 deletions.
114 changes: 83 additions & 31 deletions crates/cairo-lang-lowering/src/borrow_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
#[path = "test.rs"]
mod test;

use cairo_lang_defs::ids::ModuleFileId;
use cairo_lang_defs::ids::{ModuleFileId, TraitFunctionId};
use cairo_lang_diagnostics::{DiagnosticNote, Maybe};
use cairo_lang_semantic::corelib::get_core_trait;
use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::{zip_eq, Itertools};

use self::analysis::{Analyzer, StatementLocation};
Expand All @@ -14,7 +17,7 @@ use crate::borrow_check::analysis::BackAnalysis;
use crate::db::LoweringGroup;
use crate::diagnostic::LoweringDiagnosticKind::*;
use crate::diagnostic::LoweringDiagnostics;
use crate::ids::LocationId;
use crate::ids::{FunctionId, LocationId, SemanticFunctionIdEx};
use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};

pub mod analysis;
Expand All @@ -26,6 +29,9 @@ pub struct BorrowChecker<'a> {
diagnostics: &'a mut LoweringDiagnostics,
lowered: &'a FlatLowered,
success: Maybe<()>,
potential_destruct_calls: PotentialDestructCalls,
destruct_fn: TraitFunctionId,
panic_destruct_fn: TraitFunctionId,
}

/// A state saved for each position in the back analysis.
Expand Down Expand Up @@ -77,27 +83,49 @@ impl DropPosition {
impl<'a> DemandReporter<VariableId, PanicState> for BorrowChecker<'a> {
// Note that for in BorrowChecker `IntroducePosition` is used to pass the cause of
// the drop.
type IntroducePosition = Option<DropPosition>;
type IntroducePosition = (Option<DropPosition>, BlockId);
type UsePosition = LocationId;

fn drop_aux(
&mut self,
opt_drop_position: Option<DropPosition>,
(opt_drop_position, block_id): (Option<DropPosition>, BlockId),
var_id: VariableId,
panic_state: PanicState,
) {
let var = &self.lowered.variables[var_id];
let Err(drop_err) = var.droppable.clone() else {
return;
};
let Err(destruct_err) = var.destruct_impl.clone() else {
return;
let mut add_called_fn = |impl_id, function| {
self.potential_destruct_calls.entry(block_id).or_default().push(
self.db
.intern_function(cairo_lang_semantic::FunctionLongId {
function: cairo_lang_semantic::ConcreteFunction {
generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
impl_id,
function,
}),
generic_args: vec![],
},
})
.lowered(self.db),
);
};
let panic_destruct_err = if matches!(panic_state, PanicState::EndsWithPanic) {
let Err(panic_destruct_err) = var.panic_destruct_impl.clone() else {
let destruct_err = match var.destruct_impl.clone() {
Ok(impl_id) => {
add_called_fn(impl_id, self.destruct_fn);
return;
};
Some(panic_destruct_err)
}
Err(err) => err,
};
let panic_destruct_err = if matches!(panic_state, PanicState::EndsWithPanic) {
match var.panic_destruct_impl.clone() {
Ok(impl_id) => {
add_called_fn(impl_id, self.panic_destruct_fn);
return;
}
Err(err) => Some(err),
}
} else {
None
};
Expand Down Expand Up @@ -139,10 +167,10 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> {
fn visit_stmt(
&mut self,
info: &mut Self::Info,
_statement_location: StatementLocation,
(block_id, _): StatementLocation,
stmt: &Statement,
) {
info.variables_introduced(self, stmt.outputs(), None);
info.variables_introduced(self, stmt.outputs(), (None, block_id));
match stmt {
Statement::Call(stmt) => {
if let Ok(signature) = stmt.function.signature(self.db) {
Expand All @@ -152,11 +180,9 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> {
aux: PanicState::EndsWithPanic,
..Default::default()
};
let location = (Some(DropPosition::Panic(stmt.location)), block_id);
*info = BorrowCheckerDemand::merge_demands(
&[
(panic_demand, Some(DropPosition::Panic(stmt.location))),
(info.clone(), Some(DropPosition::Panic(stmt.location))),
],
&[(panic_demand, location), (info.clone(), location)],
self,
);
}
Expand Down Expand Up @@ -198,16 +224,16 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> {

fn merge_match(
&mut self,
_statement_location: StatementLocation,
(block_id, _): StatementLocation,
match_info: &MatchInfo,
infos: impl Iterator<Item = Self::Info>,
) -> Self::Info {
let infos: Vec<_> = infos.collect();
let arm_demands = zip_eq(match_info.arms(), &infos)
.map(|(arm, demand)| {
let mut demand = demand.clone();
demand.variables_introduced(self, &arm.var_ids, None);
(demand, Some(DropPosition::Diverge(*match_info.location())))
demand.variables_introduced(self, &arm.var_ids, (None, block_id));
(demand, (Some(DropPosition::Diverge(*match_info.location())), block_id))
})
.collect_vec();
let mut demand = BorrowCheckerDemand::merge_demands(&arm_demands, self);
Expand Down Expand Up @@ -242,27 +268,53 @@ impl<'a> Analyzer<'_> for BorrowChecker<'a> {
}
}

/// The possible destruct calls per block.
pub type PotentialDestructCalls = UnorderedHashMap<BlockId, Vec<FunctionId>>;

/// Report borrow checking diagnostics.
/// Returns the potential destruct function calls per block.
pub fn borrow_check(
db: &dyn LoweringGroup,
module_file_id: ModuleFileId,
lowered: &mut FlatLowered,
) {
) -> PotentialDestructCalls {
if lowered.blocks.has_root().is_err() {
return Default::default();
}
let mut diagnostics = LoweringDiagnostics::new(module_file_id.file_id(db.upcast()).unwrap());
diagnostics.diagnostics.extend(std::mem::take(&mut lowered.diagnostics));
let destruct_trait_id = get_core_trait(db.upcast(), "Destruct".into());
let destruct_fn =
db.trait_function_by_name(destruct_trait_id, "destruct".into()).unwrap().unwrap();
let panic_destruct_trait_id = get_core_trait(db.upcast(), "PanicDestruct".into());
let panic_destruct_fn = db
.trait_function_by_name(panic_destruct_trait_id, "panic_destruct".into())
.unwrap()
.unwrap();
let checker = BorrowChecker {
db,
diagnostics: &mut diagnostics,
lowered,
success: Ok(()),
potential_destruct_calls: Default::default(),
destruct_fn,
panic_destruct_fn,
};
let mut analysis = BackAnalysis::new(lowered, checker);
let mut root_demand = analysis.get_root_info();
root_demand.variables_introduced(
&mut analysis.analyzer,
&lowered.parameters,
(None, BlockId::root()),
);
let block_extra_calls = analysis.analyzer.potential_destruct_calls;
let success = analysis.analyzer.success;
assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");

if lowered.blocks.has_root().is_ok() {
let checker = BorrowChecker { db, diagnostics: &mut diagnostics, lowered, success: Ok(()) };
let mut analysis = BackAnalysis::new(lowered, checker);
let mut root_demand = analysis.get_root_info();
root_demand.variables_introduced(&mut analysis.analyzer, &lowered.parameters, None);
let success = analysis.analyzer.success;
assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");

if let Err(diag_added) = success {
lowered.blocks = Blocks::new_errored(diag_added);
}
if let Err(diag_added) = success {
lowered.blocks = Blocks::new_errored(diag_added);
}

lowered.diagnostics = diagnostics.build();
block_extra_calls
}
36 changes: 27 additions & 9 deletions crates/cairo-lang-lowering/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::{self as semantic, corelib, ConcreteTypeId, TypeId, TypeLongId};
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{extract_matches, Upcast};
use defs::ids::NamedLanguageElementId;
Expand All @@ -17,12 +18,12 @@ use num_traits::ToPrimitive;
use semantic::items::constant::ConstValue;

use crate::add_withdraw_gas::add_withdraw_gas;
use crate::borrow_check::borrow_check;
use crate::borrow_check::{borrow_check, PotentialDestructCalls};
use crate::concretize::concretize_lowered;
use crate::destructs::add_destructs;
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind};
use crate::graph_algorithms::feedback_set::flag_add_withdraw_gas;
use crate::ids::FunctionLongId;
use crate::ids::{FunctionId, FunctionLongId};
use crate::inline::get_inline_diagnostics;
use crate::lower::{lower_semantic_function, MultiLowering};
use crate::optimizations::config::OptimizationConfig;
Expand Down Expand Up @@ -68,6 +69,13 @@ pub trait LoweringGroup: SemanticGroup + Upcast<dyn SemanticGroup> {
function_id: ids::FunctionWithBodyId,
) -> Maybe<Arc<FlatLowered>>;

/// Computes the lowered representation of a function with a body.
/// Additionally applies borrow checking testing, and returns the possible calls per block.
fn function_with_body_lowering_with_borrow_check(
&self,
function_id: ids::FunctionWithBodyId,
) -> Maybe<(Arc<FlatLowered>, Arc<PotentialDestructCalls>)>;

/// Computes the lowered representation of a function with a body.
fn function_with_body_lowering(
&self,
Expand Down Expand Up @@ -375,14 +383,21 @@ fn priv_function_with_body_lowering(
Ok(Arc::new(lowered))
}

fn function_with_body_lowering(
fn function_with_body_lowering_with_borrow_check(
db: &dyn LoweringGroup,
function_id: ids::FunctionWithBodyId,
) -> Maybe<Arc<FlatLowered>> {
) -> Maybe<(Arc<FlatLowered>, Arc<PotentialDestructCalls>)> {
let mut lowered = (*db.priv_function_with_body_lowering(function_id)?).clone();
let module_file_id = function_id.base_semantic_function(db).module_file_id(db.upcast());
borrow_check(db, module_file_id, &mut lowered);
Ok(Arc::new(lowered))
let block_extra_calls = borrow_check(db, module_file_id, &mut lowered);
Ok((Arc::new(lowered), Arc::new(block_extra_calls)))
}

fn function_with_body_lowering(
db: &dyn LoweringGroup,
function_id: ids::FunctionWithBodyId,
) -> Maybe<Arc<FlatLowered>> {
Ok(db.function_with_body_lowering_with_borrow_check(function_id)?.0)
}

// * Concretizes lowered representation (monomorphization).
Expand Down Expand Up @@ -452,8 +467,8 @@ pub(crate) fn get_direct_callees(
db: &dyn LoweringGroup,
lowered_function: &FlatLowered,
dependency_type: DependencyType,
block_extra_calls: &UnorderedHashMap<BlockId, Vec<FunctionId>>,
) -> Vec<ids::FunctionId> {
// TODO(orizi): Follow calls for destructors as well.
let mut direct_callees = Vec::new();
if lowered_function.blocks.is_empty() {
return direct_callees;
Expand All @@ -478,6 +493,9 @@ pub(crate) fn get_direct_callees(
}
}
}
if let Some(extra_calls) = block_extra_calls.get(&block_id) {
direct_callees.extend(extra_calls.iter().copied());
}
match &block.end {
FlatBlockEnd::NotSet | FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(_) => {}
FlatBlockEnd::Goto(next, _) => stack.push(*next),
Expand Down Expand Up @@ -505,7 +523,7 @@ fn concrete_function_with_body_direct_callees(
dependency_type: DependencyType,
) -> Maybe<Vec<ids::FunctionId>> {
let lowered_function = db.priv_concrete_function_with_body_lowered_flat(function_id)?;
Ok(get_direct_callees(db, &lowered_function, dependency_type))
Ok(get_direct_callees(db, &lowered_function, dependency_type, &Default::default()))
}

fn concrete_function_with_body_postpanic_direct_callees(
Expand All @@ -514,7 +532,7 @@ fn concrete_function_with_body_postpanic_direct_callees(
dependency_type: DependencyType,
) -> Maybe<Vec<ids::FunctionId>> {
let lowered_function = db.concrete_function_with_body_postpanic_lowered(function_id)?;
Ok(get_direct_callees(db, &lowered_function, dependency_type))
Ok(get_direct_callees(db, &lowered_function, dependency_type, &Default::default()))
}

/// Given a vector of FunctionIds returns the vector of FunctionWithBodyIds of the
Expand Down
5 changes: 3 additions & 2 deletions crates/cairo-lang-lowering/src/graph_algorithms/cycles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ pub fn function_with_body_direct_callees(
function_id: FunctionWithBodyId,
dependency_type: DependencyType,
) -> Maybe<OrderedHashSet<FunctionId>> {
let lowered = db.function_with_body_lowering(function_id)?;
Ok(get_direct_callees(db, &lowered, dependency_type).into_iter().collect())
let (lowered, block_extra_calls) =
db.function_with_body_lowering_with_borrow_check(function_id)?;
Ok(get_direct_callees(db, &lowered, dependency_type, &block_extra_calls).into_iter().collect())
}

/// Query implementation of
Expand Down
50 changes: 50 additions & 0 deletions crates/cairo-lang-lowering/src/test_data/cycles
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,53 @@ blk2:
Statements:
End:
Return(v6, v7)

//! > ==========================================================================

//! > Test destructor basic cycle.

//! > test_runner_name
test_function_lowering

//! > function
fn foo() {}

//! > function_name
foo

//! > module_code
struct A {}
impl ADestruct of Destruct<A> {
fn destruct(self: A) nopanic {
let A { } = self;
B {};
}
}

struct B {}
impl BDestruct of Destruct<B> {
fn destruct(self: B) nopanic {
let B { } = self;
A {};
}
}

//! > semantic_diagnostics

//! > lowering_diagnostics
error: Call cycle of `nopanic` functions is not allowed.
--> lib.cairo:3:5
fn destruct(self: A) nopanic {
^****************************^

error: Call cycle of `nopanic` functions is not allowed.
--> lib.cairo:11:5
fn destruct(self: B) nopanic {
^****************************^

//! > lowering_flat
Parameters:
blk0 (root):
Statements:
End:
Return()
10 changes: 10 additions & 0 deletions crates/cairo-lang-lowering/src/test_data/destruct
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct A {}
impl ADestruct of Destruct<A> {
#[inline(never)]
fn destruct(self: A) nopanic {
let A { } = self;
// Use RangeCheck, a previously unused implicit.
match u128_overflowing_add(1_u128, 2_u128) {
Result::Ok(v) => v,
Expand Down Expand Up @@ -356,6 +357,15 @@ impl ADestruct of Destruct<A> {
//! > semantic_diagnostics

//! > lowering_diagnostics
error: Call cycle of `nopanic` functions is not allowed.
--> lib.cairo:3:5
#[inline(always)]
^***************^

error: Cannot inline a function that might call itself.
--> lib.cairo:3:5
#[inline(always)]
^***************^

//! > lowering_flat
Parameters:
Expand Down

0 comments on commit 7bad82d

Please sign in to comment.