Skip to content

Commit

Permalink
[SYCL][ESIMD] Fix incorrect handling of non native floating types (in…
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 authored Aug 26, 2024
1 parent 5b02c4c commit 894f40a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 11 deletions.
42 changes: 31 additions & 11 deletions sycl/include/sycl/ext/intel/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ template <typename TRes, typename TArg, int SZ>
ESIMD_NODEBUG ESIMD_INLINE simd<TRes, SZ>
__esimd_abs_common_internal(simd<TArg, SZ> src0) {
simd<TArg, SZ> Result;
if constexpr (detail::is_generic_floating_point_v<TArg>)
Result = simd<TArg, SZ>(__spirv_ocl_fabs<TArg, SZ>(src0.data()));
else
if constexpr (detail::is_generic_floating_point_v<TArg>) {
using CppT = __ESIMD_DNS::element_type_traits<TArg>::EnclosingCppT;
Result =
__ESIMD_DNS::convert_vector<TArg, CppT, SZ>(__spirv_ocl_fabs<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, TArg, SZ>(src0.data())));
} else
Result = simd<TArg, SZ>(__spirv_ocl_s_abs<TArg, SZ>(src0.data()));
return convert<TRes>(Result);
}
Expand Down Expand Up @@ -184,8 +187,12 @@ template <typename T, int SZ, class Sat = saturation_off_tag>
__ESIMD_API simd<T, SZ>(max)(simd<T, SZ> src0, simd<T, SZ> src1, Sat sat = {}) {
constexpr bool is_sat = std::is_same_v<Sat, saturation_on_tag>;

if constexpr (std::is_floating_point<T>::value) {
auto Result = __spirv_ocl_fmax<T, SZ>(src0.data(), src1.data());
if constexpr (detail::is_generic_floating_point_v<T>) {
using CppT = __ESIMD_DNS::element_type_traits<T>::EnclosingCppT;
auto Result =
__ESIMD_DNS::convert_vector<T, CppT, SZ>(__spirv_ocl_fmax<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src0.data()),
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src1.data())));
if constexpr (is_sat)
Result = __esimd_sat<T, T, SZ>(Result);
return simd<T, SZ>(Result);
Expand Down Expand Up @@ -269,8 +276,12 @@ template <typename T, int SZ, class Sat = saturation_off_tag>
__ESIMD_API simd<T, SZ>(min)(simd<T, SZ> src0, simd<T, SZ> src1, Sat sat = {}) {
constexpr bool is_sat = std::is_same_v<Sat, saturation_on_tag>;

if constexpr (std::is_floating_point<T>::value) {
auto Result = __spirv_ocl_fmin<T, SZ>(src0.data(), src1.data());
if constexpr (detail::is_generic_floating_point_v<T>) {
using CppT = __ESIMD_DNS::element_type_traits<T>::EnclosingCppT;
auto Result =
__ESIMD_DNS::convert_vector<T, CppT, SZ>(__spirv_ocl_fmin<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src0.data()),
__ESIMD_DNS::convert_vector<CppT, T, SZ>(src1.data())));
if constexpr (is_sat)
Result = __esimd_sat<T, T, SZ>(Result);
return simd<T, SZ>(Result);
Expand Down Expand Up @@ -1465,8 +1476,12 @@ template <typename T0, typename T1, int SZ> struct esimd_apply_prod {
template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_max {
template <typename... T>
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {
if constexpr (std::is_floating_point<T1>::value) {
return __spirv_ocl_fmax<T1, SZ>(v1.data(), v2.data());
if constexpr (detail::is_generic_floating_point_v<T1>) {
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
__spirv_ocl_fmax<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
} else if constexpr (std::is_unsigned<T1>::value) {
return __esimd_umax<T1, SZ>(v1.data(), v2.data());
} else {
Expand All @@ -1478,8 +1493,13 @@ template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_max {
template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_min {
template <typename... T>
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {
if constexpr (std::is_floating_point<T1>::value) {
return __spirv_ocl_fmin<T1, SZ>(v1.data(), v2.data());

if constexpr (detail::is_generic_floating_point_v<T1>) {
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
__spirv_ocl_fmin<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
} else if constexpr (std::is_unsigned<T1>::value) {
return __esimd_umin<T1, SZ>(v1.data(), v2.data());
} else {
Expand Down
92 changes: 92 additions & 0 deletions sycl/test-e2e/ESIMD/spirv_fp_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//==- spirv_fp_test.cpp - Test for abs function -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: arch-intel_gpu_pvc
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <sycl/detail/core.hpp>
#include <sycl/ext/intel/esimd.hpp>

#include <sycl/usm.hpp>
#include <sycl/usm/usm_allocator.hpp>

using namespace sycl;
using namespace sycl::ext::intel::esimd;
using bf16 = sycl::ext::oneapi::bfloat16;
using tfloat32 = sycl::ext::intel::experimental::esimd::tfloat32;

template <typename DataT>
using shared_allocator = sycl::usm_allocator<DataT, sycl::usm::alloc::shared>;
template <typename DataT>
using shared_vector = std::vector<DataT, shared_allocator<DataT>>;

template <typename T, int N>
bool test(sycl::queue &Queue, T testValue1, T testValue2) {
shared_allocator<T> Allocator(Queue);

shared_vector<T> OutputAbs(N, 0, Allocator);
shared_vector<T> OutputMin(N, 0, Allocator);
shared_vector<T> OutputMax(N, 0, Allocator);

auto *OutputAbsPtr = OutputAbs.data();
auto *OutputMinPtr = OutputMin.data();
auto *OutputMaxPtr = OutputMax.data();

Queue.submit([&](sycl::handler &cgh) {
auto Kernel = ([=]() SYCL_ESIMD_KERNEL {
simd<T, N> Input1 = testValue1;
simd<T, N> Input2 = testValue2;
simd<T, N> ResultAbs = __ESIMD_NS::abs(Input1);
simd<T, N> ResultMin = __ESIMD_NS::min(Input1, Input2);
simd<T, N> ResultMax = __ESIMD_NS::max(Input1, Input2);
ResultAbs.copy_to(OutputAbsPtr);
ResultMin.copy_to(OutputMinPtr);
ResultMax.copy_to(OutputMaxPtr);
});
cgh.single_task(Kernel);
});
Queue.wait();

for (int I = 0; I < N; I++) {
if (std::abs(testValue1) != OutputAbs[I]) {
std::cout << "Incorrect value for abs at index " << I << " "
<< std::abs(testValue1) << " != " << OutputAbs[I] << std::endl;
return false;
}
if (std::min(testValue1, testValue2) != OutputMin[I]) {
std::cout << "Incorrect value for min at index " << I << " "
<< std::min(testValue1, testValue2) << " != " << OutputMin[I]
<< std::endl;
return false;
}

if (std::max(testValue1, testValue2) != OutputMax[I]) {
std::cout << "Incorrect value for max at index " << I << " "
<< std::max(testValue1, testValue2) << " != " << OutputMax[I]
<< std::endl;
return false;
}
}

return true;
}

int main() {

bool Pass = true;
sycl::queue Q;
Pass &= test<bf16, 8>(Q, -1, -2);
Pass &= test<tfloat32, 8>(Q, -1, -2);

if (Pass)
std::cout << "Pass" << std::endl;
else
std::cout << "Fail" << std::endl;

return !Pass;
}

0 comments on commit 894f40a

Please sign in to comment.