From dd7fc8a0e21c0773a91ca378f6656c76cd5095b5 Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Mon, 23 Dec 2024 15:27:24 +0800 Subject: [PATCH 1/7] add inf-cl in embedding trainer --- llm/config/qwen/emb_argument.json | 5 +- llm/utils/argument.py | 8 +++ .../triton/inf_cl/__init__.py | 1 + .../triton/inf_cl/inf_cl_loss.py | 61 +++++++++++++++++++ paddlenlp/trl/embedding_trainer.py | 20 ++++-- 5 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py diff --git a/llm/config/qwen/emb_argument.json b/llm/config/qwen/emb_argument.json index d8c6aeeb7f6e..fdfed2d82357 100644 --- a/llm/config/qwen/emb_argument.json +++ b/llm/config/qwen/emb_argument.json @@ -32,5 +32,6 @@ "unified_checkpoint": true, "use_flash_attention": true, "amp_custom_black_list": "elementwise_div", - "release_grads": true -} + "release_grads": true, + "loss_type": "contrastive" +} \ No newline at end of file diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 99df142e826e..e46e95eec7a6 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -88,3 +88,11 @@ class EmbeddingArgument: default=None, metadata={"help": "The dims for matryoshka training."}, ) + loss_type: str = field( + default="contrastive", + metadata={"help": "The type of loss computation."}, + ) + inf_cl_head_dim: int = field( + default=64, + metadata={"help": "The size of the head dimension when gpu ops are set as 'inf_cl'."}, + ) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py b/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py index 371bdba8a6de..441f44b821cd 100644 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .flash import cal_flash_loss +from .inf_cl_loss import * from .ring import cal_inf_loss, cal_ring_loss diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py new file mode 100644 index 000000000000..379affb6a6c3 --- /dev/null +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +import paddle +import paddle.nn as nn + +from .ring import cal_inf_loss + +__all__ = ["Simple_Inf_cl_loss", "Matryoshka_Inf_cl_loss"] + + +class Simple_Inf_cl_loss(nn.Layer): + def __init__(self, inf_cl_head_dim=64): + super().__init__() + self.head_dim = inf_cl_head_dim + + def forward(self, q_reps, p_reps): + group_size = p_reps.shape[0] // q_reps.shape[0] + labels = paddle.arange(q_reps.shape[0], dtype="int64") + labels = labels * group_size + loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim) + return loss + + +class Matryoshka_Inf_cl_loss(nn.Layer): + def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64): + super().__init__() + if embedding_matryoshka_dims is None: + self.embedding_matryoshka_dims = [] + else: + self.embedding_matryoshka_dims = embedding_matryoshka_dims + self.loss_fn = Simple_Inf_cl_loss(inf_cl_head_dim) + + def forward(self, q_reps, p_reps): + if len(self.embedding_matryoshka_dims) > 0: + loss = 0.0 + for dim in self.embedding_matryoshka_dims: + reduced_q_reps = q_reps[:, :dim] + reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1) + + reduced_p_reps = p_reps[:, :dim] + reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1) + + dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps) + loss += dim_loss + else: + loss = self.loss_fn(q_reps, p_reps) + return loss diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index 4ce7601b9056..08591066ca0a 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -18,6 +18,10 @@ from paddle.base import core from paddle.distributed import fleet +from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( + Matryoshka_Inf_cl_loss, + Simple_Inf_cl_loss, +) from paddlenlp.trainer import Trainer from paddlenlp.transformers.contrastive_loss import ( MatryoshkaContrastiveLoss, @@ -44,11 +48,19 @@ def __init__(self, model_args, **kwargs): self.accum_rng_states["hybrid"] = [] if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0: - self.loss_fn = MatryoshkaContrastiveLoss( - model_args.embedding_temperature, model_args.embedding_matryoshka_dims - ) + if model_args.loss_type == "inf_cl": + self.embedding_negatives_cross_device = False + self.loss_fn = Matryoshka_Inf_cl_loss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim) + elif model_args.loss_type == "contrastive": + self.loss_fn = MatryoshkaContrastiveLoss( + model_args.embedding_temperature, model_args.embedding_matryoshka_dims + ) else: - self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature) + if model_args.loss_type == "inf_cl": + self.embedding_negatives_cross_device = False + self.loss_fn = Simple_Inf_cl_loss(model_args.inf_cl_head_dim) + elif model_args.loss_type == "contrastive": + self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature) def clear_memory(self): self.accum_q_features.clear() From 3b066559ed481c0f83cd24f9346cddadbf6f59f1 Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Mon, 23 Dec 2024 18:00:54 +0800 Subject: [PATCH 2/7] add annotations and fix import --- .../triton/inf_cl/inf_cl_loss.py | 60 ++++++++++++++++--- paddlenlp/trl/embedding_trainer.py | 12 ++-- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py index 
379affb6a6c3..f9d756261884 100644 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py @@ -24,19 +24,43 @@ class Simple_Inf_cl_loss(nn.Layer): def __init__(self, inf_cl_head_dim=64): + """ + Initializes the Simple Inf_cl Loss class. + + Args: + inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. + """ super().__init__() self.head_dim = inf_cl_head_dim def forward(self, q_reps, p_reps): - group_size = p_reps.shape[0] // q_reps.shape[0] - labels = paddle.arange(q_reps.shape[0], dtype="int64") - labels = labels * group_size + """ + Computes the instance discrimination loss. + + Args: + q_reps (Tensor): Query representations. + p_reps (Tensor): key representations. + + Returns: + Tensor: The computed loss. + """ + group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query + labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries + labels = labels * group_size # Adjust labels based on group size loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim) return loss class Matryoshka_Inf_cl_loss(nn.Layer): def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64): + """ + Initializes the Matryoshka Inf_cl Loss class. + + Args: + embedding_matryoshka_dims (List[int], optional): List of dimensions for Matryoshka embeddings. + If None, no Matryoshka embedding is used. Default is None. + inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. + """ super().__init__() if embedding_matryoshka_dims is None: self.embedding_matryoshka_dims = [] @@ -45,17 +69,35 @@ def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl self.loss_fn = Simple_Inf_cl_loss(inf_cl_head_dim) def forward(self, q_reps, p_reps): + """ + Computes the Matryoshka instance discrimination loss. + + Args: + q_reps (Tensor): Query representations. + p_reps (Tensor): key representations. + + Returns: + Tensor: The computed loss. 
+ """ if len(self.embedding_matryoshka_dims) > 0: loss = 0.0 for dim in self.embedding_matryoshka_dims: - reduced_q_reps = q_reps[:, :dim] - reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1) + reduced_q_reps = q_reps[:, :dim] # Reduce query representations to the current Matryoshka dimension + reduced_q_reps = nn.functional.normalize( + reduced_q_reps, axis=-1 + ) # Normalize the reduced query representations along the last axis - reduced_p_reps = p_reps[:, :dim] - reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1) + reduced_p_reps = p_reps[:, :dim] # Reduce key representations to the current Matryoshka dimension + reduced_p_reps = nn.functional.normalize( + reduced_p_reps, axis=-1 + ) # Normalize the reduced key representations along the last axis - dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps) + dim_loss = self.loss_fn( + reduced_q_reps, reduced_p_reps + ) # Compute the loss for the current Matryoshka dimension using the internal loss function loss += dim_loss else: - loss = self.loss_fn(q_reps, p_reps) + loss = self.loss_fn( + q_reps, p_reps + ) # If no Matryoshka dimensions are specified, compute the loss using the full representations return loss diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index 08591066ca0a..259bae40939d 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -18,10 +18,14 @@ from paddle.base import core from paddle.distributed import fleet -from ops.src.paddlenlp_kernel.triton.inf_cl.inf_cl_loss import ( - Matryoshka_Inf_cl_loss, - Simple_Inf_cl_loss, -) +try: + from paddlenlp_kernel.triton.inf_cl import ( + Matryoshka_Inf_cl_loss, + Simple_Inf_cl_loss, + ) +except ImportError: + print("WARNING: paddlenlp_kernels are not available.") + from paddlenlp.trainer import Trainer from paddlenlp.transformers.contrastive_loss import ( MatryoshkaContrastiveLoss, From e3c55c328cc6a53631b8a6c6e1a36fcf422170cf Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Mon, 23 Dec 2024 18:27:11 +0800 Subject: [PATCH 3/7] rename inf_cl_loss and fix warning --- .../paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py | 8 ++++---- paddlenlp/trl/embedding_trainer.py | 15 ++++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py index f9d756261884..58d2e5f6aff3 100644 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py @@ -19,10 +19,10 @@ from .ring import cal_inf_loss -__all__ = ["Simple_Inf_cl_loss", "Matryoshka_Inf_cl_loss"] +__all__ = ["SimpleInfclloss", "MatryoshkaInfclLoss"] -class Simple_Inf_cl_loss(nn.Layer): +class SimpleInfclloss(nn.Layer): def __init__(self, inf_cl_head_dim=64): """ Initializes the Simple Inf_cl Loss class. @@ -51,7 +51,7 @@ def forward(self, q_reps, p_reps): return loss -class Matryoshka_Inf_cl_loss(nn.Layer): +class MatryoshkaInfclLoss(nn.Layer): def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64): """ Initializes the Matryoshka Inf_cl Loss class. 
@@ -66,7 +66,7 @@ def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl self.embedding_matryoshka_dims = [] else: self.embedding_matryoshka_dims = embedding_matryoshka_dims - self.loss_fn = Simple_Inf_cl_loss(inf_cl_head_dim) + self.loss_fn = SimpleInfclloss(inf_cl_head_dim) def forward(self, q_reps, p_reps): """ diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index 259bae40939d..45f85dbb6549 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -18,13 +18,14 @@ from paddle.base import core from paddle.distributed import fleet +from paddlenlp.utils.log import logger + try: - from paddlenlp_kernel.triton.inf_cl import ( - Matryoshka_Inf_cl_loss, - Simple_Inf_cl_loss, - ) + from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclloss except ImportError: - print("WARNING: paddlenlp_kernels are not available.") + logger.warning( + "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." + ) from paddlenlp.trainer import Trainer from paddlenlp.transformers.contrastive_loss import ( @@ -54,7 +55,7 @@ def __init__(self, model_args, **kwargs): if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0: if model_args.loss_type == "inf_cl": self.embedding_negatives_cross_device = False - self.loss_fn = Matryoshka_Inf_cl_loss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim) + self.loss_fn = MatryoshkaInfclLoss(model_args.embedding_matryoshka_dims, model_args.inf_cl_head_dim) elif model_args.loss_type == "contrastive": self.loss_fn = MatryoshkaContrastiveLoss( model_args.embedding_temperature, model_args.embedding_matryoshka_dims @@ -62,7 +63,7 @@ def __init__(self, model_args, **kwargs): else: if model_args.loss_type == "inf_cl": self.embedding_negatives_cross_device = False - self.loss_fn = Simple_Inf_cl_loss(model_args.inf_cl_head_dim) + self.loss_fn = SimpleInfclloss(model_args.inf_cl_head_dim) elif model_args.loss_type == "contrastive": self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature) From 6b6a1080f14f600422e959d2ee2c401804d6bb58 Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Mon, 23 Dec 2024 18:33:37 +0800 Subject: [PATCH 4/7] rename simple_inf_cl --- ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py | 6 +++--- paddlenlp/trl/embedding_trainer.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py index 58d2e5f6aff3..302400dcf5c6 100644 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py @@ -19,10 +19,10 @@ from .ring import cal_inf_loss -__all__ = ["SimpleInfclloss", "MatryoshkaInfclLoss"] +__all__ = ["SimpleInfclLoss", "MatryoshkaInfclLoss"] -class SimpleInfclloss(nn.Layer): +class SimpleInfclLoss(nn.Layer): def __init__(self, inf_cl_head_dim=64): """ Initializes the Simple Inf_cl Loss class. 
@@ -66,7 +66,7 @@ def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl self.embedding_matryoshka_dims = [] else: self.embedding_matryoshka_dims = embedding_matryoshka_dims - self.loss_fn = SimpleInfclloss(inf_cl_head_dim) + self.loss_fn = SimpleInfclLoss(inf_cl_head_dim) def forward(self, q_reps, p_reps): """ diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index 45f85dbb6549..cced4ff718b4 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -21,7 +21,7 @@ from paddlenlp.utils.log import logger try: - from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclloss + from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclLoss except ImportError: logger.warning( "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." @@ -63,7 +63,7 @@ def __init__(self, model_args, **kwargs): else: if model_args.loss_type == "inf_cl": self.embedding_negatives_cross_device = False - self.loss_fn = SimpleInfclloss(model_args.inf_cl_head_dim) + self.loss_fn = SimpleInfclLoss(model_args.inf_cl_head_dim) elif model_args.loss_type == "contrastive": self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature) From d69ac4a10c94e97953e3e4b901db21c2a9bcf459 Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Tue, 24 Dec 2024 16:11:55 +0800 Subject: [PATCH 5/7] Change inf_cl location --- .../triton/inf_cl/__init__.py | 1 - .../triton/inf_cl/inf_cl_loss.py | 103 ------------------ paddlenlp/transformers/contrastive_loss.py | 90 +++++++++++++++ paddlenlp/trl/embedding_trainer.py | 11 +- 4 files changed, 92 insertions(+), 113 deletions(-) delete mode 100644 ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py b/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py index 441f44b821cd..371bdba8a6de 100644 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py +++ b/ops/src/paddlenlp_kernel/triton/inf_cl/__init__.py @@ -13,5 +13,4 @@ # limitations under the License. from .flash import cal_flash_loss -from .inf_cl_loss import * from .ring import cal_inf_loss, cal_ring_loss diff --git a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py b/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py deleted file mode 100644 index 302400dcf5c6..000000000000 --- a/ops/src/paddlenlp_kernel/triton/inf_cl/inf_cl_loss.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional - -import paddle -import paddle.nn as nn - -from .ring import cal_inf_loss - -__all__ = ["SimpleInfclLoss", "MatryoshkaInfclLoss"] - - -class SimpleInfclLoss(nn.Layer): - def __init__(self, inf_cl_head_dim=64): - """ - Initializes the Simple Inf_cl Loss class. 
- - Args: - inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. - """ - super().__init__() - self.head_dim = inf_cl_head_dim - - def forward(self, q_reps, p_reps): - """ - Computes the instance discrimination loss. - - Args: - q_reps (Tensor): Query representations. - p_reps (Tensor): key representations. - - Returns: - Tensor: The computed loss. - """ - group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query - labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries - labels = labels * group_size # Adjust labels based on group size - loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim) - return loss - - -class MatryoshkaInfclLoss(nn.Layer): - def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64): - """ - Initializes the Matryoshka Inf_cl Loss class. - - Args: - embedding_matryoshka_dims (List[int], optional): List of dimensions for Matryoshka embeddings. - If None, no Matryoshka embedding is used. Default is None. - inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. - """ - super().__init__() - if embedding_matryoshka_dims is None: - self.embedding_matryoshka_dims = [] - else: - self.embedding_matryoshka_dims = embedding_matryoshka_dims - self.loss_fn = SimpleInfclLoss(inf_cl_head_dim) - - def forward(self, q_reps, p_reps): - """ - Computes the Matryoshka instance discrimination loss. - - Args: - q_reps (Tensor): Query representations. - p_reps (Tensor): key representations. - - Returns: - Tensor: The computed loss. - """ - if len(self.embedding_matryoshka_dims) > 0: - loss = 0.0 - for dim in self.embedding_matryoshka_dims: - reduced_q_reps = q_reps[:, :dim] # Reduce query representations to the current Matryoshka dimension - reduced_q_reps = nn.functional.normalize( - reduced_q_reps, axis=-1 - ) # Normalize the reduced query representations along the last axis - - reduced_p_reps = p_reps[:, :dim] # Reduce key representations to the current Matryoshka dimension - reduced_p_reps = nn.functional.normalize( - reduced_p_reps, axis=-1 - ) # Normalize the reduced key representations along the last axis - - dim_loss = self.loss_fn( - reduced_q_reps, reduced_p_reps - ) # Compute the loss for the current Matryoshka dimension using the internal loss function - loss += dim_loss - else: - loss = self.loss_fn( - q_reps, p_reps - ) # If no Matryoshka dimensions are specified, compute the loss using the full representations - return loss diff --git a/paddlenlp/transformers/contrastive_loss.py b/paddlenlp/transformers/contrastive_loss.py index 0252c0712a27..06bac62fda37 100644 --- a/paddlenlp/transformers/contrastive_loss.py +++ b/paddlenlp/transformers/contrastive_loss.py @@ -17,6 +17,15 @@ import paddle import paddle.nn as nn +from paddlenlp.utils.log import logger + +try: + from paddlenlp_kernel.triton.inf_cl import cal_inf_loss +except ImportError: + logger.warning( + "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." + ) + class SimpleContrastiveLoss(nn.Layer): def __init__(self, embedding_temperature: float = 0.02): @@ -63,3 +72,84 @@ def forward(self, q_reps, p_reps): else: loss = self.loss_fn(q_reps, p_reps) return loss + + +class SimpleInfclLoss(nn.Layer): + def __init__(self, inf_cl_head_dim=64): + """ + Initializes the Simple Inf_cl Loss class. 
+ + Args: + inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. + """ + super().__init__() + self.head_dim = inf_cl_head_dim + + def forward(self, q_reps, p_reps): + """ + Computes the instance discrimination loss. + + Args: + q_reps (Tensor): Query representations. + p_reps (Tensor): key representations. + + Returns: + Tensor: The computed loss. + """ + group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query + labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries + labels = labels * group_size # Adjust labels based on group size + loss = cal_inf_loss(q_reps, p_reps, labels=labels, scale=None, head_dim=self.head_dim) + return loss + + +class MatryoshkaInfclLoss(nn.Layer): + def __init__(self, embedding_matryoshka_dims: Optional[List[int]] = None, inf_cl_head_dim=64): + """ + Initializes the Matryoshka Inf_cl Loss class. + + Args: + embedding_matryoshka_dims (List[int], optional): List of dimensions for Matryoshka embeddings. + If None, no Matryoshka embedding is used. Default is None. + inf_cl_head_dim (int, optional): Dimension of the projection head. Default is 64. + """ + super().__init__() + if embedding_matryoshka_dims is None: + self.embedding_matryoshka_dims = [] + else: + self.embedding_matryoshka_dims = embedding_matryoshka_dims + self.loss_fn = SimpleInfclLoss(inf_cl_head_dim) + + def forward(self, q_reps, p_reps): + """ + Computes the Matryoshka instance discrimination loss. + + Args: + q_reps (Tensor): Query representations. + p_reps (Tensor): key representations. + + Returns: + Tensor: The computed loss. + """ + if len(self.embedding_matryoshka_dims) > 0: + loss = 0.0 + for dim in self.embedding_matryoshka_dims: + reduced_q_reps = q_reps[:, :dim] # Reduce query representations to the current Matryoshka dimension + reduced_q_reps = nn.functional.normalize( + reduced_q_reps, axis=-1 + ) # Normalize the reduced query representations along the last axis + + reduced_p_reps = p_reps[:, :dim] # Reduce key representations to the current Matryoshka dimension + reduced_p_reps = nn.functional.normalize( + reduced_p_reps, axis=-1 + ) # Normalize the reduced key representations along the last axis + + dim_loss = self.loss_fn( + reduced_q_reps, reduced_p_reps + ) # Compute the loss for the current Matryoshka dimension using the internal loss function + loss += dim_loss + else: + loss = self.loss_fn( + q_reps, p_reps + ) # If no Matryoshka dimensions are specified, compute the loss using the full representations + return loss diff --git a/paddlenlp/trl/embedding_trainer.py b/paddlenlp/trl/embedding_trainer.py index cced4ff718b4..c50f19738bed 100644 --- a/paddlenlp/trl/embedding_trainer.py +++ b/paddlenlp/trl/embedding_trainer.py @@ -18,19 +18,12 @@ from paddle.base import core from paddle.distributed import fleet -from paddlenlp.utils.log import logger - -try: - from paddlenlp_kernel.triton.inf_cl import MatryoshkaInfclLoss, SimpleInfclLoss -except ImportError: - logger.warning( - "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." 
- ) - from paddlenlp.trainer import Trainer from paddlenlp.transformers.contrastive_loss import ( MatryoshkaContrastiveLoss, + MatryoshkaInfclLoss, SimpleContrastiveLoss, + SimpleInfclLoss, ) from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient From f05dd610a807219c31a4b3b56eccb7dac02ca19e Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Tue, 24 Dec 2024 16:38:44 +0800 Subject: [PATCH 6/7] Change import location --- paddlenlp/transformers/contrastive_loss.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/paddlenlp/transformers/contrastive_loss.py b/paddlenlp/transformers/contrastive_loss.py index 06bac62fda37..97b3d2fd88c0 100644 --- a/paddlenlp/transformers/contrastive_loss.py +++ b/paddlenlp/transformers/contrastive_loss.py @@ -17,15 +17,6 @@ import paddle import paddle.nn as nn -from paddlenlp.utils.log import logger - -try: - from paddlenlp_kernel.triton.inf_cl import cal_inf_loss -except ImportError: - logger.warning( - "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." - ) - class SimpleContrastiveLoss(nn.Layer): def __init__(self, embedding_temperature: float = 0.02): @@ -96,6 +87,14 @@ def forward(self, q_reps, p_reps): Returns: Tensor: The computed loss. """ + from paddlenlp.utils.log import logger + + try: + from paddlenlp_kernel.triton.inf_cl import cal_inf_loss + except ImportError: + logger.warning( + "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." + ) group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query labels = paddle.arange(q_reps.shape[0], dtype="int64") # Generate labels for queries labels = labels * group_size # Adjust labels based on group size From 8f55e527ba8260d1c7a59b6a240fd810a5cab4b4 Mon Sep 17 00:00:00 2001 From: jie-z-0607 <1712955306@qq.com> Date: Tue, 24 Dec 2024 16:49:49 +0800 Subject: [PATCH 7/7] Change error information --- paddlenlp/transformers/contrastive_loss.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddlenlp/transformers/contrastive_loss.py b/paddlenlp/transformers/contrastive_loss.py index 97b3d2fd88c0..3e132b6f454a 100644 --- a/paddlenlp/transformers/contrastive_loss.py +++ b/paddlenlp/transformers/contrastive_loss.py @@ -87,12 +87,10 @@ def forward(self, q_reps, p_reps): Returns: Tensor: The computed loss. """ - from paddlenlp.utils.log import logger - try: from paddlenlp_kernel.triton.inf_cl import cal_inf_loss except ImportError: - logger.warning( + raise ImportError( "Paddlenlp_kernels are not available, which means the inf_cl loss cannot be used. If you wish to use the inf_cl loss, please follow the instructions in the README.md on the `ops`." ) group_size = p_reps.shape[0] // q_reps.shape[0] # Number of keys per query