add inf-cl in embedding trainer #9673

Open
wants to merge 4 commits into
base: develop

jie-z-0607

PR types

Function optimization

PR changes

Others

Description

Add inf_cl_loss to embedding training. It substantially reduces GPU memory consumption at very large batch sizes.

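In testing, the inf-cl operator aligns well with the original loss function:

  • With dtype bf16, group_size 1, and gradient_accumulation_steps 4, the convergence curves of inf_cl_loss and the original contrastive_loss are as follows: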
    image
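
The PR does not show the operator's internals, so the following is only a minimal sketch of the general idea, assuming the memory saving comes from never materializing the full similarity matrix. It compares a plain in-batch contrastive loss with a tile-wise version that keeps a running max and running sum (online log-sum-exp). The function names, the `tile` parameter, and the toy vectors are all hypothetical, and this is pure Python, not the Triton kernel.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def full_contrastive_loss(queries, passages, temperature=1.0):
    """In-batch contrastive loss that builds the full N x N logits matrix."""
    logits = [[dot(q, p) / temperature for p in passages] for q in queries]
    total = 0.0
    for i, row in enumerate(logits):
        row_max = max(row)
        lse = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total += lse - row[i]  # the positive for query i is passage i
    return total / len(queries)


def tiled_contrastive_loss(queries, passages, temperature=1.0, tile=2):
    """Same loss, but only one tile of logits per query is alive at a time."""
    total = 0.0
    for i, q in enumerate(queries):
        running_max = -math.inf
        running_sum = 0.0
        for start in range(0, len(passages), tile):
            block = [dot(q, p) / temperature for p in passages[start:start + tile]]
            new_max = max(running_max, max(block))
            # Rescale the previous partial sum to the new max, then add this tile.
            running_sum = running_sum * math.exp(running_max - new_max)
            running_sum += sum(math.exp(x - new_max) for x in block)
            running_max = new_max
        positive = dot(q, passages[i]) / temperature
        total += running_max + math.log(running_sum) - positive
    return total / len(queries)


# Toy data (hypothetical): 5 query/passage pairs of dimension 3.
queries = [[0.1, 0.3, -0.2], [0.5, -0.1, 0.4], [-0.3, 0.2, 0.1],
           [0.0, 0.6, -0.5], [0.2, 0.2, 0.2]]
passages = [[0.2, 0.1, -0.1], [0.4, 0.0, 0.3], [-0.2, 0.3, 0.0],
            [0.1, 0.5, -0.4], [0.3, 0.1, 0.2]]

full = full_contrastive_loss(queries, passages, temperature=0.05)
tiled = tiled_contrastive_loss(queries, passages, temperature=0.05, tile=2)
print(full, tiled)
```

The two values should agree up to floating-point rounding, which is the kind of alignment the convergence curves above are checking at training scale.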

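In testing, the inf-cl operator substantially reduces GPU memory consumption during embedding training at very large batch sizes:

  • On 8 A100 (80G) GPUs, with dtype bf16, group_size 4, and gradient_accumulation_steps 4096, the GPU memory usage of inf_cl_loss and the original contrastive_loss compares as follows: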

| Configuration | GPU memory usage | Time to complete the first step |
| --- | --- | --- |
| Without inf-cl; embedding_negatives_cross_device=True | 42238MiB; 42526MiB; 42526MiB; 42470MiB; 42470MiB; 42526MiB; 42526MiB; 42182MiB | 48min42s |
| With inf-cl; embedding_negatives_cross_device=False | 29630MiB; 28392MiB; 28372MiB; 28308MiB; 28320MiB; 28384MiB; 28316MiB; 28070MiB | 49min56s |
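
To make the comparison easier to read, here is a small sketch that summarizes the numbers reported above. The values are copied from the table; the variable and helper names are only illustrative.

```python
from statistics import mean

# GPU memory readings in MiB, copied from the table above (8 values per run).
baseline_mib = [42238, 42526, 42526, 42470, 42470, 42526, 42526, 42182]
inf_cl_mib = [29630, 28392, 28372, 28308, 28320, 28384, 28316, 28070]


def to_seconds(minutes, seconds):
    return minutes * 60 + seconds


baseline_mean = mean(baseline_mib)
inf_cl_mean = mean(inf_cl_mib)
memory_saving = 1 - inf_cl_mean / baseline_mean

baseline_time = to_seconds(48, 42)  # 48min42s without inf-cl
inf_cl_time = to_seconds(49, 56)    # 49min56s with inf-cl
time_overhead = inf_cl_time / baseline_time - 1

print(f"mean memory without inf-cl: {baseline_mean} MiB")
print(f"mean memory with inf-cl:    {inf_cl_mean} MiB")
print(f"memory saving:              {memory_saving:.1%}")
print(f"first-step time overhead:   {time_overhead:.1%}")
```

On these figures the mean usage drops from 42433 MiB to 28474 MiB, roughly a third less memory, while the first step takes about 2.5% longer.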


  • On 8 A100 (80G) GPUs, with dtype bf16, group_size 1, and gradient_accumulation_steps 16384 (a total batch_size of 128K), the GPU memory usage of inf_cl_loss and the original contrastive_loss compares as follows:

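| Configuration | GPU memory usage | Time to complete the first step |
| --- | --- | --- |
| Without inf-cl; embedding_negatives_cross_device=True | Exceeds the GPU memory limit | |
| With inf-cl; embedding_negatives_cross_device=False | 46324MiB; 45192MiB; 44926MiB; 45180MiB; 44674MiB; 45022MiB; 45032MiB; 44904MiB | 2h23min46s |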
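
The 128K total can be checked with a quick calculation. This sketch assumes a per-device batch size of 1, which the description does not state but which is the value that makes 8 GPUs and 16384 accumulation steps come out to 128K. The helper name is hypothetical.

```python
def effective_batch_size(num_devices, per_device_batch_size, gradient_accumulation_steps):
    """Samples contributing to one optimizer step across all devices."""
    return num_devices * per_device_batch_size * gradient_accumulation_steps


# 8 GPUs and 16384 accumulation steps are from the description;
# per_device_batch_size=1 is an assumption inferred from the stated 128K total.
total = effective_batch_size(
    num_devices=8,
    per_device_batch_size=1,
    gradient_accumulation_steps=16384,
)
print(total, total // 1024)
```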




paddle-bot bot commented Dec 23, 2024

Thanks for your contribution!


codecov bot commented Dec 23, 2024

Codecov Report

Attention: Patch coverage is 33.33333% with 10 lines in your changes missing coverage. Please review.

Project coverage is 52.79%. Comparing base (1842d6d) to head (6b6a108).
Report is 1 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/trl/embedding_trainer.py 33.33% 10 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9673      +/-   ##
===========================================
- Coverage    53.18%   52.79%   -0.40%     
===========================================
  Files          718      718              
  Lines       113340   112267    -1073     
===========================================
- Hits         60282    59268    -1014     
+ Misses       53058    52999      -59     


__all__ = ["Simple_Inf_cl_loss", "Matryoshka_Inf_cl_loss"]


class Simple_Inf_cl_loss(nn.Layer):
Collaborator

Please add some comments.

@@ -18,6 +18,10 @@
from paddle.base import core
from paddle.distributed import fleet

from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (
Collaborator

Suggested change
- from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (
+ from paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (

@@ -18,6 +18,10 @@
from paddle.base import core
from paddle.distributed import fleet

from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import (
Collaborator

This package is not installed by default, so the import needs to be wrapped in a try/except.
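
One possible shape for that guard, as a sketch only: the module path is the one from the suggested change, while `optional_import`, `INF_CL_AVAILABLE`, and `require_inf_cl` are hypothetical names, and the symbols the PR actually imports are not shown in the diff, so none are named here.

```python
import importlib


def optional_import(module_path):
    """Return the imported module, or None if it (or a dependency) is missing."""
    try:
        return importlib.import_module(module_path)
    except ImportError:
        return None


# The kernel package is an optional dependency, so a missing install must not
# break importing the trainer module itself.
inf_cl_module = optional_import("paddlenlp_kernel.triton.inf_cl.inf_cl_loss")
INF_CL_AVAILABLE = inf_cl_module is not None


def require_inf_cl():
    """Hypothetical helper: fail with a clear message only when inf-cl is requested."""
    if not INF_CL_AVAILABLE:
        raise ImportError(
            "inf-cl loss was requested but paddlenlp_kernel is not installed; "
            "install it or disable the inf-cl option."
        )
    return inf_cl_module
```

With this layout the trainer imports cleanly everywhere, and users only see an error if they turn on inf-cl without the kernel package.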


3 participants