Skip to content

Commit

Permalink
- reverted mlx version
Browse files Browse the repository at this point in the history
- updated readme
- updated makefile
  • Loading branch information
jkrukowski committed Jul 21, 2024
1 parent 2f3780f commit fd52493
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 127 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
- {
name: "macOS",
condition: true,
clean-destination: "generic/platform=macOS",
test-destination: "platform=macOS,arch=arm64",
test-cases: "-only-testing WhisperKitTests/UnitTests -only-testing WhisperKitMLXTests/MLXUnitTests",
mlx-disabled: "0",
Expand All @@ -29,7 +28,6 @@ jobs:
- {
name: "iOS",
condition: true,
clean-destination: "generic/platform=iOS",
test-destination: "platform=iOS Simulator,OS=${{ inputs.ios-version }},name=iPhone 15",
test-cases: "-only-testing WhisperKitTests/UnitTests",
mlx-disabled: "1",
Expand All @@ -38,7 +36,6 @@ jobs:
- {
name: "watchOS",
condition: "${{ inputs.macos-runner == 'macos-14' }}",
clean-destination: "generic/platform=watchOS",
test-destination: "platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)",
test-cases: "-only-testing WhisperKitTests/UnitTests",
mlx-disabled: "1",
Expand All @@ -47,7 +44,6 @@ jobs:
- {
name: "visionOS",
condition: "${{ inputs.macos-runner == 'macos-14' }}",
clean-destination: "generic/platform=visionOS",
test-destination: "platform=visionOS Simulator,name=Apple Vision Pro",
test-cases: "-only-testing WhisperKitTests/UnitTests",
mlx-disabled: "1",
Expand Down Expand Up @@ -79,8 +75,6 @@ jobs:
echo "Available schemes:"
xcodebuild -list
xcodebuild -downloadAllPlatforms
echo "Destinations for testing:"
export ${{ matrix.run-config['compiler-flags'] }} && xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -showdestinations -skipPackagePluginValidation
- name: Boot Simulator and Wait
if: ${{ matrix.run-config['name'] != 'macOS' }} && ${{ inputs.macos-runner == 'macos-14' }}
# Slower runners require some time to fully boot the simulator
Expand All @@ -96,5 +90,8 @@ jobs:
if: ${{ matrix.run-config['condition'] == true }}
run: |
set -o pipefail
xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty
xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation
xcodebuild clean build-for-testing test \
${{ matrix.run-config['test-cases'] }} \
-scheme ${{ matrix.run-config['scheme'] }} \
-destination '${{ matrix.run-config['test-destination'] }}' \
-skipPackagePluginValidation
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DerivedData/
**/*.xcscheme
.netrc
.env
/.vscode

# Core ML Model Files
Models
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
{
"originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e",
"originHash" : "829222b514832cb61fe0002e0eebda98f23a75169c63f7d6ed7a320d57d5318f",
"pins" : [
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"branch" : "main",
"revision" : "c11212bff42a1b88aea83811210d42a5f99440ad"
"revision" : "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"
}
},
{
Expand Down
34 changes: 27 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: setup setup-huggingface-cli setup-model-repo download-models download-model download-mlx-models download-mlx-model build build-cli test clean-package-caches
.PHONY: setup setup-huggingface-cli setup-model-repo download-models download-model download-mlx-models download-mlx-model build build-cli test mlx-test clean-package-caches

PIP_COMMAND := pip3
PYTHON_COMMAND := python3
Expand All @@ -7,9 +7,9 @@ PYTHON_COMMAND := python3
MODEL_REPO := argmaxinc/whisperkit-coreml
MLX_MODEL_REPO := argmaxinc/whisperkit-mlx

MODEL_REPO_DIR := ./Models/whisperkit-coreml
MLX_MODEL_REPO_DIR := ./Models/whisperkit-mlx
BASE_MODEL_DIR := ./Models
MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/whisperkit-coreml
MLX_MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/whisperkit-mlx
BASE_MODEL_DIR := ./Sources/WhisperKitTestsUtils/Models


setup:
Expand Down Expand Up @@ -60,6 +60,7 @@ setup-model-repo:
git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \
fi


setup-mlx-model-repo:
@echo "Setting up mlx repository..."
@mkdir -p $(BASE_MODEL_DIR)
Expand Down Expand Up @@ -111,19 +112,38 @@ download-mlx-model: setup-mlx-model-repo
git lfs pull --include="openai_whisper-$(MODEL)/*"
@echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-mlx-$(MODEL)"


build:
@echo "Building WhisperKit..."
@swift build -v
@xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \
-configuration Release \
-scheme whisperkit-Package \
-destination generic/platform=macOS \
-derivedDataPath .build/.xcodebuild/ \
-clonedSourcePackagesDirPath .build/ \
-skipPackagePluginValidation


build-cli:
@echo "Building WhisperKit CLI..."
@swift build -c release --product whisperkit-cli
@xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \
-configuration Release \
-scheme whisperkit-cli \
-destination generic/platform=macOS \
-derivedDataPath .build/.xcodebuild/ \
-clonedSourcePackagesDirPath .build/ \
-skipPackagePluginValidation


test:
@echo "Running tests..."
@swift test -v
@xcodebuild clean build-for-testing test \
-scheme whisperkit-Package \
-only-testing WhisperKitMLXTests/MLXUnitTests \
-only-testing WhisperKitTests/UnitTests \
-destination 'platform=macOS,arch=arm64' \
-skipPackagePluginValidation


clean-package-caches:
@trash ~/Library/Caches/org.swift.swiftpm/repositories
Expand Down
3 changes: 1 addition & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b",
"version" : "0.16.0"
"revision" : "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"
}
},
{
Expand Down
32 changes: 11 additions & 21 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ func products() -> [PackageDescription.Product] {
}

func mlxProducts() -> [PackageDescription.Product] {
let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1"
if isMLXDisabled {
if isMLXDisabled() {
return []
} else {
return [
Expand All @@ -53,12 +52,11 @@ func dependencies() -> [PackageDescription.Package.Dependency] {
}

func mlxDependencies() -> [PackageDescription.Package.Dependency] {
let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1"
if isMLXDisabled {
if isMLXDisabled() {
return []
} else {
return [
.package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.16.0"),
.package(url: "https://github.com/ml-explore/mlx-swift", revision: "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"),
]
}
}
Expand All @@ -78,21 +76,10 @@ func targets() -> [PackageDescription.Target] {
"WhisperKit",
.product(name: "Transformers", package: "swift-transformers"),
],
path: ".",
exclude: [
"Examples",
"Sources/WhisperKit",
"Sources/WhisperKitCLI",
"Tests",
"Makefile",
"README.md",
"LICENSE",
"CONTRIBUTING.md",
],
resources: [
.copy("Models/whisperkit-coreml"),
.copy("Models/whisperkit-mlx"),
.process("Sources/WhisperKitTestsUtils/Resources")
.copy("Models/whisperkit-coreml/"),
.copy("Models/whisperkit-mlx/"),
.process("Resources")
]
),
.testTarget(
Expand All @@ -107,8 +94,7 @@ func targets() -> [PackageDescription.Target] {
}

func mlxTargets() -> [PackageDescription.Target] {
let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1"
if isMLXDisabled {
if isMLXDisabled() {
return []
} else {
return [
Expand Down Expand Up @@ -146,3 +132,7 @@ func mlxTargets() -> [PackageDescription.Target] {
]
}
}

/// Reports whether the `MLX_DISABLED` environment variable opts out of MLX.
///
/// - Returns: `true` only when `MLX_DISABLED` is present in the process
///   environment and equals exactly `"1"`; `false` when it is unset or holds
///   any other value.
func isMLXDisabled() -> Bool {
    let environment = ProcessInfo.processInfo.environment
    // An unset variable means MLX stays enabled.
    guard let flag = environment["MLX_DISABLED"] else {
        return false
    }
    return flag == "1"
}
66 changes: 35 additions & 31 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ
- [Quick Example](#quick-example)
- [Model Selection](#model-selection)
- [Generating Models](#generating-models)
- [Swift CLI](#swift-cli)
- [Testing](#testing)
- [Contributing \& Roadmap](#contributing--roadmap)
- [License](#license)
- [Citation](#citation)
Expand Down Expand Up @@ -66,7 +66,7 @@ You can install `WhisperKit` command line app using [Homebrew](https://brew.sh)

```bash
brew install whisperkit-cli
```
```

## Getting Started

Expand All @@ -80,76 +80,80 @@ This example demonstrates how to transcribe a local audio file:
import WhisperKit

// Initialize WhisperKit with default settings
Task {
let pipe = try? await WhisperKit()
let transcription = try? await pipe!.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
print(transcription)
}
let pipe = try await WhisperKit()
// Transcribe the audio file
let transcription = try await pipe.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text
// Print the transcription
print(transcription)
```

### Model Selection

WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name:

```swift
let pipe = try? await WhisperKit(model: "large-v3")
let pipe = try await WhisperKit(model: "large-v3")
```

This method also supports glob search, so you can use wildcards to select a model:

```swift
let pipe = try? await WhisperKit(model: "distil*large-v3")
let pipe = try await WhisperKit(model: "distil*large-v3")
```

Note that the model search must return a single model from the source repo, otherwise an error will be thrown.

For a list of available models, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml).
For MLX models, see [here](https://huggingface.co/argmaxinc/whisperkit-mlx).

### Generating Models

WhisperKit also comes with the supporting repo [`whisperkittools`](https://github.com/argmaxinc/whisperkittools) which lets you create and deploy your own fine-tuned versions of Whisper in CoreML format to HuggingFace. Once generated, they can be loaded by simply changing the repo name to the one used to upload the model:

```swift
let pipe = try? await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo")
let pipe = try await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo")
```

### Swift CLI
### Backend Selection

The Swift CLI allows for quick testing and debugging outside of an Xcode project. To install it, run the following:

```bash
git clone https://github.com/argmaxinc/whisperkit.git
cd whisperkit
```
WhisperKit supports both CoreML and MLX backends. By default, it uses CoreML, but you can switch some or all pipeline components to MLX.
Available pipeline components are:
- `featureExtractor`: `FeatureExtractor` is used by default; use `MLXFeatureExtractor` to switch to MLX
- `audioEncoder`: `AudioEncoder` is used by default; use `MLXAudioEncoder` to switch to MLX
- `textDecoder`: `TextDecoder` is used by default; use `MLXTextDecoder` to switch to MLX

Then, setup the environment and download your desired model.
Here is an example of how to switch the `featureExtractor` and `audioEncoder` to MLX and keep the `textDecoder` as CoreML:

```bash
make setup
make download-model MODEL=large-v3
```swift
let pipe = try await WhisperKit(
model: "tiny",
mlxModel: "tiny",
featureExtractor: MLXFeatureExtractor(),
audioEncoder: MLXAudioEncoder()
)
```

**Note**:

1. This will download only the model specified by `MODEL` (see what's available in our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml), where we use the prefix `openai_whisper-{MODEL}`)
2. Before running `download-model`, make sure [git-lfs](https://git-lfs.com) is installed
### Testing

If you would like to download all available models to your local folder, use this command instead:
If you want to run the unit tests locally, first clone the repo:

```bash
make download-models
git clone https://github.com/argmaxinc/whisperkit.git
cd whisperkit
```

You can then run them via the CLI with:
download the required models:

```bash
swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
make setup
make download-model MODEL=tiny
make download-mlx-model MODEL=tiny
```

Which should print a transcription of the audio file. If you would like to stream the audio directly from a microphone, use:
and then run the tests:

```bash
swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream
make test
```

## Contributing & Roadmap
Expand Down
2 changes: 2 additions & 0 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
return getModelInputDimention(model, named: "encoder_output_embeds", position: 1)
}

public init() {}

/// Override default so we can unload the prefill data as well
public func unloadModel() {
model = nil
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ extension MLMultiArray {
/// - index: The index of the element
/// - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes.
@inline(__always)
func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int {
public func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int {
var linearOffset = 0
let strideInts = strideInts ?? strides.map { $0.intValue }
for (dimension, stride) in zip(index, strideInts) {
Expand Down
Loading

0 comments on commit fd52493

Please sign in to comment.