Skip to content

Commit

Permalink
Ban non-array SIMD
Browse files Browse the repository at this point in the history
  • Loading branch information
scottmcm committed Aug 22, 2024
1 parent a32d4a0 commit f48908d
Show file tree
Hide file tree
Showing 85 changed files with 659 additions and 811 deletions.
42 changes: 22 additions & 20 deletions compiler/rustc_hir_analysis/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1063,20 +1063,29 @@ pub fn check_simd(tcx: TyCtxt<'_>, sp: Span, def_id: LocalDefId) {
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot be empty").emit();
return;
}
let e = fields[FieldIdx::ZERO].ty(tcx, args);
if !fields.iter().all(|f| f.ty(tcx, args) == e) {
struct_span_code_err!(tcx.dcx(), sp, E0076, "SIMD vector should be homogeneous")
.with_span_label(sp, "SIMD elements must have the same type")

let array_field = &fields[FieldIdx::ZERO];
let array_ty = array_field.ty(tcx, args);
let ty::Array(element_ty, len_const) = array_ty.kind() else {
struct_span_code_err!(
tcx.dcx(),
sp,
E0076,
"SIMD vector's only field must be an array"
)
.with_span_label(tcx.def_span(array_field.did), "not an array")
.emit();
return;
};

if let Some(second_field) = fields.get(FieldIdx::from_u32(1)) {
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot have multiple fields")
.with_span_label(tcx.def_span(second_field.did), "excess field")
.emit();
return;
}

let len = if let ty::Array(_ty, c) = e.kind() {
c.try_eval_target_usize(tcx, tcx.param_env(def.did()))
} else {
Some(fields.len() as u64)
};
if let Some(len) = len {
if let Some(len) = len_const.try_eval_target_usize(tcx, tcx.param_env(def.did())) {
if len == 0 {
struct_span_code_err!(tcx.dcx(), sp, E0075, "SIMD vector cannot be empty").emit();
return;
Expand All @@ -1096,16 +1105,9 @@ pub fn check_simd(tcx: TyCtxt<'_>, sp: Span, def_id: LocalDefId) {
// These are scalar types which directly match a "machine" type
// Yes: Integers, floats, "thin" pointers
// No: char, "fat" pointers, compound types
match e.kind() {
ty::Param(_) => (), // pass struct<T>(T, T, T, T) through, let monomorphization catch errors
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _) => (), // struct(u8, u8, u8, u8) is ok
ty::Array(t, _) if matches!(t.kind(), ty::Param(_)) => (), // pass struct<T>([T; N]) through, let monomorphization catch errors
ty::Array(t, _clen)
if matches!(
t.kind(),
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _)
) =>
{ /* struct([f32; 4]) is ok */ }
match element_ty.kind() {
ty::Param(_) => (), // pass struct<T>([T; 4]) through, let monomorphization catch errors
ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::RawPtr(_, _) => (), // struct([u8; 4]) is ok
_ => {
struct_span_code_err!(
tcx.dcx(),
Expand Down
38 changes: 15 additions & 23 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1091,29 +1091,21 @@ impl<'tcx> Ty<'tcx> {
}

pub fn simd_size_and_type(self, tcx: TyCtxt<'tcx>) -> (u64, Ty<'tcx>) {
match self.kind() {
Adt(def, args) => {
assert!(def.repr().simd(), "`simd_size_and_type` called on non-SIMD type");
let variant = def.non_enum_variant();
let f0_ty = variant.fields[FieldIdx::ZERO].ty(tcx, args);

match f0_ty.kind() {
// If the first field is an array, we assume it is the only field and its
// elements are the SIMD components.
Array(f0_elem_ty, f0_len) => {
// FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
// The way we evaluate the `N` in `[T; N]` here only works since we use
// `simd_size_and_type` post-monomorphization. It will probably start to ICE
// if we use it in generic code. See the `simd-array-trait` ui test.
(f0_len.eval_target_usize(tcx, ParamEnv::empty()), *f0_elem_ty)
}
// Otherwise, the fields of this Adt are the SIMD components (and we assume they
// all have the same type).
_ => (variant.fields.len() as u64, f0_ty),
}
}
_ => bug!("`simd_size_and_type` called on invalid type"),
}
let Adt(def, args) = self.kind() else {
bug!("`simd_size_and_type` called on invalid type")
};
assert!(def.repr().simd(), "`simd_size_and_type` called on non-SIMD type");
let variant = def.non_enum_variant();
assert_eq!(variant.fields.len(), 1);
let field_ty = variant.fields[FieldIdx::ZERO].ty(tcx, args);
let Array(f0_elem_ty, f0_len) = field_ty.kind() else {
bug!("Simd type has non-array field type {field_ty:?}")
};
// FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112
// The way we evaluate the `N` in `[T; N]` here only works since we use
// `simd_size_and_type` post-monomorphization. It will probably start to ICE
// if we use it in generic code. See the `simd-array-trait` ui test.
(f0_len.eval_target_usize(tcx, ParamEnv::empty()), *f0_elem_ty)
}

#[inline]
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/align-byval-vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ trait Freeze {}
trait Copy {}

#[repr(simd)]
pub struct i32x4(i32, i32, i32, i32);
pub struct i32x4([i32; 4]);

#[repr(C)]
pub struct Foo {
Expand All @@ -47,12 +47,12 @@ extern "C" {
}

pub fn main() {
unsafe { f(Foo { a: i32x4(1, 2, 3, 4), b: 0 }) }
unsafe { f(Foo { a: i32x4([1, 2, 3, 4]), b: 0 }) }

unsafe {
g(DoubleFoo {
one: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
two: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
one: Foo { a: i32x4([1, 2, 3, 4]), b: 0 },
two: Foo { a: i32x4([1, 2, 3, 4]), b: 0 },
})
}
}
36 changes: 10 additions & 26 deletions tests/codegen/const-vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,11 @@
// Setting up structs that can be used as const vectors
#[repr(simd)]
#[derive(Clone)]
pub struct i8x2(i8, i8);
pub struct i8x2([i8; 2]);

#[repr(simd)]
#[derive(Clone)]
pub struct i8x2_arr([i8; 2]);

#[repr(simd)]
#[derive(Clone)]
pub struct f32x2(f32, f32);

#[repr(simd)]
#[derive(Clone)]
pub struct f32x2_arr([f32; 2]);
pub struct f32x2([f32; 2]);

#[repr(simd, packed)]
#[derive(Copy, Clone)]
Expand All @@ -35,42 +27,34 @@ pub struct Simd<T, const N: usize>([T; N]);
// that they are called with a const vector

extern "unadjusted" {
#[no_mangle]
fn test_i8x2(a: i8x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_i8x2_two_args(a: i8x2, b: i8x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_i8x2_mixed_args(a: i8x2, c: i32, b: i8x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_i8x2_arr(a: i8x2_arr);
fn test_i8x2_arr(a: i8x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_f32x2(a: f32x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_f32x2_arr(a: f32x2_arr);
fn test_f32x2_arr(a: f32x2);
}

extern "unadjusted" {
#[no_mangle]
fn test_simd(a: Simd<i32, 4>);
}

extern "unadjusted" {
#[no_mangle]
fn test_simd_unaligned(a: Simd<i32, 3>);
}

Expand All @@ -81,22 +65,22 @@ extern "unadjusted" {
pub fn do_call() {
unsafe {
// CHECK: call void @test_i8x2(<2 x i8> <i8 32, i8 64>
test_i8x2(const { i8x2(32, 64) });
test_i8x2(const { i8x2([32, 64]) });

// CHECK: call void @test_i8x2_two_args(<2 x i8> <i8 32, i8 64>, <2 x i8> <i8 8, i8 16>
test_i8x2_two_args(const { i8x2(32, 64) }, const { i8x2(8, 16) });
test_i8x2_two_args(const { i8x2([32, 64]) }, const { i8x2([8, 16]) });

// CHECK: call void @test_i8x2_mixed_args(<2 x i8> <i8 32, i8 64>, i32 43, <2 x i8> <i8 8, i8 16>
test_i8x2_mixed_args(const { i8x2(32, 64) }, 43, const { i8x2(8, 16) });
test_i8x2_mixed_args(const { i8x2([32, 64]) }, 43, const { i8x2([8, 16]) });

// CHECK: call void @test_i8x2_arr(<2 x i8> <i8 32, i8 64>
test_i8x2_arr(const { i8x2_arr([32, 64]) });
test_i8x2_arr(const { i8x2([32, 64]) });

// CHECK: call void @test_f32x2(<2 x float> <float 0x3FD47AE140000000, float 0x3FE47AE140000000>
test_f32x2(const { f32x2(0.32, 0.64) });
test_f32x2(const { f32x2([0.32, 0.64]) });

// CHECK: void @test_f32x2_arr(<2 x float> <float 0x3FD47AE140000000, float 0x3FE47AE140000000>
test_f32x2_arr(const { f32x2_arr([0.32, 0.64]) });
test_f32x2_arr(const { f32x2([0.32, 0.64]) });

// CHECK: call void @test_simd(<4 x i32> <i32 2, i32 4, i32 6, i32 8>
test_simd(const { Simd::<i32, 4>([2, 4, 6, 8]) });
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/repr/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub extern "C" fn test_Nested2(_: Nested2) -> Nested2 {
}

#[repr(simd)]
struct f32x4(f32, f32, f32, f32);
struct f32x4([f32; 4]);

#[repr(transparent)]
pub struct Vector(f32x4);
Expand Down
19 changes: 7 additions & 12 deletions tests/codegen/simd-intrinsic/simd-intrinsic-float-abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x2(pub f32, pub f32);
pub struct f32x2(pub [f32; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
pub struct f32x4(pub [f32; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x8(pub [f32; 8]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x16(pub [f32; 16]);

extern "rust-intrinsic" {
fn simd_fabs<T>(x: T) -> T;
Expand Down Expand Up @@ -59,16 +55,15 @@ pub unsafe fn fabs_32x16(a: f32x16) -> f32x16 {

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x2(pub f64, pub f64);
pub struct f64x2(pub [f64; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
pub struct f64x4(pub [f64; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
pub f64, pub f64, pub f64, pub f64);
pub struct f64x8(pub [f64; 8]);

// CHECK-LABEL: @fabs_64x4
#[no_mangle]
Expand Down
19 changes: 7 additions & 12 deletions tests/codegen/simd-intrinsic/simd-intrinsic-float-ceil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x2(pub f32, pub f32);
pub struct f32x2(pub [f32; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
pub struct f32x4(pub [f32; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x8(pub [f32; 8]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x16(pub [f32; 16]);

extern "rust-intrinsic" {
fn simd_ceil<T>(x: T) -> T;
Expand Down Expand Up @@ -59,16 +55,15 @@ pub unsafe fn ceil_32x16(a: f32x16) -> f32x16 {

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x2(pub f64, pub f64);
pub struct f64x2(pub [f64; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
pub struct f64x4(pub [f64; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
pub f64, pub f64, pub f64, pub f64);
pub struct f64x8(pub [f64; 8]);

// CHECK-LABEL: @ceil_64x4
#[no_mangle]
Expand Down
19 changes: 7 additions & 12 deletions tests/codegen/simd-intrinsic/simd-intrinsic-float-cos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x2(pub f32, pub f32);
pub struct f32x2(pub [f32; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x4(pub f32, pub f32, pub f32, pub f32);
pub struct f32x4(pub [f32; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x8(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x8(pub [f32; 8]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f32x16(pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32,
pub f32, pub f32, pub f32, pub f32);
pub struct f32x16(pub [f32; 16]);

extern "rust-intrinsic" {
fn simd_fcos<T>(x: T) -> T;
Expand Down Expand Up @@ -59,16 +55,15 @@ pub unsafe fn fcos_32x16(a: f32x16) -> f32x16 {

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x2(pub f64, pub f64);
pub struct f64x2(pub [f64; 2]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x4(pub f64, pub f64, pub f64, pub f64);
pub struct f64x4(pub [f64; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct f64x8(pub f64, pub f64, pub f64, pub f64,
pub f64, pub f64, pub f64, pub f64);
pub struct f64x8(pub [f64; 8]);

// CHECK-LABEL: @fcos_64x4
#[no_mangle]
Expand Down
Loading

0 comments on commit f48908d

Please sign in to comment.