Skip to content

Commit

Permalink
test: test we can load real model with safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Aug 3, 2023
1 parent fbbf02b commit 5e2f208
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
36 changes: 36 additions & 0 deletions test/bumblebee/audio/whisper_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,42 @@ defmodule Bumblebee.Text.WhisperTest do
)
end

test "base model with safetensors" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf, "openai/whisper-tiny"},
architecture: :base,
params_filename: "model.safetensors"
)

assert %Bumblebee.Audio.Whisper{architecture: :base} = spec

input_features = Nx.sin(Nx.iota({1, 3000, 80}, type: :f32))
decoder_input_ids = Nx.tensor([[50258, 50259, 50359, 50363]])

inputs = %{
"input_features" => input_features,
"decoder_input_ids" => decoder_input_ids
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 4, 384}

assert_all_close(
outputs.hidden_state[[.., .., 1..3]],
Nx.tensor([
[
[9.1349, 0.5695, 8.7758],
[0.0160, -7.0785, 1.1313],
[6.1074, -2.0481, -1.5687],
[5.6247, -10.3924, 7.2008]
]
]),
atol: 1.0e-4
)
end

test "for conditional generation model" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "openai/whisper-tiny"})
Expand Down
12 changes: 12 additions & 0 deletions test/bumblebee_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,17 @@ defmodule BumblebeeTest do

assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params))
end

test "supports .safetensors params file" do
assert {:ok, %{params: params}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})

assert {:ok, %{params: safetensors_params}} =
Bumblebee.load_model(
{:hf, "openai/whisper-tiny"},
params_filename: "model.safetensors"
)

assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(safetensors_params))
end
end
end

0 comments on commit 5e2f208

Please sign in to comment.