Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
dclipca committed Jan 15, 2025
1 parent b6b9e39 commit 4efbb2f
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 249 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using FluentAssertions;
using Xunit;
using Xunit.Abstractions;

namespace SpongeEngine.KoboldSharp.Tests.Integration
{
[Trait("Category", "Integration")]
[Trait("API", "Native")]
public class ExtraIntegrationTests : IntegrationTestBase
{
public ExtraIntegrationTests(ITestOutputHelper output) : base(output) {}

[SkippableFact]
[Trait("Category", "Integration")]
public async Task GenerateStream_ShouldStreamTokens()
{
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Write a short story about",
MaxLength = 20,
Temperature = 0.7f,
TopP = 0.9f,
TopK = 40,
RepetitionPenalty = 1.1f,
RepetitionPenaltyRange = 64,
TrimStop = true
};

List<string> tokens = new List<string>();
using CancellationTokenSource? cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

try
{
await foreach (String token in Client.GenerateStreamAsync(request, cts.Token))
{
tokens.Add(token);
Output.WriteLine($"Received token: {token}");

if (tokens.Count >= request.MaxLength)
{
Output.WriteLine("Reached max length, breaking");
break;
}
}
}
catch (OperationCanceledException) when (cts.Token.IsCancellationRequested)
{
Output.WriteLine($"Stream timed out after receiving {tokens.Count} tokens");
}
catch (Exception ex)
{
Output.WriteLine($"Error during streaming: {ex}");
throw;
}

tokens.Should().NotBeEmpty("No tokens were received from the stream");
string.Concat(tokens).Should().NotBeNullOrEmpty("Combined token text should not be empty");
}
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
using FluentAssertions;
using SpongeEngine.KoboldSharp.Models;
using Xunit;
using Xunit.Abstractions;

namespace SpongeEngine.KoboldSharp.Tests.Integration
{
[Trait("Category", "Integration")]
[Trait("API", "Native")]
public class IntegrationTests : IntegrationTestBase
public class UnitedIntegrationTests : IntegrationTestBase
{
public IntegrationTests(ITestOutputHelper output) : base(output) { }

public UnitedIntegrationTests(ITestOutputHelper output) : base(output) {}
[SkippableFact]
[Trait("Category", "Integration")]
public async Task Generate_WithSimplePrompt_ShouldReturnResponse()
Expand All @@ -37,56 +36,7 @@ public async Task Generate_WithSimplePrompt_ShouldReturnResponse()
response.Results.Should().NotBeEmpty();
response.Results[0].Text.Should().NotBeNullOrEmpty();
}

[SkippableFact]
[Trait("Category", "Integration")]
public async Task GenerateStream_ShouldStreamTokens()
{
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Write a short story about",
MaxLength = 20,
Temperature = 0.7f,
TopP = 0.9f,
TopK = 40,
RepetitionPenalty = 1.1f,
RepetitionPenaltyRange = 64,
TrimStop = true
};

List<string> tokens = new List<string>();
using CancellationTokenSource? cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

try
{
await foreach (String token in Client.GenerateStreamAsync(request, cts.Token))
{
tokens.Add(token);
Output.WriteLine($"Received token: {token}");

if (tokens.Count >= request.MaxLength)
{
Output.WriteLine("Reached max length, breaking");
break;
}
}
}
catch (OperationCanceledException) when (cts.Token.IsCancellationRequested)
{
Output.WriteLine($"Stream timed out after receiving {tokens.Count} tokens");
}
catch (Exception ex)
{
Output.WriteLine($"Error during streaming: {ex}");
throw;
}

tokens.Should().NotBeEmpty("No tokens were received from the stream");
string.Concat(tokens).Should().NotBeNullOrEmpty("Combined token text should not be empty");
}


[SkippableFact]
[Trait("Category", "Integration")]
public async Task Generate_WithStopSequence_ShouldGenerateText()
Expand All @@ -111,7 +61,7 @@ public async Task Generate_WithStopSequence_ShouldGenerateText()
response.Results.Should().NotBeEmpty();
response.Results[0].Text.Should().NotBeNullOrEmpty();
}

