diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index a110461c..5b8f1464 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -302,8 +302,8 @@ defmodule Bumblebee.Text do this option is ignored. Defaults to `:pooled_state` * `:output_pool` - pooling to apply on top of the model output, in case - it is not already a pooled embedding. Supported values: `:mean`. By - default no pooling is applied + it is not already a pooled embedding. Supported values: `:mean_pooling`. + By default no pooling is applied * `:embedding_processor` - a post-processing step to apply to the embedding. Supported values: `:l2_norm`. By default the output is diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index c112f015..7ec18558 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -62,6 +62,16 @@ defmodule Bumblebee.Text.TextEmbedding do output :mean_pooling -> + case Nx.rank(output) do + 3 -> + :ok + + rank -> + raise ArgumentError, + "expected the output tensor to have rank 3 when :output_pool is enabled, got: #{rank}." <> + " You should either disable pooling or pick a different output using :output_attribute" + end + input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1) output