-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
318 additions
and
249 deletions.
There are no files selected for viewing
62 changes: 62 additions & 0 deletions
62
SpongeEngine.KoboldSharp.Tests/Integration/ExtraIntegrationTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
using FluentAssertions; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace SpongeEngine.KoboldSharp.Tests.Integration | ||
{ | ||
[Trait("Category", "Integration")] | ||
[Trait("API", "Native")] | ||
public class ExtraIntegrationTests : IntegrationTestBase | ||
{ | ||
public ExtraIntegrationTests(ITestOutputHelper output) : base(output) {} | ||
|
||
[SkippableFact] | ||
[Trait("Category", "Integration")] | ||
public async Task GenerateStream_ShouldStreamTokens() | ||
{ | ||
Skip.If(!ServerAvailable, "KoboldCpp server is not available"); | ||
|
||
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest | ||
{ | ||
Prompt = "Write a short story about", | ||
MaxLength = 20, | ||
Temperature = 0.7f, | ||
TopP = 0.9f, | ||
TopK = 40, | ||
RepetitionPenalty = 1.1f, | ||
RepetitionPenaltyRange = 64, | ||
TrimStop = true | ||
}; | ||
|
||
List<string> tokens = new List<string>(); | ||
using CancellationTokenSource? cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); | ||
|
||
try | ||
{ | ||
await foreach (String token in Client.GenerateStreamAsync(request, cts.Token)) | ||
{ | ||
tokens.Add(token); | ||
Output.WriteLine($"Received token: {token}"); | ||
|
||
if (tokens.Count >= request.MaxLength) | ||
{ | ||
Output.WriteLine("Reached max length, breaking"); | ||
break; | ||
} | ||
} | ||
} | ||
catch (OperationCanceledException) when (cts.Token.IsCancellationRequested) | ||
{ | ||
Output.WriteLine($"Stream timed out after receiving {tokens.Count} tokens"); | ||
} | ||
catch (Exception ex) | ||
{ | ||
Output.WriteLine($"Error during streaming: {ex}"); | ||
throw; | ||
} | ||
|
||
tokens.Should().NotBeEmpty("No tokens were received from the stream"); | ||
string.Concat(tokens).Should().NotBeNullOrEmpty("Combined token text should not be empty"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
using FluentAssertions; | ||
using WireMock.RequestBuilders; | ||
using WireMock.ResponseBuilders; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace SpongeEngine.KoboldSharp.Tests.Unit | ||
{ | ||
public class ExtraUnitTests : UnitTestBase | ||
{ | ||
public ExtraUnitTests(ITestOutputHelper output) : base(output) {} | ||
|
||
[Fact] | ||
public async Task GenerateStreamAsync_ShouldStreamTokens() | ||
{ | ||
// Arrange | ||
var tokens = new[] { "Hello", " world", "!" }; | ||
Server | ||
.Given(Request.Create() | ||
.WithPath("/api/extra/generate/stream") | ||
.UsingPost()) | ||
.RespondWith(Response.Create() | ||
.WithStatusCode(200) | ||
.WithBody("data: {\"token\": \"Hello\", \"complete\": false}\n\n" + | ||
"data: {\"token\": \" world\", \"complete\": false}\n\n" + | ||
"data: {\"token\": \"!\", \"complete\": true}\n\n") | ||
.WithHeader("Content-Type", "text/event-stream")); | ||
|
||
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest | ||
{ | ||
Prompt = "Test prompt", | ||
MaxLength = 80, | ||
Stream = true | ||
}; | ||
|
||
// Act | ||
var receivedTokens = new List<string>(); | ||
await foreach (var token in Client.GenerateStreamAsync(request)) | ||
{ | ||
receivedTokens.Add(token); | ||
} | ||
|
||
// Assert | ||
receivedTokens.Should().BeEquivalentTo(tokens); | ||
} | ||
|
||
[Fact] | ||
public async Task GenerateStreamAsync_ShouldStreamResponse() | ||
{ | ||
// Arrange | ||
var tokens = new[] { "Hello", " world", "!" }; | ||
Server | ||
.Given(Request.Create() | ||
.WithPath("/api/extra/generate/stream") | ||
.UsingPost()) | ||
.RespondWith(Response.Create() | ||
.WithStatusCode(200) | ||
.WithBody("data: {\"token\": \"Hello\", \"complete\": false}\n\n" + | ||
"data: {\"token\": \" world\", \"complete\": false}\n\n" + | ||
"data: {\"token\": \"!\", \"complete\": true}\n\n") | ||
.WithHeader("Content-Type", "text/event-stream")); | ||
|
||
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest | ||
{ | ||
Prompt = "Test prompt", | ||
MaxLength = 80, | ||
Stream = true | ||
}; | ||
|
||
// Act | ||
var receivedTokens = new List<string>(); | ||
await foreach (var token in Client.GenerateStreamAsync(request)) | ||
{ | ||
receivedTokens.Add(token); | ||
} | ||
|
||
// Assert | ||
receivedTokens.Should().BeEquivalentTo(tokens); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.