Skip to content

Commit

Permalink
Fix cache offset casting with low precision policies (#299)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Dec 8, 2023
1 parent 21832f3 commit e726e26
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee/audio/whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ defmodule Bumblebee.Audio.Whisper do
offset =
case offset do
%Axon.None{} -> 0
offset -> Nx.as_type(offset, {:s, 64})
offset -> offset
end

input_sequence_length = Nx.axis_size(input_embeddings, 1)
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ defmodule Bumblebee.Layers do
bias

offset ->
mask_shift = Nx.as_type(offset, {:s, 64})
mask_shift = offset
query_length = Nx.axis_size(query, 1)
Nx.slice_along_axis(bias, mask_shift, query_length, axis: 2)
end
Expand Down
8 changes: 4 additions & 4 deletions lib/bumblebee/layers/decoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ defmodule Bumblebee.Layers.Decoder do
|> List.duplicate(decoder_num_blocks)
|> List.to_tuple()

offset = Nx.tensor(0.0)
offset = Nx.tensor(0)

attention_mask = Nx.broadcast(0.0, {batch_size, max_length})
attention_mask = Nx.broadcast(0, {batch_size, max_length})

%{blocks: blocks, offset: offset, attention_mask: attention_mask}
end
Expand Down Expand Up @@ -170,7 +170,7 @@ defmodule Bumblebee.Layers.Decoder do

defnp append_attention_cache(key, value, attention_cache, offset, _opts \\ []) do
%{key: cached_key, value: cached_value} = attention_cache
indices = [0, Nx.as_type(offset, {:s, 64}), 0, 0]
indices = [0, offset, 0, 0]
key = Nx.put_slice(cached_key, indices, key)
value = Nx.put_slice(cached_value, indices, value)
updated_cache = %{key: key, value: value}
Expand Down Expand Up @@ -276,7 +276,7 @@ defmodule Bumblebee.Layers.Decoder do
causal_mask

offset ->
mask_shift = Nx.as_type(offset, {:s, 64})
mask_shift = offset
query_length = Nx.axis_size(query, 1)
Nx.slice_along_axis(causal_mask, mask_shift, query_length, axis: 2)
end
Expand Down
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ defmodule Bumblebee.MixProject do

defp deps do
[
{:axon, "~> 0.6.0", axon_opts()},
# {:axon, "~> 0.6.0", axon_opts()},
{:axon, github: "elixir-nx/axon", override: true},
{:tokenizers, "~> 0.4"},
# {:nx, "~> 0.6.2"},
# {:exla, ">= 0.0.0", only: [:dev, :test]},
Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
%{
"axon": {:hex, :axon, "0.6.0", "fd7560079581e4cedebaf0cd5f741d6ac3516d06f204ebaf1283b1093bf66ff6", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "204e7aeb50d231a30b25456adf17bfbaae33fe7c085e03793357ac3bf62fd853"},
"axon": {:git, "https://github.com/elixir-nx/axon.git", "67b48c7a43438f5eec2a35311572565cafe889d7", []},
"bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
"castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},
Expand Down

0 comments on commit e726e26

Please sign in to comment.