Skip to content

Commit

Permalink
[PIR] add support checks for some oneDNN kernels (#62269)
Browse files Browse the repository at this point in the history
* add support-check callbacks for some oneDNN kernels
  • Loading branch information
wanghuancoder authored Mar 4, 2024
1 parent cb8ae07 commit adb8bc2
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 7 deletions.
4 changes: 4 additions & 0 deletions paddle/phi/core/kernel_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class KernelContext {
return paddle::none;
}

// Returns the idx-th input as a raw, non-owning const TensorBase pointer,
// letting callers probe its runtime type (e.g. DenseTensor::classof) without
// committing to a concrete tensor type. Throws std::out_of_range if idx is
// out of bounds (inputs_.at).
// NOTE(review): "Iutput" looks like a typo for "Input"; renaming now would
// break the existing oneDNN support-check callers, so it is kept as-is.
const TensorBase* MutableIutputAt(size_t idx) const {
return inputs_.at(idx);
}

template <typename TensorType>
TensorType* MutableOutputAt(size_t idx) {
return static_cast<TensorType*>(outputs_.at(idx));
Expand Down
17 changes: 16 additions & 1 deletion paddle/phi/kernels/onednn/add_n_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
// Whether the oneDNN add_n kernel can serve this call: every input and the
// first output must be a DenseTensor (other TensorBase subclasses, e.g.
// SelectedRows, are not handled by the oneDNN implementation below).
bool AddNCheckIfOneDNNSupport(const KernelContext* ctx) {
  for (size_t i = 0; i < ctx->InputsSize(); ++i) {
    if (!DenseTensor::classof(ctx->MutableIutputAt(i))) {
      return false;
    }
  }
  // MutableOutputAt has no const overload, so cast away constness purely to
  // inspect the output's runtime type; the output itself is not modified.
  auto* mutable_ctx = const_cast<KernelContext*>(ctx);
  return DenseTensor::classof(mutable_ctx->MutableOutputAt(0));
}

namespace funcs {
template <typename T>
class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
Expand Down Expand Up @@ -122,4 +135,6 @@ void AddNKernel(const Context& dev_ctx,
} // namespace phi

PD_REGISTER_KERNEL(
add_n, OneDNN, ONEDNN, phi::AddNKernel, float, phi::dtype::bfloat16) {}
add_n, OneDNN, ONEDNN, phi::AddNKernel, float, phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::AddNCheckIfOneDNNSupport;
}
24 changes: 22 additions & 2 deletions paddle/phi/kernels/onednn/sgd_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@

namespace phi {

// Whether the oneDNN dense SGD kernel supports this call: input 0 (the
// parameter) and input 2 (presumably the gradient — matches the dense-grad
// kernel variant) must both be DenseTensors.
bool SgdCheckIfOneDNNSupport(const KernelContext* ctx) {
  return DenseTensor::classof(ctx->MutableIutputAt(0)) &&
         DenseTensor::classof(ctx->MutableIutputAt(2));
}

// Whether the oneDNN sparse-grad SGD kernel supports this call: input 0 (the
// parameter) must be a DenseTensor and input 2 (presumably the gradient) must
// be a SelectedRows instance.
bool SgdSparseCheckIfOneDNNSupport(const KernelContext* ctx) {
  return DenseTensor::classof(ctx->MutableIutputAt(0)) &&
         SelectedRows::classof(ctx->MutableIutputAt(2));
}

template <typename T, typename Context>
void SGDDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
Expand Down Expand Up @@ -82,11 +98,15 @@ void SGDDenseParamSparseGradKernel(
} // namespace phi

PD_REGISTER_KERNEL(
sgd, OneDNN, ONEDNN, phi::SGDDenseKernel, float, phi::dtype::bfloat16) {}
sgd, OneDNN, ONEDNN, phi::SGDDenseKernel, float, phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::SgdCheckIfOneDNNSupport;
}

PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
OneDNN,
ONEDNN,
phi::SGDDenseParamSparseGradKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::SgdSparseCheckIfOneDNNSupport;
}
11 changes: 10 additions & 1 deletion paddle/phi/kernels/onednn/slice_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@

namespace phi {

// Whether the oneDNN slice_grad kernel supports this call: input 1's memory
// descriptor must use a plain (non-blocked) layout, i.e. no inner blocks.
bool SliceGradCheckIfOneDNNSupport(const KernelContext* ctx) {
  return ctx->InputAt<phi::DenseTensor>(1).mem_desc().get_inner_nblks() == 0;
}

template <typename T, typename Context>
void SliceGradKernel(const Context& dev_ctx,
const DenseTensor& input UNUSED,
Expand Down Expand Up @@ -83,4 +90,6 @@ PD_REGISTER_KERNEL(slice_grad,
ONEDNN,
phi::SliceGradKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::SliceGradCheckIfOneDNNSupport;
}
16 changes: 15 additions & 1 deletion paddle/phi/kernels/onednn/slice_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@

namespace phi {

// Whether the oneDNN slice kernel supports this call: the input must have at
// least one non-zero dimension and a plain (non-blocked) memory layout.
bool SliceCheckIfOneDNNSupport(const KernelContext* ctx) {
  // Bind by const reference — the original `auto x = ...` copied the whole
  // DenseTensor by value just to inspect its dims and memory descriptor.
  const auto& x = ctx->InputAt<phi::DenseTensor>(0);
  const auto vec_dims = common::vectorize(x.dims());
  const bool all_zero_dims = std::all_of(
      vec_dims.cbegin(), vec_dims.cend(), [](int64_t i) { return i == 0; });

  return !all_zero_dims && x.mem_desc().get_inner_nblks() == 0;
}

template <typename T, typename Context>
void SliceKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -106,4 +118,6 @@ PD_REGISTER_KERNEL(slice,
float,
int8_t,
uint8_t,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->check_if_onednn_kernel_support_ = phi::SliceCheckIfOneDNNSupport;
}
15 changes: 13 additions & 2 deletions paddle/phi/kernels/onednn/split_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@

namespace phi {

// Whether the oneDNN split kernels (split / split_with_num) support this
// call: input 0's memory descriptor must use a plain (non-blocked) layout.
bool SplitCheckIfOneDNNSupport(const KernelContext* ctx) {
  return ctx->InputAt<phi::DenseTensor>(0).mem_desc().get_inner_nblks() == 0;
}

const std::vector<int64_t> get_slice_strides(
const std::vector<int64_t>& out_vec_dims,
const dnnl::memory::desc& full_md,
Expand Down Expand Up @@ -104,7 +111,9 @@ PD_REGISTER_KERNEL(split,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
uint8_t) {
kernel->check_if_onednn_kernel_support_ = phi::SplitCheckIfOneDNNSupport;
}

PD_REGISTER_KERNEL(split_with_num,
OneDNN,
Expand All @@ -113,4 +122,6 @@ PD_REGISTER_KERNEL(split_with_num,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
uint8_t) {
kernel->check_if_onednn_kernel_support_ = phi::SplitCheckIfOneDNNSupport;
}

0 comments on commit adb8bc2

Please sign in to comment.