-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add aten::_foreach_lerp and its variants (#395)
Includes the `aten::_foreach_lerp` and `aten::_foreach_lerp_` operators, each with its `.List` and `.Scalar` overloads. Signed-off-by: Feng Yuan <feng1.yuan@intel.com>
- Loading branch information
1 parent
8750cf3
commit 01fc85f
Showing
7 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Dispatch.h> | ||
#include <ATen/native/Lerp.h> | ||
|
||
#include <aten/sycl/ForeachFunctors.h> | ||
#include <aten/sycl/MultiTensorApply.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
/// Element-wise interpolation functor handed to `multi_tensor_apply` by the
/// `foreach_lerp_*` kernels in this file.
///
/// The call operator forwards its three operands, unchanged and in order, to
/// the unqualified `lerp` overload. This file includes <ATen/native/Lerp.h>,
/// which is presumably where that overload is declared -- confirm against the
/// header before relying on its exact numerical formula.
///
/// @tparam scalar_t Type of the operands and of the result. The kernels in
///                  this file instantiate it with `at::opmath_type<...>`.
template <typename scalar_t>
struct LerpFunctor {
  /// @param self   First operand passed to `lerp`.
  /// @param end    Second operand passed to `lerp`.
  /// @param weight Third operand passed to `lerp`.
  /// @return Whatever `lerp(self, end, weight)` returns, converted to
  ///         `scalar_t`.
  ///
  /// The functor holds no state, so the operator is `const`-qualified; this
  /// lets it be invoked through const objects and const references as well as
  /// through the by-value copies the kernels create.
  inline scalar_t operator()(
      const scalar_t self,
      const scalar_t end,
      const scalar_t weight) const {
    return lerp(self, end, weight);
  }
};
|
||
/// Out-of-place XPU kernel for `_foreach_lerp` with a tensor-list third
/// operand.
///
/// The three operand lists and the output list are copied into one nested
/// vector (operands first, `result` last) and handed to `multi_tensor_apply`
/// together with a `TernaryOpListFunctor` and a `LerpFunctor`.
///
/// @param tensors1 First operand list. Element 0 selects the dispatch dtype,
///                 so the list must not be empty (presumably validated by the
///                 caller -- TODO confirm).
/// @param tensors2 Second operand list.
/// @param tensors3 Third operand list (presumably the per-tensor weights,
///                 going by the parameter order of `LerpFunctor` -- confirm).
/// @param result   Output list; it occupies slot 3 of the packed lists, which
///                 is the result index given to the functor.
void foreach_lerp_list_kernel(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    TensorList result) {
  // Layout of the packed lists: three operand slots followed by the output.
  constexpr int kDepth = 4;
  constexpr int kOperandSlots = 3;
  constexpr int kResultSlot = 3;

  std::vector<std::vector<at::Tensor>> packed{
      tensors1.vec(), tensors2.vec(), tensors3.vec(), result.vec()};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_lerp_list_xpu",
      [&]() {
        // The element-wise op is instantiated on the opmath type of scalar_t,
        // while the list functor itself is instantiated on scalar_t.
        using opmath_t = at::opmath_type<scalar_t>;
        using list_functor_t = TernaryOpListFunctor<
            scalar_t,
            kDepth,
            kOperandSlots,
            kResultSlot>;
        multi_tensor_apply<kDepth>(
            packed, list_functor_t(), LerpFunctor<opmath_t>());
      });
}
|
||
/// In-place XPU kernel for `_foreach_lerp_` with a tensor-list third operand.
///
/// Only the three operand lists are packed; the result index given to
/// `TernaryOpListFunctor` is 0, i.e. the slot holding `tensors1`, so the
/// output presumably overwrites `tensors1` (confirm against the functor).
///
/// @param tensors1 First operand list and result slot. Element 0 selects the
///                 dispatch dtype, so the list must not be empty (presumably
///                 validated by the caller -- TODO confirm).
/// @param tensors2 Second operand list.
/// @param tensors3 Third operand list (presumably the per-tensor weights,
///                 going by the parameter order of `LerpFunctor` -- confirm).
void foreach_lerp_list_kernel_(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  // Three operand slots; the result reuses slot 0.
  constexpr int kDepth = 3;
  constexpr int kOperandSlots = 3;
  constexpr int kResultSlot = 0;

  std::vector<std::vector<at::Tensor>> packed{
      tensors1.vec(), tensors2.vec(), tensors3.vec()};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_lerp_list_xpu_",
      [&]() {
        // The element-wise op is instantiated on the opmath type of scalar_t,
        // while the list functor itself is instantiated on scalar_t.
        using opmath_t = at::opmath_type<scalar_t>;
        using list_functor_t = TernaryOpListFunctor<
            scalar_t,
            kDepth,
            kOperandSlots,
            kResultSlot>;
        multi_tensor_apply<kDepth>(
            packed, list_functor_t(), LerpFunctor<opmath_t>());
      });
}
|
||
/// Out-of-place XPU kernel for `_foreach_lerp` with a single scalar weight.
///
/// The two operand lists and the output list are copied into one nested
/// vector (operands first, `result` last) and handed to `multi_tensor_apply`
/// together with a `TernaryOpScalarFunctor`, a `LerpFunctor`, and the weight
/// converted to the opmath type.
///
/// @param tensors1 First operand list. Element 0 selects the dispatch dtype,
///                 so the list must not be empty (presumably validated by the
///                 caller -- TODO confirm).
/// @param tensors2 Second operand list.
/// @param weight   Scalar forwarded, after conversion, as the trailing
///                 argument of `multi_tensor_apply`.
/// @param result   Output list; it occupies slot 2 of the packed lists, which
///                 is the result index given to the functor.
void foreach_lerp_scalar_kernel(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& weight,
    TensorList result) {
  // Layout of the packed lists: two operand slots followed by the output.
  constexpr int kDepth = 3;
  constexpr int kOperandSlots = 2;
  constexpr int kResultSlot = 2;

  std::vector<std::vector<at::Tensor>> packed{
      tensors1.vec(), tensors2.vec(), result.vec()};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_lerp_scalar_xpu",
      [&]() {
        // Both the element-wise op and the scalar weight use the opmath type
        // of scalar_t; the list functor is instantiated on scalar_t.
        using opmath_t = at::opmath_type<scalar_t>;
        using scalar_functor_t = TernaryOpScalarFunctor<
            scalar_t,
            kDepth,
            kOperandSlots,
            kResultSlot>;
        const opmath_t converted_weight = weight.to<opmath_t>();
        multi_tensor_apply<kDepth>(
            packed,
            scalar_functor_t(),
            LerpFunctor<opmath_t>(),
            converted_weight);
      });
}
|
||
/// In-place XPU kernel for `_foreach_lerp_` with a single scalar weight.
///
/// Only the two operand lists are packed; the result index given to
/// `TernaryOpScalarFunctor` is 0, i.e. the slot holding `tensors1`, so the
/// output presumably overwrites `tensors1` (confirm against the functor).
///
/// @param tensors1 First operand list and result slot. Element 0 selects the
///                 dispatch dtype, so the list must not be empty (presumably
///                 validated by the caller -- TODO confirm).
/// @param tensors2 Second operand list.
/// @param weight   Scalar forwarded, after conversion to the opmath type, as
///                 the trailing argument of `multi_tensor_apply`.
void foreach_lerp_scalar_kernel_(
    TensorList tensors1,
    TensorList tensors2,
    const Scalar& weight) {
  // Two operand slots; the result reuses slot 0.
  constexpr int kDepth = 2;
  constexpr int kOperandSlots = 2;
  constexpr int kResultSlot = 0;

  std::vector<std::vector<at::Tensor>> packed{
      tensors1.vec(), tensors2.vec()};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensors1[0].scalar_type(),
      "foreach_lerp_scalar_xpu_",
      [&]() {
        // Both the element-wise op and the scalar weight use the opmath type
        // of scalar_t; the list functor is instantiated on scalar_t.
        using opmath_t = at::opmath_type<scalar_t>;
        using scalar_functor_t = TernaryOpScalarFunctor<
            scalar_t,
            kDepth,
            kOperandSlots,
            kResultSlot>;
        const opmath_t converted_weight = weight.to<opmath_t>();
        multi_tensor_apply<kDepth>(
            packed,
            scalar_functor_t(),
            LerpFunctor<opmath_t>(),
            converted_weight);
      });
}
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#pragma once | ||
#include <ATen/ATen.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
// Out-of-place XPU `_foreach_lerp` kernel whose third operand is a tensor list.
// The definition dispatches on the dtype of tensors1[0], so tensors1 is
// expected to be non-empty (presumably checked by the caller -- confirm).
// `result` is passed to multi_tensor_apply as the last list and is the slot
// selected as the result index.
void foreach_lerp_list_kernel( | ||
TensorList tensors1, | ||
TensorList tensors2, | ||
TensorList tensors3, | ||
TensorList result); | ||
|
||
// In-place XPU `_foreach_lerp_` kernel whose third operand is a tensor list.
// The definition dispatches on the dtype of tensors1[0], so tensors1 is
// expected to be non-empty (presumably checked by the caller -- confirm).
// The result index used by the definition is 0, the slot holding tensors1.
void foreach_lerp_list_kernel_( | ||
TensorList tensors1, | ||
TensorList tensors2, | ||
TensorList tensors3); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#pragma once | ||
#include <ATen/ATen.h> | ||
|
||
namespace at::native::xpu { | ||
|
||
// Out-of-place XPU `_foreach_lerp` kernel taking one scalar weight.
// The definition dispatches on the dtype of tensors1[0], so tensors1 is
// expected to be non-empty (presumably checked by the caller -- confirm).
// `weight` is converted to the opmath type of the dispatched dtype; `result`
// is passed to multi_tensor_apply as the last list and is the slot selected
// as the result index.
void foreach_lerp_scalar_kernel( | ||
TensorList tensors1, | ||
TensorList tensors2, | ||
const Scalar& weight, | ||
TensorList result); | ||
|
||
// In-place XPU `_foreach_lerp_` kernel taking one scalar weight.
// The definition dispatches on the dtype of tensors1[0], so tensors1 is
// expected to be non-empty (presumably checked by the caller -- confirm).
// `weight` is converted to the opmath type of the dispatched dtype; the
// result index used by the definition is 0, the slot holding tensors1.
void foreach_lerp_scalar_kernel_( | ||
TensorList tensors1, | ||
TensorList tensors2, | ||
const Scalar& weight); | ||
|
||
} // namespace at::native::xpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters