Skip to content

Commit

Permalink
add type to ConstValue (#5652)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerStarkware authored May 29, 2024
1 parent 0cba472 commit fab0b48
Show file tree
Hide file tree
Showing 24 changed files with 314 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub fn generic_completions(

fn resolved_generic_item_completion_kind(item: ResolvedGenericItem) -> CompletionItemKind {
match item {
ResolvedGenericItem::Constant(_) => CompletionItemKind::CONSTANT,
ResolvedGenericItem::GenericConstant(_) => CompletionItemKind::CONSTANT,
ResolvedGenericItem::Module(_) => CompletionItemKind::MODULE,
ResolvedGenericItem::GenericFunction(_) | ResolvedGenericItem::TraitFunction(_) => {
CompletionItemKind::FUNCTION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl SemanticTokenKind {
db.lookup_resolved_generic_item_by_ptr(lookup_item_id, identifier.stable_ptr())
{
return Some(match item {
ResolvedGenericItem::Constant(_) => SemanticTokenKind::EnumMember,
ResolvedGenericItem::GenericConstant(_) => SemanticTokenKind::EnumMember,
ResolvedGenericItem::Module(_) => SemanticTokenKind::Namespace,
ResolvedGenericItem::GenericFunction(_)
| ResolvedGenericItem::TraitFunction(_) => SemanticTokenKind::Function,
Expand All @@ -163,8 +163,7 @@ impl SemanticTokenKind {
ResolvedConcreteItem::Module(_) => SemanticTokenKind::Namespace,
ResolvedConcreteItem::Function(_)
| ResolvedConcreteItem::TraitFunction(_) => SemanticTokenKind::Function,
ResolvedConcreteItem::Type(_)
| ResolvedConcreteItem::ConstGenericParameter(_) => SemanticTokenKind::Type,
ResolvedConcreteItem::Type(_) => SemanticTokenKind::Type,
ResolvedConcreteItem::Variant(_) => SemanticTokenKind::EnumMember,
ResolvedConcreteItem::Trait(_) => SemanticTokenKind::Interface,
ResolvedConcreteItem::Impl(_) => SemanticTokenKind::Class,
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-language-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ fn resolved_generic_item_def(
) -> SyntaxStablePtrId {
let defs_db = db.upcast();
match item {
ResolvedGenericItem::Constant(item) => item.untyped_stable_ptr(defs_db),
ResolvedGenericItem::GenericConstant(item) => item.untyped_stable_ptr(defs_db),
ResolvedGenericItem::Module(module_id) => {
// Check if the module is an inline submodule.
if let ModuleId::Submodule(submodule_id) = module_id {
Expand Down
5 changes: 4 additions & 1 deletion crates/cairo-lang-lowering/src/add_withdraw_gas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ fn create_panic_block(
location,
}),
Statement::Const(StatementConst {
value: ConstValue::Int(BigInt::from_bytes_be(Sign::Plus, "Out of gas".as_bytes())),
value: ConstValue::Int(
BigInt::from_bytes_be(Sign::Plus, "Out of gas".as_bytes()),
core_felt252_ty(db.upcast()),
),
output: out_of_gas_err_var,
}),
Statement::Call(StatementCall {
Expand Down
10 changes: 7 additions & 3 deletions crates/cairo-lang-lowering/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ use cairo_lang_semantic::{self as semantic, corelib, ConcreteTypeId, TypeId, Typ
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{extract_matches, Intern, LookupIntern, Upcast};
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
use defs::ids::NamedLanguageElementId;
use itertools::Itertools;
use num_traits::ToPrimitive;
use semantic::items::constant::ConstValue;

use crate::add_withdraw_gas::add_withdraw_gas;
use crate::borrow_check::{borrow_check, PotentialDestructCalls};
Expand Down Expand Up @@ -760,7 +759,12 @@ fn type_size(db: &dyn LoweringGroup, ty: TypeId) -> usize {
TypeLongId::Snapshot(ty) => db.type_size(ty),
TypeLongId::FixedSizeArray { type_id, size } => {
db.type_size(type_id)
* extract_matches!(size.lookup_intern(db), ConstValue::Int).to_usize().unwrap()
* size
.lookup_intern(db)
.into_int()
.expect("Expected ConstValue::Int for size")
.to_usize()
.unwrap()
}
TypeLongId::Coupon(_) => 0,
TypeLongId::GenericParameter(_)
Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/src/lower/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl StatementsBuilder {
pub struct Const {
pub value: ConstValue,
pub location: LocationId,
// TODO(TomerStarkware): Remove this field and use the type from value.
pub ty: semantic::TypeId,
}
impl Const {
Expand Down
19 changes: 13 additions & 6 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,10 @@ fn lower_tuple_like_pattern_helper(
let tys = match long_type_id {
TypeLongId::Tuple(tys) => tys,
TypeLongId::FixedSizeArray { type_id, size } => {
let size = extract_matches!(size.lookup_intern(ctx.db), ConstValue::Int)
let size = size
.lookup_intern(ctx.db)
.into_int()
.expect("Expected ConstValue::Int for size")
.to_usize()
.unwrap();
vec![type_id; size]
Expand Down Expand Up @@ -874,7 +877,7 @@ fn add_chunks_to_data_array<'a>(
let remainder = chunks.remainder();
for chunk in chunks {
let chunk_usage = generators::Const {
value: ConstValue::Int(BigInt::from_bytes_be(Sign::Plus, chunk)),
value: ConstValue::Int(BigInt::from_bytes_be(Sign::Plus, chunk), bytes31_ty),
ty: bytes31_ty,
location: ctx.get_location(expr_stable_ptr),
}
Expand Down Expand Up @@ -908,15 +911,15 @@ fn add_pending_word(
let felt252_ty = core_felt252_ty(ctx.db.upcast());

let pending_word_usage = generators::Const {
value: ConstValue::Int(BigInt::from_bytes_be(Sign::Plus, pending_word_bytes)),
value: ConstValue::Int(BigInt::from_bytes_be(Sign::Plus, pending_word_bytes), felt252_ty),
ty: felt252_ty,
location: ctx.get_location(expr_stable_ptr),
}
.add(ctx, &mut builder.statements);

let pending_word_len = expr.value.len() % 31;
let pending_word_len_usage = generators::Const {
value: ConstValue::Int(pending_word_len.into()),
value: ConstValue::Int(pending_word_len.into(), u32_ty),
ty: u32_ty,
location: ctx.get_location(expr_stable_ptr),
}
Expand Down Expand Up @@ -971,8 +974,12 @@ fn lower_expr_fixed_size_array(
semantic::FixedSizeArrayItems::ValueAndSize(value, size) => {
let lowered_value = lower_expr(ctx, builder, *value)?;
let var_usage = lowered_value.as_var_usage(ctx, builder)?;
let size =
extract_matches!(size.lookup_intern(ctx.db), ConstValue::Int).to_usize().unwrap();
let size = size
.lookup_intern(ctx.db)
.into_int()
.expect("Expected ConstValue::Int for size")
.to_usize()
.unwrap();
if size == 0 {
return Err(LoweringFlowError::Failed(
ctx.diagnostics
Expand Down
43 changes: 23 additions & 20 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
mod test;

use cairo_lang_defs::ids::ModuleItemId;
use cairo_lang_semantic::corelib;
use cairo_lang_semantic::corelib::{self};
use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::{extract_matches, Intern};
use cairo_lang_utils::Intern;
use itertools::{chain, zip_eq};
use num_traits::Zero;

Expand Down Expand Up @@ -89,20 +89,22 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
Statement::Snapshot(stmt) => {
if let Some(VarInfo::Const(val)) = var_info.get(&stmt.input.var_id) {
let val = val.clone();
// TODO(Tomerstarkware): add snapshot to value type.
var_info.insert(stmt.original(), VarInfo::Const(val.clone()));
var_info.insert(stmt.snapshot(), VarInfo::Const(val));
}
}
Statement::Desnap(StatementDesnap { input, output }) => {
if let Some(VarInfo::Const(val)) = var_info.get(&input.var_id) {
let val = val.clone();
// TODO(Tomerstarkware): remove snapshot from value type.
var_info.insert(*output, VarInfo::Const(val));
}
}
Statement::Call(StatementCall { function, ref mut inputs, outputs, .. }) => {
// (a - 0) can be replaced by a.
if function == &felt_sub {
if let Some(VarInfo::Const(ConstValue::Int(val))) =
if let Some(VarInfo::Const(ConstValue::Int(val, _))) =
var_info.get(&inputs[1].var_id)
{
if val.is_zero() {
Expand All @@ -112,18 +114,19 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
} else if let Some(extrn) = function.get_extern(db) {
if extrn == into_box {
if let Some(VarInfo::Const(val)) = var_info.get(&inputs[0].var_id) {
let value = ConstValue::Boxed(
lowered.variables[inputs[0].var_id].ty,
val.clone().into(),
);
let value = ConstValue::Boxed(val.clone().into());
// Not inserting the value into the `var_info` map because the
// resulting box isn't an actual const at the Sierra level.
*stmt =
Statement::Const(StatementConst { value, output: outputs[0] });
}
} else if extrn == upcast {
if let Some(VarInfo::Const(value)) = var_info.get(&inputs[0].var_id) {
let value = value.clone();
let value = ConstValue::Int(
value.clone().into_int().unwrap(),
lowered.variables[outputs[0]].ty,
);

var_info.insert(outputs[0], VarInfo::Const(value.clone()));
*stmt =
Statement::Const(StatementConst { value, output: outputs[0] });
Expand All @@ -136,22 +139,22 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
.iter()
.map(|input| {
if let Some(VarInfo::Const(val)) = var_info.get(&input.var_id) {
Some((lowered.variables[input.var_id].ty, val.clone()))
Some(val.clone())
} else {
None
}
})
.collect::<Option<Vec<_>>>()
{
let value = ConstValue::Struct(args);
var_info.insert(*output, VarInfo::Const(value.clone()));
let value = ConstValue::Struct(args, lowered.variables[*output].ty);
var_info.insert(*output, VarInfo::Const(value));
}
}
Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
if let Some(VarInfo::Const(ConstValue::Struct(args))) =
if let Some(VarInfo::Const(ConstValue::Struct(args, _))) =
var_info.get(&input.var_id)
{
for (output, (_, val)) in zip_eq(outputs, args.clone()) {
for (output, val) in zip_eq(outputs, args.clone()) {
var_info.insert(*output, VarInfo::Const(val));
}
}
Expand Down Expand Up @@ -189,9 +192,12 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
let input_var = inputs[0].var_id;
if let Some(VarInfo::Const(val)) = var_info.get(&input_var) {
let is_zero = match val {
ConstValue::Int(v) => v.is_zero(),
ConstValue::Struct(s) => s.iter().all(|(_, v)| {
extract_matches!(v, ConstValue::Int).is_zero()
ConstValue::Int(v, _) => v.is_zero(),
ConstValue::Struct(s, _) => s.iter().all(|v| {
v.clone()
.into_int()
.expect("Expected ConstValue::Int for size")
.is_zero()
}),
_ => unreachable!(),
};
Expand All @@ -201,10 +207,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
} else {
let arm = &arms[1];
let nz_var = arm.var_ids[0];
let nz_val = ConstValue::NonZero(
lowered.variables[input_var].ty,
Box::new(val.clone()),
);
let nz_val = ConstValue::NonZero(Box::new(val.clone()));
var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
block.statements.push(Statement::Const(StatementConst {
value: nz_val,
Expand Down
17 changes: 14 additions & 3 deletions crates/cairo-lang-semantic/src/corelib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ pub fn core_felt252_ty(db: &dyn SemanticGroup) -> TypeId {
/// Returns the concrete type of a bounded int type with a given min and max.
pub fn bounded_int_ty(db: &dyn SemanticGroup, min: BigInt, max: BigInt) -> TypeId {
let internal = core_submodule(db, "internal");
let lower_id = ConstValue::Int(min).intern(db);
let upper_id = ConstValue::Int(max).intern(db);
let size_ty = core_felt252_ty(db);
let lower_id = ConstValue::Int(min, size_ty).intern(db);
let upper_id = ConstValue::Int(max, size_ty).intern(db);
try_get_ty_by_name(
db,
internal,
Expand Down Expand Up @@ -111,6 +112,15 @@ pub fn core_option_ty(db: &dyn SemanticGroup, some_type: TypeId) -> TypeId {
)
}

pub fn core_box_ty(db: &dyn SemanticGroup, inner_type: TypeId) -> TypeId {
get_ty_by_name(
db,
core_submodule(db, "box"),
"Box".into(),
vec![GenericArgumentId::Type(inner_type)],
)
}

pub fn core_array_felt252_ty(db: &dyn SemanticGroup) -> TypeId {
get_core_ty_by_name(db, "Array".into(), vec![GenericArgumentId::Type(core_felt252_ty(db))])
}
Expand Down Expand Up @@ -831,6 +841,7 @@ fn try_extract_bounded_int_type_ranges(
else {
return None;
};
let to_int = |id| try_extract_matches!(db.lookup_intern_const_value(id), ConstValue::Int);
let to_int = |id| db.lookup_intern_const_value(id).into_int();

Some((to_int(min)?, to_int(max)?))
}
1 change: 0 additions & 1 deletion crates/cairo-lang-semantic/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,6 @@ impl From<&ResolvedConcreteItem> for ElementKind {
fn from(val: &ResolvedConcreteItem) -> Self {
match val {
ResolvedConcreteItem::Constant(_) => ElementKind::Constant,
ResolvedConcreteItem::ConstGenericParameter(_) => ElementKind::Constant,
ResolvedConcreteItem::Module(_) => ElementKind::Module,
ResolvedConcreteItem::Function(_) => ElementKind::Function,
ResolvedConcreteItem::TraitFunction(_) => ElementKind::TraitFunction,
Expand Down
Loading

0 comments on commit fab0b48

Please sign in to comment.