diff --git a/crates/core_arch/src/x86/avx512bf16.rs b/crates/core_arch/src/x86/avx512bf16.rs index a9be7de5db..7d99809353 100644 --- a/crates/core_arch/src/x86/avx512bf16.rs +++ b/crates/core_arch/src/x86/avx512bf16.rs @@ -486,9 +486,9 @@ pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss) #[inline] #[target_feature(enable = "avx512bf16,avx512f")] -#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] -pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 { - f32::from_bits((a as u32) << 16) +#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] +pub unsafe fn _mm_cvtsbh_ss(a: bf16) -> f32 { + f32::from_bits((a.to_bits() as u32) << 16) } /// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit) @@ -558,9 +558,10 @@ pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh { /// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh) #[inline] #[target_feature(enable = "avx512bf16,avx512vl")] -#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] -pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 { - simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0) +#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] +pub unsafe fn _mm_cvtness_sbh(a: f32) -> bf16 { + let value: u16 = simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0); + bf16::from_bits(value) } #[cfg(test)] @@ -1910,7 +1911,7 @@ mod tests { #[simd_test(enable = "avx512bf16")] unsafe fn test_mm_cvtsbh_ss() { - let r = _mm_cvtsbh_ss(BF16_ONE); + let r = _mm_cvtsbh_ss(bf16::from_bits(BF16_ONE)); assert_eq!(r, 1.); } @@ -1944,6 +1945,6 @@ mod tests { #[simd_test(enable = "avx512bf16,avx512vl")] unsafe fn test_mm_cvtness_sbh() { let r = _mm_cvtness_sbh(1.); - assert_eq!(r, BF16_ONE); + assert_eq!(r.to_bits(), BF16_ONE); } } diff --git a/crates/core_arch/src/x86/avxneconvert.rs b/crates/core_arch/src/x86/avxneconvert.rs index 1834b43a81..4eb1a9b30c 100644 --- a/crates/core_arch/src/x86/avxneconvert.rs +++ b/crates/core_arch/src/x86/avxneconvert.rs @@ -1,5 +1,5 @@ use crate::arch::asm; -use crate::core_arch::{simd::*, x86::*}; +use crate::core_arch::x86::*; #[cfg(test)] use stdarch_test::assert_instr; @@ -15,9 +15,9 @@ use stdarch_test::assert_instr; all(test, any(target_os = "linux", target_env = "msvc")), assert_instr(vbcstnebf162ps) )] -#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] -pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 { - transmute(bcstnebf162ps_128(a)) +#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] +pub unsafe fn _mm_bcstnebf16_ps(a: *const bf16) -> __m128 { + bcstnebf162ps_128(a) } /// Convert scalar BF16 (16-bit) floating point element stored at memory locations starting at location @@ -31,9 +31,9 @@ pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 { all(test, any(target_os = "linux", target_env = "msvc")), assert_instr(vbcstnebf162ps) )] -#[unstable(feature = "stdarch_x86_avx512", issue = "111137")] -pub unsafe fn _mm256_bcstnebf16_ps(a: *const u16) -> __m256 { - transmute(bcstnebf162ps_256(a)) +#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] +pub unsafe fn _mm256_bcstnebf16_ps(a: *const bf16) -> __m256 { + bcstnebf162ps_256(a) } /// Convert packed BF16 (16-bit) floating-point even-indexed elements stored at memory locations starting at @@ -143,9 +143,9 @@ pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh { #[allow(improper_ctypes)] extern "C" { #[link_name = "llvm.x86.vbcstnebf162ps128"] - fn bcstnebf162ps_128(a: *const u16) -> f32x4; + fn bcstnebf162ps_128(a: *const bf16) -> __m128; #[link_name = "llvm.x86.vbcstnebf162ps256"] - fn bcstnebf162ps_256(a: *const u16) -> f32x8; + fn bcstnebf162ps_256(a: *const bf16) -> __m256; #[link_name = "llvm.x86.vcvtneebf162ps128"] fn cvtneebf162ps_128(a: *const __m128bh) -> __m128; @@ -177,7 +177,7 @@ mod tests { #[simd_test(enable = "avxneconvert")] unsafe fn test_mm_bcstnebf16_ps() { - let a = BF16_ONE; + let a = bf16::from_bits(BF16_ONE); let r = _mm_bcstnebf16_ps(addr_of!(a)); let e = _mm_set_ps(1., 1., 1., 1.); assert_eq_m128(r, e); @@ -185,7 +185,7 @@ mod tests { #[simd_test(enable = "avxneconvert")] unsafe fn test_mm256_bcstnebf16_ps() { - let a = BF16_ONE; + let a = bf16::from_bits(BF16_ONE); let r = _mm256_bcstnebf16_ps(addr_of!(a)); let e = _mm256_set_ps(1., 1., 1., 1., 1., 1., 1., 1.); assert_eq_m256(r, e); diff --git a/crates/core_arch/src/x86/mod.rs b/crates/core_arch/src/x86/mod.rs index 6f8c51c16a..9365fe10a2 100644 --- a/crates/core_arch/src/x86/mod.rs +++ b/crates/core_arch/src/x86/mod.rs @@ -337,6 +337,31 @@ types! { ); } +/// The BFloat16 type used in AVX-512 intrinsics. +#[repr(transparent)] +#[derive(Copy, Clone, Debug)] +#[allow(non_camel_case_types)] +#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] +pub struct bf16(u16); + +impl bf16 { + /// Raw transmutation from `u16` + #[inline] + #[must_use] + #[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] + pub const fn from_bits(bits: u16) -> bf16 { + bf16(bits) + } + + /// Raw transmutation to `u16` + #[inline] + #[must_use = "this returns the result of the operation, without modifying the original"] + #[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")] + pub const fn to_bits(self) -> u16 { + self.0 + } +} + /// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer #[allow(non_camel_case_types)] #[unstable(feature = "stdarch_x86_avx512", issue = "111137")] diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs index ff31c31c89..94569dfd0c 100644 --- a/crates/stdarch-verify/src/lib.rs +++ b/crates/stdarch-verify/src/lib.rs @@ -197,6 +197,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream { "_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM }, "_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM }, "bool" => quote! { &BOOL }, + "bf16" => quote! { &BF16 }, "f32" => quote! { &F32 }, "f64" => quote! { &F64 }, "i16" => quote! { &I16 }, diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs index f38c7b1a3f..8de2c88b81 100644 --- a/crates/stdarch-verify/tests/x86-intel.rs +++ b/crates/stdarch-verify/tests/x86-intel.rs @@ -22,6 +22,7 @@ struct Function { has_test: bool, } +static BF16: Type = Type::BFloat16; static F32: Type = Type::PrimFloat(32); static F64: Type = Type::PrimFloat(64); static I8: Type = Type::PrimSigned(8); @@ -65,6 +66,7 @@ enum Type { PrimFloat(u8), PrimSigned(u8), PrimUnsigned(u8), + BFloat16, MutPtr(&'static Type), ConstPtr(&'static Type), M128, @@ -699,7 +701,8 @@ fn equate( (&Type::PrimSigned(32), "__int32" | "const int" | "int") => {} (&Type::PrimSigned(64), "__int64" | "long long") => {} (&Type::PrimUnsigned(8), "unsigned char") => {} - (&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {} + (&Type::PrimUnsigned(16), "unsigned short") => {} + (&Type::BFloat16, "__bfloat16") => {} ( &Type::PrimUnsigned(32), "unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int", @@ -758,9 +761,10 @@ fn equate( (&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {} (&Type::ConstPtr(&Type::PrimSigned(32)), "__int32 const*" | "int const*") => {} (&Type::ConstPtr(&Type::PrimSigned(64)), "__int64 const*") => {} - (&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*" | "__bf16 const*") => {} + (&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*") => {} (&Type::ConstPtr(&Type::PrimUnsigned(32)), "unsigned int const*") => {} (&Type::ConstPtr(&Type::PrimUnsigned(64)), "unsigned __int64 const*") => {} + (&Type::ConstPtr(&Type::BFloat16), "__bf16 const*") => {} (&Type::ConstPtr(&Type::M128), "__m128 const*") => {} (&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}