
Added TimestampRulesFilter implementation #45

Merged 12 commits into argmaxinc:main on Mar 22, 2024

Conversation

@jkrukowski (Contributor)

This PR adds an implementation of TimestampRulesFilter, based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441

A couple of questions here @ZachNagengast:

  • The sampleBegin param passed to TimestampRulesFilter is 0, which I think might be incorrect. I compared it to the Python implementation in the OpenAI repo, where this param is always greater than or equal to 3 (which makes sense: the first 3 tokens are special tokens, 50258, 50259, and 50359, and AFAIK we don't want to suppress them; see the sketch after this list). If you run this code as is, some segments might be omitted because sampleBegin is 0; if you change it to 3, it should be OK.
  • This implementation slows down the whole inference; maybe you have some ideas on how to optimize it?
  • You mentioned that it has duplicated logic with SegmentSeeker, but I don't see it (AFAIK TimestampRulesFilter just suppresses token probabilities, while SegmentSeeker creates the whole segments). Could you please clarify?
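For illustration, a minimal sketch (my framing, not code from the PR) of why sampleBegin matters; the token ids are those of the multilingual Whisper vocabulary mentioned above:

    // For a multilingual model, the decoder input begins with special tokens:
    // <|startoftranscript|> (50258), a language token such as <|en|> (50259),
    // and a task token such as <|transcribe|> (50359).
    let tokens = [50258, 50259, 50359, 1770, 13, 2264] // special prefix + text tokens
    let sampleBegin = 3

    // The timestamp rules should only inspect tokens the decoder actually
    // sampled, i.e. everything after the special prefix:
    let sampledTokens = tokens[sampleBegin...] // [1770, 13, 2264]

With sampleBegin = 0, the rules would treat the special prefix as sampled text and mask logits based on it, which is how segments can get dropped.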

    -XCTAssertEqual(result?.segments.count, 2, "Expected 2 segments")
    +XCTAssertEqual(result?.segments.count, 3, "Expected 3 segments")
@jkrukowski (Contributor Author)

enabled timestamps are causing more segments to appear

@ZachNagengast linked an issue on Mar 8, 2024 that may be closed by this pull request.
@ZachNagengast (Contributor)

@jkrukowski I pushed a small commit to measure the logit filtering time. Here is what I'm getting for tiny, with and without these new timestamp rules, on the jfk.wav file:
With:
[WhisperKit] - Logit Filtering: 192.41 ms / 28 runs ( 6.87 ms/run) 37.78%
Without:
[WhisperKit] - Logit Filtering: 0.07 ms / 28 runs ( 0.00 ms/run) 0.02%

This is a bit high; it becomes especially noticeable with the tiny model. Interestingly, only the first and last few tokens are slow (graph by ChatGPT). This is for jfk.wav:

[Graph: per-token logit filtering time for jfk.wav]
Hopefully this gives you some guidance on where to look for optimizations. The majority of the slowdown is in this block of code:

            // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
            let sampledTokens = tokens[sampleBegin...]
            let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
            let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
            if lastWasTimestamp {
                if penultimateWasTimestamp {
                    // has to be non-timestamp
                    logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
                } else {
                    // cannot be normal text tokens
                    logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
                }
            }
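To spell out the rule this block implements, here is a worked illustration (not part of the PR; the token ids are made up, with timeTokenBegin = 50364 as in the multilingual vocabulary):

    let timeTokenBegin = 50364
    let endsWithOneTimestamp = [1770, 50414]         // ...text, <|1.00|>
    let endsWithTimestampPair = [1770, 50414, 50420] // ...text, <|1.00|>, <|1.12|>

    func nextTokenConstraint(_ tokens: [Int]) -> String {
        let lastWasTimestamp = (tokens.last ?? 0) >= timeTokenBegin
        let penultimateWasTimestamp = tokens.count < 2 || tokens[tokens.count - 2] >= timeTokenBegin
        guard lastWasTimestamp else { return "unconstrained" }
        return penultimateWasTimestamp
            ? "must be a non-timestamp (timestamp logits masked)"
            : "must be a timestamp or EOT (text logits masked)"
    }

    print(nextTokenConstraint(endsWithOneTimestamp))  // must be a timestamp or EOT (text logits masked)
    print(nextTokenConstraint(endsWithTimestampPair)) // must be a non-timestamp (timestamp logits masked)

In other words: an unpaired timestamp forces the pair to be completed (or the segment to end), while a completed pair forces text to follow.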

@jkrukowski (Contributor Author)

@ZachNagengast I've added a more performant version of the fillLastDimension function; it seems to be doing much better. This is what I get for the release build on the jfk.wav file:

[WhisperKit] ---- Transcription Timings ----
[WhisperKit] Audio Load:              2.33 ms /      1 runs (    2.33 ms/run)  0.66%
[WhisperKit] Audio Processing:        0.11 ms /      1 runs (    0.11 ms/run)  0.03%
[WhisperKit] Mels:                   35.53 ms /      1 runs (   35.53 ms/run) 10.11%
[WhisperKit] Encoding:               13.39 ms /      1 runs (   13.39 ms/run)  3.81%
[WhisperKit] Matrices Init:           0.22 ms /      1 runs (    0.22 ms/run)  0.06%
[WhisperKit] Prefill:                 0.00 ms /      1 runs (    0.00 ms/run)  0.00%
[WhisperKit] Decoding:              239.40 ms /     28 runs (    8.55 ms/run) 68.15%
[WhisperKit] Non-inference:          61.25 ms /     28 runs (    2.19 ms/run) 17.43%
[WhisperKit] - Logit Filtering:       3.24 ms /     28 runs (    0.12 ms/run)  0.92%
[WhisperKit] - Sampling:             14.17 ms /     28 runs (    0.51 ms/run)  4.03%
[WhisperKit] - Kv Caching:            2.79 ms /     28 runs (    0.10 ms/run)  0.80%
[WhisperKit] - Word Timestamps:       0.00 ms /      0 runs (    0.00 ms/run)  0.00%
[WhisperKit] - Windowing:             0.08 ms /      1 runs (    0.08 ms/run)  0.02%
[WhisperKit] Fallbacks:               0.00 ms /      0 runs (    0.00 ms/run)  0.00%
[WhisperKit] Decoding Full Loop:    351.06 ms /     28 runs (   12.54 ms/run) 99.93%
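For reference, the kind of vectorized fill that closes this gap might look like the following. This is a sketch under assumptions, not the PR's actual fillLastDimension: it assumes Float32 logits in a contiguous [1, 1, vocabSize] MLMultiArray, and the method name is hypothetical.

    import CoreML
    import Accelerate

    extension MLMultiArray {
        /// Writes `value` into `indexes` of the last dimension by filling the
        /// backing memory with vDSP, instead of boxing an NSNumber per element.
        func fastFillLastDimension(indexes: Range<Int>, with value: Float) {
            precondition(dataType == .float32, "sketch assumes Float32 logits")
            let ptr = dataPointer.bindMemory(to: Float.self, capacity: count)
            var fillValue = value
            vDSP_vfill(&fillValue, ptr.advanced(by: indexes.lowerBound), 1, vDSP_Length(indexes.count))
        }
    }

The per-element NSNumber subscript on MLMultiArray is the usual culprit for this kind of slowdown when the range spans most of a ~51k-token vocabulary.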

@ZachNagengast (Contributor)

Much better! This looks in line with what I was seeing for those faster middle tokens previously. Think this is ready to come out of draft now?

@jkrukowski (Contributor Author)

Good to hear. Two things are left:

  1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- not sure what value I should put there
  2. The force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken: maybe we should not force unwrap and instead return false gracefully, wdyt?

@ZachNagengast (Contributor)

> 1. self.sampleBegin = 3 // FIXME: it should not be a hardcoded value -- not sure what value I should put there

PrefilledIndex is already being passed into this function, but I think it should actually use initialPromptIndex. A good test to add for accuracy on this would be similar to this one:

func testSampleLength() async {
where you'd create a bunch of options that change this initialPromptIndex and make sure it's working properly.

> 2. The force unwrapping in sumOfProbabilityOverTimestampsIsAboveAnyOtherToken: maybe we should not force unwrap and instead return false gracefully, wdyt?

Besides the verbosity I think it's OK. If you want to be extra safe, you can wrap that whole part in a do/catch and log an error, similar to the sampling code. I'm not sure of all the scenarios where BNNS will throw, but returning false would just fall back to the default behavior, so no issues there.
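As an illustration of the return-false-gracefully idea, here is a self-contained sketch that mirrors the check in OpenAI's decoding.py in plain Swift; it is not the PR's BNNS-backed code, and the signature is assumed:

    import Foundation

    // Returns true when the total probability mass on timestamp tokens exceeds
    // the probability of any single text token. Degenerate inputs return false
    // instead of crashing, which just falls back to the default behavior.
    func sumOfProbabilityOverTimestampsIsAboveAnyOtherToken(
        logits: [Float], timeTokenBegin: Int
    ) -> Bool {
        guard timeTokenBegin > 0, timeTokenBegin < logits.count,
              let maxLogit = logits.max() else { return false }
        // log-softmax via the log-sum-exp trick for numerical stability
        let logSumExpAll = maxLogit + log(logits.reduce(0) { $0 + exp($1 - maxLogit) })
        let logProbs = logits.map { $0 - logSumExpAll }
        let timestampPart = logProbs[timeTokenBegin...]
        let maxTimestamp = timestampPart.max() ?? -.infinity
        let timestampLogProb = maxTimestamp
            + log(timestampPart.reduce(0) { $0 + exp($1 - maxTimestamp) })
        let maxTextLogProb = logProbs[..<timeTokenBegin].max() ?? -.infinity
        return timestampLogProb > maxTextLogProb
    }

Compared to force unwrapping, empty or malformed logits here simply mean the filter step is skipped.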

@jkrukowski marked this pull request as ready for review on March 20, 2024 at 11:33.
@ZachNagengast (Contributor) left a comment

Approving for the pre-release tests, but curious about your thoughts on the comments; they can be addressed in a future PR.

    -public init(tokenizer: Tokenizer, sampleBegin: Int) {
    -    // TODO: implement
    -    fatalError("Not implemented: \(#function)")
    +public init(
@ZachNagengast (Contributor)

This interface is complex in the same way SegmentSeeker's is. I believe most of these values are on the tokenizer object, but that would require passing it in.

@jkrukowski (Contributor Author)

You're right, it's complex. I didn't want to make it dependent on Tokenizer, so it's decoupled and relatively easier to test. I can change it if you think otherwise.

@ZachNagengast (Contributor)

No problem, your logic makes sense. We may want a simple object like SpecialTokens in the future and extend the tokenizer with it, rather than just adding these index properties themselves as extensions.
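A rough sketch of what such a SpecialTokens value could look like; the field names and the hardcoded ids are my assumptions for illustration, and in practice they would be looked up from the tokenizer rather than hardcoded:

    struct SpecialTokens {
        let endToken: Int          // <|endoftext|>
        let noTimestampsToken: Int // <|notimestamps|>
        let timeTokenBegin: Int    // first timestamp token, <|0.00|>
    }

    // Example values for the multilingual vocabulary:
    let multilingual = SpecialTokens(
        endToken: 50257,
        noTimestampsToken: 50363,
        timeTokenBegin: 50364
    )

A single value like this keeps the filter decoupled from Tokenizer while avoiding a long list of loose Int parameters.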

@@ -135,6 +144,10 @@ func tokenizerNameForVariant(_ variant: ModelVariant) -> String {
return tokenizerName
}

func isModelMultilingual(logitsDim: Int?) -> Bool {
@ZachNagengast (Contributor)

We have this already here:

public var isMultilingual: Bool {

Thoughts on combining them? I think checking the logits dim is more robust; perhaps it can be set on the model or the TextDecoder on load, here:

if let logitsDim = textDecoder.logitsSize,
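For illustration, one shape the combination could take (a sketch with assumed names; multilingual Whisper checkpoints have a 51865-token vocabulary and English-only ones 51864, which is what makes the logits dimension a reliable signal):

    final class DecoderModelInfo {
        let logitsSize: Int?
        init(logitsSize: Int?) { self.logitsSize = logitsSize }

        // Resolved once from the loaded model's logits dimension, so
        // `isMultilingual` no longer needs a separate name-based check.
        lazy var isMultilingual: Bool = (logitsSize ?? 0) >= 51865
    }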

@jkrukowski (Contributor Author)

Fair enough, I can tackle it in a separate PR.

Sources/WhisperKit/Core/WhisperKit.swift (review thread resolved)
@ZachNagengast merged commit 508240f into argmaxinc:main on Mar 22, 2024. 11 of 12 checks passed.
Successfully merging this pull request may close these issues:

  • Timestamp Rules Logits Filter