diff --git a/src/aten/sycl/Loops.h b/src/aten/sycl/Loops.h
index cc83a91fe..4fbbca77f 100644
--- a/src/aten/sycl/Loops.h
+++ b/src/aten/sycl/Loops.h
@@ -1,10 +1,10 @@
 #pragma once
 
+#include
 #include
 #include
 #include
 #include
-#include
 #include
 #include
 
@@ -14,7 +14,9 @@
 
 using namespace at::xpu;
 
-namespace at { namespace native { namespace xpu {
+namespace at {
+namespace native {
+namespace xpu {
 
 template <typename func_t, typename policy_t>
 inline void elementwise_kernel_helper(func_t f, policy_t policy) {
@@ -40,28 +42,6 @@ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   policy.store(results);
 }
 
-template <int vec_size, typename func_t>
-struct ElementwiseKernel {
-  void operator()(sycl::nd_item<1> item) const {
-    int glbsz = item.get_global_range(0);
-    int gid = item.get_global_linear_id();
-    #pragma unroll
-    for (int i = 0; i < vec_size; i++) {
-      if (gid < numel_) {
-        f_(gid);
-        gid += glbsz;
-      }
-    }
-  };
-
-  ElementwiseKernel(int numel, func_t f)
-      : numel_(numel), f_(f) {}
-
- private:
-  int numel_;
-  func_t f_;
-};
-
 template <
     typename func_t,
     typename array_t,
@@ -83,8 +63,7 @@ struct UnrolledElementwiseKernel {
         in_calc_t,
         out_calc_t,
         loader_t,
-        storer_t>(
-        data_, remaining, ic_, oc_, l_, s_, lid, grpid, grpsz);
+        storer_t>(data_, remaining, ic_, oc_, l_, s_, lid, grpid, grpsz);
     elementwise_kernel_helper(f_, policy);
   };
 
@@ -129,15 +108,7 @@ struct VectorizedElementwiseKernel {
           decltype(oc),
           at::native::memory::LoadWithoutCast,
           at::native::memory::StoreWithoutCast>(
-          data_,
-          remaining,
-          ic_,
-          oc,
-          l,
-          s,
-          lid,
-          grpid,
-          grpsz);
+          data_, remaining, ic_, oc, l, s, lid, grpid, grpsz);
       elementwise_kernel_helper(f_, policy);
     } else {
       auto policy = at::native::memory::policies::
@@ -239,6 +210,48 @@ struct UnrolledElementwiseKernelForMultiOutputs {
   out_calc_t oc_;
 };
 
+template <int vec_size, typename func_t>
+struct ElementwiseGroupRangeKernel {
+  void operator()(sycl::nd_item<1> item) const {
+    int wg_sz = item.get_local_range(0);
+    int group_work_size = wg_sz * vec_size;
+    int idx = group_work_size * item.get_group(0) + item.get_local_id(0);
+#pragma unroll
+    for (int i = 0; i < vec_size; i++) {
+      if (idx < numel_) {
+        f_(idx);
+        idx += wg_sz;
+      }
+    }
+  };
+
+  ElementwiseGroupRangeKernel(int numel, func_t f) : numel_(numel), f_(f) {}
+
+ private:
+  int numel_;
+  func_t f_;
+};
+
+template <typename func_t>
+struct ElementwiseGlobalRangeKernel {
+  void operator()(sycl::nd_item<1> item) const {
+    int linear_idx =
+        item.get_group(0) * item.get_local_range(0) + item.get_local_id(0);
+    for (int idx = linear_idx; idx < numel_;
+         idx += item.get_group_range(0) * item.get_local_range(0)) {
+      if (idx < numel_) {
+        f_(idx);
+      }
+    }
+  };
+
+  ElementwiseGlobalRangeKernel(int numel, func_t f) : numel_(numel), f_(f) {}
+
+ private:
+  int numel_;
+  func_t f_;
+};
+
 template <
     typename arg0_t,
     int ntensors,
@@ -292,19 +305,35 @@ struct LegacyKernelWithCastScalarFunctor {
 };
 
 template <int vec_size, typename func_t>
-static void launch_legacy_kernel(int64_t N, const func_t& f) {
+static void launch_legacy_group_range_kernel(int64_t N, const func_t& f) {
   TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
   if (N == 0) {
     return;
   }
 
-  auto ker = ElementwiseKernel<vec_size, func_t>(N, f);
+  auto ker = ElementwiseGroupRangeKernel<vec_size, func_t>(N, f);
 
   int wg_sz = syclMaxWorkItemsPerEU();
   int num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);
   sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
 }
 
+template <typename func_t>
+static void launch_legacy_global_range_kernel(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+
+  auto ker = ElementwiseGlobalRangeKernel<func_t>(N, f);
+
+  int wg_sz = syclMaxWorkItemsPerEU();
+  int num_wg = ceil_div<int64_t>(N, wg_sz);
+  int hw_max_num_wg = syclMaxWorkItemsPerTile() / wg_sz;
+  num_wg = num_wg > hw_max_num_wg ? hw_max_num_wg : num_wg;
+  sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
+}
+
 template <
     typename func_t,
     typename array_t,
@@ -330,6 +359,25 @@ static inline void launch_unrolled_kernel(
   sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);
 }
 
+constexpr int max_scalar_size_(std::tuple<>) {
+  return 0;
+}
+
+template <typename scalar_t, typename... types>
+constexpr int max_scalar_size_(std::tuple<scalar_t, types...>) {
+  return std::max<int>(
+      sizeof(scalar_t), max_scalar_size_(std::tuple<types...>{}));
+}
+
+template <typename func_t>
+constexpr static inline int max_scalar_size() {
+  using traits = function_traits<func_t>;
+  using args_t = typename traits::ArgsTuple;
+  constexpr auto size = max_scalar_size_(args_t{});
+  using return_t = typename traits::result_type;
+  return std::max<int>(sizeof(return_t), size);
+}
+
 template <typename func_t, typename array_t, typename in_calc_t>
 static inline void launch_vectorized_kernel(
     int64_t N,
@@ -337,17 +385,21 @@ static inline void launch_vectorized_kernel(
     array_t data,
     in_calc_t input_calc,
     int vec_size) {
+  constexpr auto max_scalar_bytes = max_scalar_size<func_t>();
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
   auto wg_sz = syclMaxWorkItemsPerEU();
 
 #define VEC_KER(vec_size)                                                    \
   {                                                                          \
-    auto ker =                                                               \
-        VectorizedElementwiseKernel<vec_size, func_t, array_t, in_calc_t>(   \
-            N, f, data, input_calc);                                         \
-    int num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);                     \
-    sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), ker);   \
+    TORCH_CHECK(max_scalar_bytes* vec_size <= 16);                           \
+    if constexpr (max_scalar_bytes * vec_size <= 16) {                       \
+      auto ker =                                                             \
+          VectorizedElementwiseKernel<vec_size, func_t, array_t, in_calc_t>( \
+              N, f, data, input_calc);                                       \
+      int num_wg = ceil_div<int64_t>(N, wg_sz * vec_size);                   \
+      sycl_kernel_submit(wg_sz* num_wg, wg_sz, getCurrentSYCLQueue(), ker);  \
+    }                                                                        \
   }
 
   switch (vec_size) {
@@ -382,8 +434,18 @@ static inline void launch_vectorized_kernel(
   }
 }
 
-template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
-static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, in_calc_t ic, out_calc_t oc) {
+template <
+    int num_outputs,
+    typename func_t,
+    typename array_t,
+    typename in_calc_t,
+    typename out_calc_t>
+static inline void launch_unrolled_kernel_for_multi_outputs(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    in_calc_t ic,
+    out_calc_t oc) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
 
   auto ker = UnrolledElementwiseForMultiOutputsKernel<
@@ -432,7 +494,7 @@ static inline bool can_vectorize_for_non_contigouous(
   return vec_size > 1;
 }
 
-template <typename func_t>
+template <typename func_t, bool enable_broadcast_vec = true>
 void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
   using traits = function_traits<func_t>;
   using arg0_t = typename traits::result_type;
@@ -449,8 +511,9 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
   }
 
   int64_t numel = iter.numel();
-
   bool contiguous = iter.is_contiguous();
+  bool latency_case = numel <=
+      syclMaxWorkItemsPerEU() * 4; /* on tuning for different data types */
 
   int vec_size;
   if (contiguous) {
@@ -458,23 +521,31 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
     vec_size = memory::can_vectorize_up_to<func_t>(data);
     launch_vectorized_kernel(numel, f, data, input_calc, vec_size);
     return;
-  } else if (can_vectorize_for_non_contigouous(iter, data, vec_size)) {
-    auto input_calc = make_input_offset_calculator<traits::arity>(iter);
-    launch_vectorized_kernel(numel, f, data, input_calc, vec_size);
-    return;
+  } else {
+    if constexpr (enable_broadcast_vec) {
+      if (!latency_case &&
+          can_vectorize_for_non_contigouous(iter, data, vec_size)) {
+        auto input_calc = make_input_offset_calculator<traits::arity>(iter);
+        launch_vectorized_kernel(numel, f, data, input_calc, vec_size);
+        return;
+      }
+    }
   }
 
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
-  constexpr int unroll_factor = sizeof(arg0_t) > 4 ? 2 : 4;
-  launch_legacy_kernel<unroll_factor>(
-      numel, LegacyKernelScalarFunctor<
-      arg0_t, ntensors, decltype(offset_calc), func_t>(data, offset_calc, f));
+  launch_legacy_global_range_kernel(
+      numel,
+      LegacyKernelScalarFunctor<
+          arg0_t,
+          ntensors,
+          decltype(offset_calc),
+          func_t>(data, offset_calc, f));
 }
 
-template <typename func_t>
+template <typename func_t, bool enable_broadcast_vec = true>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   if (!needs_dynamic_casting<func_t>::check(iter)) {
-    return gpu_kernel_impl_nocast(iter, f);
+    return gpu_kernel_impl_nocast<func_t, enable_broadcast_vec>(iter, f);
   }
   using traits = function_traits<func_t>;
   using arg0_t = typename traits::result_type;
@@ -513,21 +584,25 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
     }
     auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
     constexpr int unroll_factor = sizeof(arg0_t) > 4 ? 2 : 4;
-    launch_legacy_kernel<unroll_factor>(
+    launch_legacy_group_range_kernel<unroll_factor>(
         numel,
         LegacyKernelWithCastScalarFunctor<
-            arg0_t, ntensors, decltype(offset_calc), func_t>(
-            data, dtypes, offset_calc, f));
+            arg0_t,
+            ntensors,
+            decltype(offset_calc),
+            func_t>(data, dtypes, offset_calc, f));
   }
 }
 
 template <typename func_t>
 void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) {
-
   for (int arg = 0; arg < iter.ntensors(); arg++) {
     TORCH_INTERNAL_ASSERT(
-      iter.device(arg).is_xpu(),
-      "argument ", arg, ": expected an XPU device but found ", iter.device(arg));
+        iter.device(arg).is_xpu(),
+        "argument ",
+        arg,
+        ": expected an XPU device but found ",
+        iter.device(arg));
   }
 
   if (iter.numel() == 0) {
@@ -546,11 +621,13 @@ void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
-
   for (int arg = 0; arg < iter.ntensors(); arg++) {
     TORCH_INTERNAL_ASSERT(
-      iter.device(arg).is_xpu(),
-      "argument ", arg, ": expected an XPU device but found ", iter.device(arg));
+        iter.device(arg).is_xpu(),
+        "argument ",
+        arg,
+        ": expected an XPU device but found ",
+        iter.device(arg));
   }
 
   if (iter.numel() == 0) {
@@ -567,30 +644,32 @@ void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
   gpu_kernel_impl(iter, f);
 }
 
