Skip to content

Commit

Permalink
Fix repetition penalty logits processor (#1062)
Browse files Browse the repository at this point in the history
* Fix repetition penalty logits processor

* Fix return types of logits processors
  • Loading branch information
xenova authored Dec 2, 2024
1 parent 2c92943 commit 9584263
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions src/generation/logits_process.js
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ export class ForcedBOSTokenLogitsProcessor extends LogitsProcessor {
* Apply the BOS token forcing to the logits.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The logits with BOS token forcing.
* @returns {Tensor} The logits with BOS token forcing.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand Down Expand Up @@ -221,7 +221,7 @@ export class SuppressTokensAtBeginLogitsProcessor extends LogitsProcessor {
* Apply the begin-of-generation token suppression to the logits.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The logits with BOS token forcing.
* @returns {Tensor} The logits with the begin-suppressed tokens masked out.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand Down Expand Up @@ -391,7 +391,7 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
* Apply the no-repeat-ngram processor to the logits.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The logits with no-repeat-ngram processing.
* @returns {Tensor} The logits with no-repeat-ngram processing.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand All @@ -406,12 +406,22 @@ export class NoRepeatNGramLogitsProcessor extends LogitsProcessor {
}

/**
* A logits processor that penalises repeated output tokens.
* A logits processor that prevents the repetition of previous tokens through a penalty.
* This penalty is applied at most once per token. Note that, for decoder-only models like most LLMs,
* the considered tokens include the prompt.
*
* In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a
* penalty of around 1.2 to achieve a good balance between truthful generation and lack of repetition.
* To penalize and reduce repetition, use `penalty` values above 1.0, where a higher value penalizes
* more strongly. To reward and encourage repetition, use `penalty` values between 0.0 and 1.0, where
* a lower value rewards more strongly.
*/
export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
/**
* Create a RepetitionPenaltyLogitsProcessor.
* @param {number} penalty The penalty to apply for repeated tokens.
* @param {number} penalty The parameter for repetition penalty.
* - 1.0 means no penalty. Above 1.0 penalizes previously generated tokens.
* - Between 0.0 and 1.0 rewards previously generated tokens.
*/
constructor(penalty) {
super();
Expand All @@ -422,16 +432,12 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
* Apply the repetition penalty to the logits.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The logits with repetition penalty processing.
* @returns {Tensor} The logits with repetition penalty processing.
*/
_call(input_ids, logits) {
// Modify the logits corresponding to each element in `input_ids`.
// As a consequence, the logits corresponding to tokens that appear
// many times in the output will be penalised more.

for (let i = 0; i < input_ids.length; ++i) {
const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
for (const input_id of input_ids[i]) {
for (const input_id of new Set(input_ids[i])) {
const token = Number(input_id);
if (batch_logits_data[token] < 0) {
batch_logits_data[token] *= this.penalty;
Expand Down Expand Up @@ -464,7 +470,7 @@ export class MinLengthLogitsProcessor extends LogitsProcessor {
* Apply logit processor.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The processed logits.
* @returns {Tensor} The processed logits.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand Down Expand Up @@ -502,7 +508,7 @@ export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
* Apply logit processor.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The processed logits.
* @returns {Tensor} The processed logits.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand Down Expand Up @@ -535,7 +541,7 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
* Apply logit processor.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The processed logits.
* @returns {Tensor} The processed logits.
*/
_call(input_ids, logits) {
for (let i = 0; i < input_ids.length; ++i) {
Expand Down Expand Up @@ -596,7 +602,7 @@ export class ClassifierFreeGuidanceLogitsProcessor extends LogitsProcessor {
* Apply logit processor.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The processed logits.
* @returns {Tensor} The processed logits.
*/
_call(input_ids, logits) {
if (logits.dims[0] !== 2 * input_ids.length) {
Expand Down Expand Up @@ -650,7 +656,7 @@ export class TemperatureLogitsWarper extends LogitsWarper {
* Apply logit warper.
* @param {bigint[][]} input_ids The input IDs.
* @param {Tensor} logits The logits.
* @returns {Object} The processed logits.
* @returns {Tensor} The processed logits.
*/
_call(input_ids, logits) {
const batch_logits_data = /** @type {Float32Array} */(logits.data);
Expand Down

0 comments on commit 9584263

Please sign in to comment.