Skip to content

Commit

Permalink
Check text embedding output shape before applying pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 18, 2023
1 parent 6a139fa commit b374fa6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ defmodule Bumblebee.Text do
this option is ignored. Defaults to `:pooled_state`
* `:output_pool` - pooling to apply on top of the model output, in case
it is not already a pooled embedding. Supported values: `:mean`. By
default no pooling is applied
it is not already a pooled embedding. Supported values: `:mean_pooling`.
By default no pooling is applied
* `:embedding_processor` - a post-processing step to apply to the
embedding. Supported values: `:l2_norm`. By default the output is
Expand Down
10 changes: 10 additions & 0 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ defmodule Bumblebee.Text.TextEmbedding do
output

:mean_pooling ->
case Nx.rank(output) do
3 ->
:ok

rank ->
raise ArgumentError,
"expected the output tensor to have rank 3 with :output_pool is enabled, got: #{rank}." <>
" You should either disable pooling or pick a different output using :output_attribute"
end

input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)

output
Expand Down

0 comments on commit b374fa6

Please sign in to comment.