Skip to content

Commit

Permalink
Remove assert on downcast, fixing a panic (#822)
Browse files Browse the repository at this point in the history
* Remove assert on downcast, you can downcast smaller types to bigger types in cairo

* optimize
  • Loading branch information
edg-l authored Oct 1, 2024
1 parent ae17dd3 commit c866fbf
Showing 1 changed file with 108 additions and 79 deletions.
187 changes: 108 additions & 79 deletions src/libfuncs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ pub fn build_downcast<'ctx, 'this>(
} else {
src_ty.integer_range(registry)?
};
assert!(
dst_range.lower > src_range.lower || dst_range.upper < src_range.upper,
"invalid downcast `{}` into `{}`: target range contains the source range",
info.signature.param_signatures[1].ty,
info.signature.branch_signatures[0].vars[1].ty
);

let src_width = if src_ty.is_bounded_int(registry)? {
src_range.offset_bit_width()
Expand Down Expand Up @@ -169,86 +163,121 @@ pub fn build_downcast<'ctx, 'this>(
src_value
};

let lower_check = if dst_range.lower > src_range.lower {
let dst_lower = entry.const_int_from_type(
context,
location,
dst_range.lower.clone(),
src_value.r#type(),
)?;
Some(entry.append_op_result(arith::cmpi(
context,
if !is_signed {
CmpiPredicate::Uge
} else {
CmpiPredicate::Sge
},
src_value,
dst_lower,
location,
))?)
} else {
None
};
let upper_check = if dst_range.upper < src_range.upper {
let dst_upper = entry.const_int_from_type(
context,
location,
dst_range.upper.clone(),
src_value.r#type(),
)?;
Some(entry.append_op_result(arith::cmpi(
context,
if !is_signed {
CmpiPredicate::Ult
} else {
CmpiPredicate::Slt
},
src_value,
dst_upper,
location,
))?)
} else {
None
};
if !(dst_range.lower > src_range.lower || dst_range.upper < src_range.upper) {
let dst_value = if dst_ty.is_bounded_int(registry)? && dst_range.lower != BigInt::ZERO {
let dst_offset = entry.const_int_from_type(
context,
location,
dst_range.lower.clone(),
src_value.r#type(),
)?;
entry.append_op_result(arith::subi(src_value, dst_offset, location))?
} else {
src_value
};

let is_in_bounds = match (lower_check, upper_check) {
(Some(lower_check), Some(upper_check)) => {
entry.append_op_result(arith::andi(lower_check, upper_check, location))?
}
(Some(lower_check), None) => lower_check,
(None, Some(upper_check)) => upper_check,
(None, None) => unreachable!(),
};
let dst_value = if dst_width < compute_width {
entry.append_op_result(arith::trunci(
dst_value,
IntegerType::new(context, dst_width).into(),
location,
))?
} else {
dst_value
};

let dst_value = if dst_ty.is_bounded_int(registry)? && dst_range.lower != BigInt::ZERO {
let dst_offset = entry.const_int_from_type(
let is_in_bounds = entry.const_int(context, location, 1, 1)?;

entry.append_operation(helper.cond_br(
context,
is_in_bounds,
[0, 1],
[&[range_check, dst_value], &[range_check]],
location,
dst_range.lower.clone(),
src_value.r#type(),
)?;
entry.append_op_result(arith::subi(src_value, dst_offset, location))?
));
} else {
src_value
};
let lower_check = if dst_range.lower > src_range.lower {
let dst_lower = entry.const_int_from_type(
context,
location,
dst_range.lower.clone(),
src_value.r#type(),
)?;
Some(entry.append_op_result(arith::cmpi(
context,
if !is_signed {
CmpiPredicate::Uge
} else {
CmpiPredicate::Sge
},
src_value,
dst_lower,
location,
))?)
} else {
None
};
let upper_check = if dst_range.upper < src_range.upper {
let dst_upper = entry.const_int_from_type(
context,
location,
dst_range.upper.clone(),
src_value.r#type(),
)?;
Some(entry.append_op_result(arith::cmpi(
context,
if !is_signed {
CmpiPredicate::Ult
} else {
CmpiPredicate::Slt
},
src_value,
dst_upper,
location,
))?)
} else {
None
};

let dst_value = if dst_width < compute_width {
entry.append_op_result(arith::trunci(
dst_value,
IntegerType::new(context, dst_width).into(),
let is_in_bounds = match (lower_check, upper_check) {
(Some(lower_check), Some(upper_check)) => {
entry.append_op_result(arith::andi(lower_check, upper_check, location))?
}
(Some(lower_check), None) => lower_check,
(None, Some(upper_check)) => upper_check,
// its always in bounds since dst is larger than src (i.e no bounds checks needed)
(None, None) => unreachable!(),
};

let dst_value = if dst_ty.is_bounded_int(registry)? && dst_range.lower != BigInt::ZERO {
let dst_offset = entry.const_int_from_type(
context,
location,
dst_range.lower.clone(),
src_value.r#type(),
)?;
entry.append_op_result(arith::subi(src_value, dst_offset, location))?
} else {
src_value
};

let dst_value = if dst_width < compute_width {
entry.append_op_result(arith::trunci(
dst_value,
IntegerType::new(context, dst_width).into(),
location,
))?
} else {
dst_value
};
entry.append_operation(helper.cond_br(
context,
is_in_bounds,
[0, 1],
[&[range_check, dst_value], &[range_check]],
location,
))?
} else {
dst_value
};
entry.append_operation(helper.cond_br(
context,
is_in_bounds,
[0, 1],
[&[range_check, dst_value], &[range_check]],
location,
));
));
}

Ok(())
}
Expand Down

0 comments on commit c866fbf

Please sign in to comment.