diff --git a/LLama.Unittest/StreamingTextDecoderTests.cs b/LLama.Unittest/StreamingTextDecoderTests.cs
new file mode 100644
index 000000000..680ca076f
--- /dev/null
+++ b/LLama.Unittest/StreamingTextDecoderTests.cs
@@ -0,0 +1,53 @@
+using System.Text;
+using LLama.Common;
+using Xunit.Abstractions;
+
+namespace LLama.Unittest;
+
+public class StreamingTextDecoderTests
+ : IDisposable
+{
+ private readonly LLamaWeights _model;
+ private readonly ITestOutputHelper _testOutputHelper;
+ private readonly ModelParams _params;
+
+ public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
+ {
+ _testOutputHelper = testOutputHelper;
+ _params = new ModelParams(Constants.ModelPath);
+ _model = LLamaWeights.LoadFromFile(_params);
+ }
+
+ public void Dispose()
+ {
+ _model.Dispose();
+ }
+
+ [Fact]
+ public void DecodesSimpleText()
+ {
+ var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
+
+ const string text = "The cat sat on the mat";
+ var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
+
+ foreach (var lLamaToken in tokens)
+ decoder.Add(lLamaToken);
+
+ Assert.Equal(text, decoder.Read().Trim());
+ }
+
+ [Fact]
+ public void DecodesComplexText()
+ {
+ var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
+
+ const string text = "猫坐在垫子上 😀🤨🤐😏";
+ var tokens = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
+
+ foreach (var lLamaToken in tokens)
+ decoder.Add(lLamaToken);
+
+ Assert.Equal(text, decoder.Read().Trim());
+ }
+}
\ No newline at end of file
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index c69577215..17fa13cf5 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -194,7 +194,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
- public int TokenToSpan(LLamaToken token, Span<byte> dest)
+ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d4fe2d71a..47ffb4dd6 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -126,10 +126,10 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
- public int TokenToSpan(LLamaToken token, Span<byte> dest)
+ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
var length = NativeApi.llama_token_to_piece(this, token, dest);
- return Math.Abs(length);
+ return (uint)Math.Abs(length);
}
///
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs
index a66d5c9e9..4c1ea58d0 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/StreamingTokenDecoder.cs
@@ -113,19 +113,19 @@ static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaMode
// Try to get bytes
var l = model.TokenToSpan(token, bytes);
- // Negative length indicates that the output was too small. Expand it to twice that size and try again.
- if (l < 0)
+ // Check if the length was larger than the buffer. If so expand the buffer and try again
+ if (l > bytes.Length)
{
// Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes);
- bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
+ bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));
// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}
- Debug.Assert(l >= 0);
- return new Span<byte>(bytes, 0, l);
+ Debug.Assert(l <= bytes.Length);
+ return new Span<byte>(bytes, 0, (int)l);
}
}