From 92b9bbe7792faf4936e517659792fdb16aa5bd67 Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Tue, 23 Jan 2024 16:16:02 +0000
Subject: [PATCH] Added methods to `SafeLLamaContextHandle` for KV cache manipulation

---
 LLama.Examples/Examples/BatchedDecoding.cs | 56 +++++++++--------
 LLama/Native/LLamaBatch.cs                 | 73 +++++++++++++++++++---
 LLama/Native/LLamaKvCacheView.cs           |  2 +-
 LLama/Native/NativeApi.cs                  |  2 +-
 LLama/Native/SafeLLamaContextHandle.cs     | 72 ++++++++++++++++++++-
 5 files changed, 166 insertions(+), 39 deletions(-)

diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 9cdf01fa6..f893b613f 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -2,6 +2,7 @@
 using System.Text;
 using LLama.Common;
 using LLama.Native;
+using LLama.Sampling;

 namespace LLama.Examples.Examples;

@@ -14,10 +15,6 @@ public class BatchedDecoding
     private const int n_parallel = 8;
     private const int n_len = 32;

-    private const int top_k = 80;
-    private const float top_p = 0.8f;
-    private const float temp = 0.75f;
-
     public static async Task Run()
     {
         Console.Write("Please input your model path: ");
@@ -55,10 +52,9 @@ public static async Task Run()
         var batch = new LLamaBatch();

         // evaluate the initial prompt
-        for (var i = 0; i < prompt_tokens.Length; i++)
-            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+        batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

-        if (await context.DecodeAsync(batch) != 0)
+        if (await context.DecodeAsync(batch) != DecodeResult.Ok)
         {
             await Console.Error.WriteLineAsync("llama_decode failed");
             return;
@@ -68,7 +64,7 @@ public static async Task Run()
         // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
         for (var i = 1; i < n_parallel; ++i)
         {
-            NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
+            context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
         }

         if (n_parallel > 1)
@@ -83,15 +79,21 @@ public static async Task Run()
         for (var i = 0; i < n_parallel; i++)
             i_batch.Add(batch.TokenCount - 1);

-        var n_cur = batch.TokenCount;
-        var n_decode = 0;
-
-        var streams = new StreamingTokenDecoder[n_parallel];
+        // Create per-stream decoders and samplers
+        var decoders = new StreamingTokenDecoder[n_parallel];
+        var samplers = new ISamplingPipeline[n_parallel];
         for (var i = 0; i < n_parallel; i++)
-            streams[i] = new StreamingTokenDecoder(context);
+        {
+            decoders[i] = new StreamingTokenDecoder(context);
+            samplers[i] = new DefaultSamplingPipeline
+            {
+                Temperature = 0.1f + (float)i / n_parallel,
+                MinP = 0.25f,
+            };
+        }

-        var eos = model.EndOfSentenceToken;
-        var nl = model.NewlineToken;
+        var n_cur = batch.TokenCount;
+        var n_decode = 0;

         var timer = new Stopwatch();
         timer.Start();
@@ -105,31 +107,33 @@ public static async Task Run()
                 if (i_batch[i] < 0)
                     continue;

-                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));
-
-                candidates.TopK(context.NativeHandle, top_k);
-                candidates.TopP(context.NativeHandle, top_p);
-                candidates.Temperature(context.NativeHandle, temp);
-                var new_token_id = candidates.SampleToken(context.NativeHandle);
+                // Use the sampling pipeline to select a token
+                var new_token_id = samplers[i].Sample(
+                    context.NativeHandle,
+                    context.NativeHandle.GetLogitsIth(i_batch[i]),
+                    Array.Empty<LLamaToken>()
+                );

-                if (new_token_id == eos || new_token_id == nl)
+                // Finish this stream early if necessary
+                if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
                 {
                     i_batch[i] = -1;
                     Console.WriteLine($"Completed Stream {i} early");
                     continue;
                 }

-                streams[i].Add(new_token_id);
+                // Add this token to the decoder, so it will be turned into text
+                decoders[i].Add(new_token_id);

                 i_batch[i] = batch.TokenCount;

                 // push this new token for next evaluation
-                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+                batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

                 n_decode++;
             }

-            // all streams are finished
+            // Check if all streams are finished
             if (batch.TokenCount == 0)
             {
                 break;
@@ -152,7 +156,7 @@ public static async Task Run()
         Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

         var index = 0;
-        foreach (var stream in streams)
+        foreach (var stream in decoders)
         {
             var text = stream.Read();

diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 20e145306..532e16fff 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,4 +1,6 @@
 using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;

 namespace LLama.Native;

@@ -22,14 +24,14 @@ public class LLamaBatch
     public int TokenCount { get; private set; }

     /// <summary>
-    /// Maximum number of tokens that can be added to this batch
+    /// Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
     /// </summary>
     private int TokenCapacity { get; set; }

     /// <summary>
-    /// Maximum number of sequences a token can be assigned to
+    /// Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
     /// </summary>
-    public int MaxSequences { get; private set; }
+    public int SequenceCapacity { get; private set; }

     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
@@ -40,7 +42,7 @@ public LLamaBatch()
         const int n_tokens = 128;
         const int n_seq_max = 1;

-        MaxSequences = n_seq_max;
+        SequenceCapacity = n_seq_max;
         TokenCapacity = n_tokens;

         _logits = new byte[n_tokens];
@@ -52,9 +54,10 @@ public LLamaBatch()

         _sequenceIds = new LLamaSeqId[n_tokens][];
         for (var i = 0; i < _sequenceIds.Length; i++)
-            _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+            _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
     }

+    #region grow
     private void GrowTokenCapacity()
     {
         var n_tokens = TokenCount * 2;
@@ -73,18 +76,19 @@ private void GrowTokenCapacity()
             // Growing the array filled elements with null, temporarily violating the nullability contract!
             // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
             if (_sequenceIds[i] == null)
-                _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+                _sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
         }
     }

     private void GrowMaxSequences(int atLeast)
     {
-        var n_seq = Math.Max(MaxSequences * 2, atLeast);
-        MaxSequences = n_seq;
+        var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
+        SequenceCapacity = n_seq;

         for (var i = 0; i < _sequenceIds.Length; i++)
-            Array.Resize(ref _sequenceIds[i], MaxSequences);
+            Array.Resize(ref _sequenceIds[i], SequenceCapacity);
     }
+    #endregion

     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     {
@@ -117,6 +121,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
         return group;
     }

+    #region add
     /// <summary>
     /// Add a single token to the batch at the same position in several sequences
     /// </summary>
@@ -129,7 +134,7 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequenc
     {
         if (TokenCount == TokenCapacity)
             GrowTokenCapacity();
-        if (sequences.Length > MaxSequences)
+        if (sequences.Length > SequenceCapacity)
             GrowMaxSequences(sequences.Length);

         _tokens[TokenCount] = token;
@@ -144,6 +149,37 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequenc
         TokenCount++;
     }

+    /// <summary>
+    /// Add a single token to the batch at the same position in several sequences
+    /// </summary>
+    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
+    /// <param name="token">The token to add</param>
+    /// <param name="pos">The position to add it at</param>
+    /// <param name="sequences">The set of sequences to add this token to</param>
+    /// <param name="logits"></param>
+    public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+    {
+#if NET5_0_OR_GREATER
+        var seqSpan = CollectionsMarshal.AsSpan(sequences);
+        Add(token, pos, seqSpan, logits);
+#else
+        // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
+        // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
+        // avoid the copying.
+
+        var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
+        try
+        {
+            sequences.CopyTo(rented, 0);
+            Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
+        }
+        finally
+        {
+            System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
+        }
+#endif
+    }
+
     /// <summary>
     /// Add a single token to the batch at a certain position for a single sequences
     /// </summary>
@@ -162,6 +198,23 @@ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits
         Add(token, pos, sequences, logits);
     }

+    /// <summary>
+    /// Add a range of tokens to a single sequence, starting at the given position.
+    /// </summary>
+    /// <param name="tokens">The tokens to add</param>
+    /// <param name="start">The starting position to add tokens at</param>
+    /// <param name="sequence">The sequence to add these tokens to</param>
+    /// <param name="logitsLast">Whether the final token should generate logits</param>
+    public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
+    {
+        for (var i = 0; i < tokens.Length; i++)
+        {
+            var logits = (i == tokens.Length - 1) & logitsLast;
+            Add(tokens[i], start.Value + i, sequence, logits);
+        }
+    }
+#endregion
+
     /// <summary>
     /// Set TokenCount to zero for this batch
     /// </summary>
diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index 65fbccba3..4cccd13c5 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -112,7 +112,7 @@ public int CountCells()
     }

     /// <summary>
-    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times
+    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
     /// </summary>
     /// <returns></returns>
     public int CountTokens()
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 2846b2d31..bb28e7ab1 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -422,6 +422,6 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
         /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
         /// </summary>
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
+        public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
     }
 }
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index b10e083f0..2c5d82884 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -167,8 +167,9 @@ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
         {
             return ThrowIfDisposed().TokenToSpan(token, dest);
         }
-#endregion
+        #endregion

+        #region infer
         /// <summary>
         /// Run the llama inference to obtain the logits and probabilities for the next token.
         /// </summary>
@@ -202,6 +203,7 @@ public int Decode(LLamaBatch batch)
             using (batch.ToNativeBatch(out var nb))
                 return NativeApi.llama_decode(this, nb);
         }
+        #endregion

         #region state
         /// <summary>
@@ -275,5 +277,73 @@ public void SetSeed(uint seed)
         {
             NativeApi.llama_set_rng_seed(this, seed);
         }
+
+        /// <summary>
+        /// Set the number of threads used for decoding
+        /// </summary>
+        /// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
+        /// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
+        public void SetThreads(uint threads, uint threadsBatch)
+        {
+            NativeApi.llama_set_n_threads(this, threads, threadsBatch);
+        }
+
+        #region KV Cache Management
+        /// <summary>
+        /// Clear the KV cache
+        /// </summary>
+        public void KvCacheClear()
+        {
+            NativeApi.llama_kv_cache_clear(this);
+        }
+
+        /// <summary>
+        /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+        /// </summary>
+        /// <param name="seq"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
+        {
+            NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
+        }
+
+        /// <summary>
+        /// Copy all tokens that belong to the specified sequence to another sequence. Note that
+        /// this does not allocate extra KV cache memory - it simply assigns the tokens to the
+        /// new sequence
+        /// </summary>
+        /// <param name="src"></param>
+        /// <param name="dest"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
+        {
+            NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
+        }
+
+        /// <summary>
+        /// Removes all tokens that do not belong to the specified sequence
+        /// </summary>
+        /// <param name="seq"></param>
+        public void KvCacheSequenceKeep(LLamaSeqId seq)
+        {
+            NativeApi.llama_kv_cache_seq_keep(this, seq);
+        }
+
+        /// <summary>
+        /// Adds relative position "delta" to all tokens that belong to the specified sequence
+        /// and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
+        /// accordingly
+        /// </summary>
+        /// <param name="seq"></param>
+        /// <param name="p0"></param>
+        /// <param name="p1"></param>
+        /// <param name="delta"></param>
+        public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
+        {
+            NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
+        }
+        #endregion
     }
 }
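
Note: the KV cache wrappers added above are thin pass-throughs to the corresponding llama.cpp `llama_kv_cache_*` functions. As a rough usage sketch (not part of the patch; it assumes a `context`, a tokenized `prompt_tokens` array and an `n_parallel` count already exist, as in `BatchedDecoding.cs`), the shared-prompt pattern they enable looks like this:

    // Decode the prompt once into sequence 0...
    var batch = new LLamaBatch();
    batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);
    if (await context.DecodeAsync(batch) != DecodeResult.Ok)
        throw new InvalidOperationException("llama_decode failed");

    // ...then fan it out to the other sequences by copying the KV cache entries,
    // instead of re-evaluating the same prompt n_parallel times.
    for (var i = 1; i < n_parallel; i++)
        context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);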