Skip to content

Commit

Permalink
Auto merge of rust-lang#96862 - oli-obk:enum_cast_mir, r=RalfJung
Browse files Browse the repository at this point in the history
Change enum->int casts to not go through MIR casts.

follow-up to rust-lang#96814

this simplifies all backends and even gives LLVM more information about the return value of `Rvalue::Discriminant`, enabling optimizations in more cases.
  • Loading branch information
bors committed Jul 5, 2022
2 parents 4045ce6 + 82c73af commit 53792b9
Show file tree
Hide file tree
Showing 21 changed files with 238 additions and 143 deletions.
23 changes: 0 additions & 23 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,29 +635,6 @@ fn codegen_stmt<'tcx>(
let (ptr, _extra) = operand.load_scalar_pair(fx);
lval.write_cvalue(fx, CValue::by_val(ptr, dest_layout))
}
} else if let ty::Adt(adt_def, _substs) = from_ty.kind() {
// enum -> discriminant value
assert!(adt_def.is_enum());
match to_ty.kind() {
ty::Uint(_) | ty::Int(_) => {}
_ => unreachable!("cast adt {} -> {}", from_ty, to_ty),
}
let to_clif_ty = fx.clif_type(to_ty).unwrap();

let discriminant = crate::discriminant::codegen_get_discriminant(
fx,
operand,
fx.layout_of(operand.layout().ty.discriminant_ty(fx.tcx)),
)
.load_scalar(fx);

let res = crate::cast::clif_intcast(
fx,
discriminant,
to_clif_ty,
to_ty.is_signed(),
);
lval.write_cvalue(fx, CValue::by_val(res, dest_layout));
} else {
let to_clif_ty = fx.clif_type(to_ty).unwrap();
let from = operand.load_scalar(fx);
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::ffi::CStr;
use std::iter;
use std::ops::Deref;
use std::ptr;
use tracing::debug;
use tracing::{debug, instrument};

// All Builders must have an llfn associated with them
#[must_use]
Expand Down Expand Up @@ -464,15 +464,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

#[instrument(level = "trace", skip(self))]
fn load_operand(&mut self, place: PlaceRef<'tcx, &'ll Value>) -> OperandRef<'tcx, &'ll Value> {
debug!("PlaceRef::load: {:?}", place);

assert_eq!(place.llextra.is_some(), place.layout.is_unsized());

if place.layout.is_zst() {
return OperandRef::new_zst(self, place.layout);
}

#[instrument(level = "trace", skip(bx))]
fn scalar_load_metadata<'a, 'll, 'tcx>(
bx: &mut Builder<'a, 'll, 'tcx>,
load: &'ll Value,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ extern "C" {
B: &Builder<'a>,
Val: &'a Value,
DestTy: &'a Type,
IsSized: bool,
IsSigned: bool,
) -> &'a Value;

// Comparisons
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}

/// Obtain the actual discriminant of a value.
#[instrument(level = "trace", skip(bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
bx: &mut Bx,
Expand Down Expand Up @@ -420,12 +421,12 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "trace", skip(self, bx))]
pub fn codegen_place(
&mut self,
bx: &mut Bx,
place_ref: mir::PlaceRef<'tcx>,
) -> PlaceRef<'tcx, Bx::Value> {
debug!("codegen_place(place_ref={:?})", place_ref);
let cx = self.cx;
let tcx = self.cx.tcx();

