Skip to content

Commit

Permalink
Add repetition penalty warper (#85)
Browse files Browse the repository at this point in the history
* Add repetition penalty warper

* Float the penalty

  * Add penalty to logits warpers
  * Test repetition penalty
  • Loading branch information
shavit authored Apr 26, 2024
1 parent 9df94c1 commit 5e02089
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Sources/Generation/Generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ public extension Generation {
if config.topP < 1.0 {
logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
}
if config.repetitionPenalty != 1.0 {
logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
}
return logitsWarpers
}
}
25 changes: 25 additions & 0 deletions Sources/TensorUtils/LogitsWarper/RepetitionPenaltyWarper.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import Foundation

/// `RepetitionPenaltyWarper` prevents the repetition of previous tokens through a penalty.
/// This penalty is applied at most once per token.
/// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L294
public struct RepetitionPenaltyWarper: LogitsWarper {
public var penalty: Float

public init(penalty: Double) {
self.penalty = Float(penalty)
}

public func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float]) {
var logits = logits
for index in indices {
if logits[index] < 0 {
logits[index] *= penalty
} else {
logits[index] /= penalty
}
}

return (indices, logits)
}
}
30 changes: 30 additions & 0 deletions Tests/TensorUtilsTests/LogitsWarperTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,36 @@ final class LogitsWarperTests: XCTestCase {
XCTAssertEqual(result5.logits, [2, 1, 0], accuracy: accuracy)
}

func testRepetitionPenaltyWarper() {
let indices = Array(0..<10)
let logits = indices.map({ Float($0) })

let result1 = RepetitionPenaltyWarper(penalty: 1.0)(indices, logits)
XCTAssertEqual(result1.indices, indices)
XCTAssertEqual(result1.logits, logits, accuracy: accuracy)

let result2 = RepetitionPenaltyWarper(penalty: 3.75)(indices, logits)
XCTAssertEqual(result2.indices, indices)
let logits2 = indices.map({ Float($0) / 3.75 })
XCTAssertEqual(result2.logits, logits2, accuracy: accuracy)

let result3 = RepetitionPenaltyWarper(penalty: 0.75)([0, 1, 2], [0.8108, 0.9954, 0.0119])
XCTAssertEqual(result3.indices, [0, 1, 2])
XCTAssertEqual(result3.logits, [1.0811, 1.3272, 0.0158], accuracy: 1e-4)

let result4 = RepetitionPenaltyWarper(penalty: 1.11)([2, 3, 4], [0.5029, 0.8694, 0.4765, 0.9967, 0.4190, 0.9158])
XCTAssertEqual(result4.indices, [2, 3, 4])
XCTAssertEqual(result4.logits, [0.5029, 0.8694, 0.4293, 0.8980, 0.3775, 0.9158], accuracy: 1e-4)

let result5 = RepetitionPenaltyWarper(penalty: 0.9)([0, 1, 2], [-0.7433, -0.4738, -0.2966])
XCTAssertEqual(result5.indices, [0, 1, 2])
XCTAssertEqual(result5.logits, [-0.6690, -0.4264, -0.2669], accuracy: 1e-4)

let result6 = RepetitionPenaltyWarper(penalty: 1.125)([3, 1, 2], [0.1674, 0.6431, 0.6780, 0.2755])
XCTAssertEqual(result6.indices, [3, 1, 2])
XCTAssertEqual(result6.logits, [0.1674, 0.5716, 0.6026, 0.2449], accuracy: 1e-4)
}

func testLogitsProcessor() {
let processor1 = LogitsProcessor(logitsWarpers: [])
let result1 = processor1([])
Expand Down

0 comments on commit 5e02089

Please sign in to comment.