Skip to content

Commit

Permalink
Use a Flag to enable more niche optimizations.
Browse files Browse the repository at this point in the history
Also:
Try to fit other variants by moving fields around the niche.
Keep multiple niches, not just the largest one.
Look for multiple largest variants.
Introduce repr(flag).

Fixes rust-lang#101567
  • Loading branch information
mikebenfield committed Oct 7, 2022
1 parent 0ca3565 commit 418c278
Show file tree
Hide file tree
Showing 27 changed files with 1,301 additions and 631 deletions.
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4249,6 +4249,7 @@ dependencies = [
"rustc_target",
"rustc_trait_selection",
"rustc_type_ir",
"smallvec",
"tracing",
]

Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3065,6 +3065,8 @@ mod size_asserts {
static_assert_size!(PathSegment, 24);
static_assert_size!(Stmt, 32);
static_assert_size!(StmtKind, 16);
static_assert_size!(Ty, 96);
static_assert_size!(TyKind, 72);
#[cfg(not(bootstrap))]
static_assert_size!(Ty, 88);
#[cfg(not(bootstrap))]
static_assert_size!(TyKind, 64);
}
2 changes: 2 additions & 0 deletions compiler/rustc_attr/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ pub enum ReprAttr {
ReprSimd,
ReprTransparent,
ReprAlign(u32),
ReprFlag,
}

#[derive(Eq, PartialEq, Debug, Copy, Clone)]
Expand Down Expand Up @@ -998,6 +999,7 @@ pub fn parse_repr_attr(sess: &Session, attr: &Attribute) -> Vec<ReprAttr> {
recognised = true;
None
}
sym::flag => Some(ReprFlag),
name => int_type_of_word(name).map(ReprInt),
};

Expand Down
10 changes: 2 additions & 8 deletions compiler/rustc_codegen_cranelift/src/abi/comments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,8 @@ pub(super) fn add_local_place_comments<'tcx>(
return;
}
let TyAndLayout { ty, layout } = place.layout();
let rustc_target::abi::LayoutS {
size,
align,
abi: _,
variants: _,
fields: _,
largest_niche: _,
} = layout.0.0;
let rustc_target::abi::LayoutS { size, align, abi: _, variants: _, fields: _, niches: _ } =
layout.0.0;

let (kind, extra) = match *place.inner() {
CPlaceInner::Var(place_local, var) => {
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_codegen_cranelift/src/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
Variants::Multiple {
tag: _,
tag_field,
tag_encoding: TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, .. },
variants: _,
} => {
if variant_index != untagged_variant {
Expand Down Expand Up @@ -113,7 +114,7 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
let res = CValue::by_val(val, dest_layout);
dest.write_cvalue(fx, res);
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, .. } => {
// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.

Expand Down
50 changes: 26 additions & 24 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,32 +430,34 @@ fn compute_discriminant_value<'ll, 'tcx>(
enum_type_and_layout.ty.discriminant_for_variant(cx.tcx, variant_index).unwrap().val,
),
&Variants::Multiple {
tag_encoding: TagEncoding::Niche { ref niche_variants, niche_start, untagged_variant },
tag,
//tag_encoding: TagEncoding::Niche { ref niche_variants, niche_start, untagged_variant, .. },
//tag,
..
} => {
if variant_index == untagged_variant {
let valid_range = enum_type_and_layout
.for_variant(cx, variant_index)
.largest_niche
.as_ref()
.unwrap()
.valid_range;

let min = valid_range.start.min(valid_range.end);
let min = tag.size(cx).truncate(min);

let max = valid_range.start.max(valid_range.end);
let max = tag.size(cx).truncate(max);

DiscrResult::Range(min, max)
} else {
let value = (variant_index.as_u32() as u128)
.wrapping_sub(niche_variants.start().as_u32() as u128)
.wrapping_add(niche_start);
let value = tag.size(cx).truncate(value);
DiscrResult::Value(value)
}
// YYY
DiscrResult::Range(0, 1)
//if variant_index == untagged_variant {
// let valid_range = enum_type_and_layout
// .for_variant(cx, variant_index)
// .largest_niche
// .as_ref()
// .unwrap()
// .valid_range;

// let min = valid_range.start.min(valid_range.end);
// let min = tag.size(cx).truncate(min);

// let max = valid_range.start.max(valid_range.end);
// let max = tag.size(cx).truncate(max);

// DiscrResult::Range(min, max)
//} else {
// let value = (variant_index.as_u32() as u128)
// .wrapping_sub(niche_variants.start().as_u32() as u128)
// .wrapping_add(niche_start);
// let value = tag.size(cx).truncate(value);
// DiscrResult::Value(value)
//}
}
}
}
111 changes: 70 additions & 41 deletions compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::operand::OperandValue;
use super::{FunctionCx, LocalRef};

use crate::common::IntPredicate;
use crate::common::{IntPredicate, TypeKind};
use crate::glue;
use crate::traits::*;

Expand Down Expand Up @@ -227,13 +227,13 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}
};