-template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
 struct AUnaryFunctor {
   using traits = function_traits<func_t>;
   using opmath_arg1_t = typename traits::template arg<0>::type;
   return_t operator()(arg2_t b) const {
     return f(a, b);
   }
-  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
-  private:
-    func_t f;
-    opmath_arg1_t a;
+  AUnaryFunctor(func_t f_, opmath_arg1_t a_) : f(f_), a(a_) {}
+
+ private:
+  func_t f;
+  opmath_arg1_t a;
 };
 
-template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
+template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
 struct BUnaryFunctor {
   using traits = function_traits<func_t>;
   using opmath_arg2_t = typename traits::template arg<1>::type;
   return_t operator()(arg1_t a) const {
     return f(a, b);
   }
-  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
-  private:
-    func_t f;
-    opmath_arg2_t b;
+  BUnaryFunctor(func_t f_, opmath_arg2_t b_) : f(f_), b(b_) {}
+
+ private:
+  func_t f;
+  opmath_arg2_t b;
 };
 
 template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
@@ -598,12 +677,17 @@ struct BinaryFunctor {
   return_t operator()(arg1_t a, arg2_t b) const {
     return f(a, b);
   }
-  BinaryFunctor(func_t f_): f(f_) {}
-  private:
-    func_t f;
+  BinaryFunctor(func_t f_) : f(f_) {}
+
+ private:
+  func_t f;
 };
 
