Skip to content

Commit

Permalink
improve cold_path()
Browse files Browse the repository at this point in the history
  • Loading branch information
x17jiri committed Feb 17, 2025
1 parent c705b7d commit 7bb5f4d
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 15 deletions.
48 changes: 46 additions & 2 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{iter, ptr};

pub(crate) mod autodiff;

use libc::{c_char, c_uint};
use libc::{c_char, c_uint, size_t};
use rustc_abi as abi;
use rustc_abi::{Align, Size, WrappingRange};
use rustc_codegen_ssa::MemFlags;
Expand Down Expand Up @@ -32,7 +32,7 @@ use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::common::Funclet;
use crate::context::{CodegenCx, SimpleCx};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True};
use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
use crate::value::Value;
Expand Down Expand Up @@ -333,6 +333,50 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

fn switch_with_weights(
&mut self,
v: Self::Value,
else_llbb: Self::BasicBlock,
else_is_cold: bool,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
) {
if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No {
self.switch(v, else_llbb, cases.map(|(val, dest, _)| (val, dest)));
return;
}

let id_str = "branch_weights";
let id = unsafe {
llvm::LLVMMDStringInContext2(self.cx.llcx, id_str.as_ptr().cast(), id_str.len())
};

// For switch instructions with 2 targets, the `llvm.expect` intrinsic is used.
// This function handles switch instructions with more than 2 targets and it needs to
// emit branch weights metadata instead of using the intrinsic.
// The values 1 and 2000 are the same as the values used by the `llvm.expect` intrinsic.
let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) };
let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) };
let weight =
|is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } };

let mut md: SmallVec<[&Metadata; 16]> = SmallVec::with_capacity(cases.len() + 2);
md.push(id);
md.push(weight(else_is_cold));

let switch =
unsafe { llvm::LLVMBuildSwitch(self.llbuilder, v, else_llbb, cases.len() as c_uint) };
for (on_val, dest, is_cold) in cases {
let on_val = self.const_uint_big(self.val_ty(v), on_val);
unsafe { llvm::LLVMAddCase(switch, on_val, dest) }
md.push(weight(is_cold));
}

unsafe {
let md_node = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len() as size_t);
self.cx.set_metadata(switch, llvm::MD_prof, md_node);
}
}

fn invoke(
&mut self,
llty: &'ll Type,
Expand Down
33 changes: 28 additions & 5 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,34 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval);
bx.cond_br(cmp, ll1, ll2);
} else {
bx.switch(
discr_value,
helper.llbb_with_cleanup(self, targets.otherwise()),
target_iter.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
);
let otherwise = targets.otherwise();
let otherwise_cold = self.cold_blocks[otherwise];
let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable();
let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count();
let none_cold = cold_count == 0;
let all_cold = cold_count == targets.iter().len();
if (none_cold && (!otherwise_cold || otherwise_unreachable))
|| (all_cold && (otherwise_cold || otherwise_unreachable))
{
// All targets have the same weight,
// or `otherwise` is unreachable and it's the only target with a different weight.
bx.switch(
discr_value,
helper.llbb_with_cleanup(self, targets.otherwise()),
target_iter
.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))),
);
} else {
// Targets have different weights
bx.switch_with_weights(
discr_value,
helper.llbb_with_cleanup(self, targets.otherwise()),
otherwise_cold,
target_iter.map(|(value, target)| {
(value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target])
}),
);
}
}
}

Expand Down
27 changes: 19 additions & 8 deletions compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,14 +502,25 @@ fn find_cold_blocks<'tcx>(
for (bb, bb_data) in traversal::postorder(mir) {
let terminator = bb_data.terminator();

// If a BB ends with a call to a cold function, mark it as cold.
if let mir::TerminatorKind::Call { ref func, .. } = terminator.kind
&& let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
&& let attrs = tcx.codegen_fn_attrs(def_id)
&& attrs.flags.contains(CodegenFnAttrFlags::COLD)
{
cold_blocks[bb] = true;
continue;
match terminator.kind {
// If a BB ends with a call to a cold function, mark it as cold.
mir::TerminatorKind::Call { ref func, .. }
| mir::TerminatorKind::TailCall { ref func, .. }
if let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind()
&& let attrs = tcx.codegen_fn_attrs(def_id)
&& attrs.flags.contains(CodegenFnAttrFlags::COLD) =>
{
cold_blocks[bb] = true;
continue;
}

// If a BB ends with an `unreachable`, also mark it as cold.
mir::TerminatorKind::Unreachable => {
cold_blocks[bb] = true;
continue;
}

_ => {}
}

// If all successors of a BB are cold and there's at least one of them, mark this BB as cold
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,20 @@ pub trait BuilderMethods<'a, 'tcx>:
else_llbb: Self::BasicBlock,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock)>,
);

