Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reduce_merge files for reduce op #32697

Merged
merged 16 commits into from
May 25, 2021

Conversation

AnnaTrainingG
Copy link
Contributor

@AnnaTrainingG AnnaTrainingG commented Apr 30, 2021

PR types

Performance optimization

PR changes

APIs

Describe

add reduce_op.h and functor.h in reduce_merge for reduce op

性能测试数据:

axis case pytorch us paddle_old us paddle_new us 加速比 old/new 加速比pytorch/paddle_new 是否为benchmark
axis=0 [512    2048] 12.442 28.272 10.821 2.61 1.15
axis=0 [128    1024] 5.595 5.181 3.711 1.40 1.51
axis=0 [30522  1024] 162.77 1767.3 152.229 11.61 1.07
axis=0 [1024   16] 4.703 2.471 3.509 0.70 1.34
axis=0 [256    12800] 18.756 81.647 17.734 4.60 1.06
axis=0 [256    10240] 15.742 59.888 15.379 3.89 1.02
axis=0 [1024   1280] 11.625 33.204 8.399 3.95 1.38
axis=0 [32768  1280] 205.95 3504.7 198.15 17.69 1.04
axis=0 [30522  10240] 1414.6 32643 1437.523 22.71 0.98
axis=0 [256    10240] 15.257 65.901 14.79 4.46 1.03
axis=0 [1024   1280] 8.265 31.31 7.158 4.37 1.15
axis=0 [32768  1280] 207.58 3501 198.297 17.66 1.05
axis=0 [30522  10240] 1415.5 32554 1438.646 22.63 0.98
axis=0 [2560   10240] 127.21 585.19 126.275 4.63 1.01
axis=0 [10240  1280] 76.668 413.34 67.667 6.11 1.13
axis=0 [32768  2560] 390.23 8323.7 383.609 21.70 1.02
axis=0 [30522  1024] 160.21 1808.7 151.341 11.95 1.06
axis=0 [16 16  1   1] 319.00% 170.70% 177.60% 0.96 1.80
               
axis=1 [2  512 2048] 20.069 58.785 18.84 3.12 1.07
axis=1 [2  128 1024] 5.864 9.161 3.76 2.44 1.56
axis=1 [2  30522   1024] 296.54 4734 297.31 15.92 1.00
axis=1 [2  1024    16] 5.163 2.678 3.37 0.80 1.53
axis=1 [2  256 12800] 33.255 173.95 32.43 5.36 1.03
axis=1 [2  256 10240] 27.616 120.5 26.23 4.59 1.05
axis=1 [2  1024    1280] 18.325 66.735 18.69 3.57 0.98
axis=1 [2  32768   1280] 390.39 8531.5 390.52 21.85 1.00
axis=1 [2  30522   10240] 1420.7 71629 2878.59 24.88 0.49
axis=1 [2  256 10240] 27.114 140.58 26.47 5.31 1.02
axis=1 [2  1024    1280] 17.84 57.585 19.25 2.99 0.93
axis=1 [2  32768   1280] 390.53 8500 390.55 21.76 1.00
axis=1 [2  30522   10240] 1420.1 71578 2878.75 24.86 0.49
axis=1 [2  2560    10240] 244.58 1157.2 247.81 4.67 0.99
axis=1 [2  10240   1280] 132.68 1228.3 129.82 9.46 1.02
axis=1 [2  32768   2560] 762.9 19096 766.86 24.90 0.99
axis=1 [2  30522   1024] 296.87 4708.1 296.50 15.88 1.00
axis=1 [16 8   128] 3.95 4.253 1.31 3.24 3.01

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@xingfeng01
Copy link
Contributor

建议文件放在reduce_ops下即可,后缀改为.cuh

@xingfeng01
Copy link
Contributor

注意clang-format

@CLAassistant
Copy link

CLAassistant commented May 17, 2021

CLA assistant check
All committers have signed the CLA.

zhangting2020
zhangting2020 previously approved these changes May 20, 2021
Copy link
Contributor

@zhangting2020 zhangting2020 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhangting2020 zhangting2020 mentioned this pull request May 24, 2021
@xingfeng01
Copy link
Contributor

LGTM

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议PR代码接入一个算子,跑下完整的CI,不然光看代码看不出是否有问题。#32804 正在给reduce_sum添加fp16 kernel,这个实现当前没有针对fp16进行优化,先不要改reduce_sum。

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/macros.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

极小化include的头文件,看起来以上头文件都没有实际依赖。

}

template <typename T, size_t ElementCount, typename VectorLikeType>
static inline paddle::framework::Array<T, ElementCount> from(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数是从什么转到什么?单个的From函数一般是用作类的成员函数,如果是单独的函数,最好命名更清楚一些。

vec.size(), ElementCount));
size_t n = static_cast<size_t>(vec.size());
paddle::framework::Array<T, ElementCount> ret;
for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要写在同一行,加{}。


template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp, int kRank, int kReduceRank>
static void launchKernel(const Tx* x_data, Ty* y_data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

launchKernel -> LaunchKernel,首字母大写。


template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp>
static void launchReduceKernel(const Tx* x_data, Ty* y_data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

launchReduceKernel -> LaunchReduceKernel

Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. 先合入这个pr,下个pr尽快接入算子进行ci验证。

@Xreki Xreki merged commit 88b43b5 into PaddlePaddle:develop May 25, 2021
@AnnaTrainingG AnnaTrainingG deleted the reduce_merge branch October 9, 2022 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants