From f37e6aa58cdcc494fbdb33d79bb94cebe14d975e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Fri, 8 Dec 2023 23:24:18 +0700 Subject: [PATCH] Fix cache offset casting with low precision policies --- lib/bumblebee/audio/whisper.ex | 2 +- lib/bumblebee/layers.ex | 2 +- lib/bumblebee/layers/decoder.ex | 8 ++++---- mix.exs | 3 ++- mix.lock | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/bumblebee/audio/whisper.ex b/lib/bumblebee/audio/whisper.ex index a6482319..ef591046 100644 --- a/lib/bumblebee/audio/whisper.ex +++ b/lib/bumblebee/audio/whisper.ex @@ -398,7 +398,7 @@ defmodule Bumblebee.Audio.Whisper do offset = case offset do %Axon.None{} -> 0 - offset -> Nx.as_type(offset, {:s, 64}) + offset -> offset end input_sequence_length = Nx.axis_size(input_embeddings, 1) diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index 6a7bba06..7a1a6f88 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -120,7 +120,7 @@ defmodule Bumblebee.Layers do bias offset -> - mask_shift = Nx.as_type(offset, {:s, 64}) + mask_shift = offset query_length = Nx.axis_size(query, 1) Nx.slice_along_axis(bias, mask_shift, query_length, axis: 2) end diff --git a/lib/bumblebee/layers/decoder.ex b/lib/bumblebee/layers/decoder.ex index e77a009d..b1cfe6ce 100644 --- a/lib/bumblebee/layers/decoder.ex +++ b/lib/bumblebee/layers/decoder.ex @@ -82,9 +82,9 @@ defmodule Bumblebee.Layers.Decoder do |> List.duplicate(decoder_num_blocks) |> List.to_tuple() - offset = Nx.tensor(0.0) + offset = Nx.tensor(0) - attention_mask = Nx.broadcast(0.0, {batch_size, max_length}) + attention_mask = Nx.broadcast(0, {batch_size, max_length}) %{blocks: blocks, offset: offset, attention_mask: attention_mask} end @@ -170,7 +170,7 @@ defmodule Bumblebee.Layers.Decoder do defnp append_attention_cache(key, value, attention_cache, offset, _opts \\ []) do %{key: cached_key, value: cached_value} = attention_cache - indices = [0, Nx.as_type(offset, {:s, 64}), 0, 0] + indices = [0, offset, 0, 0] key = Nx.put_slice(cached_key, indices, key) value = Nx.put_slice(cached_value, indices, value) updated_cache = %{key: key, value: value} @@ -276,7 +276,7 @@ defmodule Bumblebee.Layers.Decoder do causal_mask offset -> - mask_shift = Nx.as_type(offset, {:s, 64}) + mask_shift = offset query_length = Nx.axis_size(query, 1) Nx.slice_along_axis(causal_mask, mask_shift, query_length, axis: 2) end diff --git a/mix.exs b/mix.exs index 796171ac..013a67a4 100644 --- a/mix.exs +++ b/mix.exs @@ -30,7 +30,8 @@ defmodule Bumblebee.MixProject do defp deps do [ - {:axon, "~> 0.6.0", axon_opts()}, + # {:axon, "~> 0.6.0", axon_opts()}, + {:axon, github: "elixir-nx/axon", override: true}, {:tokenizers, "~> 0.4"}, # {:nx, "~> 0.6.2"}, # {:exla, ">= 0.0.0", only: [:dev, :test]}, diff --git a/mix.lock b/mix.lock index 64c53b22..232e4bb2 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,5 @@ %{ - "axon": {:hex, :axon, "0.6.0", "fd7560079581e4cedebaf0cd5f741d6ac3516d06f204ebaf1283b1093bf66ff6", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "204e7aeb50d231a30b25456adf17bfbaae33fe7c085e03793357ac3bf62fd853"}, + "axon": {:git, "https://github.com/elixir-nx/axon.git", "67b48c7a43438f5eec2a35311572565cafe889d7", []}, "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"}, "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, "cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},