From a48f0e3ccbb881545a0a1f3632b25eff22fe755a Mon Sep 17 00:00:00 2001
From: Apoorv Reddy
Date: Sun, 9 Feb 2025 22:03:06 -0800
Subject: [PATCH] Factor out DecodeStepT from GenerateT into a separate
 function.

This will be useful for adding sampling functionality such as beam decoding,
parallel sampling, and CoT decoding (as described in the
[Chain-of-Thought Reasoning Without Prompting paper](https://arxiv.org/abs/2402.10200)).

PiperOrigin-RevId: 725066435
---
 gemma/gemma-inl.h | 90 +++++++++++++++++++++++++++++++----------------
 1 file changed, 60 insertions(+), 30 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 01b3930..a9065c4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "gemma/weights.h"
 #include "paligemma/image.h"
 #include "util/allocator.h"
@@ -1217,6 +1218,60 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
+template <typename T>
+void DecodeStepT(const ModelWeightsPtrs<T>& weights,
+                 const RuntimeConfig& runtime_config,
+                 const QueriesPromptTokens& queries_prompt,
+                 const size_t query_idx_start,
+                 const KVCaches& kv_caches,
+                 const QueriesPos& queries_prefix_end,
+                 const hwy::Divisor div_seq_len,
+                 const size_t vocab_size,
+                 const SampleFunc& sample_token,
+                 double prefill_start,
+                 double gen_start,
+                 Activations& activations,
+                 TokenStreamer& token_streamer,
+                 std::vector<int>& gen_tokens,
+                 TimingInfo& timing_info,
+                 size_t max_generated_tokens,
+                 const QueriesMutablePos& queries_mutable_pos,
+                 bool& all_queries_eos) {
+  const size_t num_queries = queries_prompt.size();
+  // Decode generates one token per query and increments
+  // queries_mutable_pos.
+  Transformer(QueriesToken(gen_tokens.data(), num_queries),
+              queries_mutable_pos, queries_prefix_end, weights, activations,
+              div_seq_len, kv_caches, runtime_config.layers_output,
+              runtime_config.activations_observer);
+  // queries_pos are incremented by Transformer.
+
+  all_queries_eos = true;
+  {
+    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    // Compute logits from last layer activations.
+    MatMul(ConstMatFromBatch(num_queries, activations.x),
+           ConstMatFromWeights(weights.embedder_input_embedding),
+           /*add=*/nullptr, *activations.env,
+           RowPtrFromBatch(activations.logits));
+  }
+  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+    MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
+                       vocab_size);
+    const TokenAndProb tp = sample_token(logits, vocab_size);
+    timing_info.NotifyGenerated(prefill_start, gen_start);
+
+    const bool is_eos =
+        token_streamer(query_idx_start + query_idx,
+                       queries_mutable_pos[query_idx], tp.token, tp.prob);
+    all_queries_eos &= is_eos;
+    gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+  }
+}
+
+
 // Generates one continuation for each query in `queries_prompt`, which is one
 // qbatch whose size is at most the `batch_size` passed to
 // `activations.Allocate`.
@@ -1310,37 +1365,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments
-    // queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
-                         vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
-
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
+    DecodeStepT(weights, runtime_config, queries_prompt, query_idx_start,
+                kv_caches, queries_prefix_end, div_seq_len, vocab_size,
+                sample_token, prefill_start, gen_start, activations,
+                token_streamer, gen_tokens, timing_info, max_generated_tokens,
+                queries_mutable_pos, all_queries_eos);
     if (all_queries_eos) break;
   }  // foreach token to generate
   timing_info.NotifyGenerateDone(gen_start);
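
Below is a minimal, self-contained sketch (not part of this patch, and not the
gemma.cpp API) of why factoring the per-token work into a DecodeStepT-style
call helps with the sampling strategies mentioned in the commit message. The
names SampleState, DecodeStep, and ParallelSample are hypothetical stand-ins:
once "advance a continuation by one token" is a single call, a driver loop can
run several independent continuations and stop when all of them reach EOS,
mirroring the all_queries_eos early exit above.

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical per-continuation state; in gemma.cpp this role is played by
// per-query KV caches, positions, and the gen_tokens entries.
struct SampleState {
  std::vector<int> tokens;  // tokens generated so far
  bool done = false;        // has this continuation reached EOS?
};

// Stand-in for a DecodeStepT-style callable: appends one token to the state
// and returns true once the continuation has ended.
using DecodeStep = std::function<bool(SampleState&)>;

// Drives num_samples independent continuations, one token per step each,
// stopping early when every continuation has finished.
std::vector<SampleState> ParallelSample(const DecodeStep& decode_step,
                                        size_t num_samples,
                                        size_t max_generated_tokens) {
  std::vector<SampleState> samples(num_samples);
  for (size_t step = 0; step < max_generated_tokens; ++step) {
    bool all_done = true;
    for (SampleState& s : samples) {
      if (!s.done) s.done = decode_step(s);
      all_done &= s.done;  // mirrors all_queries_eos in GenerateT
    }
    if (all_done) break;  // same early exit as the generation loop above
  }
  return samples;
}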