// Read the tag/niche-encoded discriminant from memory.
let tag = self.project_field(bx, tag_field);
let tag = bx.load_operand(tag);
let tag_place = self.project_field(bx, tag_field);

// Decode the discriminant (specifically if it's niche-encoded).
match *tag_encoding {
TagEncoding::Direct => {
// Read the tag from memory.
let tag = bx.load_operand(tag_place);
let signed = match tag_scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
Expand All @@ -244,11 +244,30 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
};
bx.intcast(tag.immediate(), cast_to, signed)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
// Rebase from niche values to discriminants, and check
// whether the result is in range for the niche variants.
let niche_llty = bx.cx().immediate_backend_type(tag.layout);
let tag = tag.immediate();
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, } => {
let read = |bx: &mut Bx, place: Self| -> (V, <Bx as BackendTypes>::Type) {
let ty = bx.cx().immediate_backend_type(place.layout);
let op = bx.load_operand(place);
let val = op.immediate();
if bx.cx().type_kind(ty) == TypeKind::Pointer {
let new_ty = bx.cx().type_isize();
let new_val = bx.ptrtoint(val, new_ty);
(new_val, new_ty)
} else {
(val, ty)
}
};

let (tag, niche_llty) = read(bx, tag_place);

let (untagged_in_niche, flag_eq_magic_value_opt) = if let Some(flag) = flag {
let flag_place = self.project_field(bx, flag.field);
let (flag_imm, flag_llty) = read(bx, flag_place);
let magic_value = bx.cx().const_uint_big(flag_llty, flag.magic_value);
(flag.untagged_in_niche, Some(bx.icmp(IntPredicate::IntEQ, flag_imm, magic_value)))
} else {
(true, None)
};

// We first compute the "relative discriminant" (wrt `niche_variants`),
// that is, if `n = niche_variants.end() - niche_variants.start()`,
Expand All @@ -259,23 +278,8 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
// and check that it is in the range `niche_variants`, because
// that might not fit in the same type, on top of needing an extra
// comparison (see also the comment on `let niche_discr`).
let relative_discr = if niche_start == 0 {
// Avoid subtracting `0`, which wouldn't work for pointers.
// FIXME(eddyb) check the actual primitive type here.
tag
} else {
bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start))
};
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start));
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let is_niche = if relative_max == 0 {
// Avoid calling `const_uint`, which wouldn't work for pointers.
// Also use canonical == 0 instead of non-canonical u<= 0.
// FIXME(eddyb) check the actual primitive type here.
bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty))
} else {
let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64);
bx.icmp(IntPredicate::IntULE, relative_discr, relative_max)
};

// NOTE(eddyb) this addition needs to be performed on the final
// type, in case the niche itself can't represent all variant
Expand All @@ -285,7 +289,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
// In other words, `niche_variants.end - niche_variants.start`
// is representable in the niche, but `niche_variants.end`
// might not be, in extreme cases.
let niche_discr = {
let potential_niche_discr = {
let relative_discr = if relative_max == 0 {
// HACK(eddyb) since we have only one niche, we know which
// one it is, and we can avoid having a dynamic value here.
Expand All @@ -299,11 +303,29 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
)
};

bx.select(
is_niche,
niche_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
)
let untagged_discr = bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64);

let niche_discr = if untagged_in_niche {
let relative_max_const = bx.cx().const_uint(niche_llty, relative_max as u64);
let is_niche = bx.icmp(IntPredicate::IntULE, relative_discr, relative_max_const);
bx.select(
is_niche,
potential_niche_discr,
untagged_discr,
)
} else {
potential_niche_discr
};

