From c2a324a42d1cf63dd3d0ed3c48f828da4ed01417 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Tue, 12 Dec 2023 00:47:12 +0700
Subject: [PATCH 1/2] Refactor attention implementation

---
 .formatter.exs                      |   2 +-
 lib/bumblebee/layers.ex             | 218 +++++++++++++++++++---
 lib/bumblebee/layers/decoder.ex     |  58 --------
 lib/bumblebee/layers/transformer.ex |  40 ++---
 mix.lock                            |   6 +-
 5 files changed, 172 insertions(+), 152 deletions(-)

diff --git a/.formatter.exs b/.formatter.exs
index c185fe18..e7f3f46f 100644
--- a/.formatter.exs
+++ b/.formatter.exs
@@ -1,4 +1,4 @@
-# Used by "mix format"
 [
+  import_deps: [:nx],
   inputs: ["{mix,.formatter}.exs", "{config,lib,test,examples}/**/*.{ex,exs}"]
 ]
diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index 7a1a6f88..1c55aa71 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -56,33 +56,6 @@ defmodule Bumblebee.Layers do
     input * Nx.sigmoid(1.702 * input)
   end
 
-  @doc """
-  Expands an attention mask of shape `{batch_size, sequence_length}` to
-  a full mask.
-  """
-  def expand_attention_mask(attention_mask) do
-    Axon.nx(attention_mask, fn attention_mask ->
-      attention_mask
-      |> Nx.new_axis(-2)
-      |> Nx.new_axis(-2)
-    end)
-  end
-
-  @doc """
-  Converts attention mask to bias.
-  """
-  def attention_bias(attention_mask) do
-    attention_mask
-    |> Axon.optional()
-    |> Axon.nx(fn
-      %Axon.None{} ->
-        Nx.tensor(0)
-
-      attention_mask ->
-        Nx.select(Nx.greater(attention_mask, 0), 0, -1.0e10)
-    end)
-  end
-
   @doc """
   Computes relative attention bias.
   """
@@ -130,7 +103,8 @@ defmodule Bumblebee.Layers do
   end
 
   defnp compute_relative_position_buckets(query, key, attention_cache, opts \\ []) do
-    opts = keyword!(opts, mode: :train, bidirectional: true, num_buckets: 32, max_distance: 128)
+    opts =
+      keyword!(opts, mode: :inference, bidirectional: true, num_buckets: 32, max_distance: 128)
 
     {key_length, query_length} = key_query_lengths(query, key, attention_cache)
 
@@ -191,71 +165,185 @@ defmodule Bumblebee.Layers do
     end
   end
 
-  @doc """
-  Computes attention weights.
+  @doc ~S"""
+  Computes scaled dot-product attention for multiple attention heads.
+
+  This is the core calculation behind multi-head attention; the projection
+  layers should be applied on top of this layer.
+
+  Given input sequences $Q, K, V \in R^{N \times d}$, where $N$ is the
+  sequence length and $d$ is the head dimension, the scaled dot-product
+  attention is defined as:
+
+  $$
+  Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d}})V
+  $$
+
+  This operation is further batched across multiple heads and multiple
+  input sequences.
+
+  Intuitively, scaled dot-product attention can be thought of as information
+  retrieval, where for each sequence element in $Q$ the objective is
+  to extract relevant context from sequence elements in $V$. In this
+  analogy, $K$ is a summary of the information, while $V$ is the
+  actual information. Then, assuming $Q$ and $K$ are embedded into a
+  common space (which is the job of prior projection layers), the
+  $QK^T$ dot product acts as a similarity measure and gives us relevance
+  weights for sequence elements in $V$.
+
+  In the case of self-attention, where $Q, K, V$ originate from the same
+  sequence, the $QK^T$ weights indicate how much "each word attends
+  to other words".
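+
+  For intuition, this computation can be sketched directly with Nx for a
+  single sequence and a single head. The snippet below is a simplified
+  illustration with arbitrary example shapes and without the batching,
+  masking, bias, and dropout that this layer handles:
+
+      query = Nx.iota({3, 4}, type: :f32)
+      key = Nx.iota({5, 4}, type: :f32)
+      value = Nx.iota({5, 4}, type: :f32)
+
+      d = Nx.axis_size(query, -1)
+
+      # QK^T scaled by 1/sqrt(d), followed by a row-wise softmax
+      scores = Nx.divide(Nx.dot(query, [1], key, [1]), Nx.sqrt(d))
+      weights = Axon.Activations.softmax(scores, axis: -1)
+
+      # Weighted sum over value, giving a {3, 4} result shaped like query
+      Nx.dot(weights, value)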
+ + ## Parameter Shapes + + * `query` - `{batch_size, sequence_length, num_heads, head_size}` + * `key` - `{batch_size, kv_sequence_length, num_heads, head_size}` + * `value` - `{batch_size, kv_sequence_length, num_heads, head_size}` + * `key_mask` (optional) - `{batch_size, kv_sequence_length}` + * `head_mask` (optional) - `{num_heads}` + * `bias` (optional) - `{batch_size | 1, num_heads | 1, sequence_length, kv_sequence_length}` + * `offset` (optional) - `{}` + + ## Output Shape + + `{batch_size, sequence_length, num_heads, head_size}` ## Options - * `:scale` - whether to scale the weights. Defaults to `true` + * `:causal` - whether to apply causal mask to attention weights. + This is typically used for next token prediction and it + effectively makes each input token use information exclusively + from prior tokens. Defaults to `false` + + * `:scale` - whether to scale attention weights by $\frac{1}{\sqrt{d}}$. + Defaults to `true` + + * `:dropout_rate` - the dropout rate for attention weights dropout. + Defaults to `0.0` + + ## References + + * [Attention Is All You Need](https://arxiv.org/abs/1706.03762), Figure 2 (left) """ - def attention_weights(query, key, bias, opts \\ []) do - Axon.layer(&attention_weights_impl/4, [query, key, bias], opts) + def attention(query, key, value, key_mask, head_mask, bias, offset, opts \\ []) do + opts = Keyword.validate!(opts, causal: false, scale: true, dropout_rate: 0.0) + + weights = + Axon.layer( + &attention_weights_impl/7, + [ + query, + key, + Axon.optional(key_mask), + Axon.optional(head_mask), + Axon.optional(bias), + Axon.optional(offset) + ], + causal: opts[:causal], + scale: opts[:scale] + ) + |> Axon.dropout(rate: opts[:dropout_rate]) + + output = Axon.layer(&attention_output_impl/3, [weights, value], opts) + + {output, weights} end - defnp attention_weights_impl(query, key, bias, opts \\ []) do - opts = keyword!(opts, mode: :train, scale: true) + defnp attention_weights_impl(query, key, key_mask, head_mask, bias, offset, opts \\ []) do + opts = keyword!(opts, mode: :inference, scale: true, causal: false) - key = Nx.transpose(key, axes: [0, 2, 1, 3]) query = Nx.transpose(query, axes: [0, 2, 1, 3]) + key = Nx.transpose(key, axes: [0, 2, 1, 3]) weights = Nx.dot(query, [3], [0, 1], key, [3], [0, 1]) weights = if opts[:scale] do depth = Nx.axis_size(query, -1) - weights / Nx.sqrt(depth) + weights / Nx.as_type(Nx.sqrt(depth), Nx.type(query)) else weights end + key_mask = + case key_mask do + %Axon.None{} -> Nx.broadcast(1, {1, 1, 1, 1}) + key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1) + end + + causal_mask = + if opts[:causal] do + query_sequence_length = Nx.axis_size(query, 2) + key_sequence_length = Nx.axis_size(key, 2) + offset = ensure_offset(offset) + + Nx.greater_equal( + Nx.iota({query_sequence_length, 1}) + offset, + Nx.iota({1, key_sequence_length}) + ) + |> Nx.new_axis(0) + |> Nx.new_axis(0) + else + Nx.broadcast(1, {1, 1, 1, 1}) + end + + mask = Nx.logical_and(key_mask, causal_mask) + + bias = + case bias do + %Axon.None{} -> + Nx.select( + mask, + Nx.tensor(0.0, type: Nx.type(query)), + Nx.Constants.min_finite(Nx.type(query)) + ) + + bias -> + Nx.select( + Nx.broadcast(mask, max_shape(mask, bias)), + bias, + Nx.Constants.min_finite(Nx.type(query)) + ) + end + weights = weights + bias - Axon.Activations.softmax(weights, axis: -1) - end - @doc """ - Computes attention outputs. 
- """ - def attention_output(attention_weights, value) do - Axon.layer(&attention_output_impl/3, [attention_weights, value]) + weights = Axon.Activations.softmax(weights, axis: -1) + + case head_mask do + %Axon.None{} -> + weights + + head_mask -> + head_mask = Nx.reshape(head_mask, {1, :auto, 1, 1}) + Nx.multiply(weights, head_mask) + end end - defnp attention_output_impl(attention_weights, value, _opts \\ []) do + defnp attention_output_impl(weights, value, _opts \\ []) do value = Nx.transpose(value, axes: [0, 2, 1, 3]) - out = Nx.dot(attention_weights, [3], [0, 1], value, [2], [0, 1]) + out = Nx.dot(weights, [3], [0, 1], value, [2], [0, 1]) Nx.transpose(out, axes: [0, 2, 1, 3]) end - @doc """ - Applies head mask to the given attention weights. - - This layer expects computed attention weights and an optional mask. - If the mask is not specified, it will skip masking altogether. - """ - def apply_attention_head_mask(attention_weights, head_mask) do - if_present head_mask do - Axon.layer( - fn attention_weights, head_mask, _ -> - head_mask = Nx.reshape(head_mask, {1, :auto, 1, 1}) - Nx.multiply(attention_weights, head_mask) - end, - [attention_weights, head_mask] - ) - else - attention_weights + defnp ensure_offset(offset) do + case offset do + %Axon.None{} -> 0 + offset -> offset end end + deftransformp max_shape(left, right) do + Enum.zip_with( + Tuple.to_list(Nx.shape(left)), + Tuple.to_list(Nx.shape(right)), + &max/2 + ) + |> List.to_tuple() + end + @doc """ Adds a dense layer to the network. @@ -1063,8 +1151,8 @@ defmodule Bumblebee.Layers do position_ids = Nx.as_type(position_ids, :s64) - cos = cos |> Nx.take(position_ids) |> Nx.new_axis(2) - sin = sin |> Nx.take(position_ids) |> Nx.new_axis(2) + cos = cos |> Nx.take(position_ids) |> Nx.new_axis(2) |> Nx.as_type(Nx.type(query)) + sin = sin |> Nx.take(position_ids) |> Nx.new_axis(2) |> Nx.as_type(Nx.type(query)) rotated_query = query * cos + rotate_half(query) * sin rotated_key = key * cos + rotate_half(key) * sin diff --git a/lib/bumblebee/layers/decoder.ex b/lib/bumblebee/layers/decoder.ex index b1cfe6ce..e08c7c0c 100644 --- a/lib/bumblebee/layers/decoder.ex +++ b/lib/bumblebee/layers/decoder.ex @@ -243,62 +243,4 @@ defmodule Bumblebee.Layers.Decoder do [cache, input_embeddings] ) end - - @doc """ - Builds a causal mask and combines it with the given attention mask. - - A causal mask is used to mask bidirectional self-attention, such - that it works in a single direction. - - Accepts an optional offset, which should be set when passing a - partial query. 
- """ - def apply_causal_mask(attention_mask, query, offset) do - Axon.layer( - fn - %Axon.None{}, query, %Axon.None{}, _opts -> - # The default attention mask would be all ones (matching - # the batch size and sequence length in query), so we can - # skip it altogether - sequence_length = Nx.axis_size(query, 1) - build_causal_mask(Nx.broadcast(1, {1, sequence_length})) - - attention_mask, query, offset, _opts -> - sequence_length = Nx.axis_size(attention_mask, -1) - - # We generate a full causal mask, then slice it in case of - # iterative decoding - causal_mask = build_causal_mask(Nx.broadcast(1, {1, sequence_length})) - - causal_mask = - case offset do - %Axon.None{} -> - causal_mask - - offset -> - mask_shift = offset - query_length = Nx.axis_size(query, 1) - Nx.slice_along_axis(causal_mask, mask_shift, query_length, axis: 2) - end - - Nx.logical_and(attention_mask, causal_mask) - end, - [Axon.optional(attention_mask), query, Axon.optional(offset)] - ) - end - - defnp build_causal_mask(input) do - size = Nx.axis_size(input, -1) - idx = Nx.iota({size}) |> Nx.broadcast(input) - build_attention_mask(idx, idx) - end - - # Expects a batched, flat inputs of length corresponding to query - # and key length respectively. - defnp build_attention_mask(query_input, key_input) do - query_input - |> Nx.new_axis(-1) - |> Nx.greater_equal(Nx.new_axis(key_input, -2)) - |> Nx.new_axis(-3) - end end diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 13107ac2..b5b4a6dc 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -710,7 +710,7 @@ defmodule Bumblebee.Layers.Transformer do ## References - * [Attention Is All You Need](https://arxiv.org/abs/1706.03762), Figure 2 + * [Attention Is All You Need](https://arxiv.org/abs/1706.03762), Figure 2 (right) """ def multi_head_attention(query, key, value, opts) do @@ -847,17 +847,6 @@ defmodule Bumblebee.Layers.Transformer do {key, value, attention_cache} = Layers.Decoder.cached_attention_key_values(key, value, attention_cache, offset) - attention_mask = Layers.expand_attention_mask(attention_mask) - - attention_mask = - if causal do - Layers.Decoder.apply_causal_mask(attention_mask, query, offset) - else - attention_mask - end - - attention_bias = Layers.attention_bias(attention_mask) - attention_relative_bias = case attention_relative_bias do %Axon{} -> @@ -876,21 +865,22 @@ defmodule Bumblebee.Layers.Transformer do ) end - attention_bias = - Layers.if_present attention_relative_bias do - Axon.add(attention_bias, attention_relative_bias) - else - attention_bias - end - - attention_weights = - Layers.attention_weights(query, key, attention_bias, scale: scale_attention_weights) - |> Axon.dropout(rate: dropout_rate) - |> Layers.apply_attention_head_mask(attention_head_mask) + {attention_output, attention_weights} = + Layers.attention( + query, + key, + value, + attention_mask, + attention_head_mask, + attention_relative_bias, + offset, + scale: scale_attention_weights, + causal: causal, + dropout_rate: dropout_rate + ) attention_output = - attention_weights - |> Layers.attention_output(value) + attention_output |> Layers.flatten_trailing() |> Axon.dense(hidden_size, kernel_initializer: kernel_initializer, diff --git a/mix.lock b/mix.lock index 232e4bb2..20e35911 100644 --- a/mix.lock +++ b/mix.lock @@ -12,14 +12,14 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.33", "3c3fd9673bb5dcc9edc28dd90f50c87ce506d1f71b70e3de69aa8154bc695d44", [:mix], [], "hexpm", 
"2d526833729b59b9fdb85785078697c72ac5e5066350663e5be6a1182da61b8f"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"}, - "exla": {:git, "https://github.com/elixir-nx/nx.git", "9f2854d860e7520be5ed427cd8c0bfca087ddc51", [sparse: "exla"]}, + "exla": {:git, "https://github.com/elixir-nx/nx.git", "07e95ad34883dac7e6ca5a2d41cfbe19aa6477b1", [sparse: "exla"]}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "9f2854d860e7520be5ed427cd8c0bfca087ddc51", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "07e95ad34883dac7e6ca5a2d41cfbe19aa6477b1", [sparse: "nx"]}, "nx_image": {:hex, :nx_image, "0.1.1", "69cf0d2fd873d12b028583aa49b5e0a25f6aca307afc337a5d871851a20fba1d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "55c8206a822237f6027168f11214e3887263c5b8a1f8e0634eea82c96e5093e3"}, "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, "plug": {:hex, :plug, "1.14.2", "cff7d4ec45b4ae176a227acd94a7ab536d9b37b942c8e8fa6dfc0fff98ff4d80", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: 
false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "842fc50187e13cf4ac3b253d47d9474ed6c296a8732752835ce4a86acdf68d13"}, @@ -33,7 +33,7 @@ "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, - "torchx": {:git, "https://github.com/elixir-nx/nx.git", "9f2854d860e7520be5ed427cd8c0bfca087ddc51", [sparse: "torchx"]}, + "torchx": {:git, "https://github.com/elixir-nx/nx.git", "07e95ad34883dac7e6ca5a2d41cfbe19aa6477b1", [sparse: "torchx"]}, "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"}, "xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"}, From 7997735a3cddbd6d3c82b2fceb04c4be3740e1df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 12 Dec 2023 14:12:00 +0700 Subject: [PATCH 2/2] Update mix.exs --- mix.exs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mix.exs b/mix.exs index 013a67a4..9e721d80 100644 --- a/mix.exs +++ b/mix.exs @@ -30,7 +30,7 @@ defmodule Bumblebee.MixProject do defp deps do [ - # {:axon, "~> 0.6.0", axon_opts()}, + # {:axon, "~> 0.6.0"}, {:axon, github: "elixir-nx/axon", override: true}, {:tokenizers, "~> 0.4"}, # {:nx, "~> 0.6.2"}, @@ -53,14 +53,6 @@ defmodule Bumblebee.MixProject do ] end - defp axon_opts do - if path = System.get_env("AXON_PATH") do - [path: path] - else - [] - end - end - defp docs do [ main: "Bumblebee",