Skip to content

Commit

Permalink
fused rms spmd (#7830)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhenhai93 authored Jan 11, 2024
1 parent 4069f22 commit f039d09
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
#include "layer_norm_cuda.h" // NOLINT
#include "paddle/extension.h"

#ifdef CUSTOM_OP_WITH_SPMD
#include "paddle/phi/api/ext/spmd_infer.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

#define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor")

static void GetRowsCols(const std::vector<int64_t> &shape,
Expand Down Expand Up @@ -214,14 +219,22 @@ PD_BUILD_OP(fused_rms_norm)
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(RMSLnFwd))
.SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype));
.SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype))
#ifdef CUSTOM_OP_WITH_SPMD
.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd))
#endif
;

// Registers the backward (gradient) op for `fused_rms_norm`.
//   Inputs : x, scale, invvar, and the gradient of y.
//   Outputs: the gradients of x and scale.
//   Attrs  : epsilon (float).
// The kernel is RMSLnBwd; output shapes come from RMSLnBwdInferShape.
// When CUSTOM_OP_WITH_SPMD is defined, the SPMD inference rule
// phi::distributed::RmsNormGradInferSpmd is attached as well.
PD_BUILD_GRAD_OP(fused_rms_norm)
    .Inputs({"x", "scale", "invvar", paddle::Grad("y")})
    .Outputs({paddle::Grad("x"), paddle::Grad("scale")})
    .Attrs({"epsilon: float"})
    .SetKernelFn(PD_KERNEL(RMSLnBwd))
    .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape))
#ifdef CUSTOM_OP_WITH_SPMD
    .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd))
#endif
    // The terminating semicolon sits on its own line so the statement is
    // closed exactly once whether or not the SPMD setter is compiled in.
    ;


// https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679

0 comments on commit f039d09

Please sign in to comment.