Expand Down
81 changes: 9 additions & 72 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ use rustc_middle::ty::cast::{CastTy, IntTy};
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf};
use rustc_middle::ty::{self, adjustment::PointerCast, Instance, Ty, TyCtxt};
use rustc_span::source_map::{Span, DUMMY_SP};
use rustc_target::abi::{Abi, Int, Variants};

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "trace", skip(self, bx))]
pub fn codegen_rvalue(
&mut self,
mut bx: Bx,
dest: PlaceRef<'tcx, Bx::Value>,
rvalue: &mir::Rvalue<'tcx>,
) -> Bx {
debug!("codegen_rvalue(dest.llval={:?}, rvalue={:?})", dest.llval, rvalue);

match *rvalue {
mir::Rvalue::Use(ref operand) => {
let cg_operand = self.codegen_operand(&mut bx, operand);
Expand Down Expand Up @@ -285,74 +283,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
CastTy::from_ty(operand.layout.ty).expect("bad input type for cast");
let r_t_out = CastTy::from_ty(cast.ty).expect("bad output type for cast");
let ll_t_in = bx.cx().immediate_backend_type(operand.layout);
match operand.layout.variants {
Variants::Single { index } => {
if let Some(discr) =
operand.layout.ty.discriminant_for_variant(bx.tcx(), index)
{
let discr_layout = bx.cx().layout_of(discr.ty);
let discr_t = bx.cx().immediate_backend_type(discr_layout);
let discr_val = bx.cx().const_uint_big(discr_t, discr.val);
let discr_val =
bx.intcast(discr_val, ll_t_out, discr.ty.is_signed());

return (
bx,
OperandRef {
val: OperandValue::Immediate(discr_val),
layout: cast,
},
);
}
}
Variants::Multiple { .. } => {}
}
let llval = operand.immediate();

let mut signed = false;
if let Abi::Scalar(scalar) = operand.layout.abi {
if let Int(_, s) = scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., E::B) is effectively `i8 -1`.
signed = !scalar.is_bool() && s;

if !scalar.is_always_valid(bx.cx())
&& scalar.valid_range(bx.cx()).end
>= scalar.valid_range(bx.cx()).start
{
// We want `table[e as usize ± k]` to not
// have bound checks, and this is the most
// convenient place to put the `assume`s.
if scalar.valid_range(bx.cx()).start > 0 {
let enum_value_lower_bound = bx.cx().const_uint_big(
ll_t_in,
scalar.valid_range(bx.cx()).start,
);
let cmp_start = bx.icmp(
IntPredicate::IntUGE,
llval,
enum_value_lower_bound,
);
bx.assume(cmp_start);
}

let enum_value_upper_bound = bx
.cx()
.const_uint_big(ll_t_in, scalar.valid_range(bx.cx()).end);
let cmp_end = bx.icmp(
IntPredicate::IntULE,
llval,
enum_value_upper_bound,
);
bx.assume(cmp_end);
}
}
}

let newval = match (r_t_in, r_t_out) {
(CastTy::Int(_), CastTy::Int(_)) => bx.intcast(llval, ll_t_out, signed),
(CastTy::Int(i), CastTy::Int(_)) => {
bx.intcast(llval, ll_t_out, i.is_signed())
}
(CastTy::Float, CastTy::Float) => {
let srcsz = bx.cx().float_width(ll_t_in);
let dstsz = bx.cx().float_width(ll_t_out);
Expand All @@ -364,8 +300,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
llval
}
}
(CastTy::Int(_), CastTy::Float) => {
if signed {
(CastTy::Int(i), CastTy::Float) => {
if i.is_signed() {
bx.sitofp(llval, ll_t_out)
} else {
bx.uitofp(llval, ll_t_out)
Expand All @@ -374,8 +310,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
(CastTy::Ptr(_) | CastTy::FnPtr, CastTy::Ptr(_)) => {
bx.pointercast(llval, ll_t_out)
}
(CastTy::Int(_), CastTy::Ptr(_)) => {
let usize_llval = bx.intcast(llval, bx.cx().type_isize(), signed);
(CastTy::Int(i), CastTy::Ptr(_)) => {
let usize_llval =
bx.intcast(llval, bx.cx().type_isize(), i.is_signed());
bx.inttoptr(usize_llval, ll_t_out)
}
(CastTy::Float, CastTy::Int(IntTy::I)) => {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ use crate::traits::BuilderMethods;
use crate::traits::*;

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_statement(&mut self, mut bx: Bx, statement: &mir::Statement<'tcx>) -> Bx {
debug!("codegen_statement(statement={:?})", statement);

self.set_debug_loc(&mut bx, statement.source_info);
match statement.kind {
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
Expand Down
25 changes: 2 additions & 23 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rustc_middle::mir::CastKind;
use rustc_middle::ty::adjustment::PointerCast;
use rustc_middle::ty::layout::{IntegerExt, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, FloatTy, Ty, TypeAndMut};
use rustc_target::abi::{Integer, Variants};
use rustc_target::abi::Integer;
use rustc_type_ir::sty::TyKind::*;

use super::{
Expand Down Expand Up @@ -128,12 +128,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Float(FloatTy::F64) => {
return Ok(self.cast_from_float(src.to_scalar()?.to_f64()?, cast_ty).into());
}
// The rest is integer/pointer-"like", including fn ptr casts and casts from enums that
// are represented as integers.
// The rest is integer/pointer-"like", including fn ptr casts
_ => assert!(
src.layout.ty.is_bool()
|| src.layout.ty.is_char()
|| src.layout.ty.is_enum()
|| src.layout.ty.is_integral()
|| src.layout.ty.is_any_ptr(),
"Unexpected cast from type {:?}",
Expand All @@ -143,25 +141,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

// # First handle non-scalar source values.

// Handle cast from a ZST enum (0 or 1 variants).
match src.layout.variants {
Variants::Single { index } => {
if src.layout.abi.is_uninhabited() {
// This is dead code, because an uninhabited enum is UB to
// instantiate.
throw_ub!(Unreachable);
}
if let Some(discr) = src.layout.ty.discriminant_for_variant(*self.tcx, index) {
assert!(src.layout.is_zst());
let discr_layout = self.layout_of(discr.ty)?;

let scalar = Scalar::from_uint(discr.val, discr_layout.layout.size());
return Ok(self.cast_from_int_like(scalar, discr_layout, cast_ty)?.into());
}
}
Variants::Multiple { .. } => {}
}

// Handle casting any ptr to raw ptr (might be a fat ptr).
if src.layout.ty.is_any_ptr() && cast_ty.is_unsafe_ptr() {
let dest_layout = self.layout_of(cast_ty)?;
Expand Down
32 changes: 28 additions & 4 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
use rustc_middle::mir::{
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, Local, Location, MirPass,
MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope, Statement,
StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, Local, Location,
MirPass, MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope,
Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable};
Expand Down Expand Up @@ -361,6 +361,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
);
}
}
Rvalue::Ref(..) => {}
Rvalue::Len(p) => {
let pty = p.ty(&self.body.local_decls, self.tcx).ty;
check_kinds!(
Expand Down Expand Up @@ -503,7 +504,30 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
let a = operand.ty(&self.body.local_decls, self.tcx);
check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..));
}
_ => {}
Rvalue::Cast(kind, operand, target_type) => {
match kind {
CastKind::Misc => {
let op_ty = operand.ty(self.body, self.tcx);
if op_ty.is_enum() {
self.fail(
location,
format!(
"enum -> int casts should go through `Rvalue::Discriminant`: {operand:?}:{op_ty} as {target_type}",
),
);
}
}
// Nothing to check here
CastKind::PointerFromExposedAddress
| CastKind::PointerExposeAddress
| CastKind::Pointer(_) => {}
}
}
Rvalue::Repeat(_, _)
| Rvalue::ThreadLocalRef(_)
| Rvalue::AddressOf(_, _)
| Rvalue::NullaryOp(_, _)
| Rvalue::Discriminant(_) => {}
}
self.super_rvalue(rvalue, location);
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_middle/src/ty/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ pub enum IntTy {
Char,
}

impl IntTy {
pub fn is_signed(self) -> bool {
matches!(self, Self::I)
}
}

// Valid types for the result of a non-coercion cast
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CastTy<'tcx> {
Expand Down
29 changes: 25 additions & 4 deletions compiler/rustc_mir_build/src/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! See docs in `build/expr/mod.rs`.

use rustc_index::vec::Idx;
use rustc_middle::ty::util::IntTypeExt;

use crate::build::expr::as_place::PlaceBase;
use crate::build::expr::category::{Category, RvalueFunc};
Expand Down Expand Up @@ -190,7 +191,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
ExprKind::Cast { source } => {
let source = &this.thir[source];
let from_ty = CastTy::from_ty(source.ty);

// Casting an enum to an integer is equivalent to computing the discriminant and casting the
// discriminant. Previously every backend had to repeat the logic for this operation. Now we
// create all the steps directly in MIR with operations all backends need to support anyway.
let (source, ty) = if let ty::Adt(adt_def, ..) = source.ty.kind() && adt_def.is_enum() {
let discr_ty = adt_def.repr().discr_type().to_ty(this.tcx);
let place = unpack!(block = this.as_place(block, source));
let discr = this.temp(discr_ty, source.span);
this.cfg.push_assign(
block,
source_info,
discr,
Rvalue::Discriminant(place),
);

(Operand::Move(discr), discr_ty)
} else {
let ty = source.ty;
let source = unpack!(
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
);
(source, ty)
};
let from_ty = CastTy::from_ty(ty);
let cast_ty = CastTy::from_ty(expr.ty);
let cast_kind = match (from_ty, cast_ty) {
(Some(CastTy::Ptr(_) | CastTy::FnPtr), Some(CastTy::Int(_))) => {
Expand All @@ -201,9 +225,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
(_, _) => CastKind::Misc,
};
let source = unpack!(
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
);
block.and(Rvalue::Cast(cast_kind, source, expr.ty))
}
ExprKind::Pointer { cast, source } => {
Expand Down
Loading

0 comments on commit 53792b9

Please sign in to comment.