[SkippableFact]
[Trait("Category", "Integration")]
public async Task Generate_WithDifferentTemperatures_ShouldWork()
Expand Down Expand Up @@ -148,5 +98,20 @@ public async Task Generate_WithDifferentTemperatures_ShouldWork()
await Task.Delay(500);
}
}

[SkippableFact]
[Trait("Category", "Integration")]
public async Task GetVersionInfo_ShouldReturnValidVersion()
{
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

// Act
var versionInfo = await Client.GetVersionInfoAsync();

// Assert
versionInfo.Should().NotBeNull();
versionInfo.Version.Should().NotBeNullOrEmpty();
Output.WriteLine($"Server version: {versionInfo.Version}");
}
}
}
81 changes: 81 additions & 0 deletions SpongeEngine.KoboldSharp.Tests/Unit/ExtraUnitTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using FluentAssertions;
using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;
using Xunit;
using Xunit.Abstractions;

namespace SpongeEngine.KoboldSharp.Tests.Unit
{
public class ExtraUnitTests : UnitTestBase
{
public ExtraUnitTests(ITestOutputHelper output) : base(output) {}

[Fact]
public async Task GenerateStreamAsync_ShouldStreamTokens()
{
// Arrange
var tokens = new[] { "Hello", " world", "!" };
Server
.Given(Request.Create()
.WithPath("/api/extra/generate/stream")
.UsingPost())
.RespondWith(Response.Create()
.WithStatusCode(200)
.WithBody("data: {\"token\": \"Hello\", \"complete\": false}\n\n" +
"data: {\"token\": \" world\", \"complete\": false}\n\n" +
"data: {\"token\": \"!\", \"complete\": true}\n\n")
.WithHeader("Content-Type", "text/event-stream"));

KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80,
Stream = true
};

// Act
var receivedTokens = new List<string>();
await foreach (var token in Client.GenerateStreamAsync(request))
{
receivedTokens.Add(token);
}

// Assert
receivedTokens.Should().BeEquivalentTo(tokens);
}

[Fact]
public async Task GenerateStreamAsync_ShouldStreamResponse()
{
// Arrange
var tokens = new[] { "Hello", " world", "!" };
Server
.Given(Request.Create()
.WithPath("/api/extra/generate/stream")
.UsingPost())
.RespondWith(Response.Create()
.WithStatusCode(200)
.WithBody("data: {\"token\": \"Hello\", \"complete\": false}\n\n" +
"data: {\"token\": \" world\", \"complete\": false}\n\n" +
"data: {\"token\": \"!\", \"complete\": true}\n\n")
.WithHeader("Content-Type", "text/event-stream"));

KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80,
Stream = true
};

// Act
var receivedTokens = new List<string>();
await foreach (var token in Client.GenerateStreamAsync(request))
{
receivedTokens.Add(token);
}

// Assert
receivedTokens.Should().BeEquivalentTo(tokens);
}
}
}
13 changes: 13 additions & 0 deletions SpongeEngine.KoboldSharp.Tests/Unit/UnitTestBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Microsoft.Extensions.Logging;
using SpongeEngine.KoboldSharp.Models;
using SpongeEngine.KoboldSharp.Tests.Common;
using WireMock.Server;
using Xunit.Abstractions;

Expand All @@ -9,6 +11,8 @@ public abstract class UnitTestBase : IDisposable
protected readonly WireMockServer Server;
protected readonly ILogger Logger;
protected readonly string BaseUrl;
protected readonly KoboldSharpClient Client;
protected readonly HttpClient HttpClient;

protected UnitTestBase(ITestOutputHelper output)
{
Expand All @@ -17,10 +21,19 @@ protected UnitTestBase(ITestOutputHelper output)
Logger = LoggerFactory
.Create(builder => builder.AddXUnit(output))
.CreateLogger(GetType());
HttpClient = new HttpClient { BaseAddress = new Uri(BaseUrl) };
Client = new KoboldSharpClient(
new KoboldSharpClientOptions()
{
HttpClient = HttpClient,
BaseUrl = TestConfig.BaseUrl,
Logger = Logger,
});
}

public virtual void Dispose()
{
HttpClient.Dispose();
Server.Dispose();
}
}
Expand Down
Loading

0 comments on commit 4efbb2f

Please sign in to comment.