Added TimestampRulesFilter implementation #45
Conversation
XCTAssertEqual(result?.segments.count, 2, "Expected 2 segments")
XCTAssertEqual(result?.segments.count, 3, "Expected 3 segments")
enabled timestamps are causing more segments to appear
@jkrukowski I pushed a small commit to measure the logit filtering time; here is what I'm getting for tiny with and without these new timestamp rules. This is a bit high, and it becomes especially noticeable with the tiny model. Something interesting is that only the first and last few tokens are slow (graph by ChatGPT). This is for the jfk.wav.
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
let sampledTokens = tokens[sampleBegin...]
let lastWasTimestamp = sampledTokens.count >= 1 && sampledTokens.last! >= timeTokenBegin
let penultimateWasTimestamp = sampledTokens.count < 2 || sampledTokens.dropLast().last! >= timeTokenBegin
if lastWasTimestamp {
if penultimateWasTimestamp {
// has to be non-timestamp
logits.fillLastDimension(indexes: timeTokenBegin..<logits.count, with: -FloatType.infinity)
} else {
// cannot be normal text tokens
logits.fillLastDimension(indexes: 0..<endToken, with: -FloatType.infinity)
}
}
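The Swift logic above mirrors the timestamp-pairing rule in openai/whisper's decoding.py. For reference, here is a plain-Python sketch of the same rule, simplified to lists and floats instead of tensors (the function name is my own, not from either codebase):

```python
NEG_INF = float("-inf")

def apply_timestamp_pairing_rule(logits, tokens, sample_begin,
                                 time_token_begin, end_token):
    # Timestamps have to appear in pairs, except directly before EOT;
    # mask logits accordingly.
    sampled = tokens[sample_begin:]
    last_was_timestamp = len(sampled) >= 1 and sampled[-1] >= time_token_begin
    penultimate_was_timestamp = len(sampled) < 2 or sampled[-2] >= time_token_begin
    if last_was_timestamp:
        if penultimate_was_timestamp:
            # two timestamps in a row: the next token has to be non-timestamp
            for i in range(time_token_begin, len(logits)):
                logits[i] = NEG_INF
        else:
            # one timestamp so far: the next token cannot be normal text
            for i in range(end_token):
                logits[i] = NEG_INF
    return logits
```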
# Conflicts:
#   Sources/WhisperKitCLI/transcribe.swift
@ZachNagengast I've added a more performant version of
Much better! This looks in line with what I was seeing for those faster middle tokens previously. Think this is ready to come out of draft now?
Good to hear this, 2 things are left:
PrefilledIndex is already being passed into this function, but I think actually it should use
Besides the verbosity I think it's ok. If you want to be extra safe, you can wrap that whole part in a do/catch and log an error, similar to the sampling code. I'm not sure of all the scenarios where BNNS will throw, but returning false would just fall back to the default behavior, so no issues there.
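The fallback shape suggested here (catch anything, log it, and report failure so the caller keeps the default path) can be sketched in Python terms; accelerated_topk is a hypothetical stand-in for the BNNS-backed call, not a real API:

```python
import logging

def accelerated_topk(values, k):
    # Hypothetical stand-in for an accelerated routine that may throw
    # on unsupported input.
    if k > len(values):
        raise ValueError("k larger than input")
    return sorted(values, reverse=True)[:k]

def try_accelerated_topk(values, k):
    # Return (True, result) on success; on any error, log it and return
    # (False, None) so the caller falls back to the default behavior.
    try:
        return True, accelerated_topk(values, k)
    except Exception as exc:
        logging.error("accelerated path failed, falling back: %s", exc)
        return False, None
```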
Approving for the pre-release tests, but curious about your thoughts on the comments; they can be a future PR.
public init(tokenizer: Tokenizer, sampleBegin: Int) {
    // TODO: implement
    fatalError("Not implemented: \(#function)")
public init(
This interface is complex in the same way the SegmentSeeker is. I believe most of these are in the tokenizer object, but that would require passing this in.
You're right, it's complex. I didn't want to make it dependent on Tokenizer so it's decoupled and relatively easy to test. I can change it if you think otherwise.
No problem, your logic makes sense. We may want a simple object like SpecialTokens in the future and extend the tokenizer with it, rather than just adding these index properties as extensions.
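A SpecialTokens value object of the kind mentioned could look roughly like this (a hypothetical sketch in Python terms, not an actual WhisperKit API; the field names are illustrative, and the ids shown are the standard multilingual Whisper ones):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SpecialTokens:
    # Groups the special-token indices so a filter can depend on this
    # small value object instead of the full tokenizer.
    sot_token: int
    end_token: int
    no_timestamps_token: int
    time_token_begin: int

# Illustrative multilingual Whisper ids:
tokens = SpecialTokens(sot_token=50258, end_token=50257,
                       no_timestamps_token=50363, time_token_begin=50364)
```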
@@ -135,6 +144,10 @@ func tokenizerNameForVariant(_ variant: ModelVariant) -> String {
    return tokenizerName
}

func isModelMultilingual(logitsDim: Int?) -> Bool {
We have this already here:
public var isMultilingual: Bool {
Thoughts on combining them? I think checking the logitsDim is more robust; perhaps it can be set on the model or the TextDecoder on load here:
if let logitsDim = textDecoder.logitsSize,
fair enough, I can tackle it in a separate PR
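For reference, the logits-dimension check works because the vocabulary size differs by model type: English-only Whisper checkpoints emit 51864 logits, while multilingual ones emit 51865 (51866 for large-v3). A sketch of such a check, under that assumption:

```python
def is_model_multilingual(logits_dim):
    # English-only Whisper models have a vocab of 51864 tokens;
    # multilingual ones have 51865 (51866 for large-v3).
    if logits_dim is None:
        return False
    return logits_dim >= 51865
```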
This PR adds an implementation for TimestampRulesFilter. The implementation is based on https://github.com/openai/whisper/blob/master/whisper/decoding.py#L441

Couple of questions here @ZachNagengast:

1. The sampleBegin param passed to TimestampRulesFilter is 0; I think it might be incorrect. I compared it to the Python implementation from the OpenAI repo, and there this param is always greater than or equal to 3 (and this makes sense: the first 3 tokens are special tokens 50258, 50259, and 50359, and AFAIK we don't want to suppress them). If you run this code as is, some segments might be omitted (because sampleBegin is 0; if you change it to 3, it should be ok).
2. I thought this would be handled by SegmentSeeker, but I don't see it (AFAIK TimestampRulesFilter just suppresses the token probabilities, while SegmentSeeker creates the whole segments). Could you please clarify?
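To illustrate the first question: with sampleBegin set past the forced prefix, the slice that the filter inspects no longer contains the special tokens. A small sketch (the helper name is made up; 50364 is the first timestamp token in multilingual models):

```python
SOT, LANG_EN, TRANSCRIBE = 50258, 50259, 50359

def sampled_tokens(tokens, sample_begin):
    # The timestamp rules should only inspect tokens generated after the
    # special prefix, so the prefix is sliced off first.
    return tokens[sample_begin:]

prefix_and_text = [SOT, LANG_EN, TRANSCRIBE, 50364, 440]
# sample_begin = 0 wrongly includes the special prefix in the checks:
assert sampled_tokens(prefix_and_text, 0)[0] == SOT
# sample_begin = 3 skips it, matching the OpenAI implementation:
assert sampled_tokens(prefix_and_text, 3)[0] == 50364
```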