Skip to content

Commit

Permalink
Add temperature to generation options
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Nov 17, 2023
1 parent 5773ccf commit 6384b4c
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ defmodule Bumblebee.Text.Generation do
end,
if config.forced_token_ids do
&forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
end,
if config.temperature && config.temperature != 1.0 do
&temperature_processor(&1, &2, temperature: config.temperature)
end
] ++
if config.strategy.type == :multinomial_sampling do
Expand Down
7 changes: 7 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
Nx.put_slice(logits, [token_id], Nx.broadcast(Nx.Constants.neg_infinity(), {1}))
end

defn temperature_processor(logits, _context, opts \\ []) do
opts = keyword!(opts, [:temperature])
temperature = opts[:temperature]

logits / temperature
end

# Processors manipulating the probability distribution

defn top_k_processor(logits, _context, opts \\ []) do
Expand Down
9 changes: 9 additions & 0 deletions lib/bumblebee/text/generation_config.ex
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ defmodule Bumblebee.Text.GenerationConfig do
no_repeat_ngram_length: [
default: nil,
doc: "when set, n-grams of the given length can occur only once in the generated sequence"
],
temperature: [
default: nil,
doc: """
enables exponential scaling of the output probability distribution. The temperature value effectively
determines the randomness of the predicted tokens. Values smaller than 1.0 decrease the randomness,
while bigger values increase it. Note that this is only relevant for generation `:strategy` that does
sampling based on the output probability distribution
"""
]
]

Expand Down
13 changes: 13 additions & 0 deletions test/bumblebee/text/generation/logits_processing_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,19 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
end
end

describe "temperature_processor/3" do
test "scales the logits" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])

context = context([1, 0, 0, 0])

assert_equal(
LogitsProcessing.temperature_processor(logits, context, temperature: 10),
Nx.tensor([0.1, 0.2, 0.3, 0.4])
)
end
end

describe "top_k_processor/3" do
test "keeps top-k highest logits" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
Expand Down

0 comments on commit 6384b4c

Please sign in to comment.