Factor out DecodeStepT from GenerateT into a separate function.
This will be useful for adding sampling functionality such as beam decoding, parallel sampling, and CoT decoding (as described in the [Chain-of-Thought Reasoning Without Prompting paper](https://arxiv.org/abs/2402.10200)).
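
For context, the CoT-decoding scheme the message refers to branches on the top-k first tokens, decodes each branch greedily, and keeps the branch with the highest average top-1/top-2 probability margin. A minimal, self-contained sketch of that selection logic follows; it is independent of the gemma.cpp API, averages the margin over all generated tokens rather than only answer tokens (a simplification of the paper), and every name in it is illustrative:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Stands in for a model call: next-step logits for a token history.
using LogitsFn =
    std::function<std::vector<float>(const std::vector<int>& history)>;

// Numerically stable softmax over raw logits.
std::vector<float> Softmax(const std::vector<float>& logits) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  for (float& p : probs) p /= sum;
  return probs;
}

// Indices of the k largest values.
std::vector<int> TopK(const std::vector<float>& values, size_t k) {
  std::vector<int> idx(values.size());
  for (size_t i = 0; i < idx.size(); ++i) idx[i] = static_cast<int>(i);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return values[a] > values[b]; });
  idx.resize(k);
  return idx;
}

// Greedily decodes one branch starting from `first_token` and returns the
// confidence signal: the mean top-1/top-2 probability margin per step.
float DecodeBranchGreedy(const LogitsFn& next_logits,
                         std::vector<int>& history, int first_token,
                         size_t max_steps, int eos_id) {
  history.push_back(first_token);
  float margin_sum = 0.0f;
  size_t num_steps = 0;
  while (num_steps < max_steps) {
    const std::vector<float> probs = Softmax(next_logits(history));
    const std::vector<int> top2 = TopK(probs, 2);
    margin_sum += probs[top2[0]] - probs[top2[1]];
    history.push_back(top2[0]);
    ++num_steps;
    if (top2[0] == eos_id) break;
  }
  return num_steps == 0 ? 0.0f : margin_sum / num_steps;
}

// Branches on the top-k first tokens and keeps the highest-confidence branch.
std::vector<int> CotDecode(const LogitsFn& next_logits,
                           const std::vector<int>& prompt, size_t k,
                           size_t max_steps, int eos_id) {
  std::vector<int> best;
  float best_score = -1.0f;
  for (int t : TopK(next_logits(prompt), k)) {
    std::vector<int> history = prompt;
    const float score =
        DecodeBranchGreedy(next_logits, history, t, max_steps, eos_id);
    if (score > best_score) {
      best_score = score;
      best = history;
    }
  }
  return best;
}

A per-token decode step factored out of the generation loop, as in this commit, is exactly the building block such a driver would call in place of the stub LogitsFn above.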

PiperOrigin-RevId: 725066435
apoorvreddy authored and copybara-github committed Feb 10, 2025
1 parent b0fe9a4 commit a48f0e3
Showing 1 changed file with 60 additions and 30 deletions.
gemma/gemma-inl.h (90 changes: 60 additions, 30 deletions)
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "gemma/weights.h"
 #include "paligemma/image.h"
 #include "util/allocator.h"
@@ -1217,6 +1218,60 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }

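+// Runs one decoding step for a batch of queries: feeds the previously
+// generated tokens through the Transformer (which advances
+// queries_mutable_pos), computes logits via the embedding matmul, then
+// soft-caps, samples, and streams one token per query. Sets all_queries_eos
+// when every query has emitted end-of-sequence.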
+template <typename T>
+void DecodeStepT(const ModelWeightsPtrs<T>& weights,
+                 const RuntimeConfig& runtime_config,
+                 const QueriesPromptTokens& queries_prompt,
+                 const size_t query_idx_start,
+                 const KVCaches& kv_caches,
+                 const QueriesPos& queries_prefix_end,
+                 const hwy::Divisor div_seq_len,
+                 const size_t vocab_size,
+                 const SampleFunc& sample_token,
+                 double prefill_start,
+                 double gen_start,
+                 Activations& activations,
+                 TokenStreamer& token_streamer,
+                 std::vector<int>& gen_tokens,
+                 TimingInfo& timing_info,
+                 size_t max_generated_tokens,
+                 const QueriesMutablePos& queries_mutable_pos,
+                 bool& all_queries_eos) {
+  const size_t num_queries = queries_prompt.size();
+  // Decode generates one token per query and increments
+  // queries_mutable_pos.
+  Transformer(QueriesToken(gen_tokens.data(), num_queries),
+              queries_mutable_pos, queries_prefix_end, weights, activations,
+              div_seq_len, kv_caches, runtime_config.layers_output,
+              runtime_config.activations_observer);
+  // queries_pos are incremented by Transformer.
+
+  all_queries_eos = true;
+  {
+    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    // Compute logits from last layer activations.
+    MatMul(ConstMatFromBatch(num_queries, activations.x),
+           ConstMatFromWeights(weights.embedder_input_embedding),
+           /*add=*/nullptr, *activations.env,
+           RowPtrFromBatch(activations.logits));
+  }
+  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+    MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
+                       vocab_size);
+    const TokenAndProb tp = sample_token(logits, vocab_size);
+    timing_info.NotifyGenerated(prefill_start, gen_start);
+
+    const bool is_eos =
+        token_streamer(query_idx_start + query_idx,
+                       queries_mutable_pos[query_idx], tp.token, tp.prob);
+    all_queries_eos &= is_eos;
+    gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+  }
+}
+
+
 // Generates one continuation for each query in `queries_prompt`, which is one
 // qbatch whose size is at most the `batch_size` passed to
 // `activations.Allocate`.
@@ -1310,37 +1365,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments
-    // queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
-
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
-                         vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
-
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
+    bool all_queries_eos = true;
+    DecodeStepT<T>(weights, runtime_config, queries_prompt, query_idx_start,
+                   kv_caches, queries_prefix_end, div_seq_len, vocab_size,
+                   sample_token, prefill_start, gen_start, activations,
+                   token_streamer, gen_tokens, timing_info, max_generated_tokens,
+                   queries_mutable_pos, all_queries_eos);
     if (all_queries_eos) break;
   } // foreach token to generate
   timing_info.NotifyGenerateDone(gen_start);

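With the per-token work isolated in DecodeStepT, alternative samplers mostly differ in the outer loop that drives it. For example, parallel sampling can treat each sample as its own query slot and step all slots in lockstep until every slot reaches EOS. The following is a schematic, self-contained sketch of that driver shape only; StepFn, SampleParallel, and the other names are illustrative and not part of the gemma.cpp API:

#include <cstddef>
#include <functional>
#include <vector>

// One decode step for one query slot: given the slot's last token, produce
// the next token and report whether it is EOS. Stands in for a call into a
// DecodeStepT-style function that owns the model, sampler, and KV caches.
using StepFn = std::function<bool(size_t query_idx, int last_token,
                                  int& next_token)>;

// Draws num_samples independent continuations of the same prompt by giving
// each sample its own query slot and stepping the slots in lockstep.
std::vector<std::vector<int>> SampleParallel(const StepFn& step,
                                             int first_token,
                                             size_t num_samples,
                                             size_t max_steps) {
  std::vector<std::vector<int>> samples(num_samples, {first_token});
  std::vector<bool> done(num_samples, false);
  for (size_t gen = 0; gen < max_steps; ++gen) {
    bool all_done = true;  // Mirrors all_queries_eos in GenerateT.
    for (size_t q = 0; q < num_samples; ++q) {
      if (done[q]) continue;
      int next_token = 0;
      done[q] = step(q, samples[q].back(), next_token);
      samples[q].push_back(next_token);
      all_done = all_done && done[q];
    }
    if (all_done) break;
  }
  return samples;
}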