Skip to content

Commit

Permalink
Added a bf16 type
Browse files Browse the repository at this point in the history
  • Loading branch information
sayantn authored and Amanieu committed Jul 6, 2024
1 parent fe8f300 commit 3dd9579
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 21 deletions.
17 changes: 9 additions & 8 deletions crates/core_arch/src/x86/avx512bf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
#[inline]
#[target_feature(enable = "avx512bf16,avx512f")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
f32::from_bits((a as u32) << 16)
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_cvtsbh_ss(a: bf16) -> f32 {
f32::from_bits((a.to_bits() as u32) << 16)
}

/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
Expand Down Expand Up @@ -558,9 +558,10 @@ pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_cvtness_sbh(a: f32) -> bf16 {
let value: u16 = simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0);
bf16::from_bits(value)
}

#[cfg(test)]
Expand Down Expand Up @@ -1910,7 +1911,7 @@ mod tests {

#[simd_test(enable = "avx512bf16")]
unsafe fn test_mm_cvtsbh_ss() {
let r = _mm_cvtsbh_ss(BF16_ONE);
let r = _mm_cvtsbh_ss(bf16::from_bits(BF16_ONE));
assert_eq!(r, 1.);
}

Expand Down Expand Up @@ -1944,6 +1945,6 @@ mod tests {
#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_cvtness_sbh() {
let r = _mm_cvtness_sbh(1.);
assert_eq!(r, BF16_ONE);
assert_eq!(r.to_bits(), BF16_ONE);
}
}
22 changes: 11 additions & 11 deletions crates/core_arch/src/x86/avxneconvert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::arch::asm;
use crate::core_arch::{simd::*, x86::*};
use crate::core_arch::x86::*;

#[cfg(test)]
use stdarch_test::assert_instr;
Expand All @@ -15,9 +15,9 @@ use stdarch_test::assert_instr;
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(vbcstnebf162ps)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
transmute(bcstnebf162ps_128(a))
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm_bcstnebf16_ps(a: *const bf16) -> __m128 {
bcstnebf162ps_128(a)
}

/// Convert scalar BF16 (16-bit) floating point element stored at memory locations starting at location
Expand All @@ -31,9 +31,9 @@ pub unsafe fn _mm_bcstnebf16_ps(a: *const u16) -> __m128 {
all(test, any(target_os = "linux", target_env = "msvc")),
assert_instr(vbcstnebf162ps)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_bcstnebf16_ps(a: *const u16) -> __m256 {
transmute(bcstnebf162ps_256(a))
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub unsafe fn _mm256_bcstnebf16_ps(a: *const bf16) -> __m256 {
bcstnebf162ps_256(a)
}

/// Convert packed BF16 (16-bit) floating-point even-indexed elements stored at memory locations starting at
Expand Down Expand Up @@ -143,9 +143,9 @@ pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
#[allow(improper_ctypes)]
extern "C" {
#[link_name = "llvm.x86.vbcstnebf162ps128"]
fn bcstnebf162ps_128(a: *const u16) -> f32x4;
fn bcstnebf162ps_128(a: *const bf16) -> __m128;
#[link_name = "llvm.x86.vbcstnebf162ps256"]
fn bcstnebf162ps_256(a: *const u16) -> f32x8;
fn bcstnebf162ps_256(a: *const bf16) -> __m256;

#[link_name = "llvm.x86.vcvtneebf162ps128"]
fn cvtneebf162ps_128(a: *const __m128bh) -> __m128;
Expand Down Expand Up @@ -177,15 +177,15 @@ mod tests {

#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm_bcstnebf16_ps() {
let a = BF16_ONE;
let a = bf16::from_bits(BF16_ONE);
let r = _mm_bcstnebf16_ps(addr_of!(a));
let e = _mm_set_ps(1., 1., 1., 1.);
assert_eq_m128(r, e);
}

#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm256_bcstnebf16_ps() {
let a = BF16_ONE;
let a = bf16::from_bits(BF16_ONE);
let r = _mm256_bcstnebf16_ps(addr_of!(a));
let e = _mm256_set_ps(1., 1., 1., 1., 1., 1., 1., 1.);
assert_eq_m256(r, e);
Expand Down
25 changes: 25 additions & 0 deletions crates/core_arch/src/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,31 @@ types! {
);
}

/// The BFloat16 type used in AVX-512 intrinsics.
#[repr(transparent)]
#[derive(Copy, Clone, Debug)]
#[allow(non_camel_case_types)]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub struct bf16(u16);

impl bf16 {
/// Raw transmutation from `u16`
#[inline]
#[must_use]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub const fn from_bits(bits: u16) -> bf16 {
bf16(bits)
}

/// Raw transmutation to `u16`
#[inline]
#[must_use = "this returns the result of the operation, without modifying the original"]
#[unstable(feature = "stdarch_x86_avx512_bf16", issue = "127356")]
pub const fn to_bits(self) -> u16 {
self.0
}
}

/// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
#[allow(non_camel_case_types)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
Expand Down
1 change: 1 addition & 0 deletions crates/stdarch-verify/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
"_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM },
"_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
"bool" => quote! { &BOOL },
"bf16" => quote! { &BF16 },
"f32" => quote! { &F32 },
"f64" => quote! { &F64 },
"i16" => quote! { &I16 },
Expand Down
8 changes: 6 additions & 2 deletions crates/stdarch-verify/tests/x86-intel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct Function {
has_test: bool,
}

static BF16: Type = Type::BFloat16;
static F32: Type = Type::PrimFloat(32);
static F64: Type = Type::PrimFloat(64);
static I8: Type = Type::PrimSigned(8);
Expand Down Expand Up @@ -65,6 +66,7 @@ enum Type {
PrimFloat(u8),
PrimSigned(u8),
PrimUnsigned(u8),
BFloat16,
MutPtr(&'static Type),
ConstPtr(&'static Type),
M128,
Expand Down Expand Up @@ -699,7 +701,8 @@ fn equate(
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
(&Type::PrimSigned(64), "__int64" | "long long") => {}
(&Type::PrimUnsigned(8), "unsigned char") => {}
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
(&Type::PrimUnsigned(16), "unsigned short") => {}
(&Type::BFloat16, "__bfloat16") => {}
(
&Type::PrimUnsigned(32),
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",
Expand Down Expand Up @@ -758,9 +761,10 @@ fn equate(
(&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {}
(&Type::ConstPtr(&Type::PrimSigned(32)), "__int32 const*" | "int const*") => {}
(&Type::ConstPtr(&Type::PrimSigned(64)), "__int64 const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*" | "__bf16 const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(16)), "unsigned short const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(32)), "unsigned int const*") => {}
(&Type::ConstPtr(&Type::PrimUnsigned(64)), "unsigned __int64 const*") => {}
(&Type::ConstPtr(&Type::BFloat16), "__bf16 const*") => {}

(&Type::ConstPtr(&Type::M128), "__m128 const*") => {}
(&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}
Expand Down

0 comments on commit 3dd9579

Please sign in to comment.