Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't aggregate homogeneous floats in the Rust ABI #93564

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use rustc_session::{config::OptLevel, DataTypeKind, FieldInfo, SizeKind, Variant
use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::call::{
ArgAbi, ArgAttribute, ArgAttributes, ArgExtension, Conv, FnAbi, PassMode, Reg, RegKind,
ArgAbi, ArgAttribute, ArgAttributes, ArgExtension, Conv, FnAbi, HomogeneousAggregate, PassMode,
Reg, RegKind,
};
use rustc_target::abi::*;
use rustc_target::spec::{abi::Abi as SpecAbi, HasTargetSpec, PanicStrategy, Target};
Expand Down Expand Up @@ -3194,7 +3195,39 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> {
}

match arg.layout.abi {
Abi::Aggregate { .. } => {}
Abi::Aggregate { .. } => {
// Pass and return structures up to 2 pointers in size by value,
// matching `ScalarPair`. LLVM will usually pass these in 2 registers
// which is more efficient than by-ref.
let max_by_val_size = Pointer.size(self) * 2;
let size = arg.layout.size;

if arg.layout.is_unsized() || size > max_by_val_size {
arg.make_indirect();
} else if let Ok(HomogeneousAggregate::Homogeneous(Reg {
kind: RegKind::Float,
..
})) = arg.layout.homogeneous_aggregate(self)
{
// We don't want to pass homogeneous float aggregates as integer
// aggregates, because doing so hurts the generated assembly (#93490).
//
// As an optimization, homogeneous float aggregates larger than a
// pointer are passed indirectly.
if size > Pointer.size(self) {
arg.make_indirect();
}
} else {
// We want to pass small aggregates as immediates, but using
// a LLVM aggregate type for this leads to bad optimizations,
// so we pick an appropriately sized integer type instead.
//
// NOTE: This is sub-optimal because in the case of (f32, f32, u32, u32)
// we could do ([f32; 2], u64) which is better but this is the best we
// can do right now.
arg.cast_to(Reg { kind: RegKind::Integer, size });
}
}

// This is a fun case! The gist of what this is doing is
// that we want callers and callees to always agree on the
Expand All @@ -3220,24 +3253,9 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> {
&& self.tcx.sess.target.simd_types_indirect =>
{
arg.make_indirect();
return;
}

_ => return,
}

// Pass and return structures up to 2 pointers in size by value, matching `ScalarPair`.
// LLVM will usually pass these in 2 registers, which is more efficient than by-ref.
let max_by_val_size = Pointer.size(self) * 2;
let size = arg.layout.size;

if arg.layout.is_unsized() || size > max_by_val_size {
arg.make_indirect();
} else {
// We want to pass small aggregates as immediates, but using
// a LLVM aggregate type for this leads to bad optimizations,
// so we pick an appropriately sized integer type instead.
arg.cast_to(Reg { kind: RegKind::Integer, size });
_ => {}
}
};
fixup(&mut fn_abi.ret);
Expand Down
45 changes: 45 additions & 0 deletions src/test/assembly/x86-64-homogenous-floats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// assembly-output: emit-asm
// needs-llvm-components: x86
// compile-flags: --target x86_64-unknown-linux-gnu
// compile-flags: -C llvm-args=--x86-asm-syntax=intel
// compile-flags: -C opt-level=3

#![crate_type = "rlib"]
#![no_std]

// CHECK-LABEL: sum_f32:
// CHECK: addss xmm0, xmm1
// CHECK-NEXT: ret
// Scalar baseline: both f32 arguments are expected in SSE registers, so the
// whole function should compile to a single `addss` followed by `ret`.
#[no_mangle]
pub fn sum_f32(a: f32, b: f32) -> f32 {
    a + b
}

// CHECK-LABEL: sum_f32x2:
// CHECK: addss xmm{{[0-9]}}, xmm{{[0-9]}}
// CHECK-NEXT: addss xmm{{[0-9]}}, xmm{{[0-9]}}
// CHECK-NEXT: ret
// A pair of f32s is small enough to stay by-value: the CHECK lines pin two
// scalar adds immediately followed by `ret`, i.e. no round-trip through an
// integer register or the stack.
#[no_mangle]
pub fn sum_f32x2(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
    [
        a[0] + b[0],
        a[1] + b[1],
    ]
}

// CHECK-LABEL: sum_f32x4:
// CHECK: mov rax, [[PTR_IN:.*]]
// CHECK-NEXT: movups [[XMMA:xmm[0-9]]], xmmword ptr [rsi]
// CHECK-NEXT: movups [[XMMB:xmm[0-9]]], xmmword ptr [rdx]
// CHECK-NEXT: addps [[XMMB]], [[XMMA]]
// CHECK-NEXT: movups xmmword ptr {{\[}}[[PTR_IN]]{{\]}}, [[XMMB]]
// CHECK-NEXT: ret
// Four f32s: the CHECK lines pin indirect passing — both arguments are loaded
// through pointers (rsi/rdx), summed with one vectorized `addps`, and the
// result is stored through the return pointer.
#[no_mangle]
pub fn sum_f32x4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    [
        a[0] + b[0],
        a[1] + b[1],
        a[2] + b[2],
        a[3] + b[3],
    ]
}
32 changes: 32 additions & 0 deletions src/test/codegen/homogeneous-floats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//! Check that small (less than 128 bits on x86_64) homogeneous float aggregates are either
//! passed as an array or by a pointer

