From f039d097d3fc31028aef609614496e517de84371 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 11 Jan 2024 16:44:22 +0800 Subject: [PATCH] fused rms spmd (#7830) --- .../external_ops/fused_ln/layer_norm_cuda.cu | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu index 66af7f8bb41f..68c563b09235 100644 --- a/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu +++ b/model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.cu @@ -24,6 +24,11 @@ #include "layer_norm_cuda.h" // NOLINT #include "paddle/extension.h" +#ifdef CUSTOM_OP_WITH_SPMD +#include "paddle/phi/api/ext/spmd_infer.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + #define CHECK_CUDA(x) PD_CHECK(!x.is_cpu(), #x " must be a CUDA tensor") static void GetRowsCols(const std::vector &shape, @@ -214,14 +219,22 @@ PD_BUILD_OP(fused_rms_norm) .Attrs({"epsilon: float"}) .SetKernelFn(PD_KERNEL(RMSLnFwd)) .SetInferShapeFn(PD_INFER_SHAPE(RMSLnFwdInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype)); + .SetInferDtypeFn(PD_INFER_DTYPE(RMSLnFwdInferDtype)) +#ifdef CUSTOM_OP_WITH_SPMD + .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormInferSpmd)) +#endif + ; PD_BUILD_GRAD_OP(fused_rms_norm) .Inputs({"x", "scale", "invvar", paddle::Grad("y")}) .Outputs({paddle::Grad("x"), paddle::Grad("scale")}) .Attrs({"epsilon: float"}) .SetKernelFn(PD_KERNEL(RMSLnBwd)) - .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape)); + .SetInferShapeFn(PD_INFER_SHAPE(RMSLnBwdInferShape)) +#ifdef CUSTOM_OP_WITH_SPMD + .SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::RmsNormGradInferSpmd)) +#endif + ; // https://github.com/NVIDIA/apex/blob/85e9eddece9d4ac72b48c2407f8162f2173e1bf4/csrc/layer_norm_cuda_kernel.cu#L679