if let Some(flag_eq_magic_value) = flag_eq_magic_value_opt {
bx.select(
flag_eq_magic_value,
niche_discr,
untagged_discr,
)
} else {
niche_discr
}
}
}
}
Expand Down Expand Up @@ -337,23 +359,30 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}
Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, },
tag_field,
..
} => {
let store = |bx: &mut Bx, value: u128, place: Self| {
let ty = bx.cx().immediate_backend_type(place.layout);
let val = if bx.cx().type_kind(ty) == TypeKind::Pointer {
let ty_isize = bx.cx().type_isize();
let llvalue = bx.cx().const_uint_big(ty_isize, value);
bx.inttoptr(llvalue, ty)
} else {
bx.cx().const_uint_big(ty, value)
};
OperandValue::Immediate(val).store(bx, place);
};
if variant_index != untagged_variant {
if let Some(flag) = flag {
let place = self.project_field(bx, flag.field);
store(bx, flag.magic_value, place);
}
let niche = self.project_field(bx, tag_field);
let niche_llty = bx.cx().immediate_backend_type(niche.layout);
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(niche_start);
// FIXME(eddyb): check the actual primitive type here.
let niche_llval = if niche_value == 0 {
// HACK(eddyb): using `c_null` as it works on all types.
bx.cx().const_null(niche_llty)
} else {
bx.cx().const_uint_big(niche_llty, niche_value)
};
OperandValue::Immediate(niche_llval).store(bx, niche);
store(bx, niche_value, niche);
}
}
}
Expand Down
46 changes: 38 additions & 8 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,27 +722,55 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// Return the cast value, and the index.
(discr_val, index.0)
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, } => {
let is_magic_val = if let Some(flag) = flag {
let flag_val = self.read_immediate(&self.operand_field(op, flag.field)?)?;
let flag_val = flag_val.to_scalar();
match flag_val.try_to_int() {
Err(dbg_val) => {
// So this is a pointer then, and casting to an int
// failed. Can only happen during CTFE. If the magic
// value is 0 and the scalar is not null, we know
// the pointer cannot be the magic value. Anything
// else we conservatively reject.
let ptr_definitely_not_magic_value =
flag.magic_value == 0 && !self.scalar_may_be_null(flag_val)?;
if !ptr_definitely_not_magic_value {
throw_ub!(InvalidTag(dbg_val))
}
false
}
Ok(flag_bits) => {
let flag_layout =
self.layout_of(flag.scalar.primitive().to_int_ty(*self.tcx))?;
let flag_bits = flag_bits.assert_bits(flag_layout.size);
flag_bits == flag.magic_value
}
}
} else {
true
};
let tag_val = tag_val.to_scalar();
// Compute the variant this niche value/"tag" corresponds to. With niche layout,
// discriminant (encoded in niche/tag) and variant index are the same.
let variants_start = niche_variants.start().as_u32();
let variants_end = niche_variants.end().as_u32();
let variant = match tag_val.try_to_int() {
Err(dbg_val) => {
let variant = match (is_magic_val, tag_val.try_to_int()) {
(false, _) => untagged_variant,
(true, Err(dbg_val)) => {
// So this is a pointer then, and casting to an int failed.
// Can only happen during CTFE.
// The niche must be just 0, and the ptr not null, then we know this is
// okay. Everything else, we conservatively reject.
let ptr_valid = niche_start == 0
let ptr_definitely_not_in_niche_variants = niche_start == 0
&& variants_start == variants_end
&& !self.scalar_may_be_null(tag_val)?;
if !ptr_valid {
if !ptr_definitely_not_in_niche_variants {
throw_ub!(InvalidTag(dbg_val))
}
untagged_variant
}
Ok(tag_bits) => {
(true, Ok(tag_bits)) => {
let tag_bits = tag_bits.assert_bits(tag_layout.size);
// We need to use machine arithmetic to get the relative variant idx:
// variant_index_relative = tag_val - niche_start_val
Expand Down Expand Up @@ -791,6 +819,8 @@ mod size_asserts {
// These are in alphabetical order, which is easy to maintain.
static_assert_size!(Immediate, 48);
static_assert_size!(ImmTy<'_>, 64);
static_assert_size!(Operand, 56);
static_assert_size!(OpTy<'_>, 80);
#[cfg(not(bootstrap))]
static_assert_size!(Operand, 48);
#[cfg(not(bootstrap))]
static_assert_size!(OpTy<'_>, 72);
}
11 changes: 9 additions & 2 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,15 +823,22 @@ where
}
abi::Variants::Multiple {
tag_encoding:
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start, ref flag, },
tag: tag_layout,
tag_field,
..
} => {
// No need to validate that the discriminant here because the
// No need to validate the discriminant here because the
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.

if variant_index != untagged_variant {
if let Some(flag) = flag {
let flag_layout = self.layout_of(flag.scalar.primitive().to_int_ty(*self.tcx))?;
let val = ImmTy::from_uint(flag.magic_value, flag_layout);
let flag_dest = self.place_field(dest, flag.field)?;
self.write_immediate(*val, &flag_dest)?;
}

let variants_start = niche_variants.start().as_u32();
let variant_index_relative = variant_index
.as_u32()
Expand Down
Loading

0 comments on commit 418c278

Please sign in to comment.