Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated Binaries December 2023 #361

Merged
merged 9 commits into from
Dec 15, 2023
Prev Previous commit
Next Next commit
Added a safe handle for LLamaKvCacheView
  • Loading branch information
martindevans committed Dec 14, 2023
commit bab6b65b61fc86c9462dff71c95139ea2dfbe416
6 changes: 5 additions & 1 deletion LLama.Examples/Program.cs
Original file line number Diff line number Diff line change
@@ -7,7 +7,11 @@

// Print a separator line to the console.
Console.WriteLine("======================================================================================================");

// NOTE(review): the one-line call and the multi-line chain below appear to be the removed and added
// sides of the same diff hunk — confirm that only one of them is present in the real file.
NativeLibraryConfig.Instance.WithCuda().WithLogs();
// Fluent configuration on the NativeLibraryConfig singleton: WithCuda, WithLogs, then WithAvx with AvxLevel.Avx512.
NativeLibraryConfig
.Instance
.WithCuda()
.WithLogs()
.WithAvx(NativeLibraryConfig.AvxLevel.Avx512);

// NOTE(review): presumably a no-op native call used to force the native library to load with the
// configuration above — confirm against NativeApi.
NativeApi.llama_empty_call();
Console.WriteLine();
87 changes: 82 additions & 5 deletions LLama/Native/LLamaKvCacheView.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Runtime.InteropServices;
using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

@@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell
/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary>
//todo: rewrite to safe handle?
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaKvCacheView
{
@@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView
LLamaSeqId* cells_sequences;
}

/// <summary>
/// A safe handle which owns a <see cref="LLamaKvCacheView"/> and frees it when released
/// </summary>
public class LLamaKvCacheViewSafeHandle
    : SafeLLamaHandleBase
{
    private readonly SafeLLamaContextHandle _ctx;
    private LLamaKvCacheView _view;

    /// <summary>
    /// Wrap an existing view in a handle which will call `llama_kv_cache_view_free` on it when released
    /// </summary>
    /// <remarks>
    /// The view itself is stored in a field of this object; <c>IntPtr.MaxValue</c> is passed to the base class as
    /// the handle value. NOTE(review): presumably this is only a non-zero placeholder so the handle is not
    /// treated as invalid - confirm against SafeLLamaHandleBase.
    /// </remarks>
    /// <param name="ctx">The context passed to the native calls made through this handle</param>
    /// <param name="view">The view this handle takes ownership of</param>
    public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
        : base(IntPtr.MaxValue, true)
    {
        _view = view;
        _ctx = ctx;
    }

    /// <summary>
    /// Allocate a new KV cache view (using `llama_kv_cache_view_init`) and wrap it in a safe handle
    /// </summary>
    /// <param name="ctx">The context to create the view for</param>
    /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
    /// <returns>A handle owning the newly initialized view</returns>
    public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
    {
        return new LLamaKvCacheViewSafeHandle(
            ctx,
            NativeApi.llama_kv_cache_view_init(ctx, maxSequences)
        );
    }

    /// <inheritdoc />
    protected override bool ReleaseHandle()
    {
        // Free the native view first, then clear the stored handle value.
        NativeApi.llama_kv_cache_view_free(ref _view);
        SetHandle(IntPtr.Zero);

        return true;
    }

    /// <summary>
    /// Update this view by calling `llama_kv_cache_view_update` with the stored context
    /// </summary>
    public void Update() => NativeApi.llama_kv_cache_view_update(_ctx, ref _view);

    /// <summary>
    /// Count the number of used cells in the KV cache
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_used_cells` for the stored context</returns>
    public int CountCells() => NativeApi.llama_get_kv_cache_used_cells(_ctx);

    /// <summary>
    /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
    /// </summary>
    /// <returns>The value reported by `llama_get_kv_cache_token_count` for the stored context</returns>
    public int CountTokens() => NativeApi.llama_get_kv_cache_token_count(_ctx);

    /// <summary>
    /// Get the raw KV cache view
    /// </summary>
    /// <returns>A reference to the view stored inside this handle (not a copy)</returns>
    public ref LLamaKvCacheView GetView() => ref _view;
}

partial class NativeApi
{
/// <summary>
@@ -66,17 +144,16 @@ partial class NativeApi
/// <summary>
/// Free a KV cache view. (use only for debugging purposes)
/// </summary>
/// <param name="view">The KV cache view to free</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view);
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);

/// <summary>
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
/// </summary>
/// <param name="ctx">The context whose KV cache state is read</param>
/// <param name="view">The view structure to update</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view);
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);

/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)