Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added TimestampRulesFilter implementation #45

Merged
merged 12 commits into from
Mar 22, 2024
194 changes: 188 additions & 6 deletions Sources/WhisperKit/Core/LogitsFilter.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Accelerate
import CoreML
import Foundation
import Tokenizers
Expand Down Expand Up @@ -46,19 +47,200 @@ public class SuppressBlankFilter: LogitsFiltering {
}
}

/// Enforces Whisper's timestamp-token sampling rules on decoder logits.
/// Implementation based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public class TimestampRulesFilter: LogitsFiltering {
    let transcribeToken: Int
    let translateToken: Int
    let noTimestampsToken: Int
    // First token id of the timestamp range; every id >= this is a timestamp token.
    let timeTokenBegin: Int
    let endToken: Int
    // Index in the token sequence where sampling (as opposed to prefill) begins.
    let sampleBegin: Int
    // Largest timestamp offset allowed as the very first sampled token, if any.
    let maxInitialTimestampIndex: Int?
    let isModelMultilingual: Bool

    public init(
        transcribeToken: Int,
        translateToken: Int,
        noTimestampsToken: Int,
        timeTokenBegin: Int,
        endToken: Int,
        sampleBegin: Int,
        maxInitialTimestampIndex: Int?,
        isModelMultilingual: Bool
    ) {
        self.transcribeToken = transcribeToken
        self.translateToken = translateToken
        self.noTimestampsToken = noTimestampsToken
        self.timeTokenBegin = timeTokenBegin
        self.endToken = endToken
        self.sampleBegin = sampleBegin
        self.maxInitialTimestampIndex = maxInitialTimestampIndex
        self.isModelMultilingual = isModelMultilingual
    }

    /// Masks `logits` in place so that sampled timestamps are monotonic, appear in
    /// pairs (except directly before end-of-transcript), and respect the optional
    /// initial-timestamp cap. Returns the same `logits` array.
    public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
        guard let sampleBegin = sampleBegin(for: tokens) else {
            // Prefill (task/language tokens) not complete yet; nothing to filter.
            return logits
        }
        // suppress <|notimestamps|> which is handled by `withoutTimestamps`
        logits.fill(indexes: [[0, 0, noTimestampsToken as NSNumber]], with: -FloatType.infinity)

        if tokens.count > sampleBegin {
            // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
            let sampledTokens = tokens[sampleBegin...]
            let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
            let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
            if lastWasTimestamp {
                if penultimateWasTimestamp {
                    // has to be non-timestamp
                    logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
                } else {
                    // cannot be normal text tokens
                    logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
                }
            }

            let timestamps = sampledTokens.filter { $0 >= timeTokenBegin }
            if let lastTimestamp = timestamps.last {
                // timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                // also force each segment to have a nonzero length, to prevent infinite looping
                let timestampLast =
                    if lastWasTimestamp && !penultimateWasTimestamp {
                        lastTimestamp
                    } else {
                        lastTimestamp + 1
                    }
                logits.fillLastDimension(indexes: timeTokenBegin..<timestampLast, with: -FloatType.infinity)
            }
        }

        if tokens.count == sampleBegin {
            // suppress generating non-timestamp tokens at the beginning
            logits.fillLastDimension(indexes: 0..<timeTokenBegin, with: -FloatType.infinity)
            if let maxInitialTimestampIndex {
                // apply the `maxInitialTimestamp` option
                let lastAllowed = timeTokenBegin + maxInitialTimestampIndex + 1
                logits.fillLastDimension(indexes: lastAllowed..<logits.count, with: -FloatType.infinity)
            }
        }

        // if sum of probability over timestamps is above any other token, sample timestamp
        if sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(logits: logits, timeTokenBegin: timeTokenBegin) {
            logits.fillLastDimension(indexes: 0..<timeTokenBegin, with: -FloatType.infinity)
        }
        return logits
    }

    /// Effective sample-begin index for the current token sequence, or `nil`
    /// when filtering should be skipped (multilingual model whose task token
    /// has not yet appeared in the first three tokens).
    private func sampleBegin(for tokens: [Int]) -> Int? {
        if isModelMultilingual {
            // NOTE: for multilingual model we don't want to suppress "<|transcribe|>" or "<|translate|>" tokens
            if let taskTokenIndex = tokens.prefix(3).firstIndex(where: { $0 == transcribeToken || $0 == translateToken }) {
                return max(taskTokenIndex + 1, sampleBegin)
            } else {
                return nil
            }
        } else {
            return sampleBegin
        }
    }

    /// Returns `true` when logSumExp of the timestamp-token log-probabilities
    /// exceeds the maximum text-token log-probability, i.e. a timestamp should
    /// be forced. Computed with BNNS (logSoftmax, then logSumExp / max reductions)
    /// directly over the MLMultiArray's backing memory. Returns `false` on any
    /// BNNS setup or execution failure (logged, best-effort).
    private func sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(logits: MLMultiArray, timeTokenBegin: Int) -> Bool {
        // Linear offset of the first timestamp token in the flat [1, 1, n] buffer.
        let timeTokenBeginOffset = logits.linearOffset(for: [0, 0, timeTokenBegin as NSNumber])

        let logprobsInputPointer = UnsafeMutableRawBufferPointer(
            start: logits.dataPointer,
            count: logits.count * MemoryLayout<FloatType>.stride
        )

        guard let logprobsInputDescriptor = BNNSNDArrayDescriptor(
            data: logprobsInputPointer,
            scalarType: FloatType.self,
            shape: .vector(logits.count, stride: 1)
        ) else {
            Logging.error("Cannot create `logprobsInputDescriptor`")
            return false
        }

        let logprobs = BNNSNDArrayDescriptor.allocateUninitialized(
            scalarType: FloatType.self,
            shape: .vector(logits.count, stride: 1)
        )
        defer { logprobs.deallocate() }

        do {
            // Normalize raw logits to log-probabilities.
            try BNNS.applyActivation(
                activation: BNNS.ActivationFunction.logSoftmax,
                input: logprobsInputDescriptor,
                output: logprobs,
                batchSize: 1
            )

            let timeTokenCount = logits.count - timeTokenBeginOffset
            let noTimeTokenCount = timeTokenBeginOffset
            // View over the timestamp-token tail of the logprobs buffer.
            let logSumExpInputPointer = UnsafeMutableRawBufferPointer(
                start: logprobs.data!.advanced(by: timeTokenBeginOffset * MemoryLayout<FloatType>.stride),
                count: timeTokenCount * MemoryLayout<FloatType>.stride
            )

            guard let logSumExpInputDescriptor = BNNSNDArrayDescriptor(
                data: logSumExpInputPointer,
                scalarType: FloatType.self,
                shape: .vector(timeTokenCount, stride: 1)
            ) else {
                Logging.error("Cannot create `logSumExpInputDescriptor`")
                return false
            }

            let timestampLogProb = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: FloatType.self,
                shape: .vector(1, stride: 1)
            )
            defer { timestampLogProb.deallocate() }

            // log(sum(exp(...))) over all timestamp tokens = total timestamp log-probability.
            try BNNS.applyReduction(
                .logSumExp,
                input: logSumExpInputDescriptor,
                output: timestampLogProb,
                weights: nil
            )

            // View over the text-token head of the logprobs buffer.
            let maxTextTokenLogProbInputPointer = UnsafeMutableRawBufferPointer(
                start: logprobs.data,
                count: noTimeTokenCount * MemoryLayout<FloatType>.stride
            )

            guard let maxTextTokenLogProbInputDescriptor = BNNSNDArrayDescriptor(
                data: maxTextTokenLogProbInputPointer,
                scalarType: FloatType.self,
                shape: .vector(noTimeTokenCount, stride: 1)
            ) else {
                Logging.error("Cannot create `maxTextTokenLogProbInputDescriptor`")
                return false
            }

            let maxTextTokenLogProb = BNNSNDArrayDescriptor.allocateUninitialized(
                scalarType: FloatType.self,
                shape: .vector(1, stride: 1)
            )
            defer { maxTextTokenLogProb.deallocate() }

            try BNNS.applyReduction(
                .max,
                input: maxTextTokenLogProbInputDescriptor,
                output: maxTextTokenLogProb,
                weights: nil
            )

            guard let timestampLogProbValue = timestampLogProb.makeArray(of: FloatType.self)?.first,
                  let maxTextTokenLogProbValue = maxTextTokenLogProb.makeArray(of: FloatType.self)?.first else {
                Logging.error("Cannot create logProb arrays")
                return false
            }
            return timestampLogProbValue > maxTextTokenLogProbValue
        } catch {
            Logging.error("TimestampRulesFilter error: \(error)")
            return false
        }
    }
}
10 changes: 9 additions & 1 deletion Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public typealias FloatType = Float