// compile-flags: -C no-prepopulate-passes -O
// only-x86_64

#![crate_type = "lib"]

// A named struct of four f32 fields — a homogeneous float aggregate. Used to
// check that a nested/named aggregate gets the same treatment the CHECK line
// below shows for the bare `[f32; 4]` array (passed via pointers / sret).
pub struct Foo {
    bar1: f32,
    bar2: f32,
    bar3: f32,
    bar4: f32,
}

// CHECK: define [2 x float] @array_f32x2([2 x float] %0, [2 x float] %1)
// Two f32s: pinned to be passed and returned directly as an LLVM float array
// rather than cast to an integer. Body is irrelevant — only the signature is
// checked.
#[no_mangle]
pub fn array_f32x2(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
    todo!()
}

// CHECK: define void @array_f32x4([4 x float]* {{.*}} sret([4 x float]) {{.*}} %0, [4 x float]* {{.*}} %a, [4 x float]* {{.*}} %b)
// Four f32s: pinned to be passed indirectly — pointer arguments plus an
// `sret` return slot. Body is irrelevant — only the signature is checked.
#[no_mangle]
pub fn array_f32x4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    todo!()
}

// CHECK: define void @array_f32x4_nested(%Foo* {{.*}} sret(%Foo) {{.*}} %0, %Foo* {{.*}} %a, %Foo* {{.*}} %b)
// Same as `array_f32x4`, but with the floats wrapped in the named struct
// `Foo`: still passed indirectly with an `sret` return slot.
#[no_mangle]
pub fn array_f32x4_nested(a: Foo, b: Foo) -> Foo {
    todo!()
}
184 changes: 184 additions & 0 deletions src/test/ui/abi/homogenous-floats-target-feature-mixup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// This test checks that even if we mix up the target features of functions taking homogeneous
// floats, the ABI is sound and still produces the right answer.
//
// This is basically the same test as src/test/ui/simd/target-feature-mixup.rs, but for floats
// and without #[repr(simd)]

// run-pass
// ignore-emscripten
// ignore-sgx no processes

#![feature(target_feature, cfg_target_feature)]
#![feature(avx512_target_feature)]

#![allow(overflowing_literals)]
#![allow(unused_variables)]
#![allow(stable_features)]

use std::process::{Command, ExitStatus};
use std::env;

/// Parent/child driver: with no CLI argument, re-runs this executable once per
/// feature level ("sse", "avx", "avx512"); with an argument, runs the actual
/// checks for that level in the child process.
fn main() {
    // Child mode: the first argument selects the feature level to exercise.
    if let Some(level) = env::args().nth(1) {
        return test::main(&level)
    }

    let exe = env::current_exe().unwrap();
    for &level in ["sse", "avx", "avx512"].iter() {
        let status = Command::new(&exe).arg(level).status().unwrap();
        if status.success() {
            println!("success with {}", level);
        } else if is_sigill(status) {
            // We don't actually know if our computer has the requisite target features
            // for the test below. Testing for that will get added to libstd later so
            // for now just assume sigill means this is a machine that can't run this test.
            println!("sigill with {}, assuming spurious", level);
        } else {
            panic!("invalid status at {}: {}", level, status);
        }
    }
}