// This is like `switch()`, but every case has a bool flag indicating whether it's cold.
//
// Default implementation throws away the cold flags and calls `switch()`.
fn switch_with_weights(
&mut self,
v: Self::Value,
else_llbb: Self::BasicBlock,
_else_is_cold: bool,
cases: impl ExactSizeIterator<Item = (u128, Self::BasicBlock, bool)>,
) {
self.switch(v, else_llbb, cases.map(|(val, bb, _)| (val, bb)))
}

fn invoke(
&mut self,
llty: Self::Type,
Expand Down
36 changes: 36 additions & 0 deletions tests/codegen/intrinsics/cold_path2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//@ compile-flags: -O
#![crate_type = "lib"]
#![feature(core_intrinsics)]

use std::intrinsics::cold_path;

#[inline(never)]
#[no_mangle]
pub fn path_a() {
println!("path a");
}

#[inline(never)]
#[no_mangle]
pub fn path_b() {
println!("path b");
}

#[no_mangle]
pub fn test(x: Option<bool>) {
if let Some(_) = x {
path_a();
} else {
cold_path();
path_b();
}

// CHECK-LABEL: @test(
// CHECK: br i1 %1, label %bb2, label %bb1, !prof ![[NUM:[0-9]+]]
// CHECK: bb1:
// CHECK: path_a
// CHECK: bb2:
// CHECK: path_b
}

// CHECK: ![[NUM]] = !{!"branch_weights", {{(!"expected", )?}}i32 1, i32 2000}
87 changes: 87 additions & 0 deletions tests/codegen/intrinsics/cold_path3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//@ compile-flags: -O
#![crate_type = "lib"]
#![feature(core_intrinsics)]

use std::intrinsics::cold_path;

#[inline(never)]
#[no_mangle]
pub fn path_a() {
println!("path a");
}

#[inline(never)]
#[no_mangle]
pub fn path_b() {
println!("path b");
}

#[inline(never)]
#[no_mangle]
pub fn path_c() {
println!("path c");
}

#[inline(never)]
#[no_mangle]
pub fn path_d() {
println!("path d");
}

#[no_mangle]
pub fn test(x: Option<u32>) {
match x {
Some(0) => path_a(),
Some(1) => {
cold_path();
path_b()
}
Some(2) => path_c(),
Some(3) => {
cold_path();
path_d()
}
_ => path_a(),
}

// CHECK-LABEL: @test(
// CHECK: switch i32 %1, label %bb1 [
// CHECK: i32 0, label %bb6
// CHECK: i32 1, label %bb5
// CHECK: i32 2, label %bb4
// CHECK: i32 3, label %bb3
// CHECK: ], !prof ![[NUM1:[0-9]+]]
}

#[no_mangle]
pub fn test2(x: Option<u32>) {
match x {
Some(10) => path_a(),
Some(11) => {
cold_path();
path_b()
}
Some(12) => {
unsafe { core::intrinsics::unreachable() };
path_c()
}
Some(13) => {
cold_path();
path_d()
}
_ => {
cold_path();
path_a()
}
}

// CHECK-LABEL: @test2(
// CHECK: switch i32 %1, label %bb1 [
// CHECK: i32 10, label %bb5
// CHECK: i32 11, label %bb4
// CHECK: i32 13, label %bb3
// CHECK: ], !prof ![[NUM2:[0-9]+]]
}

// CHECK: ![[NUM1]] = !{!"branch_weights", i32 2000, i32 2000, i32 1, i32 2000, i32 1}
// CHECK: ![[NUM2]] = !{!"branch_weights", i32 1, i32 2000, i32 1, i32 1}

0 comments on commit 7bb5f4d

Please sign in to comment.