Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamingTextDecoder Fix & Tests #428

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions LLama.Unittest/StreamingTextDecoderTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Text;
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

public class StreamingTextDecoderTests
: IDisposable
{
private readonly LLamaWeights _model;
private readonly ITestOutputHelper _testOutputHelper;
private readonly ModelParams _params;

public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams(Constants.ModelPath);
_model = LLamaWeights.LoadFromFile(_params);
}

public void Dispose()
{
_model.Dispose();
}

[Fact]
public void DecodesSimpleText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

const string text = "The cat sat on the mat";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);

Assert.Equal(text, decoder.Read().Trim());
}

[Fact]
public void DecodesComplexText()
{
var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

const string text = "猫坐在垫子上 😀🤨🤐😏";
var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

foreach (var lLamaToken in tokens)
decoder.Add(lLamaToken);

Assert.Equal(text, decoder.Read().Trim());
}
}
2 changes: 1 addition & 1 deletion LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
var length = NativeApi.llama_token_to_piece(this, token, dest);
return Math.Abs(length);
return (uint)Math.Abs(length);
}

/// <summary>
Expand Down
10 changes: 5 additions & 5 deletions LLama/StreamingTokenDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,19 @@ static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaMode
// Try to get bytes
var l = model.TokenToSpan(token, bytes);

// Negative length indicates that the output was too small. Expand it to twice that size and try again.
if (l < 0)
// Check if the length was larger than the buffer. If so expand the buffer and try again
if (l > bytes.Length)
{
// Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));

// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}

Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
Debug.Assert(l <= bytes.Length);
return new Span<byte>(bytes, 0, (int)l);
}
}

Expand Down
Loading