diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index d7ba0ccb..fff77bba 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -304,6 +304,9 @@ defmodule Bumblebee.Text.Generation do
         end,
         if config.forced_token_ids do
           &forced_tokens_processor(&1, &2, forced_token_ids: config.forced_token_ids)
+        end,
+        if config.temperature && config.temperature != 1.0 do
+          &temperature_processor(&1, &2, temperature: config.temperature)
         end
       ] ++
         if config.strategy.type == :multinomial_sampling do
diff --git a/lib/bumblebee/text/generation/logits_processing.ex b/lib/bumblebee/text/generation/logits_processing.ex
index 0d619fba..1d93b4d1 100644
--- a/lib/bumblebee/text/generation/logits_processing.ex
+++ b/lib/bumblebee/text/generation/logits_processing.ex
@@ -117,6 +117,13 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
     Nx.put_slice(logits, [token_id], Nx.broadcast(Nx.Constants.neg_infinity(), {1}))
   end
 
+  defn temperature_processor(logits, _context, opts \\ []) do
+    opts = keyword!(opts, [:temperature])
+    temperature = opts[:temperature]
+
+    logits / temperature
+  end
+
   # Processors manipulating the probability distribution
 
   defn top_k_processor(logits, _context, opts \\ []) do
diff --git a/lib/bumblebee/text/generation_config.ex b/lib/bumblebee/text/generation_config.ex
index 09d4cadb..d8a6a003 100644
--- a/lib/bumblebee/text/generation_config.ex
+++ b/lib/bumblebee/text/generation_config.ex
@@ -96,6 +96,15 @@ defmodule Bumblebee.Text.GenerationConfig do
     no_repeat_ngram_length: [
       default: nil,
       doc: "when set, n-grams of the given length can occur only once in the generated sequence"
+    ],
+    temperature: [
+      default: nil,
+      doc: """
+      enables exponential scaling of the output probability distribution. The temperature value effectively
+      determines the randomness of the predicted tokens. Values smaller than 1.0 decrease the randomness,
+      while bigger values increase it. Note that this is only relevant for generation `:strategy` that does
+      sampling based on the output probability distribution
+      """
     ]
   ]
 
diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs
index 0d8c945d..d88409e0 100644
--- a/test/bumblebee/text/generation/logits_processing_test.exs
+++ b/test/bumblebee/text/generation/logits_processing_test.exs
@@ -155,6 +155,19 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
     end
   end
 
+  describe "temperature_processor/3" do
+    test "scales the logits" do
+      logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
+
+      context = context([1, 0, 0, 0])
+
+      assert_equal(
+        LogitsProcessing.temperature_processor(logits, context, temperature: 10),
+        Nx.tensor([0.1, 0.2, 0.3, 0.4])
+      )
+    end
+  end
+
   describe "top_k_processor/3" do
     test "keeps top-k highest logits" do
       logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])