-template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
+template <
+    typename arg1_t,
+    typename arg2_t = arg1_t,
+    typename return_t = arg1_t,
+    typename func_t>
 void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
@@ -615,12 +699,14 @@ void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
       "gpu_kernel_with_scalars only supports two input arguments");
 
   if (iter.is_cpu_scalar(1)) {
-    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
+    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(
+        f, iter.scalar_value<opmath_arg1_t>(1));
     iter.remove_operand(1);
     const OptionalDeviceGuard device_guard(iter.device(1));
     gpu_kernel(iter, af);
   } else if (iter.is_cpu_scalar(2)) {
-    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
+    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(
+        f, iter.scalar_value<opmath_arg2_t>(2));
     iter.remove_operand(2);
     gpu_kernel(iter, bf);
   } else {
@@ -629,7 +715,9 @@ void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
 }
 
 template <typename arg_t, typename return_t = arg_t, typename func_t>
-void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
+void opmath_symmetric_gpu_kernel_with_scalars(
+    TensorIteratorBase& iter,
+    const func_t& f) {
   // Use symmetric property of the functor to reduce number of kernels,
   // requires f(a, b) == f(b, a)
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
@@ -639,8 +727,9 @@ void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const fu
   static_assert(
       traits::arity == 2,
       "gpu_kernel_with_scalars only supports two input arguments");
-  static_assert(std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
-      "f is not symmetric");
+  static_assert(
+      std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
+      "f is not symmetric");
 
   OptionalDeviceGuard device_guard;
   opmath_arg_t scalar_val{};
@@ -678,7 +767,9 @@ void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
 }
 
 template <typename func_t>
