Skip to content

Commit

Permalink
Add helper for mlpackage detection
Browse files Browse the repository at this point in the history
  • Loading branch information
ZachNagengast committed Aug 10, 2024
1 parent 5918a59 commit fde66a8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 46 deletions.
12 changes: 10 additions & 2 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,6 @@ public struct TranscriptionResult: Codable {
// NOTE: this is a relative value for percentage calculations
let fullDecodingDuration = max(timings.decodingLoop, timings.fullPipeline) * 1000 // Convert to milliseconds

let encoderLoadTime = timings.encoderLoadTime
let decoderLoadTime = timings.decoderLoadTime
let audioLoadTime = formatTimeWithPercentage(timings.audioLoading, 1, fullDecodingDuration)
let audioProcTime = formatTimeWithPercentage(timings.audioProcessing, timings.totalAudioProcessingRuns, fullDecodingDuration)
let logmelsTime = formatTimeWithPercentage(timings.logmels, timings.totalLogmelRuns, fullDecodingDuration)
Expand Down Expand Up @@ -574,6 +572,10 @@ public struct TranscriptionResult: Codable {
Decoding Full Loop: \(decodingLoopInfo)
-------------------------------
Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds
- Prewarm: \(String(format: "%.2f", timings.prewarmLoadTime)) seconds
- Encoder: \(String(format: "%.2f", timings.encoderLoadTime)) seconds
- Decoder: \(String(format: "%.2f", timings.decoderLoadTime)) seconds
- Tokenizer: \(String(format: "%.2f", timings.tokenizerLoadTime)) seconds
Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds
- Decoding Loop (Avg/window): \(String(format: "%.2f", decodeTimePerWindow)) seconds
- Audio Windows: \(String(format: "%.2f", timings.totalAudioProcessingRuns))
Expand Down Expand Up @@ -652,8 +654,10 @@ public struct TranscriptionTimings: Codable {
public var firstTokenTime: CFAbsoluteTime
public var inputAudioSeconds: TimeInterval
public var modelLoading: TimeInterval
public var prewarmLoadTime: TimeInterval
public var encoderLoadTime: TimeInterval
public var decoderLoadTime: TimeInterval
public var tokenizerLoadTime: TimeInterval
public var audioLoading: TimeInterval
public var audioProcessing: TimeInterval
public var logmels: TimeInterval
Expand Down Expand Up @@ -694,8 +698,10 @@ public struct TranscriptionTimings: Codable {

/// Initialize with all time intervals set to zero.
public init(modelLoading: TimeInterval = 0,
prewarmLoadTime: TimeInterval = 0,
encoderLoadTime: TimeInterval = 0,
decoderLoadTime: TimeInterval = 0,
tokenizerLoadTime: TimeInterval = 0,
audioLoading: TimeInterval = 0,
audioProcessing: TimeInterval = 0,
logmels: TimeInterval = 0,
Expand Down Expand Up @@ -725,8 +731,10 @@ public struct TranscriptionTimings: Codable {
self.firstTokenTime = Double.greatestFiniteMagnitude
self.inputAudioSeconds = 0.001
self.modelLoading = modelLoading
self.prewarmLoadTime = prewarmLoadTime
self.encoderLoadTime = encoderLoadTime
self.decoderLoadTime = decoderLoadTime
self.tokenizerLoadTime = tokenizerLoadTime
self.audioLoading = audioLoading
self.audioProcessing = audioProcessing
self.logmels = logmels
Expand Down
18 changes: 18 additions & 0 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,22 @@ public func modelSupport(for deviceName: String) -> (default: String, disabled:
return ("openai_whisper-base", [""])
}

/// Locates a Core ML model inside `path`, preferring a compiled `.mlmodelc`
/// bundle and falling back to the raw `.mlmodel` inside a `.mlpackage`.
///
/// - Parameters:
///   - path: Folder expected to contain the model artifacts.
///   - modelName: Base name of the model (no extension).
/// - Returns: The compiled model URL when it exists, otherwise the
///   `.mlpackage` inner model URL when only that exists. If neither exists,
///   the compiled URL is returned unchanged — callers are expected to verify
///   existence before loading.
public func detectModelURL(inFolder path: URL, named modelName: String) -> URL {
    let compiledUrl = path.appending(path: "\(modelName).mlmodelc")
    let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel")

    let compiledModelExists = FileManager.default.fileExists(atPath: compiledUrl.path)
    let packageModelExists = FileManager.default.fileExists(atPath: packageUrl.path)

    // Prefer the compiled .mlmodelc; use the .mlpackage's inner .mlmodel only
    // when the package was found and no compiled model is present.
    return (packageModelExists && !compiledModelExists) ? packageUrl : compiledUrl
}

public func resolveAbsolutePath(_ inputPath: String) -> String {
let fileManager = FileManager.default

Expand Down Expand Up @@ -621,8 +637,10 @@ public func mergeTranscriptionResults(_ results: [TranscriptionResult?], confirm
// Update the merged timings with non-overlapping time values
var mergedTimings = TranscriptionTimings(
modelLoading: validResults.map { $0.timings.modelLoading }.max() ?? 0,
prewarmLoadTime: validResults.map { $0.timings.prewarmLoadTime }.max() ?? 0,
encoderLoadTime: validResults.map { $0.timings.encoderLoadTime }.max() ?? 0,
decoderLoadTime: validResults.map { $0.timings.decoderLoadTime }.max() ?? 0,
tokenizerLoadTime: validResults.map { $0.timings.tokenizerLoadTime }.max() ?? 0,
audioLoading: validResults.map { $0.timings.audioLoading }.reduce(0, +),
audioProcessing: validResults.map { $0.timings.audioProcessing }.reduce(0, +),
logmels: validResults.map { $0.timings.logmels }.reduce(0, +),
Expand Down
63 changes: 19 additions & 44 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -261,36 +261,11 @@ open class WhisperKit {

Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)")

var logmelUrl = path.appending(path: "MelSpectrogram.mlmodelc")
var encoderUrl = path.appending(path: "AudioEncoder.mlmodelc")
var decoderUrl = path.appending(path: "TextDecoder.mlmodelc")
var decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc")

var logmelMLPackageUrl = path.appending(path: "MelSpectrogram.mlpackage/Data/com.apple.CoreML/model.mlmodel")
var encoderMLPackageUrl = path.appending(path: "AudioEncoder.mlpackage/Data/com.apple.CoreML/model.mlmodel")
var decoderMLPackageUrl = path.appending(path: "TextDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel")
var decoderPrefillMLPackageUrl = path.appending(path: "TextDecoderContextPrefill/Data/com.apple.CoreML/model.mlmodel")

let encoderMLModelc: Bool = FileManager.default.fileExists(atPath: encoderUrl.path)
let encoderMLPackage: Bool = FileManager.default.fileExists(atPath: encoderMLPackageUrl.path)

let decoderMLModelc: Bool = FileManager.default.fileExists(atPath: decoderUrl.path)
let decoderMLPackage: Bool = FileManager.default.fileExists(atPath: decoderMLPackageUrl.path)

let logmelMLModelc: Bool = FileManager.default.fileExists(atPath: logmelUrl.path)
let logmelMLPackage: Bool = FileManager.default.fileExists(atPath: logmelMLPackageUrl.path)

// Swap to mlpackage only if the following is true: we found the mlmodel within the mlpackage, and we did not find a .mlmodelc
let swapURLIfTrue: (Bool, Bool, URL, URL) -> URL = { foundMLPackage, foundMLModelc, mlPackageURL, mlModelcUrl in
if (foundMLPackage && !foundMLModelc) {
return mlPackageURL
}
return mlModelcUrl
}

encoderUrl = swapURLIfTrue(encoderMLPackage, encoderMLModelc, encoderMLPackageUrl, encoderUrl)
decoderUrl = swapURLIfTrue(decoderMLPackage, decoderMLModelc, decoderMLPackageUrl, decoderUrl)
logmelUrl = swapURLIfTrue(logmelMLPackage, logmelMLModelc, logmelMLPackageUrl, logmelUrl)
// Find either mlmodelc or mlpackage models
let logmelUrl = detectModelURL(inFolder: path, named: "MelSpectrogram")
let encoderUrl = detectModelURL(inFolder: path, named: "AudioEncoder")
let decoderUrl = detectModelURL(inFolder: path, named: "TextDecoder")
let decoderPrefillUrl = detectModelURL(inFolder: path, named: "TextDecoderContextPrefill")

for item in [logmelUrl, encoderUrl, decoderUrl] {
if !FileManager.default.fileExists(atPath: item.path) {
Expand Down Expand Up @@ -321,38 +296,34 @@ open class WhisperKit {

if let textDecoder = textDecoder as? WhisperMLModel {
Logging.debug("Loading text decoder")
let decoderLoadStart = modelLoadStart //CFAbsoluteTimeGetCurrent()
let decoderLoadStart = CFAbsoluteTimeGetCurrent()
try await textDecoder.loadModel(
at: decoderUrl,
computeUnits: modelCompute.textDecoderCompute,
prewarmMode: prewarmMode
)
let decoderLoadEnd = CFAbsoluteTimeGetCurrent()

currentTimings.decoderLoadTime = (decoderLoadEnd - decoderLoadStart)
currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart

Logging.debug("Loaded text decoder")
Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
}

if let audioEncoder = audioEncoder as? WhisperMLModel {
Logging.debug("Loading audio encoder")
let encoderLoadStart = modelLoadStart // CFAbsoluteTimeGetCurrent()
let encoderLoadStart = CFAbsoluteTimeGetCurrent()

try await audioEncoder.loadModel(
at: encoderUrl,
computeUnits: modelCompute.audioEncoderCompute,
prewarmMode: prewarmMode
)
let encoderLoadEnd = CFAbsoluteTimeGetCurrent()
currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart

currentTimings.encoderLoadTime = encoderLoadEnd - encoderLoadStart - currentTimings.decoderLoadTime

Logging.debug("Loaded audio encoder")
Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
}

if prewarmMode {
modelState = .prewarmed
currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
currentTimings.prewarmLoadTime = CFAbsoluteTimeGetCurrent() - modelLoadStart
return
}

Expand All @@ -363,20 +334,24 @@ open class WhisperKit {
textDecoder.isModelMultilingual = isModelMultilingual(logitsDim: logitsDim)
modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
Logging.debug("Loading tokenizer for \(modelVariant)")
let tokenizerLoadStart = CFAbsoluteTimeGetCurrent()

let tokenizer = try await loadTokenizer(
for: modelVariant,
tokenizerFolder: tokenizerFolder,
useBackgroundSession: useBackgroundDownloadSession
)
currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart

self.tokenizer = tokenizer
textDecoder.tokenizer = tokenizer
Logging.debug("Loaded tokenizer")
Logging.debug("Loaded tokenizer in \(String(format: "%.2f", currentTimings.tokenizerLoadTime))s")

modelState = .loaded

currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart + currentTimings.prewarmLoadTime

Logging.info("Loaded models for whisper size: \(modelVariant)")
Logging.info("Loaded models for whisper size: \(modelVariant) in \(String(format: "%.2f", currentTimings.modelLoading))s")
}

public func unloadModels() async {
Expand Down
1 change: 1 addition & 0 deletions Sources/WhisperKitCLI/TranscribeCLI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ struct TranscribeCLI: AsyncParsableCommand {
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug,
prewarm: false,
load: true,
useBackgroundDownloadSession: false
)
Expand Down

0 comments on commit fde66a8

Please sign in to comment.