Add aten::_foreach_lerp and its variants (#395)
Including aten::_foreach_lerp and aten::_foreach_lerp_, each with .List and .Scalar overloads

Signed-off-by: Feng Yuan <feng1.yuan@intel.com>
fengyuan14 authored Jun 21, 2024
1 parent 8750cf3 commit 01fc85f
Showing 7 changed files with 238 additions and 0 deletions.
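
For context, the new overloads can be exercised from Python roughly as follows. This is a minimal sketch, assuming an XPU-enabled PyTorch build; the shapes, list lengths, and device string are illustrative:

import torch

device = "xpu"  # assumes torch is built with XPU support

starts = [torch.randn(8, device=device) for _ in range(3)]
ends = [torch.randn(8, device=device) for _ in range(3)]
weights = [torch.rand(8, device=device) for _ in range(3)]

# .List overload: one weight tensor per pair (aten::_foreach_lerp.List).
outs = torch._foreach_lerp(starts, ends, weights)

# .Scalar overload, in-place: a single scalar weight for every pair
# (aten::_foreach_lerp_.Scalar).
torch._foreach_lerp_(starts, ends, 0.5)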
42 changes: 42 additions & 0 deletions src/aten/ForeachOpList.cpp
@@ -8,6 +8,7 @@
#endif
#include <aten/sycl/ForeachBinaryOpListKernels.h>
#include <aten/sycl/ForeachPointwiseOpListKernels.h>
#include <aten/sycl/ForeachTernaryOpListKernels.h>

namespace at {

@@ -107,4 +108,45 @@ FOREACH_BINARY_OP_LIST(div, true);
FOREACH_POINTWISE_OP_TENSOR(addcmul)
FOREACH_POINTWISE_OP_TENSOR(addcdiv)

std::vector<at::Tensor> XPUNativeFunctions::_foreach_lerp(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
at::native::check_foreach_api_restrictions(tensors1, tensors2, tensors3);
if (!at::native::can_use_fast_route(
{tensors1, tensors2, tensors3}, {}, true)) {
return at::native::foreach_tensor_ternary_lerp_slow(
tensors1, tensors2, tensors3);
}

std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors1.size());
for (const auto& t : tensors1) {
vec_res.emplace_back(at::native::empty_like(t));
}

native::xpu::foreach_lerp_list_kernel(tensors1, tensors2, tensors3, vec_res);
return vec_res;
}

void XPUNativeFunctions::_foreach_lerp_(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
at::native::check_foreach_api_restrictions(tensors1, tensors2, tensors3);
if (!at::native::can_use_fast_route(
{tensors1, tensors2, tensors3}, {}, true)) {
return at::native::foreach_tensor_ternary_lerp_slow_(
tensors1, tensors2, tensors3);
}

native::xpu::foreach_lerp_list_kernel_(tensors1, tensors2, tensors3);

// TODO: Handle version bump in codegen.
// increment_version
for (const auto& t : tensors1) {
t.unsafeGetTensorImpl()->bump_version();
}
}

} // namespace at
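
Like the other foreach ops in this file, both entry points first try the fused fast route and otherwise fall back to the per-tensor slow path. A rough Python analogue of what the slow fallback computes, as a sketch of the semantics rather than the actual fallback implementation:

import torch

def foreach_lerp_list_slow(starts, ends, weights):
    # One torch.lerp call per tensor triple, matching the out-of-place
    # .List overload element for element.
    return [torch.lerp(a, b, w) for a, b, w in zip(starts, ends, weights)]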
34 changes: 34 additions & 0 deletions src/aten/ForeachOpScalar.cpp
@@ -3,6 +3,7 @@
#include <ATen/XPUNativeFunctions.h>
#include <aten/sycl/ForeachBinaryOpScalarKernels.h>
#include <aten/sycl/ForeachPointwiseOpScalarKernels.h>
#include <aten/sycl/ForeachTernaryOpScalarKernels.h>

namespace at {

@@ -74,4 +75,37 @@ FOREACH_BINARY_OP_SCALAR(div, /*div_op*/ true);
FOREACH_POINTWISE_OP_SCALAR(addcmul)
FOREACH_POINTWISE_OP_SCALAR(addcdiv)

std::vector<at::Tensor> XPUNativeFunctions::_foreach_lerp(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight) {
at::native::check_foreach_api_restrictions(tensors1, tensors2);
if (!at::native::can_use_fast_route({tensors1, tensors2}, {}, true)) {
return at::native::foreach_tensor_lerp_list_kernel_slow(
tensors1, tensors2, weight);
}

std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors1.size());
for (const auto& t : tensors1) {
vec_res.emplace_back(at::native::empty_like(t));
}

native::xpu::foreach_lerp_scalar_kernel(tensors1, tensors2, weight, vec_res);

return vec_res;
}

void XPUNativeFunctions::_foreach_lerp_(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight) {
at::native::check_foreach_api_restrictions(tensors1, tensors2);
if (!at::native::can_use_fast_route({tensors1, tensors2}, {}, true)) {
return at::native::foreach_tensor_lerp_list_kernel_slow_(
tensors1, tensors2, weight);
}

native::xpu::foreach_lerp_scalar_kernel_(tensors1, tensors2, weight);
}
} // namespace at
123 changes: 123 additions & 0 deletions src/aten/sycl/ForeachTernaryKernels.cpp
@@ -0,0 +1,123 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Lerp.h>

#include <aten/sycl/ForeachFunctors.h>
#include <aten/sycl/MultiTensorApply.h>

namespace at::native::xpu {

template <typename scalar_t>
struct LerpFunctor {
inline scalar_t operator()(
const scalar_t self,
const scalar_t end,
const scalar_t weight) {
return lerp(self, end, weight);
}
};

void foreach_lerp_list_kernel(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3,
TensorList result) {
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec(), tensors3.vec(), result.vec()};

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_lerp_list_xpu",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<4>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3>(),
LerpFunctor<opmath_t>());
});
}

void foreach_lerp_list_kernel_(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3) {
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec(), tensors3.vec()};

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_lerp_list_xpu_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<3>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0>(),
LerpFunctor<opmath_t>());
});
}

void foreach_lerp_scalar_kernel(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight,
TensorList result) {
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec(), result.vec()};

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_lerp_scalar_xpu",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<3>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
}

void foreach_lerp_scalar_kernel_(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight) {
std::vector<std::vector<at::Tensor>> tensor_lists{
tensors1.vec(), tensors2.vec()};

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
tensors1[0].scalar_type(),
"foreach_lerp_scalar_xpu_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
multi_tensor_apply<2>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
}
} // namespace at::native::xpu
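
All four kernels dispatch over floating and complex types plus Half and BFloat16, and LerpFunctor runs in opmath_t, the wider accumulation type. A hedged sketch of the per-element math, including the reduced-precision round trip; this is illustrative Python, not the SYCL code, and the shared at::native::lerp helper may use a numerically adjusted form:

import torch

def lerp_elem(start, end, weight):
    # For float16/bfloat16 inputs the functor computes in float32
    # (opmath_t) and the result is stored back in the input dtype.
    a, b = start.float(), end.float()
    return (a + weight * (b - a)).to(start.dtype)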
17 changes: 17 additions & 0 deletions src/aten/sycl/ForeachTernaryOpListKernels.h
@@ -0,0 +1,17 @@
#pragma once
#include <ATen/ATen.h>

namespace at::native::xpu {

void foreach_lerp_list_kernel(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3,
TensorList result);

void foreach_lerp_list_kernel_(
TensorList tensors1,
TensorList tensors2,
TensorList tensors3);

} // namespace at::native::xpu
17 changes: 17 additions & 0 deletions src/aten/sycl/ForeachTernaryOpScalarKernels.h
@@ -0,0 +1,17 @@
#pragma once
#include <ATen/ATen.h>

namespace at::native::xpu {

void foreach_lerp_scalar_kernel(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight,
TensorList result);

void foreach_lerp_scalar_kernel_(
TensorList tensors1,
TensorList tensors2,
const Scalar& weight);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/run_test_with_only.py
@@ -38,6 +38,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
# Compiler optimization on data type conversion introduces precision error.
"_foreach_addcdiv_ and not slowpath and not test_pointwise_op_with_tensor_of_scalarlist_overload__foreach_addcdiv_is_fastpath_True_xpu_float16",
"_foreach_sqrt_ and not slowpath",
"_foreach_lerp_ and not slowpath",
)
res += launch_test("test_foreach_xpu.py", exe_list=execute_list)
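
The exe_list strings read like pytest -k filter expressions; assuming launch_test forwards them that way (its exact plumbing is not shown here), the new entry corresponds roughly to this stand-alone invocation:

import subprocess

# Hypothetical equivalent of the new exe_list entry, assuming each entry
# is passed to pytest as a -k expression:
subprocess.run(
    ["pytest", "test_foreach_xpu.py", "-k", "_foreach_lerp_ and not slowpath"],
    check=True,
)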

4 changes: 4 additions & 0 deletions yaml/xpu_functions.yaml
@@ -314,6 +314,10 @@ supported:
- _foreach_addcdiv_.Tensor
- _foreach_sqrt
- _foreach_sqrt_
- _foreach_lerp.List
- _foreach_lerp_.List
- _foreach_lerp.Scalar
- _foreach_lerp_.Scalar
- maximum
- maximum.out
- minimum