From fde66a82a1f875c9720320e4dfe2a7d5ed8a2049 Mon Sep 17 00:00:00 2001
From: ZachNagengast
Date: Sat, 10 Aug 2024 12:23:51 -0700
Subject: [PATCH] Add helper for mlpackage detection

---
 Sources/WhisperKit/Core/Models.swift      | 12 ++++-
 Sources/WhisperKit/Core/Utils.swift       | 18 +++++++
 Sources/WhisperKit/Core/WhisperKit.swift  | 63 +++++++----------
 Sources/WhisperKitCLI/TranscribeCLI.swift |  1 +
 4 files changed, 48 insertions(+), 46 deletions(-)

diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift
index 7bbad5d..b564695 100644
--- a/Sources/WhisperKit/Core/Models.swift
+++ b/Sources/WhisperKit/Core/Models.swift
@@ -536,8 +536,6 @@ public struct TranscriptionResult: Codable {
 
         // NOTE: this is a relative value for percentage calculations
         let fullDecodingDuration = max(timings.decodingLoop, timings.fullPipeline) * 1000 // Convert to milliseconds
-        let encoderLoadTime = timings.encoderLoadTime
-        let decoderLoadTime = timings.decoderLoadTime
         let audioLoadTime = formatTimeWithPercentage(timings.audioLoading, 1, fullDecodingDuration)
         let audioProcTime = formatTimeWithPercentage(timings.audioProcessing, timings.totalAudioProcessingRuns, fullDecodingDuration)
         let logmelsTime = formatTimeWithPercentage(timings.logmels, timings.totalLogmelRuns, fullDecodingDuration)
@@ -574,6 +572,10 @@ public struct TranscriptionResult: Codable {
         Decoding Full Loop: \(decodingLoopInfo)
         -------------------------------
         Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds
+        - Prewarm: \(String(format: "%.2f", timings.prewarmLoadTime)) seconds
+        - Encoder: \(String(format: "%.2f", timings.encoderLoadTime)) seconds
+        - Decoder: \(String(format: "%.2f", timings.decoderLoadTime)) seconds
+        - Tokenizer: \(String(format: "%.2f", timings.tokenizerLoadTime)) seconds
         Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds
         - Decoding Loop (Avg/window): \(String(format: "%.2f", decodeTimePerWindow)) seconds
         - Audio Windows: \(String(format: "%.2f", timings.totalAudioProcessingRuns))
@@ -652,8 +654,10 @@ public struct TranscriptionTimings: Codable {
     public var firstTokenTime: CFAbsoluteTime
     public var inputAudioSeconds: TimeInterval
     public var modelLoading: TimeInterval
+    public var prewarmLoadTime: TimeInterval
     public var encoderLoadTime: TimeInterval
     public var decoderLoadTime: TimeInterval
+    public var tokenizerLoadTime: TimeInterval
     public var audioLoading: TimeInterval
     public var audioProcessing: TimeInterval
     public var logmels: TimeInterval
@@ -694,8 +698,10 @@ public struct TranscriptionTimings: Codable {
 
     /// Initialize with all time intervals set to zero.
     public init(modelLoading: TimeInterval = 0,
+                prewarmLoadTime: TimeInterval = 0,
                 encoderLoadTime: TimeInterval = 0,
                 decoderLoadTime: TimeInterval = 0,
+                tokenizerLoadTime: TimeInterval = 0,
                 audioLoading: TimeInterval = 0,
                 audioProcessing: TimeInterval = 0,
                 logmels: TimeInterval = 0,
@@ -725,8 +731,10 @@ public struct TranscriptionTimings: Codable {
         self.firstTokenTime = Double.greatestFiniteMagnitude
         self.inputAudioSeconds = 0.001
         self.modelLoading = modelLoading
+        self.prewarmLoadTime = prewarmLoadTime
         self.encoderLoadTime = encoderLoadTime
         self.decoderLoadTime = decoderLoadTime
+        self.tokenizerLoadTime = tokenizerLoadTime
         self.audioLoading = audioLoading
         self.audioProcessing = audioProcessing
         self.logmels = logmels
diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift
index f9ff449..561ddd9 100644
--- a/Sources/WhisperKit/Core/Utils.swift
+++ b/Sources/WhisperKit/Core/Utils.swift
@@ -482,6 +482,22 @@ public func modelSupport(for deviceName: String) -> (default: String, disabled:
     return ("openai_whisper-base", [""])
 }
 
+public func detectModelURL(inFolder path: URL, named modelName: String) -> URL {
+    let compiledUrl = path.appending(path: "\(modelName).mlmodelc")
+    let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel")
+
+    let compiledModelExists: Bool = FileManager.default.fileExists(atPath: compiledUrl.path)
+    let packageModelExists: Bool = FileManager.default.fileExists(atPath: packageUrl.path)
+
+    // Swap to mlpackage only if the following is true: we found the mlmodel within the mlpackage, and we did not find a .mlmodelc
+    var modelURL = compiledUrl
+    if (packageModelExists && !compiledModelExists) {
+        modelURL = packageUrl
+    }
+
+    return modelURL
+}
+
 public func resolveAbsolutePath(_ inputPath: String) -> String {
     let fileManager = FileManager.default
 
@@ -621,8 +637,10 @@ public func mergeTranscriptionResults(_ results: [TranscriptionResult?], confirm
     // Update the merged timings with non-overlapping time values
     var mergedTimings = TranscriptionTimings(
         modelLoading: validResults.map { $0.timings.modelLoading }.max() ?? 0,
+        prewarmLoadTime: validResults.map { $0.timings.prewarmLoadTime }.max() ?? 0,
         encoderLoadTime: validResults.map { $0.timings.encoderLoadTime }.max() ?? 0,
         decoderLoadTime: validResults.map { $0.timings.decoderLoadTime }.max() ?? 0,
+        tokenizerLoadTime: validResults.map { $0.timings.tokenizerLoadTime }.max() ?? 0,
         audioLoading: validResults.map { $0.timings.audioLoading }.reduce(0, +),
         audioProcessing: validResults.map { $0.timings.audioProcessing }.reduce(0, +),
         logmels: validResults.map { $0.timings.logmels }.reduce(0, +),
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 141b36b..1a6837c 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -261,36 +261,11 @@ open class WhisperKit {
 
         Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)")
 
-        var logmelUrl = path.appending(path: "MelSpectrogram.mlmodelc")
-        var encoderUrl = path.appending(path: "AudioEncoder.mlmodelc")
-        var decoderUrl = path.appending(path: "TextDecoder.mlmodelc")
-        var decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc")
-
-        var logmelMLPackageUrl = path.appending(path: "MelSpectrogram.mlpackage/Data/com.apple.CoreML/model.mlmodel")
-        var encoderMLPackageUrl = path.appending(path: "AudioEncoder.mlpackage/Data/com.apple.CoreML/model.mlmodel")
-        var decoderMLPackageUrl = path.appending(path: "TextDecoder.mlpackage/Data/com.apple.CoreML/model.mlmodel")
-        var decoderPrefillMLPackageUrl = path.appending(path: "TextDecoderContextPrefill/Data/com.apple.CoreML/model.mlmodel")
-
-        let encoderMLModelc: Bool = FileManager.default.fileExists(atPath: encoderUrl.path)
-        let encoderMLPackage: Bool = FileManager.default.fileExists(atPath: encoderMLPackageUrl.path)
-
-        let decoderMLModelc: Bool = FileManager.default.fileExists(atPath: decoderUrl.path)
-        let decoderMLPackage: Bool = FileManager.default.fileExists(atPath: decoderMLPackageUrl.path)
-
-        let logmelMLModelc: Bool = FileManager.default.fileExists(atPath: logmelUrl.path)
-        let logmelMLPackage: Bool = FileManager.default.fileExists(atPath: logmelMLPackageUrl.path)
-
-        // Swap to mlpackage only if the following is true: we found the mlmodel within the mlpackage, and we did not find a .mlmodelc
-        let swapURLIfTrue: (Bool, Bool, URL, URL) -> URL = { foundMLPackage, foundMLModelc, mlPackageURL, mlModelcUrl in
-            if (foundMLPackage && !foundMLModelc) {
-                return mlPackageURL
-            }
-            return mlModelcUrl
-        }
-
-        encoderUrl = swapURLIfTrue(encoderMLPackage, encoderMLModelc, encoderMLPackageUrl, encoderUrl)
-        decoderUrl = swapURLIfTrue(decoderMLPackage, decoderMLModelc, decoderMLPackageUrl, decoderUrl)
-        logmelUrl = swapURLIfTrue(logmelMLPackage, logmelMLModelc, logmelMLPackageUrl, logmelUrl)
+        // Find either mlmodelc or mlpackage models
+        let logmelUrl = detectModelURL(inFolder: path, named: "MelSpectrogram")
+        let encoderUrl = detectModelURL(inFolder: path, named: "AudioEncoder")
+        let decoderUrl = detectModelURL(inFolder: path, named: "TextDecoder")
+        let decoderPrefillUrl = detectModelURL(inFolder: path, named: "TextDecoderContextPrefill")
 
         for item in [logmelUrl, encoderUrl, decoderUrl] {
             if !FileManager.default.fileExists(atPath: item.path) {
@@ -321,38 +296,34 @@ open class WhisperKit {
 
         if let textDecoder = textDecoder as? WhisperMLModel {
             Logging.debug("Loading text decoder")
-            let decoderLoadStart = modelLoadStart //CFAbsoluteTimeGetCurrent()
+            let decoderLoadStart = CFAbsoluteTimeGetCurrent()
 
             try await textDecoder.loadModel(
                 at: decoderUrl,
                 computeUnits: modelCompute.textDecoderCompute,
                 prewarmMode: prewarmMode
             )
-            let decoderLoadEnd = CFAbsoluteTimeGetCurrent()
-
-            currentTimings.decoderLoadTime = (decoderLoadEnd - decoderLoadStart)
+            currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart
 
-            Logging.debug("Loaded text decoder")
+            Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
         }
 
         if let audioEncoder = audioEncoder as? WhisperMLModel {
             Logging.debug("Loading audio encoder")
-            let encoderLoadStart = modelLoadStart // CFAbsoluteTimeGetCurrent()
+            let encoderLoadStart = CFAbsoluteTimeGetCurrent()
 
             try await audioEncoder.loadModel(
                 at: encoderUrl,
                 computeUnits: modelCompute.audioEncoderCompute,
                 prewarmMode: prewarmMode
             )
-            let encoderLoadEnd = CFAbsoluteTimeGetCurrent()
+            currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
 
-            currentTimings.encoderLoadTime = encoderLoadEnd - encoderLoadStart - currentTimings.decoderLoadTime
-
-            Logging.debug("Loaded audio encoder")
+            Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
         }
 
         if prewarmMode {
             modelState = .prewarmed
-            currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
+            currentTimings.prewarmLoadTime = CFAbsoluteTimeGetCurrent() - modelLoadStart
             return
         }
@@ -363,20 +334,24 @@ open class WhisperKit {
         textDecoder.isModelMultilingual = isModelMultilingual(logitsDim: logitsDim)
         modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
         Logging.debug("Loading tokenizer for \(modelVariant)")
+        let tokenizerLoadStart = CFAbsoluteTimeGetCurrent()
+
         let tokenizer = try await loadTokenizer(
             for: modelVariant,
             tokenizerFolder: tokenizerFolder,
             useBackgroundSession: useBackgroundDownloadSession
         )
+        currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart
+
        self.tokenizer = tokenizer
         textDecoder.tokenizer = tokenizer
-        Logging.debug("Loaded tokenizer")
+        Logging.debug("Loaded tokenizer in \(String(format: "%.2f", currentTimings.tokenizerLoadTime))s")
 
         modelState = .loaded
 
-        currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
+        currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart + currentTimings.prewarmLoadTime
 
-        Logging.info("Loaded models for whisper size: \(modelVariant)")
+        Logging.info("Loaded models for whisper size: \(modelVariant) in \(String(format: "%.2f", currentTimings.modelLoading))s")
     }
 
     public func unloadModels() async {
diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift
index 578645c..d347fd2 100644
--- a/Sources/WhisperKitCLI/TranscribeCLI.swift
+++ b/Sources/WhisperKitCLI/TranscribeCLI.swift
@@ -312,6 +312,7 @@ struct TranscribeCLI: AsyncParsableCommand {
             computeOptions: computeOptions,
             verbose: cliArguments.verbose,
             logLevel: .debug,
+            prewarm: false,
             load: true,
             useBackgroundDownloadSession: false
         )
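
Note (illustrative, not part of the patch): a minimal sketch of how the new
detectModelURL(inFolder:named:) helper introduced above is expected to be
called. The model folder path is hypothetical; the helper's signature comes
from the Utils.swift hunk.

    import Foundation
    import WhisperKit

    // Hypothetical local folder containing downloaded WhisperKit models.
    let modelFolder = URL(fileURLWithPath: "/path/to/whisperkit-models")

    // Resolves to AudioEncoder.mlmodelc when a compiled model is present;
    // otherwise falls back to the .mlmodel inside AudioEncoder.mlpackage.
    let encoderURL = detectModelURL(inFolder: modelFolder, named: "AudioEncoder")
    print("Resolved encoder model at: \(encoderURL.path)")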