Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPU] refine update_loss_scaling npu kernel #32580

Merged
merged 7 commits into from
May 8, 2021

Conversation

pangyoki
Copy link
Contributor

@pangyoki pangyoki commented Apr 26, 2021

PR types

Performance optimization

PR changes

OPs

Describe

use ZerosLike and Memcpy instead of NPUMemsetAsync.

  • before (use NPUMemsetAsync)

图片

As shown in the timeline, there is a blank correspondding to update_loss_scaling_op caused by NPUMemsetAsync.
update_loss_scaling_op cost about 103 ms.

图片

  • only use ZerosLike
    If only use ZerosLike to replace NPUMemsetAsync.

图片

update_loss_scaling_op will launch many ZerosLike NPU ops.
update_loss_scaling_op cost about 22.2 ms.

  • In this PR, use ZerosLike and Memcpy

图片

update_loss_scaling_op will launch only 1 ZerosLike NPU op, and then use Memcpy to set tensor to 0.
update_loss_scaling_op cost about 5.5 ms.

图片

Performance

Speed up: 19448 tokens/s -> 20679 tokens/s, +6.33 %

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

auto g = out->mutable_data<T>(place);
platform::NPUMemsetAsync(static_cast<void*>(g), 0,
out->numel() * sizeof(T), stream);
auto runner_zeros = NpuOpRunner("ZerosLike", {*out}, {*out});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutable_data is needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

zhiqiu
zhiqiu previously approved these changes Apr 29, 2021
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@phlrain phlrain self-requested a review May 8, 2021 06:48
@pangyoki pangyoki merged commit 4628b6f into PaddlePaddle:develop May 8, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants