Skip to content

Commit

Permalink
Auto merge of rust-lang#123572 - Mark-Simulacrum:vtable-methods, r=<try>
Browse files Browse the repository at this point in the history
Increase vtable layout size

This improves LLVM's codegen by allowing vtable loads to be hoisted out of loops (as just one example). The calculation here is an under-approximation but works for simple trait hierarchies (e.g., FnMut will be improved). We have a runtime assert that the approximation is accurate, so there's no risk of UB as a result of getting this wrong.

```rust
#[no_mangle]
pub fn foo(elements: &[u32], callback: &mut dyn Callback) {
    for element in elements.iter() {
        if *element != 0 {
            callback.call(*element);
        }
    }
}

pub trait Callback {
    fn call(&mut self, _: u32);
}
```

Simplifying a bit (e.g., numbering ends up different):

```diff
 ; Function Attrs: nonlazybind uwtable
-define void `@foo(ptr` noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(24) %callback.1) unnamed_addr #0 {
+define void `@foo(ptr` noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(32) %callback.1) unnamed_addr #0 {
 start:
   %_15 = getelementptr inbounds i32, ptr %elements.0, i64 %elements.1
`@@` -13,4 +13,5 `@@`
 bb4.lr.ph:                                        ; preds = %start
   %1 = getelementptr inbounds i8, ptr %callback.1, i64 24
+  %2 = load ptr, ptr %1, align 8, !nonnull !3
   br label %bb4

 bb6:                                              ; preds = %bb4
-  %4 = load ptr, ptr %1, align 8, !invariant.load !3, !nonnull !3
-  tail call void %4(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
+  tail call void %2(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
   br label %bb7
 }
```
  • Loading branch information
bors committed Apr 8, 2024
2 parents 0e5f520 + 26954e7 commit 3d8cc34
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 71 deletions.
25 changes: 7 additions & 18 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,25 +771,14 @@ where
});
}

let mk_dyn_vtable = || {
let mk_dyn_vtable = |principal: Option<ty::PolyExistentialTraitRef<'tcx>>| {
let min_count = ty::vtable_min_entries(tcx, principal);
Ty::new_imm_ref(
tcx,
tcx.lifetimes.re_static,
Ty::new_array(tcx, tcx.types.usize, 3),
// FIXME: properly type (e.g. usize and fn pointers) the fields.
Ty::new_array(tcx, tcx.types.usize, min_count.try_into().unwrap()),
)
/* FIXME: use actual fn pointers
Warning: naively computing the number of entries in the
vtable by counting the methods on the trait + methods on
all parent traits does not work, because some methods can
be not object safe and thus excluded from the vtable.
Increase this counter if you tried to implement this but
failed to do it without duplicating a lot of code from
other places in the compiler: 2
Ty::new_tup(tcx,&[
Ty::new_array(tcx,tcx.types.usize, 3),
Ty::new_array(tcx,Option<fn()>),
])
*/
};

let metadata = if let Some(metadata_def_id) = tcx.lang_items().metadata_type()
Expand All @@ -808,16 +797,16 @@ where
// `std::mem::uninitialized::<&dyn Trait>()`, for example.
if let ty::Adt(def, args) = metadata.kind()
&& Some(def.did()) == tcx.lang_items().dyn_metadata()
&& args.type_at(0).is_trait()
&& let ty::Dynamic(data, _, ty::Dyn) = args.type_at(0).kind()
{
mk_dyn_vtable()
mk_dyn_vtable(data.principal())
} else {
metadata
}
} else {
match tcx.struct_tail_erasing_lifetimes(pointee, cx.param_env()).kind() {
ty::Slice(_) | ty::Str => tcx.types.usize,
ty::Dynamic(_, _, ty::Dyn) => mk_dyn_vtable(),
ty::Dynamic(data, _, ty::Dyn) => mk_dyn_vtable(data.principal()),
_ => bug!("TyAndLayout::field({:?}): not applicable", this),
}
};
Expand Down
64 changes: 64 additions & 0 deletions compiler/rustc_middle/src/ty/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt;
use crate::mir::interpret::{alloc_range, AllocId, Allocation, Pointer, Scalar};
use crate::ty::{self, Instance, PolyTraitRef, Ty, TyCtxt};
use rustc_ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def_id::DefId;

#[derive(Clone, Copy, PartialEq, HashStable)]
pub enum VtblEntry<'tcx> {
Expand Down Expand Up @@ -45,6 +47,65 @@ pub const COMMON_VTABLE_ENTRIES_DROPINPLACE: usize = 0;
pub const COMMON_VTABLE_ENTRIES_SIZE: usize = 1;
pub const COMMON_VTABLE_ENTRIES_ALIGN: usize = 2;

// FIXME: This is duplicating equivalent code in compiler/rustc_trait_selection/src/traits/util.rs
// But that is a downstream crate, and this code is pretty simple. Probably OK for now.
struct SupertraitDefIds<'tcx> {
tcx: TyCtxt<'tcx>,
stack: Vec<DefId>,
visited: FxHashSet<DefId>,
}

fn supertrait_def_ids(tcx: TyCtxt<'_>, trait_def_id: DefId) -> SupertraitDefIds<'_> {
SupertraitDefIds {
tcx,
stack: vec![trait_def_id],
visited: Some(trait_def_id).into_iter().collect(),
}
}

impl Iterator for SupertraitDefIds<'_> {
type Item = DefId;

fn next(&mut self) -> Option<DefId> {
let def_id = self.stack.pop()?;
let predicates = self.tcx.super_predicates_of(def_id);
let visited = &mut self.visited;
self.stack.extend(
predicates
.predicates
.iter()
.filter_map(|(pred, _)| pred.as_trait_clause())
.map(|trait_ref| trait_ref.def_id())
.filter(|&super_def_id| visited.insert(super_def_id)),
);
Some(def_id)
}
}

// Note that we don't have access to a self type here, this has to be purely based on the trait (and
// supertrait) definitions. That means we can't call into the same vtable_entries code since that
// returns a specific instantiation (e.g., with Vacant slots when bounds aren't satisfied). The goal
// here is to do a best-effort approximation without duplicating a lot of code.
//
// This function is used in layout computation for e.g. &dyn Trait, so it's critical that this
// function is an accurate approximation. We verify this when actually computing the vtable below.
pub(crate) fn vtable_min_entries<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
) -> usize {
let mut count = TyCtxt::COMMON_VTABLE_ENTRIES.len();
let Some(trait_ref) = trait_ref else {
return count;
};

// This includes self in supertraits.
for def_id in supertrait_def_ids(tcx, trait_ref.def_id()) {
count += tcx.own_existential_vtable_entries(def_id).len();
}

count
}

/// Retrieves an allocation that represents the contents of a vtable.
/// Since this is a query, allocations are cached and not duplicated.
pub(super) fn vtable_allocation_provider<'tcx>(
Expand All @@ -62,6 +123,9 @@ pub(super) fn vtable_allocation_provider<'tcx>(
TyCtxt::COMMON_VTABLE_ENTRIES
};

// This confirms that the layout computation for &dyn Trait has an accurate sizing.
assert!(vtable_entries.len() >= vtable_min_entries(tcx, poly_trait_ref));

let layout = tcx
.layout_of(ty::ParamEnv::reveal_all().and(ty))
.expect("failed to build vtable representation");
Expand Down
59 changes: 6 additions & 53 deletions compiler/rustc_trait_selection/src/traits/object_safety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,59 +497,12 @@ fn virtual_call_violations_for_method<'tcx>(
};
errors.push(MethodViolationCode::UndispatchableReceiver(span));
} else {
// Do sanity check to make sure the receiver actually has the layout of a pointer.

use rustc_target::abi::Abi;

let param_env = tcx.param_env(method.def_id);

let abi_of_ty = |ty: Ty<'tcx>| -> Option<Abi> {
match tcx.layout_of(param_env.and(ty)) {
Ok(layout) => Some(layout.abi),
Err(err) => {
// #78372
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!("error: {err}\n while computing layout for type {ty:?}"),
);
None
}
}
};

// e.g., `Rc<()>`
let unit_receiver_ty =
receiver_for_self_ty(tcx, receiver_ty, Ty::new_unit(tcx), method.def_id);

match abi_of_ty(unit_receiver_ty) {
Some(Abi::Scalar(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = ()` should have a Scalar ABI; found {abi:?}"
),
);
}
}

let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);

// e.g., `Rc<dyn Trait>`
let trait_object_receiver =
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method.def_id);

match abi_of_ty(trait_object_receiver) {
Some(Abi::ScalarPair(..)) => (),
abi => {
tcx.dcx().span_delayed_bug(
tcx.def_span(method.def_id),
format!(
"receiver when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
),
);
}
}
// We used to have a sanity check here for whether the ABI of the receiver with `()` and `dyn Trait`
// self types was correct. For now that code has been deleted since
// computing the layout_of such types requires knowing the number of methods in the
// virtual table, which in turn requires this code. So we skip these checks. Both
// of these are just sanity checking (i.e. this code is not responsible for ABI) so this
// should be fine.
}
}

Expand Down

0 comments on commit 3d8cc34

Please sign in to comment.