Merge pull request #454 from martindevans/kv_cache_instance_methods
kv_cache_instance_methods
martindevans authored Jan 25, 2024
2 parents 8dfd07f + 92b9bbe commit 5cf481d
Showing 5 changed files with 166 additions and 39 deletions.
56 changes: 30 additions & 26 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -2,6 +2,7 @@
using System.Text;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

@@ -14,10 +15,6 @@ public class BatchedDecoding
private const int n_parallel = 8;
private const int n_len = 32;

private const int top_k = 80;
private const float top_p = 0.8f;
private const float temp = 0.75f;

public static async Task Run()
{
Console.Write("Please input your model path: ");
@@ -55,10 +52,9 @@ public static async Task Run()
var batch = new LLamaBatch();

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
batch.AddRange(prompt_tokens, 0, LLamaSeqId.Zero, true);

if (await context.DecodeAsync(batch) != 0)
if (await context.DecodeAsync(batch) != DecodeResult.Ok)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
@@ -68,7 +64,7 @@ public static async Task Run()
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
context.NativeHandle.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
}

if (n_parallel > 1)
@@ -83,15 +79,21 @@ public static async Task Run()
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.TokenCount - 1);

var n_cur = batch.TokenCount;
var n_decode = 0;

var streams = new StreamingTokenDecoder[n_parallel];
// Create per-stream decoder and sampler
var decoders = new StreamingTokenDecoder[n_parallel];
var samplers = new ISamplingPipeline[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new StreamingTokenDecoder(context);
{
decoders[i] = new StreamingTokenDecoder(context);
samplers[i] = new DefaultSamplingPipeline
{
Temperature = 0.1f + (float)i / n_parallel,
MinP = 0.25f,
};
}

var eos = model.EndOfSentenceToken;
var nl = model.NewlineToken;
var n_cur = batch.TokenCount;
var n_decode = 0;

var timer = new Stopwatch();
timer.Start();
@@ -105,31 +107,33 @@ public static async Task Run()
if (i_batch[i] < 0)
continue;

var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);
// Use the sampling pipeline to select a token
var new_token_id = samplers[i].Sample(
context.NativeHandle,
context.NativeHandle.GetLogitsIth(i_batch[i]),
Array.Empty<LLamaToken>()
);

if (new_token_id == eos || new_token_id == nl)
// Finish this stream early if necessary
if (new_token_id == model.EndOfSentenceToken || new_token_id == model.NewlineToken)
{
i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early");
continue;
}

streams[i].Add(new_token_id);
// Add this token to the decoder, so it will be turned into text
decoders[i].Add(new_token_id);

i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
batch.Add(new_token_id, n_cur, (LLamaSeqId)i, true);

n_decode++;
}

// all streams are finished
// Check if all streams are finished
if (batch.TokenCount == 0)
{
break;
@@ -152,7 +156,7 @@ public static async Task Run()
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

var index = 0;
foreach (var stream in streams)
foreach (var stream in decoders)
{
var text = stream.Read();

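For reference, the sampling change above replaces the shared top_k/top_p/temp constants with one sampling pipeline per stream. A minimal sketch of that pattern, using only types that appear in this diff (the MakeSamplers helper name is hypothetical, not part of the commit):

using LLama.Native;
using LLama.Sampling;

// Hypothetical helper: build one sampler per parallel stream, each with a slightly
// different temperature so the streams diverge, mirroring the loop in the example above.
static ISamplingPipeline[] MakeSamplers(int nParallel)
{
    var samplers = new ISamplingPipeline[nParallel];
    for (var i = 0; i < nParallel; i++)
    {
        samplers[i] = new DefaultSamplingPipeline
        {
            Temperature = 0.1f + (float)i / nParallel,
            MinP = 0.25f,
        };
    }
    return samplers;
}

Each sampler is then fed its own stream's logits, exactly as the updated loop does with samplers[i].Sample(context.NativeHandle, context.NativeHandle.GetLogitsIth(i_batch[i]), Array.Empty<LLamaToken>()).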
73 changes: 63 additions & 10 deletions LLama/Native/LLamaBatch.cs
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace LLama.Native;

@@ -22,14 +24,14 @@ public class LLamaBatch
public int TokenCount { get; private set; }

/// <summary>
/// Maximum number of tokens that can be added to this batch
/// Maximum number of tokens that can be added to this batch (automatically grows if exceeded)
/// </summary>
private int TokenCapacity { get; set; }

/// <summary>
/// Maximum number of sequences a token can be assigned to
/// Maximum number of sequences a token can be assigned to (automatically grows if exceeded)
/// </summary>
public int MaxSequences { get; private set; }
public int SequenceCapacity { get; private set; }

/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
@@ -40,7 +42,7 @@ public LLamaBatch()
const int n_tokens = 128;
const int n_seq_max = 1;

MaxSequences = n_seq_max;
SequenceCapacity = n_seq_max;
TokenCapacity = n_tokens;

_logits = new byte[n_tokens];
@@ -52,9 +54,10 @@ public LLamaBatch()

_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
}

#region grow
private void GrowTokenCapacity()
{
var n_tokens = TokenCount * 2;
@@ -73,18 +76,19 @@ private void GrowTokenCapacity()
// Growing the array filled elements with null, temporarily violating the nullability contract!
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (_sequenceIds[i] == null)
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
}
}

private void GrowMaxSequences(int atLeast)
{
var n_seq = Math.Max(MaxSequences * 2, atLeast);
MaxSequences = n_seq;
var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
SequenceCapacity = n_seq;

for (var i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds[i], MaxSequences);
Array.Resize(ref _sequenceIds[i], SequenceCapacity);
}
#endregion

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
@@ -117,6 +121,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
return group;
}

#region add
/// <summary>
/// Add a single token to the batch at the same position in several sequences
/// </summary>
@@ -129,7 +134,7 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequenc
{
if (TokenCount == TokenCapacity)
GrowTokenCapacity();
if (sequences.Length > MaxSequences)
if (sequences.Length > SequenceCapacity)
GrowMaxSequences(sequences.Length);

_tokens[TokenCount] = token;
@@ -144,6 +149,37 @@ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequenc
TokenCount++;
}

/// <summary>
/// Add a single token to the batch at the same position in several sequences
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void Add(LLamaToken token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
{
#if NET5_0_OR_GREATER
var seqSpan = CollectionsMarshal.AsSpan(sequences);
Add(token, pos, seqSpan, logits);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<LLamaSeqId>.Shared.Rent(sequences.Count);
try
{
sequences.CopyTo(rented, 0);
Add(token, pos, rented.AsSpan(0, sequences.Count), logits);
}
finally
{
System.Buffers.ArrayPool<LLamaSeqId>.Shared.Return(rented);
}
#endif
}

/// <summary>
/// Add a single token to the batch at a certain position for a single sequence
/// </summary>
@@ -162,6 +198,23 @@ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits
Add(token, pos, sequences, logits);
}

/// <summary>
/// Add a range of tokens to a single sequence, starting at the given position.
/// </summary>
/// <param name="tokens">The tokens to add</param>
/// <param name="start">The starting position to add tokens at</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logitsLast">Whether the final token should generate logits</param>
public void AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId sequence, bool logitsLast)
{
for (var i = 0; i < tokens.Length; i++)
{
var logits = (i == tokens.Length - 1) & logitsLast;
Add(tokens[i], start.Value + i, sequence, logits);
}
}
#endregion

/// <summary>
/// Set TokenCount to zero for this batch
/// </summary>
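A short usage sketch of the two new overloads above; the batch, promptTokens and nextToken variables are placeholders rather than code from this commit. AddRange pushes a whole prompt into a single sequence and requests logits only for the final token, while the List<LLamaSeqId> overload accepts an existing list without the caller building a span by hand:

using System.Collections.Generic;
using LLama.Native;

var batch = new LLamaBatch();

// Feed a tokenized prompt (LLamaToken[]) into sequence 0, logits for the last token only.
batch.AddRange(promptTokens, 0, LLamaSeqId.Zero, true);

// Add one token at the next position to two sequences via the new List overload.
var sequences = new List<LLamaSeqId> { (LLamaSeqId)0, (LLamaSeqId)1 };
batch.Add(nextToken, batch.TokenCount, sequences, true);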
2 changes: 1 addition & 1 deletion LLama/Native/LLamaKvCacheView.cs
@@ -112,7 +112,7 @@ public int CountCells()
}

/// <summary>
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be countered multiple times
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
/// </summary>
/// <returns></returns>
public int CountTokens()
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.cs
@@ -422,6 +422,6 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken ll
/// <param name="n_threads_batch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
public static extern void llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch);
}
}
72 changes: 71 additions & 1 deletion LLama/Native/SafeLLamaContextHandle.cs
@@ -167,8 +167,9 @@ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
#endregion
#endregion

#region infer
/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary>
@@ -202,6 +203,7 @@ public int Decode(LLamaBatch batch)
using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb);
}
#endregion

#region state
/// <summary>
@@ -275,5 +277,73 @@ public void SetSeed(uint seed)
{
NativeApi.llama_set_rng_seed(this, seed);
}

/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
public void SetThreads(uint threads, uint threadsBatch)
{
NativeApi.llama_set_n_threads(this, threads, threadsBatch);
}

#region KV Cache Management
/// <summary>
/// Clear the KV cache
/// </summary>
public void KvCacheClear()
{
NativeApi.llama_kv_cache_clear(this);
}

/// <summary>
/// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
public void KvCacheRemove(LLamaSeqId seq, LLamaPos p0, LLamaPos p1)
{
NativeApi.llama_kv_cache_seq_rm(this, seq, p0, p1);
}

/// <summary>
/// Copy all tokens that belong to the specified sequence to another sequence. Note that
/// this does not allocate extra KV cache memory - it simply assigns the tokens to the
/// new sequence
/// </summary>
/// <param name="src"></param>
/// <param name="dest"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
public void KvCacheSequenceCopy(LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1)
{
NativeApi.llama_kv_cache_seq_cp(this, src, dest, p0, p1);
}

/// <summary>
/// Removes all tokens that do not belong to the specified sequence
/// </summary>
/// <param name="seq"></param>
public void KvCacheSequenceKeep(LLamaSeqId seq)
{
NativeApi.llama_kv_cache_seq_keep(this, seq);
}

/// <summary>
/// Adds relative position "delta" to all tokens that belong to the specified sequence
/// and have positions in [p0, p1). If the KV cache is RoPEd, the KV data is updated
/// accordingly
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="delta"></param>
public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta)
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}
#endregion
}
}
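
Taken together, these instance methods let callers manage the KV cache and thread counts without calling NativeApi directly. A sketch of the main operations, reusing the context, batch and n_parallel variables from the batched decoding example above (the positions and sequence ids passed here are illustrative only):

var kv = context.NativeHandle;

// Fork the evaluated prompt into the extra parallel sequences without copying KV memory,
// as the updated BatchedDecoding example does.
for (var i = 1; i < n_parallel; i++)
    kv.KvCacheSequenceCopy((LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);

kv.KvCacheRemove((LLamaSeqId)1, 0, 16);               // drop sequence 1's tokens at positions [0, 16)
kv.KvCacheSequenceShift((LLamaSeqId)0, 16, 32, -16);  // shift positions [16, 32) back by 16 (RoPE-aware)
kv.KvCacheSequenceKeep((LLamaSeqId)0);                // discard every sequence except sequence 0
kv.KvCacheClear();                                    // wipe the whole cache
kv.SetThreads(8, 8);                                  // also new: generation / batch thread counts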