-void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
+void gpu_kernel_multiple_outputs_impl(
+    TensorIteratorBase& iter,
+    const func_t& f) {
   using traits = function_traits<func_t>;
   using output_t = typename traits::result_type;
   constexpr int num_outputs = std::tuple_size<output_t>::value;
@@ -698,11 +789,13 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f)
 
   if (iter.is_contiguous()) {
     auto input_calc = TrivialOffsetCalculator<traits::arity>();
     auto output_calc = TrivialOffsetCalculator<num_outputs>();
-    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(
+        numel, f, data, input_calc, output_calc);
   } else {
     auto input_calc = make_input_offset_calculator<traits::arity>(iter);
     auto output_calc = make_output_offset_calculator<num_outputs>(iter);
-    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(
+        numel, f, data, input_calc, output_calc);
   }
 }
@@ -726,4 +819,6 @@ void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
   gpu_kernel_multiple_outputs_impl(iter, f);
 }
 
-}}} //namespace at::native::xpu
+} // namespace xpu
+} // namespace native
+} // namespace at
diff --git a/src/aten/sycl/MemoryAccess.h b/src/aten/sycl/MemoryAccess.h
index e13379479..184b682b8 100644
--- a/src/aten/sycl/MemoryAccess.h
+++ b/src/aten/sycl/MemoryAccess.h
@@ -149,9 +149,9 @@ struct vectorized {
     // like broadcast. The broadcasted operands are dense and could be
    // optimized by the policy if satisfying vectorization conditions.
    // auto offset = input_offset_calculator.get(group_offset);
-    auto linear_idx = group_offset + item_idx* vec_size;
+    auto linear_idx = group_offset + item_idx * vec_size;
     auto offset = input_offset_calculator.get(linear_idx);
-    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, offset, 0);
+    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, offset);
   }
 
   template <typename scalar_t>