/// On Unix, "illegal instruction" shows up as termination by signal 4 (SIGILL).
#[cfg(unix)]
fn is_sigill(status: ExitStatus) -> bool {
    use std::os::unix::prelude::*;
    matches!(status.signal(), Some(4))
}

// On Windows there is no SIGILL; an illegal instruction surfaces as exit code
// 0xc000001d (STATUS_ILLEGAL_INSTRUCTION). The literal overflows i32 and
// relies on the crate-level `#![allow(overflowing_literals)]`.
#[cfg(windows)]
fn is_sigill(status: ExitStatus) -> bool {
    status.code() == Some(0xc000001d)
}

// x86/x86_64 implementation: passes homogeneous-float structs between
// functions compiled with *different* target features (baseline/SSE/AVX/
// AVX-512) and asserts the values survive the calls, so an ABI mismatch shows
// up as a failed assert rather than silent corruption.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(nonstandard_style)]
mod test {
    // Plain tuple structs of f32 (deliberately NOT #[repr(simd)]) of
    // increasing size.
    #[derive(PartialEq, Debug, Clone, Copy)]
    struct f32x2(f32, f32);

    #[derive(PartialEq, Debug, Clone, Copy)]
    struct f32x4(f32, f32, f32, f32);

    #[derive(PartialEq, Debug, Clone, Copy)]
    struct f32x8(f32, f32, f32, f32, f32, f32, f32, f32);

    // Child-process entry point: `level` is the highest feature level this
    // machine claims to support, so drivers above that level are skipped.
    pub fn main(level: &str) {
        unsafe {
            main_normal(level);
            main_sse(level);
            if level == "sse" {
                return
            }
            main_avx(level);
            if level == "avx" {
                return
            }
            main_avx512(level);
        }
    }

