Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed unnecessary parameters from some low level sampler methods #125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions LLama/Extensions/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System.Collections.Generic;

namespace LLama.Extensions
{
internal static class DictionaryExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill for <c>GetValueOrDefault</c>, which is not available on netstandard2.0.
    /// Looks up <paramref name="key"/> in <paramref name="dictionary"/> and returns the
    /// stored value, or <paramref name="defaultValue"/> when the key is absent.
    /// </summary>
    public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
    {
        if (dictionary.TryGetValue(key, out var value))
            return value;

        return defaultValue;
    }
#endif
}
}
21 changes: 21 additions & 0 deletions LLama/Extensions/IEnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Collections.Generic;
using System.Linq;

namespace LLama.Extensions
{
internal static class IEnumerableExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill for <c>Enumerable.TakeLast</c>, which is not available on netstandard2.0.
    /// Returns the last <paramref name="count"/> elements of <paramref name="source"/>,
    /// or the whole sequence when it has fewer than <paramref name="count"/> elements.
    /// </summary>
    /// <remarks>
    /// NOTE(review): unlike the BCL version this enumerates <paramref name="source"/> eagerly
    /// rather than deferring execution.
    /// </remarks>
    public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> source, int count)
    {
        // Match BCL TakeLast semantics: a zero or negative count yields an empty sequence.
        // (Previously a negative count reached RemoveRange with a length greater than
        // list.Count and threw ArgumentException.)
        if (count <= 0)
            return Enumerable.Empty<T>();

        var list = source.ToList();

        if (count >= list.Count)
            return list;

        // Drop everything except the trailing `count` elements.
        list.RemoveRange(0, list.Count - count);
        return list;
    }
#endif
}
}
45 changes: 25 additions & 20 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,36 +355,41 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
var n_vocab = _ctx.VocabCount;
var logits = _ctx.GetLogits();

// Apply params.logit_bias map
if(logitBias is not null)
if (logitBias is not null)
{
foreach (var (key, value) in logitBias)
{
logits[key] += value;
}
}

var candidates = new LLamaTokenData[n_vocab];
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);

// Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl()];
int lastTokensCount = lastTokens.Count();
var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
(ulong)last_n_repeat, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p,
lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(),
(ulong)last_n_repeat, alphaFrequency, alphaPresence);
// Save the newline logit value
var nl_token = NativeApi.llama_token_nl();
var nl_logit = logits[nl_token];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);

// Extract most recently returned tokens
var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount);
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

// Apply penalties to candidates
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);

// Restore newline token logit value if necessary
if (!penalizeNL)
{
logits[NativeApi.llama_token_nl()] = nl_logit;
var candidatesSpan = candidates_p.data.Span;
for (var i = 0; i < candidates_p.data.Length; i++)
{
ref var item = ref candidatesSpan[i];
if (item.id == nl_token)
item.logit = nl_logit;
}
candidates_p.sorted = false;
}

return candidates_p;
Expand Down
20 changes: 18 additions & 2 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using System.Buffers;
using System.Runtime.InteropServices;

using llama_token = System.Int32;

namespace LLama.Native
{
/// <summary>
Expand All @@ -15,9 +17,9 @@ public struct LLamaTokenDataArray
public readonly Memory<LLamaTokenData> data;

/// <summary>
/// Indicates if `data` is sorted
/// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_.
/// </summary>
public readonly bool sorted;
public bool sorted;

/// <summary>
/// Create a new LLamaTokenDataArray
Expand All @@ -29,6 +31,20 @@ public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
data = tokens;
sorted = isSorted;
}

/// <summary>
/// Create a new LLamaTokenDataArray, copying the data from the given logits
/// </summary>
/// <param name="logits">Raw logit values; the index of each value within the span is used as its token id.</param>
/// <returns>An unsorted LLamaTokenDataArray holding one candidate per logit.</returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
{
// One candidate per vocabulary entry; the span index doubles as the token id.
// The trailing 0.0f initialises the third LLamaTokenData field — presumably the
// probability, to be filled in by later softmax/sampling calls (TODO confirm).
var candidates = new LLamaTokenData[logits.Length];
for (var token_id = 0; token_id < logits.Length; token_id++)
candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f);

// Candidates are in vocabulary order, not ranked by logit, so `sorted` stays false.
return new LLamaTokenDataArray(candidates);
}
}

/// <summary>
Expand Down
31 changes: 29 additions & 2 deletions LLama/Native/SamplingApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,25 @@ public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDa
/// <param name="last_tokens"></param>
/// <param name="last_tokens_size"></param>
/// <param name="penalty"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
{
// Backwards-compatibility shim: last_tokens_size is ignored; the replacement
// overload derives the token count from last_tokens.Length instead.
llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
}

/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Most recently generated tokens; its Length is passed to the native call as the token count.</param>
/// <param name="penalty"></param>
public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
{
    // Marshal the managed candidate array into its native representation for the
    // duration of the call (handle disposal copies results back / releases pins).
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    // Pin the token buffer so native code can read it through a raw pointer.
    using var last_tokens_handle = last_tokens.Pin();

    // Token count is derived from the buffer length rather than a caller-supplied size
    // (the stale pre-change call referencing last_tokens_size is removed).
    NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
}

/// <summary>
Expand All @@ -42,12 +55,26 @@ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, L
/// <param name="last_tokens_size"></param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
[Obsolete("last_tokens_size parameter is no longer needed")]
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
{
// Backwards-compatibility shim: last_tokens_size is ignored; the replacement
// overload derives the token count from last_tokens.Length instead.
llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
}

/// <summary>
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
/// <param name="last_tokens">Most recently generated tokens; its Length is passed to the native call as the token count.</param>
/// <param name="alpha_frequency"></param>
/// <param name="alpha_presence"></param>
public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
{
    // Marshal the managed candidate array into its native representation for the
    // duration of the call (handle disposal copies results back / releases pins).
    using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
    // Pin the token buffer so native code can read it through a raw pointer.
    using var last_tokens_handle = last_tokens.Pin();

    // Token count is derived from the buffer length rather than a caller-supplied size
    // (the stale pre-change call referencing last_tokens_size is removed).
    NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
}

/// <summary>
Expand Down