@@ -160,14 +160,12 @@ struct vectorized {
     scalar_t* to =
         reinterpret_cast<scalar_t*>(data[0]) + group_work_size * group_idx;
     vec_t* to_ = reinterpret_cast<vec_t*>(to);
-
-    int index = item_idx;
     vec_t v;
 #pragma unroll
     for (int j = 0; j < vec_size; j++) {
       v.val[j] = from[j];
     }
-    to_[index] = v;
+    to_[item_idx] = v;
   }
 };
diff --git a/src/aten/sycl/MemoryAccessUtils.h b/src/aten/sycl/MemoryAccessUtils.h
index e924ecca8..26388d520 100644
--- a/src/aten/sycl/MemoryAccessUtils.h
+++ b/src/aten/sycl/MemoryAccessUtils.h
@@ -46,13 +46,13 @@ struct static_unroll {
 template <int arg_index>
 struct vectorized_load_helper {
   template <typename args_t, typename policy_t, typename offset_t>
-  static C10_DEVICE void apply(policy_t &self, args_t *args, offset_t offset, int args_vec_base) {
+  static C10_DEVICE void apply(policy_t &self, args_t *args, offset_t offset) {
     using arg_t = std::tuple_element_t<arg_index, args_t>;
     // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
     // need a +1 offset to get the input
     auto ptr = reinterpret_cast<arg_t*>(self.data[arg_index + 1]) + offset[arg_index];
-    auto args_accessor = [&args, args_vec_base] C10_DEVICE (int thread_unroll_idx) -> arg_t & {
-      return std::get<arg_index>(args[args_vec_base + thread_unroll_idx]);
+    auto args_accessor = [&args] C10_DEVICE (int thread_unroll_idx) -> arg_t & {
+      return std::get<arg_index>(args[thread_unroll_idx]);
     };
     self.load_single_arg(args_accessor, ptr);
   }
diff --git a/test/python/examples/test_loops.py b/test/python/examples/test_loops.py
index 87a74e3ed..f06292575 100644
--- a/test/python/examples/test_loops.py
+++ b/test/python/examples/test_loops.py
@@ -17,7 +17,7 @@
 ]
 
 class TestLoopsKernel(TestCase):
-    def test_loops(self, dtype=torch.float):
+    def _test_loops(self, dtype=torch.float):
         for shape in test_shapes:
             if len(shape) == 2:
                 a = torch.randn(shape[0], dtype=dtype)
@@ -31,6 +31,34 @@ def test_loops(self, dtype=torch.float):
                 )
             a_xpu = a.xpu()
             b_xpu = b.xpu()
-            c = a + b
-            c_xpu = a_xpu + b_xpu
+            c = a + b + 1
+            c_xpu = a_xpu + b_xpu + 1
             self.assertEqual(c, c_xpu.cpu())
+
+    def test_loops_float(self):
+        self._test_loops(torch.float)
+
+    def test_loops_half(self):
+        self._test_loops(torch.half)
+
+    def test_loops_bfloat16(self):
+        self._test_loops(torch.bfloat16)
+
+    def test_loops_dynamic_cast(self):
+        for shape in test_shapes:
+            if len(shape) == 2:
+                a = torch.randn(shape[0], dtype=torch.float)
+                b = torch.randn(shape[1], dtype=torch.half)
+            elif len(shape) == 4:
+                a = torch.as_strided(
+                    torch.randn(shape[0][0] * shape[1][0], dtype=torch.float), shape[0], shape[1]
+                )
+                b = torch.as_strided(
+                    torch.randn(shape[2][0] * shape[3][0], dtype=torch.half), shape[2], shape[3]
+                )
+            a_xpu = a.xpu()
+            b_xpu = b.xpu()
+            print(f'a_xpu:{a_xpu.dtype}, {a_xpu.shape}, {a.stride()}; b_xpu:{b_xpu.dtype}, {b_xpu.shape}, {b_xpu.stride()}', flush=True)
+            c = a + b + 1
+            c_xpu = a_xpu + b_xpu + 1
+            self.assertEqual(c, c_xpu.cpu())