    // Generates one driver per `#[target_feature]` attribute. The literal
    // `...` tokens in the pattern are matched verbatim from the invocation
    // below; every generated driver gets the same shared body.
    macro_rules! mains {
        ($(
            $(#[$attr:meta])*
            unsafe fn $main:ident(level: &str) {
                ...
            }
        )*) => ($(
            $(#[$attr])*
            unsafe fn $main(level: &str) {
                let m128 = f32x2(1., 2.);
                let m256 = f32x4(3., 4., 5., 6.);
                let m512 = f32x8(7., 8., 9., 10., 11., 12., 13., 14.);
                // Round-trip each value through identity functions at every
                // feature level allowed on this machine.
                assert_eq!(id_sse_128(m128), m128);
                assert_eq!(id_sse_256(m256), m256);
                assert_eq!(id_sse_512(m512), m512);

                if level == "sse" {
                    return
                }
                assert_eq!(id_avx_128(m128), m128);
                assert_eq!(id_avx_256(m256), m256);
                assert_eq!(id_avx_512(m512), m512);

                if level == "avx" {
                    return
                }
                assert_eq!(id_avx512_128(m128), m128);
                assert_eq!(id_avx512_256(m256), m256);
                assert_eq!(id_avx512_512(m512), m512);
            }
        )*)
    }

    mains! {
        unsafe fn main_normal(level: &str) { ... }
        #[target_feature(enable = "sse2")]
        unsafe fn main_sse(level: &str) { ... }
        #[target_feature(enable = "avx")]
        unsafe fn main_avx(level: &str) { ... }
        #[target_feature(enable = "avx512bw")]
        unsafe fn main_avx512(level: &str) { ... }
    }

    // Identity functions compiled per feature level. Each asserts it received
    // the expected value before returning a copy, so corruption is caught on
    // both sides of the call boundary.
    #[target_feature(enable = "sse2")]
    unsafe fn id_sse_128(a: f32x2) -> f32x2 {
        assert_eq!(a, f32x2(1., 2.));
        a.clone()
    }

    #[target_feature(enable = "sse2")]
    unsafe fn id_sse_256(a: f32x4) -> f32x4 {
        assert_eq!(a, f32x4(3., 4., 5., 6.));
        a.clone()
    }

    #[target_feature(enable = "sse2")]
    unsafe fn id_sse_512(a: f32x8) -> f32x8 {
        assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
        a.clone()
    }

    #[target_feature(enable = "avx")]
    unsafe fn id_avx_128(a: f32x2) -> f32x2 {
        assert_eq!(a, f32x2(1., 2.));
        a.clone()
    }

    #[target_feature(enable = "avx")]
    unsafe fn id_avx_256(a: f32x4) -> f32x4 {
        assert_eq!(a, f32x4(3., 4., 5., 6.));
        a.clone()
    }

    #[target_feature(enable = "avx")]
    unsafe fn id_avx_512(a: f32x8) -> f32x8 {
        assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
        a.clone()
    }

    #[target_feature(enable = "avx512bw")]
    unsafe fn id_avx512_128(a: f32x2) -> f32x2 {
        assert_eq!(a, f32x2(1., 2.));
        a.clone()
    }

    #[target_feature(enable = "avx512bw")]
    unsafe fn id_avx512_256(a: f32x4) -> f32x4 {
        assert_eq!(a, f32x4(3., 4., 5., 6.));
        a.clone()
    }

    #[target_feature(enable = "avx512bw")]
    unsafe fn id_avx512_512(a: f32x8) -> f32x8 {
        assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
        a.clone()
    }
}

// Stub for non-x86 targets: the feature levels under test are x86-specific,
// so the child process has nothing to exercise and succeeds trivially.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
mod test {
    pub fn main(level: &str) {}
}
46 changes: 46 additions & 0 deletions src/test/ui/abi/homogenous-floats.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// This tests that, no matter the optimization level or the target features enabled, the
// non-aggregation of homogeneous floats in the ABI is sound and still produces the right answer.

// revisions: opt-0 opt-0-native opt-1 opt-1-native opt-2 opt-2-native opt-3 opt-3-native
// [opt-0]: compile-flags: -C opt-level=0
// [opt-1]: compile-flags: -C opt-level=1
// [opt-2]: compile-flags: -C opt-level=2
// [opt-3]: compile-flags: -C opt-level=3
// [opt-0-native]: compile-flags: -C target-cpu=native
// [opt-1-native]: compile-flags: -C target-cpu=native
// [opt-2-native]: compile-flags: -C target-cpu=native
// [opt-3-native]: compile-flags: -C target-cpu=native
// run-pass

#![feature(core_intrinsics)]

use std::intrinsics::black_box;

// Scalar baseline: a plain f32 pair exercises the non-aggregate float ABI.
pub fn sum_f32(a: f32, b: f32) -> f32 {
    a + b
}

// Element-wise sum of two 2-element f32 arrays (homogeneous float aggregate
// fitting in a pointer's worth of bytes).
pub fn sum_f32x2(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
    let [a0, a1] = a;
    let [b0, b1] = b;
    [a0 + b0, a1 + b1]
}

// Element-wise sum of two 3-element f32 arrays (odd element count, larger
// than a pointer but smaller than two).
pub fn sum_f32x3(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [a0, a1, a2] = a;
    let [b0, b1, b2] = b;
    [a0 + b0, a1 + b1, a2 + b2]
}

// Element-wise sum of two 4-element f32 arrays (two pointers' worth of
// homogeneous floats).
pub fn sum_f32x4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [a0 + b0, a1 + b1, a2 + b2, a3 + b3]
}

/// Drives each sum function with concrete values and checks the results.
/// Every input and output is routed through `black_box` so the optimizer
/// cannot const-fold the calls away — each call must cross the real ABI.
fn main() {
    let s1 = black_box(sum_f32(black_box(0.), black_box(1.)));
    assert_eq!(1., s1);

    let s2 = black_box(sum_f32x2(black_box([2., 0.]), black_box([0., 2.])));
    assert_eq!([2., 2.], s2);

    let s3 = black_box(sum_f32x3(black_box([1., 2., 3.]), black_box([2., 1., 0.])));
    assert_eq!([3., 3., 3.], s3);

    let s4 = black_box(sum_f32x4(black_box([1., 2., 3., 4.]), black_box([3., 2., 1., 0.])));
    assert_eq!([4., 4., 4., 4.], s4);
}