From 291adae9ffc6ff3f3133031f8fa24ac2d6de9cf8 Mon Sep 17 00:00:00 2001
From: lewtun
Date: Mon, 11 Mar 2024 13:23:56 +0100
Subject: [PATCH] Fix import error from deprecation in transformers (#1415)

* Fix import error from deprecation in transformers

* Fix import path
---
 trl/core.py | 38 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/trl/core.py b/trl/core.py
index 9d92ee18f0..0f673c6ddb 100644
--- a/trl/core.py
+++ b/trl/core.py
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
-from transformers import top_k_top_p_filtering
+from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
 from .import_utils import is_npu_available, is_xpu_available
 
@@ -36,6 +36,42 @@
 WANDB_PADDING = -1
 
 
+def top_k_top_p_filtering(
+    logits: torch.FloatTensor,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+) -> torch.FloatTensor:
+    """
+    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        top_k (`int`, *optional*, defaults to 0):
+            If > 0, only keep the top k tokens with highest probability (top-k filtering)
+        top_p (`float`, *optional*, defaults to 1.0):
+            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
+            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        min_tokens_to_keep (`int`, *optional*, defaults to 1):
+            Minimum number of tokens we keep per batch example in the output.
+
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+
+    if top_k > 0:
+        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
+            None, logits
+        )
+
+    if 0 <= top_p <= 1.0:
+        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
+            None, logits
+        )
+
+    return logits
+
+
 def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
     """Flatten dictionary and concatenate nested keys with separator."""
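
A minimal usage sketch (not part of the patch; the tensor values are made up for illustration): because transformers no longer exports a top-level top_k_top_p_filtering function, the patch gives trl.core its own copy with the same signature, delegating to the TopKLogitsWarper and TopPLogitsWarper classes from transformers.generation, so existing callers should behave as before.

    import torch

    from trl.core import top_k_top_p_filtering

    # Toy logits for a batch of 1 example over a 5-token vocabulary (illustrative values only).
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

    # Keep only the 2 most likely tokens; all other logits are set to -inf.
    top_k_filtered = top_k_top_p_filtering(logits, top_k=2)

    # Keep the smallest set of top tokens whose cumulative probability reaches 0.9.
    top_p_filtered = top_k_top_p_filtering(logits, top_p=0.9)

    # Sample a next token from the filtered distribution.
    probs = torch.softmax(top_k_filtered, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)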