Skip to content

Commit

Permalink
Implemented some missing functions
Browse files Browse the repository at this point in the history
These intrinsics cannot be implemented by linking against the LLVM intrinsics because Rust lacks `bfloat16` and `i1` types, so inline assembly was the only way to implement them.
  • Loading branch information
sayantn authored and Amanieu committed Jul 6, 2024
1 parent cc5e826 commit fe8f300
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 13 deletions.
12 changes: 0 additions & 12 deletions crates/core_arch/missing-x86.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,6 @@
</p></details>


<details><summary>["AVX512_BF16", "AVX512VL"]</summary><p>

* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
* [ ] [`_mm_cvtness_sbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
* [ ] [`_mm_mask_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
* [ ] [`_mm_maskz_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
</p></details>


<details><summary>["AVX512_FP16"]</summary><p>

* [ ] [`_mm256_castpd_ph`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_castpd_ph)
Expand Down Expand Up @@ -1125,12 +1116,9 @@
* [ ] [`_mm256_bcstnesh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_bcstnesh_ps)
* [ ] [`_mm256_cvtneeph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneeph_ps)
* [ ] [`_mm256_cvtneoph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneoph_ps)
* [ ] [`_mm256_cvtneps_avx_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneps_avx_pbh)
* [ ] [`_mm_bcstnesh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_bcstnesh_ps)
* [ ] [`_mm_cvtneeph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneeph_ps)
* [ ] [`_mm_cvtneoph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneoph_ps)
* [ ] [`_mm_cvtneps_avx_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_avx_pbh)
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
</p></details>


Expand Down
112 changes: 111 additions & 1 deletion crates/core_arch/src/x86/avx512bf16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! [AVX512BF16 intrinsics]: https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769&avx512techs=AVX512_BF16

use crate::arch::asm;
use crate::core_arch::{simd::*, x86::*};
use crate::intrinsics::simd::*;

Expand Down Expand Up @@ -490,9 +491,85 @@ pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
f32::from_bits((a as u32) << 16)
}

/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
/// floating-point elements, and store the results in dst.
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl,sse")]
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtneps_pbh(a: __m128) -> __m128bh {
    // Implemented with inline asm: Rust has no `bfloat16` type, so the
    // corresponding LLVM intrinsic cannot be declared and linked directly.
    let mut dst: __m128bh;
    asm!(
        // Both operands in xmm registers selects the 128-bit form of
        // VCVTNEPS2BF16; the four BF16 results land in the low lanes of dst.
        "vcvtneps2bf16 {dst}, {src}",
        dst = lateout(xmm_reg) dst,
        src = in(xmm_reg) a,
        // pure + nomem: output depends only on `src`, so the compiler may
        // CSE/eliminate calls; the asm touches no memory, stack, or flags.
        options(pure, nomem, nostack, preserves_flags)
    );
    dst
}

/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
/// floating-point elements, and store the results in dst using writemask k (elements are copied
/// from src when the corresponding mask bit is not set).
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl,sse,avx512f")]
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_mask_cvtneps_pbh(src: __m128bh, k: __mmask8, a: __m128) -> __m128bh {
    // Seed dst with `src` and tie it to the asm output via `inlateout`:
    // with merge-masking, lanes whose mask bit is clear keep their input
    // (i.e. `src`) value, which is exactly the writemask semantics.
    let mut dst = src;
    asm!(
        // `{{{k}}}` emits a literal `{k}` opmask suffix (merge-masking).
        "vcvtneps2bf16 {dst}{{{k}}},{src}",
        dst = inlateout(xmm_reg) dst,
        src = in(xmm_reg) a,
        k = in(kreg) k,
        options(pure, nomem, nostack, preserves_flags)
    );
    dst
}

/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
/// floating-point elements, and store the results in dst using zeromask k (elements are zeroed out
/// when the corresponding mask bit is not set).
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl,sse,avx512f")]
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
    // No input binding needed here: with zero-masking (`{z}` suffix) the
    // instruction writes every lane, zeroing those whose mask bit is clear,
    // so a plain `lateout` is sufficient.
    let mut dst: __m128bh;
    asm!(
        // `{{{k}}}{{z}}` emits a literal `{k}{z}` opmask + zeroing suffix.
        "vcvtneps2bf16 {dst}{{{k}}}{{z}},{src}",
        dst = lateout(xmm_reg) dst,
        src = in(xmm_reg) a,
        k = in(kreg) k,
        options(pure, nomem, nostack, preserves_flags)
    );
    dst
}

/// Converts a single-precision (32-bit) floating-point element in a to a BF16 (16-bit) floating-point
/// element, and store the result in dst.
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
#[inline]
#[target_feature(enable = "avx512bf16,avx512vl")]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
}

#[cfg(test)]
mod tests {
use crate::{core_arch::x86::*, mem::transmute};
use crate::core_arch::simd::u16x4;
use crate::{
core_arch::x86::*,
mem::{transmute, transmute_copy},
};
use stdarch_test::simd_test;

#[simd_test(enable = "avx512bf16,avx512vl")]
Expand Down Expand Up @@ -1836,4 +1913,37 @@ mod tests {
let r = _mm_cvtsbh_ss(BF16_ONE);
assert_eq!(r, 1.);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_cvtneps_pbh() {
    let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
    // __m128bh has 8 u16 lanes but the 128-bit convert only produces 4
    // results, so transmute_copy reads just the low-half prefix as u16x4.
    let r: u16x4 = transmute_copy(&_mm_cvtneps_pbh(a));
    let e = u16x4::new(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR);
    assert_eq!(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_mask_cvtneps_pbh() {
    let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
    // Pass-through values for masked-off lanes; the upper four lanes are
    // filler since only the low four are checked below.
    let src = __m128bh(5, 6, 7, 8, !0, !0, !0, !0);
    // Mask bits 1 and 3 set: those lanes take converted values, lanes 0
    // and 2 keep the corresponding `src` values (5 and 7).
    let k = 0b1010;
    let r: u16x4 = transmute_copy(&_mm_mask_cvtneps_pbh(src, k, a));
    let e = u16x4::new(5, BF16_TWO, 7, BF16_FOUR);
    assert_eq!(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_maskz_cvtneps_pbh() {
    let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
    // Mask bits 1 and 3 set: unselected lanes 0 and 2 are zeroed out.
    let k = 0b1010;
    let r: u16x4 = transmute_copy(&_mm_maskz_cvtneps_pbh(k, a));
    let e = u16x4::new(0, BF16_TWO, 0, BF16_FOUR);
    assert_eq!(r, e);
}

#[simd_test(enable = "avx512bf16,avx512vl")]
unsafe fn test_mm_cvtness_sbh() {
    // 1.0 is exactly representable in BF16, so the scalar conversion must
    // yield the BF16_ONE bit pattern with no rounding involved.
    let r = _mm_cvtness_sbh(1.);
    assert_eq!(r, BF16_ONE);
}
}
65 changes: 65 additions & 0 deletions crates/core_arch/src/x86/avxneconvert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::arch::asm;
use crate::core_arch::{simd::*, x86::*};

#[cfg(test)]
Expand Down Expand Up @@ -95,6 +96,50 @@ pub unsafe fn _mm256_cvtneobf16_ps(a: *const __m256bh) -> __m256 {
transmute(cvtneobf162ps_256(a))
}

/// Convert packed single precision (32-bit) floating-point elements in a to packed BF16 (16-bit) floating-point
/// elements, and store the results in dst.
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_avx_pbh)
#[inline]
#[target_feature(enable = "avxneconvert,sse")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(vcvtneps2bf16)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm_cvtneps_avx_pbh(a: __m128) -> __m128bh {
    // Implemented with inline asm: Rust has no `bfloat16` type, so the
    // corresponding LLVM intrinsic cannot be declared and linked directly.
    let mut dst: __m128bh;
    asm!(
        // `{{vex}}` emits a literal `{vex}` prefix, forcing the assembler to
        // pick the VEX (AVX-NE-CONVERT) encoding rather than the EVEX
        // (AVX512-BF16) encoding of VCVTNEPS2BF16.
        "{{vex}}vcvtneps2bf16 {dst},{src}",
        dst = lateout(xmm_reg) dst,
        src = in(xmm_reg) a,
        options(pure, nomem, nostack, preserves_flags)
    );
    dst
}

/// Convert packed single precision (32-bit) floating-point elements in a to packed BF16 (16-bit) floating-point
/// elements, and store the results in dst.
///
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneps_avx_pbh)
#[inline]
#[target_feature(enable = "avxneconvert,sse,avx")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(vcvtneps2bf16)
)]
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
    // Narrowing conversion: eight f32 in a ymm source become eight BF16
    // packed into a single xmm destination.
    let mut dst: __m128bh;
    asm!(
        // `{{vex}}` emits a literal `{vex}` prefix, forcing the assembler to
        // pick the VEX (AVX-NE-CONVERT) encoding rather than the EVEX
        // (AVX512-BF16) encoding of VCVTNEPS2BF16.
        "{{vex}}vcvtneps2bf16 {dst},{src}",
        dst = lateout(xmm_reg) dst,
        src = in(ymm_reg) a,
        options(pure, nomem, nostack, preserves_flags)
    );
    dst
}

#[allow(improper_ctypes)]
extern "C" {
#[link_name = "llvm.x86.vbcstnebf162ps128"]
Expand All @@ -115,7 +160,9 @@ extern "C" {

#[cfg(test)]
mod tests {
use crate::core_arch::simd::{u16x4, u16x8};
use crate::core_arch::x86::*;
use crate::mem::transmute_copy;
use std::ptr::addr_of;
use stdarch_test::simd_test;

Expand Down Expand Up @@ -185,4 +232,22 @@ mod tests {
let e = _mm256_setr_ps(2., 4., 6., 8., 2., 4., 6., 8.);
assert_eq_m256(r, e);
}

#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm_cvtneps_avx_pbh() {
    let a = _mm_setr_ps(1., 2., 3., 4.);
    // Only four results are produced from a 128-bit source, so read just
    // the low-half prefix of the __m128bh via transmute_copy into u16x4.
    let r: u16x4 = transmute_copy(&_mm_cvtneps_avx_pbh(a));
    let e = u16x4::new(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR);
    assert_eq!(r, e);
}

#[simd_test(enable = "avxneconvert")]
unsafe fn test_mm256_cvtneps_avx_pbh() {
    let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.);
    // The 256-bit source fills all eight u16 lanes of the result, so a
    // full transmute (not transmute_copy) is appropriate here.
    let r: u16x8 = transmute(_mm256_cvtneps_avx_pbh(a));
    let e = u16x8::new(
        BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
    );
    assert_eq!(r, e);
}
}

0 comments on commit fe8f300

Please sign in to comment.