Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BinOp ty() assertion and fn_sig() for closures #118846

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions compiler/rustc_smir/src/rustc_smir/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,25 @@ impl<'tcx> Context for TablesWrapper<'tcx> {
def.internal(&mut *tables).is_box()
}

fn adt_is_simd(&self, def: AdtDef) -> bool {
let mut tables = self.0.borrow_mut();
def.internal(&mut *tables).repr().simd()
}

fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig {
let mut tables = self.0.borrow_mut();
let def_id = def.0.internal(&mut *tables);
let sig = tables.tcx.fn_sig(def_id).instantiate(tables.tcx, args.internal(&mut *tables));
sig.stable(&mut *tables)
}

fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig {
let mut tables = self.0.borrow_mut();
let args_ref = args.internal(&mut *tables);
let sig = args_ref.as_closure().sig();
sig.stable(&mut *tables)
}

fn adt_variants_len(&self, def: AdtDef) -> usize {
let mut tables = self.0.borrow_mut();
def.internal(&mut *tables).variants().len()
Expand Down
6 changes: 6 additions & 0 deletions compiler/stable_mir/src/compiler_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ pub trait Context {
/// Returns if the ADT is a box.
fn adt_is_box(&self, def: AdtDef) -> bool;

/// Returns whether this ADT is simd.
fn adt_is_simd(&self, def: AdtDef) -> bool;

/// Retrieve the function signature for the given generic arguments.
fn fn_sig(&self, def: FnDef, args: &GenericArgs) -> PolyFnSig;

/// Retrieve the closure signature for the given generic arguments.
fn closure_sig(&self, args: &GenericArgs) -> PolyFnSig;

/// The number of variants in this ADT.
fn adt_variants_len(&self, def: AdtDef) -> usize;

Expand Down
30 changes: 19 additions & 11 deletions compiler/stable_mir/src/mir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ pub struct InlineAsmOperand {
pub raw_rpr: String,
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum UnwindAction {
Continue,
Unreachable,
Expand All @@ -248,7 +248,7 @@ pub enum AssertMessage {
MisalignedPointerDereference { required: Operand, found: Operand },
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum BinOp {
Add,
AddUnchecked,
Expand Down Expand Up @@ -278,8 +278,6 @@ impl BinOp {
/// Return the type of this operation for the given input Ty.
/// This function does not perform type checking, and it currently doesn't handle SIMD.
pub fn ty(&self, lhs_ty: Ty, rhs_ty: Ty) -> Ty {
assert!(lhs_ty.kind().is_primitive());
assert!(rhs_ty.kind().is_primitive());
match self {
BinOp::Add
| BinOp::AddUnchecked
Expand All @@ -293,20 +291,30 @@ impl BinOp {
| BinOp::BitAnd
| BinOp::BitOr => {
assert_eq!(lhs_ty, rhs_ty);
assert!(lhs_ty.kind().is_primitive());
lhs_ty
}
BinOp::Shl | BinOp::ShlUnchecked | BinOp::Shr | BinOp::ShrUnchecked | BinOp::Offset => {
BinOp::Shl | BinOp::ShlUnchecked | BinOp::Shr | BinOp::ShrUnchecked => {
assert!(lhs_ty.kind().is_primitive());
assert!(rhs_ty.kind().is_primitive());
lhs_ty
}
BinOp::Offset => {
assert!(lhs_ty.kind().is_raw_ptr());
assert!(rhs_ty.kind().is_integral());
lhs_ty
}
BinOp::Eq | BinOp::Lt | BinOp::Le | BinOp::Ne | BinOp::Ge | BinOp::Gt => {
assert_eq!(lhs_ty, rhs_ty);
let lhs_kind = lhs_ty.kind();
assert!(lhs_kind.is_primitive() || lhs_kind.is_raw_ptr() || lhs_kind.is_fn_ptr());
Ty::bool_ty()
}
}
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum UnOp {
Not,
Neg,
Expand All @@ -319,7 +327,7 @@ pub enum CoroutineKind {
Gen(CoroutineSource),
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum CoroutineSource {
Block,
Closure,
Expand All @@ -343,15 +351,15 @@ pub enum FakeReadCause {
}

/// Describes what kind of retag is to be performed
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum RetagKind {
FnEntry,
TwoPhase,
Raw,
Default,
}

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum Variance {
Covariant,
Invariant,
Expand Down Expand Up @@ -862,7 +870,7 @@ pub enum Safety {
Normal,
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum PointerCoercion {
/// Go from a fn-item type to a fn-pointer type.
ReifyFnPointer,
Expand All @@ -889,7 +897,7 @@ pub enum PointerCoercion {
Unsize,
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum CastKind {
PointerExposeAddress,
PointerFromExposedAddress,
Expand Down
121 changes: 119 additions & 2 deletions compiler/stable_mir/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,38 +214,62 @@ impl TyKind {
if let TyKind::RigidTy(inner) = self { Some(inner) } else { None }
}

#[inline]
pub fn is_unit(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Tuple(data)) if data.is_empty())
}

#[inline]
pub fn is_bool(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Bool))
}

#[inline]
pub fn is_char(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Char))
}

#[inline]
pub fn is_trait(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Dynamic(_, _, DynKind::Dyn)))
}

#[inline]
pub fn is_enum(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Enum)
}

#[inline]
pub fn is_struct(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Struct)
}

#[inline]
pub fn is_union(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.kind() == AdtKind::Union)
}

#[inline]
pub fn is_adt(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Adt(..)))
}

#[inline]
pub fn is_ref(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Ref(..)))
}

#[inline]
pub fn is_fn(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::FnDef(..)))
}

#[inline]
pub fn is_fn_ptr(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::FnPtr(..)))
}

#[inline]
pub fn is_primitive(&self) -> bool {
matches!(
self,
Expand All @@ -259,6 +283,84 @@ impl TyKind {
)
}

#[inline]
pub fn is_float(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Float(_)))
}

#[inline]
pub fn is_integral(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Int(_) | RigidTy::Uint(_)))
}

#[inline]
pub fn is_numeric(&self) -> bool {
self.is_integral() || self.is_float()
}

#[inline]
pub fn is_signed(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Int(_)))
}

#[inline]
pub fn is_str(&self) -> bool {
*self == TyKind::RigidTy(RigidTy::Str)
}

#[inline]
pub fn is_slice(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Slice(_)))
}

#[inline]
pub fn is_array(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Array(..)))
}

#[inline]
pub fn is_mutable_ptr(&self) -> bool {
matches!(
self,
TyKind::RigidTy(RigidTy::RawPtr(_, Mutability::Mut))
| TyKind::RigidTy(RigidTy::Ref(_, _, Mutability::Mut))
)
}

#[inline]
pub fn is_raw_ptr(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::RawPtr(..)))
}

/// Tests if this is any kind of primitive pointer type (reference, raw pointer, fn pointer).
#[inline]
pub fn is_any_ptr(&self) -> bool {
self.is_ref() || self.is_raw_ptr() || self.is_fn_ptr()
}

#[inline]
pub fn is_coroutine(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Coroutine(..)))
}

#[inline]
pub fn is_closure(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Closure(..)))
}

#[inline]
pub fn is_box(&self) -> bool {
match self {
TyKind::RigidTy(RigidTy::Adt(def, _)) => def.is_box(),
_ => false,
}
}

#[inline]
pub fn is_simd(&self) -> bool {
matches!(self, TyKind::RigidTy(RigidTy::Adt(def, _)) if def.is_simd())
}

pub fn trait_principal(&self) -> Option<Binder<ExistentialTraitRef>> {
if let TyKind::RigidTy(RigidTy::Dynamic(predicates, _, _)) = self {
if let Some(Binder { value: ExistentialPredicate::Trait(trait_ref), bound_vars }) =
Expand Down Expand Up @@ -300,12 +402,12 @@ impl TyKind {
}
}

/// Get the function signature for function like types (Fn, FnPtr, Closure, Coroutine)
/// FIXME(closure)
/// Get the function signature for function like types (Fn, FnPtr, and Closure)
pub fn fn_sig(&self) -> Option<PolyFnSig> {
match self {
TyKind::RigidTy(RigidTy::FnDef(def, args)) => Some(with(|cx| cx.fn_sig(*def, args))),
TyKind::RigidTy(RigidTy::FnPtr(sig)) => Some(sig.clone()),
TyKind::RigidTy(RigidTy::Closure(_def, args)) => Some(with(|cx| cx.closure_sig(args))),
_ => None,
}
}
Expand Down Expand Up @@ -481,6 +583,10 @@ impl AdtDef {
with(|cx| cx.adt_is_box(*self))
}

pub fn is_simd(&self) -> bool {
with(|cx| cx.adt_is_simd(*self))
}

/// The number of variants in this ADT.
pub fn num_variants(&self) -> usize {
with(|cx| cx.adt_variants_len(*self))
Expand Down Expand Up @@ -738,13 +844,24 @@ pub enum Abi {
RiscvInterruptS,
}

/// A binder represents a possibly generic type and its bound vars.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Binder<T> {
pub value: T,
pub bound_vars: Vec<BoundVariableKind>,
}

impl<T> Binder<T> {
/// Create a new binder with the given bound vars.
pub fn bind_with_vars(value: T, bound_vars: Vec<BoundVariableKind>) -> Self {
Binder { value, bound_vars }
}

/// Create a new binder with no bounded variable.
pub fn dummy(value: T) -> Self {
Binder { value, bound_vars: vec![] }
}

pub fn skip_binder(self) -> T {
self.value
}
Expand Down
Loading