add inf-cl in embedding trainer #9673

Open · wants to merge 4 commits into base: develop
5 changes: 3 additions & 2 deletions llm/config/qwen/emb_argument.json
@@ -32,5 +32,6 @@
"unified_checkpoint": true,
"use_flash_attention": true,
"amp_custom_black_list": "elementwise_div",
"release_grads": true
}
"release_grads": true,
"loss_type": "contrastive"
}
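For reference, the shipped example keeps the default contrastive loss. A hedged sketch of the same tail of the config switched to the new kernel-based loss (the inf_cl_head_dim value is illustrative and simply restates the argument's default):

    "unified_checkpoint": true,
    "use_flash_attention": true,
    "amp_custom_black_list": "elementwise_div",
    "release_grads": true,
    "loss_type": "inf_cl",
    "inf_cl_head_dim": 64
}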
8 changes: 8 additions & 0 deletions llm/utils/argument.py
@@ -88,3 +88,11 @@ class EmbeddingArgument:
default=None,
metadata={"help": "The dims for matryoshka training."},
)
loss_type: str = field(
default="contrastive",
metadata={"help": "The type of loss computation."},
)
inf_cl_head_dim: int = field(
default=64,
metadata={"help": "The size of the head dimension when gpu ops are set as 'inf_cl'."},
)
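Since these are plain dataclass fields, they flow in from the JSON config (or the command line) alongside the existing embedding options. Below is a minimal, self-contained sketch of just the fields the trainer dispatch reads; it is a trimmed stand-in for the real EmbeddingArgument, not the class itself, and the temperature default shown is illustrative.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class EmbeddingArgsSketch:
    """Trimmed stand-in for EmbeddingArgument, showing only the fields used in the loss dispatch."""

    embedding_temperature: float = 0.02  # existing field; the default here is illustrative
    embedding_matryoshka_dims: Optional[List[int]] = None  # existing field
    loss_type: str = field(default="contrastive", metadata={"help": "'contrastive' or 'inf_cl'"})
    inf_cl_head_dim: int = field(default=64, metadata={"help": "head dim for the inf_cl GPU ops"})


args = EmbeddingArgsSketch(loss_type="inf_cl", embedding_matryoshka_dims=[64, 128])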
1 change: 1 addition & 0 deletions ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

from .flash import cal_flash_loss
from .inf_cl_loss import *
from .ring import cal_inf_loss, cal_ring_loss
103 changes: 103 additions & 0 deletions ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import paddle
import paddle.nn as nn

from .ring import cal_inf_loss

__all__ = ["SimpleInfclLoss", "MatryoshkaInfclLoss"]


class SimpleInfclLoss(nn.Layer):
def __init__(self, inf_cl_head_dim=64):
"""
Initializes the Simple Inf_cl Loss class.

Args:
inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64.
"""
super().__init__()
self.head_dim = inf_cl_head_dim

def forward(self, q_reps, p_reps):
"""
Computes the instance discrimination loss.

Args:
q_reps (Tensor): Query representations.
p_reps (Tensor): Key representations.

Returns:
Tensor: The computed loss.
"""
group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query
labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries
labels = labels * group_size # Adjust labels based on group size
loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim)
return loss


class MatryoshkaInfclLoss(nn.Layer):
def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64):
"""
Initializes the Matryoshka Inf_cl Loss class.

Args:
embedding_matryoshka_dims (List[int], optional): List of dimensions for Matryoshka embeddings.
If None, no Matryoshka embedding is used. Default is None.
inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64.
"""
super().__init__()
if embedding_matryoshka_dims is None:
self.embedding_matryoshka_dims = []
else:
self.embedding_matryoshka_dims = embedding_matryoshka_dims
self.loss_fn = SimpleInfclLoss(inf_cl_head_dim)

def forward(self, q_reps, p_reps):
"""
Computes the Matryoshka instance discrimination loss.

Args:
q_reps (Tensor): Query representations.
p_reps (Tensor): Key representations.

Returns:
Tensor: The computed loss.
"""
if len(self.embedding_matryoshka_dims) > 0:
loss = 0.0
for dim in self.embedding_matryoshka_dims:
reduced_q_reps = q_reps[:, :dim] # Reduce query representations to the current Matryoshka dimension
reduced_q_reps = nn.functional.normalize(
reduced_q_reps, axis=-1
) # Normalize the reduced query representations along the last axis

reduced_p_reps = p_reps[:, :dim] # Reduce key representations to the current Matryoshka dimension
reduced_p_reps = nn.functional.normalize(
reduced_p_reps, axis=-1
) # Normalize the reduced key representations along the last axis

dim_loss = self.loss_fn(
reduced_q_reps, reduced_p_reps
) # Compute the loss for the current Matryoshka dimension using the internal loss function
loss += dim_loss
else:
loss = self.loss_fn(
q_reps, p_reps
) # If no Matryoshka dimensions are specified, compute the loss using the full representations
return loss
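A hedged usage sketch for the two new loss layers follows; the shapes are illustrative, and actually running it assumes a GPU with the Triton kernels from ops/ installed, since cal_inf_loss dispatches to them.

import paddle
import paddle.nn.functional as F

from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclLoss

batch_size, group_size, hidden = 8, 2, 256  # one positive + one negative per query (illustrative)

q_reps = F.normalize(paddle.randn([batch_size, hidden]), axis=-1)
p_reps = F.normalize(paddle.randn([batch_size * group_size, hidden]), axis=-1)

# Plain inf_cl loss: query i is labeled with passage i * group_size (its positive).
loss = SimpleInfclLoss(inf_cl_head_dim=64)(q_reps, p_reps)

# Matryoshka variant: sums the loss over truncated (and re-normalized) embedding dims.
mrl_loss = MatryoshkaInfclLoss(embedding_matryoshka_dims=[64, 128, 256], inf_cl_head_dim=64)(q_reps, p_reps)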
25 changes: 21 additions & 4 deletions paddlenlp/trl/embedding_trainer.py
@@ -18,6 +18,15 @@
from paddle.base import core
from paddle.distributed import fleet

from paddlenlp.utils.log import logger

try:
from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclLoss
except ImportError:
    logger.warning(
        "paddlenlp_kernel is not available, so the inf_cl loss cannot be used. "
        "To enable it, follow the instructions in the README.md under `ops`."
    )

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import (
MatryoshkaContrastiveLoss,
@@ -44,11 +53,19 @@
self.accum_rng_states["hybrid"] = []

if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0:
    if model_args.loss_type == "inf_cl":
        self.embedding_negatives_cross_device = False
        self.loss_fn = MatryoshkaInfclLoss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim)
    elif model_args.loss_type == "contrastive":
        self.loss_fn = MatryoshkaContrastiveLoss(
            model_args.embedding_temperature, model_args.embedding_matryoshka_dims
        )
else:
    if model_args.loss_type == "inf_cl":
        self.embedding_negatives_cross_device = False
        self.loss_fn = SimpleInfclLoss(model_args.inf_cl_head_dim)
    elif model_args.loss_type == "contrastive":
        self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)

def clear_memory(self):
self.accum_q_features.clear()
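Taken together, the trainer now chooses one of four loss layers from two switches: whether Matryoshka dims are set, and loss_type. For inf_cl it also turns off the trainer's own cross-device negative gathering, presumably because the ring-based kernel handles the negatives itself. A minimal standalone sketch of that dispatch follows; the helper name is hypothetical and simply mirrors the branches added above.

from paddlenlp.transformers.contrastive_loss import (
    MatryoshkaContrastiveLoss,
    SimpleContrastiveLoss,
)
from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclLoss


def select_embedding_loss(model_args):
    """Hypothetical helper mirroring the loss selection added to EmbeddingTrainer.__init__."""
    use_matryoshka = (
        model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0
    )
    cross_device_negatives = True  # stands in for self.embedding_negatives_cross_device

    if model_args.loss_type == "inf_cl":
        cross_device_negatives = False
        loss_fn = (
            MatryoshkaInfclLoss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim)
            if use_matryoshka
            else SimpleInfclLoss(model_args.inf_cl_head_dim)
        )
    else:  # "contrastive"
        loss_fn = (
            MatryoshkaContrastiveLoss(model_args.embedding_temperature, model_args.embedding_matryoshka_dims)
            if use_matryoshka
            else SimpleContrastiveLoss(model_args.embedding_temperature)
        )
    return loss_fn, cross_device_negatives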