#if (os(macOS) || targetEnvironment(macCatalyst)) && arch(arm64)
extension Float16: BNNSScalar {}
extension Float16: MLShapedArrayScalar {}
#endif

// MARK: - CoreML
Expand Down Expand Up @@ -209,9 +210,10 @@ public struct DecodingCache {
/// - sampleLength: The maximum number of tokens to sample.
/// - topK: Number of candidates when sampling with non-zero temperature.
/// - usePrefillPrompt: If true, the prefill tokens will be forced according to task and language settings.
/// - usePrefillPrompt: If true, the kv cache will be prefilled based on the prefill data mlmodel.
/// - usePrefillCache: If true, the kv cache will be prefilled based on the prefill data mlmodel.
/// - skipSpecialTokens: Whether to skip special tokens in the output.
/// - withoutTimestamps: Whether to include timestamps in the transcription result.
/// - maxInitialTimestamp: Maximal initial timestamp.
/// - suppressBlank: If true, blank tokens will be suppressed during decoding.
/// - supressTokens: List of token IDs to suppress during decoding.
/// - compressionRatioThreshold: If the compression ratio of the transcription text is above this value, it is too repetitive and treated as failed.
Expand All @@ -233,6 +235,7 @@ public struct DecodingOptions {
public var skipSpecialTokens: Bool
public var withoutTimestamps: Bool
public var wordTimestamps: Bool
public var maxInitialTimestamp: Float?
public var clipTimestamps: [Float]
public var suppressBlank: Bool
public var supressTokens: [Int]
Expand All @@ -253,6 +256,7 @@ public struct DecodingOptions {
skipSpecialTokens: Bool = false,
withoutTimestamps: Bool = false,
wordTimestamps: Bool = false,
maxInitialTimestamp: Float? = nil,
clipTimestamps: [Float] = [],
suppressBlank: Bool = false,
supressTokens: [Int]? = nil,
Expand All @@ -273,6 +277,7 @@ public struct DecodingOptions {
self.skipSpecialTokens = skipSpecialTokens
self.withoutTimestamps = withoutTimestamps
self.wordTimestamps = wordTimestamps
self.maxInitialTimestamp = maxInitialTimestamp
self.clipTimestamps = clipTimestamps
self.suppressBlank = suppressBlank
self.supressTokens = supressTokens ?? [] // nonSpeechTokens() // TODO: implement these as default
Expand Down Expand Up @@ -399,6 +404,7 @@ public struct TranscriptionTimings: Codable {
public var decodingInit: TimeInterval
public var decodingLoop: TimeInterval
public var decodingPredictions: TimeInterval
public var decodingFiltering: TimeInterval
public var decodingSampling: TimeInterval
public var decodingFallback: TimeInterval
public var decodingWindowing: TimeInterval
Expand Down Expand Up @@ -434,6 +440,7 @@ public struct TranscriptionTimings: Codable {
decodingInit: TimeInterval = 0,
decodingLoop: TimeInterval = 0,
decodingPredictions: TimeInterval = 0,
decodingFiltering: TimeInterval = 0,
decodingSampling: TimeInterval = 0,
decodingFallback: TimeInterval = 0,
decodingWindowing: TimeInterval = 0,
Expand Down Expand Up @@ -462,6 +469,7 @@ public struct TranscriptionTimings: Codable {
self.decodingInit = decodingInit
self.decodingLoop = decodingLoop
self.decodingPredictions = decodingPredictions
self.decodingFiltering = decodingFiltering
self.decodingSampling = decodingSampling
self.decodingFallback = decodingFallback
self.decodingWindowing = decodingWindowing
Expand Down
18 changes: 8 additions & 10 deletions Sources/WhisperKit/Core/SegmentSeeker.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,19 @@ public class SegmentSeeker: SegmentSeeking {
let lastThreeTokens = isTimestampToken.suffix(3)
let singleTimestampEnding = lastThreeTokens == [false, true, false]

// find all indexes of time token pairs
var consecutive = [(start: Int, end: Int)]()
// find all end indexes of time token pairs
var sliceIndexes = [Int]()

var previousTokenIsTimestamp = false
for (i, tokenIsTimestamp) in isTimestampToken.enumerated() {
if previousTokenIsTimestamp && tokenIsTimestamp {
consecutive.append((i - 1, i))
for (currentTokenIsTimestampIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() {
if previousTokenIsTimestamp && currentTokenIsTimestamp {
sliceIndexes.append(currentTokenIsTimestampIndex)
}
previousTokenIsTimestamp = tokenIsTimestamp
previousTokenIsTimestamp = currentTokenIsTimestamp
}

if !consecutive.isEmpty {
// Window contains multiple consecutive timestamps, split into sub-segments
var sliceIndexes = consecutive.map { $0.end }

// Window contains multiple consecutive timestamps, split into sub-segments
if !sliceIndexes.isEmpty {
// If the last timestamp is not consecutive, we need to add it as the final slice manually
if singleTimestampEnding {
let singleTimestampEndingIndex = isTimestampToken.lastIndex(where: { $0 })!
Expand Down
23 changes: 21 additions & 2 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,24 @@ public class TextDecoder: TextDecoding, WhisperMLModel {
}

if !options.withoutTimestamps {
// TODO: implement
// logitsFilters.append(TimestampRulesFilter(tokenizer: tokenizer, sampleBegin: prefilledIndex))
let maxInitialTimestampIndex: Int? =
if let maxInitialTimestamp = options.maxInitialTimestamp {
Int(maxInitialTimestamp / WhisperKit.secondsPerTimeToken)
} else {
nil
}
logitsFilters.append(
TimestampRulesFilter(
transcribeToken: tokenizer.transcribeToken,
translateToken: tokenizer.translateToken,
noTimestampsToken: tokenizer.noTimestampsToken,
timeTokenBegin: tokenizer.timeTokenBegin,
endToken: tokenizer.endToken,
sampleBegin: intialPromptIndex,
maxInitialTimestampIndex: maxInitialTimestampIndex,
isModelMultilingual: isModelMultilingual(logitsDim: logitsSize)
)
)
}

// MARK: Main loop
Expand Down Expand Up @@ -417,6 +433,9 @@ public class TextDecoder: TextDecoding, WhisperMLModel {
logits = filter.filterLogits(logits, withTokens: currentTokens)
}

let filteringTime = Date().timeIntervalSince(nonInferenceStartTime)
timings.decodingFiltering += filteringTime

// MARK: Sampling

let samplingStartTime = Date()
Expand Down
13 changes: 13 additions & 0 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ extension MLMultiArray {
return linearOffset
}

/// Overwrites the given index range along the last axis with `value`.
/// Only valid for arrays shaped `[1, 1, n]`; traps otherwise.
func fillLastDimension(indexes: Range<Int>, with value: FloatType) {
    precondition(shape.count == 3 && shape[0] == 1 && shape[1] == 1, "Must have [1, 1, n] shape")
    withUnsafeMutableBufferPointer(ofType: FloatType.self) { pointer, strides in
        // Hoist the last-axis stride; each element sits at index * stride.
        let lastAxisStride = strides[2]
        indexes.forEach { pointer[$0 * lastAxisStride] = value }
    }
}

func fill<Value>(indexes: [[NSNumber]], with value: Value) {
let pointer = UnsafeMutablePointer<Value>(OpaquePointer(dataPointer))
let strideInts = strides.map { $0.intValue }
Expand Down Expand Up @@ -135,6 +144,10 @@ func tokenizerNameForVariant(_ variant: ModelVariant) -> String {
return tokenizerName
}

func isModelMultilingual(logitsDim: Int?) -> Bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this already here:

public var isMultilingual: Bool {

Thoughts on combining them? I think checking the logitDims is more robust, perhaps it can be set on the model or the textdecoder on load here:

if let logitsDim = textDecoder.logitsSize,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough, I can tackle it in a separate PR

logitsDim != 51864
}

func detectVariant(logitsDim: Int, encoderDim: Int) -> ModelVariant {
// Defaults
var modelVariant: ModelVariant = .base
Expand Down
Loading
Loading