misc
dclipca committed Jan 15, 2025
1 parent 87b3398 commit b6b9e39
Showing 35 changed files with 1,457 additions and 391 deletions.
13 changes: 10 additions & 3 deletions SpongeEngine.KoboldSharp.Tests/Integration/IntegrationTestBase.cs
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using SpongeEngine.KoboldSharp.Models;
using SpongeEngine.KoboldSharp.Tests.Common;
using SpongeEngine.LLMSharp.Core.Configuration;
using SpongeEngine.LLMSharp.Core;
using Xunit;
using Xunit.Abstractions;

@@ -25,8 +26,14 @@ protected IntegrationTestBase(ITestOutputHelper output)
BaseAddress = new Uri(TestConfig.NativeApiBaseUrl),
Timeout = TimeSpan.FromSeconds(TestConfig.TimeoutSeconds)
};

Client = new KoboldSharpClient(httpClient, new LlmClientOptions(), "", TestConfig.BaseUrl, Logger);

Client = new KoboldSharpClient(
new KoboldSharpClientOptions()
{
HttpClient = httpClient,
BaseUrl = TestConfig.BaseUrl,
Logger = Logger,
});
}

public async Task InitializeAsync()
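
For readers following the API change above, a minimal sketch of the new options-based construction outside the test harness. It assumes only the KoboldSharpClientOptions properties visible in this diff (HttpClient, BaseUrl, Logger); the endpoint URL is a placeholder and NullLogger.Instance merely stands in for whatever ILogger a caller already has.

using Microsoft.Extensions.Logging.Abstractions;
using SpongeEngine.KoboldSharp;

// Placeholder endpoint; point this at a running KoboldCpp instance.
var httpClient = new HttpClient { BaseAddress = new Uri("http://localhost:5001") };

var client = new KoboldSharpClient(
    new KoboldSharpClientOptions
    {
        HttpClient = httpClient,
        BaseUrl = "http://localhost:5001",  // placeholder, mirrors TestConfig.BaseUrl in the tests
        Logger = NullLogger.Instance,       // assumption: the option accepts any ILogger
    });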
14 changes: 7 additions & 7 deletions SpongeEngine.KoboldSharp.Tests/Integration/IntegrationTests.cs
@@ -18,7 +18,7 @@ public async Task Generate_WithSimplePrompt_ShouldReturnResponse()
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

// Arrange
KoboldSharpRequest request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Once upon a time",
MaxLength = 20,
@@ -30,7 +30,7 @@ public async Task Generate_WithSimplePrompt_ShouldReturnResponse()
};

// Act
KoboldSharpResponse response = await Client.GenerateAsync(request);
KoboldSharpClient.GenerateAsyncResponse response = await Client.GenerateAsync(request);

// Assert
response.Should().NotBeNull();
@@ -44,7 +44,7 @@ public async Task GenerateStream_ShouldStreamTokens()
{
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

KoboldSharpRequest request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Write a short story about",
MaxLength = 20,
@@ -94,7 +94,7 @@ public async Task Generate_WithStopSequence_ShouldGenerateText()
Skip.If(!ServerAvailable, "KoboldCpp server is not available");

// Arrange
KoboldSharpRequest request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Write a short story",
MaxLength = 20,
@@ -104,7 +104,7 @@ public async Task Generate_WithStopSequence_ShouldGenerateText()
};

// Act
KoboldSharpResponse response = await Client.GenerateAsync(request);
KoboldSharpClient.GenerateAsyncResponse response = await Client.GenerateAsync(request);

// Assert
response.Should().NotBeNull();
@@ -123,7 +123,7 @@ public async Task Generate_WithDifferentTemperatures_ShouldWork()
foreach (float temp in temperatures)
{
// Arrange
KoboldSharpRequest request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "The quick brown fox",
MaxLength = 20,
@@ -135,7 +135,7 @@ public async Task Generate_WithDifferentTemperatures_ShouldWork()
};

// Act
KoboldSharpResponse response = await Client.GenerateAsync(request);
KoboldSharpClient.GenerateAsyncResponse response = await Client.GenerateAsync(request);

// Assert
response.Should().NotBeNull();
18 changes: 12 additions & 6 deletions SpongeEngine.KoboldSharp.Tests/Unit/UnitTests.cs
@@ -1,7 +1,7 @@
using FluentAssertions;
using SpongeEngine.KoboldSharp.Models;
using SpongeEngine.KoboldSharp.Tests.Common;
using SpongeEngine.LLMSharp.Core.Configuration;
using SpongeEngine.LLMSharp.Core;
using WireMock.RequestBuilders;
using WireMock.ResponseBuilders;
using Xunit;
@@ -17,7 +17,7 @@ public class UnitTests : UnitTestBase
public UnitTests(ITestOutputHelper output) : base(output)
{
_httpClient = new HttpClient { BaseAddress = new Uri(BaseUrl) };
_client = new KoboldSharpClient(_httpClient, new LlmClientOptions(), "KoboldCpp", TestConfig.BaseUrl, Logger);
_client = new KoboldSharpClient(
new KoboldSharpClientOptions()
{
HttpClient = _httpClient,
BaseUrl = TestConfig.BaseUrl,
Logger = Logger,
});
}

[Fact]
@@ -33,7 +39,7 @@ public async Task GenerateAsync_ShouldReturnValidResponse()
.WithStatusCode(200)
.WithBody($"{{\"results\": [{{\"text\": \"{expectedResponse}\", \"tokens\": 3}}]}}"));

var request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80
@@ -63,7 +69,7 @@ public async Task GenerateStreamAsync_ShouldStreamTokens()
"data: {\"token\": \"!\", \"complete\": true}\n\n")
.WithHeader("Content-Type", "text/event-stream"));

var request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80,
@@ -131,7 +137,7 @@ public async Task GenerateAsync_WithValidRequest_ShouldReturnResponse()
.WithStatusCode(200)
.WithBody($"{{\"results\": [{{\"text\": \"{expectedResponse}\", \"tokens\": 3}}]}}"));

var request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80
@@ -161,7 +167,7 @@ public async Task GenerateStreamAsync_ShouldStreamResponse()
"data: {\"token\": \"!\", \"complete\": true}\n\n")
.WithHeader("Content-Type", "text/event-stream"));

var request = new KoboldSharpRequest
KoboldSharpClient.KoboldSharpRequest request = new KoboldSharpClient.KoboldSharpRequest
{
Prompt = "Test prompt",
MaxLength = 80,
6 changes: 6 additions & 0 deletions SpongeEngine.KoboldSharp.sln
@@ -3,6 +3,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SpongeEngine.KoboldSharp",
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SpongeEngine.KoboldSharp.Tests", "SpongeEngine.KoboldSharp.Tests\SpongeEngine.KoboldSharp.Tests.csproj", "{B4C06852-1A11-45B9-8E83-CF7ABEE36F12}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SpongeEngine.LLMSharp.Core", "..\LLMSharp.Core\SpongeEngine.LLMSharp.Core\SpongeEngine.LLMSharp.Core.csproj", "{A50A731B-68AC-4D86-869B-A9CB50C91466}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -17,5 +19,9 @@ Global
{B4C06852-1A11-45B9-8E83-CF7ABEE36F12}.Debug|Any CPU.Build.0 = Debug|Any CPU
{B4C06852-1A11-45B9-8E83-CF7ABEE36F12}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B4C06852-1A11-45B9-8E83-CF7ABEE36F12}.Release|Any CPU.Build.0 = Release|Any CPU
{A50A731B-68AC-4D86-869B-A9CB50C91466}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A50A731B-68AC-4D86-869B-A9CB50C91466}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A50A731B-68AC-4D86-869B-A9CB50C91466}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A50A731B-68AC-4D86-869B-A9CB50C91466}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
EndGlobal
73 changes: 73 additions & 0 deletions SpongeEngine.KoboldSharp/Common/KoboldSharpRequest.cs
@@ -0,0 +1,73 @@
using System.Text.Json.Serialization;

namespace SpongeEngine.KoboldSharp
{
public partial class KoboldSharpClient
{
public class KoboldSharpRequest
{
[JsonPropertyName("prompt")]
public string Prompt { get; set; } = string.Empty;

[JsonPropertyName("max_length")]
public int MaxLength { get; set; } = 80;

[JsonPropertyName("max_context_length")]
public int? MaxContextLength { get; set; }

[JsonPropertyName("temperature")]
public float Temperature { get; set; } = 0.7f;

[JsonPropertyName("top_p")]
public float TopP { get; set; } = 0.9f;

[JsonPropertyName("top_k")]
public int TopK { get; set; } = 40;

[JsonPropertyName("top_a")]
public float TopA { get; set; } = 0.0f;

[JsonPropertyName("typical")]
public float Typical { get; set; } = 1.0f;

[JsonPropertyName("tfs")]
public float Tfs { get; set; } = 1.0f;

[JsonPropertyName("rep_pen")]
public float RepetitionPenalty { get; set; } = 1.1f;

[JsonPropertyName("rep_pen_range")]
public int RepetitionPenaltyRange { get; set; } = 64;

[JsonPropertyName("mirostat")]
public int MirostatMode { get; set; } = 0;

[JsonPropertyName("mirostat_tau")]
public float MirostatTau { get; set; } = 5.0f;

[JsonPropertyName("mirostat_eta")]
public float MirostatEta { get; set; } = 0.1f;

[JsonPropertyName("stop_sequence")]
public List<string>? StopSequences { get; set; }

[JsonPropertyName("stream")]
public bool Stream { get; set; }

[JsonPropertyName("trim_stop")]
public bool TrimStop { get; set; } = true;

[JsonPropertyName("grammar")]
public string? Grammar { get; set; }

[JsonPropertyName("memory")]
public string? Memory { get; set; }

[JsonPropertyName("banned_tokens")]
public List<string>? BannedTokens { get; set; }

[JsonPropertyName("logit_bias")]
public Dictionary<string, float>? LogitBias { get; set; }
}
}
}
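
A hedged usage sketch for the request type defined above: it only sets properties declared in this file and uses the GenerateAsync/GenerateAsyncResponse names exercised by the tests in this commit. It assumes a KoboldSharpClient named client has already been constructed as shown earlier; the response's members are not shown in this diff, so they are not accessed here.

var request = new KoboldSharpClient.KoboldSharpRequest
{
    Prompt = "Once upon a time",
    MaxLength = 20,
    Temperature = 0.7f,
    TopP = 0.9f,
    StopSequences = new List<string> { "\n" },
    Stream = false,
};

// GenerateAsync and GenerateAsyncResponse follow the test code in this commit.
KoboldSharpClient.GenerateAsyncResponse response = await client.GenerateAsync(request);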
49 changes: 49 additions & 0 deletions SpongeEngine.KoboldSharp/Extra/AbortGenerateAsync.cs
@@ -0,0 +1,49 @@
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using SpongeEngine.LLMSharp.Core.Exceptions;

namespace SpongeEngine.KoboldSharp
{
public partial class KoboldSharpClient
{
private class AbortResponse
{
[JsonPropertyName("success")]
public string Success { get; set; } = string.Empty;

[JsonPropertyName("done")]
public string Done { get; set; } = string.Empty;
}

public async Task<bool> AbortGenerateAsync(string? genKey = null, CancellationToken cancellationToken = default)
{
try
{
using HttpRequestMessage httpRequest = new(HttpMethod.Post, "api/extra/abort");
if (!string.IsNullOrEmpty(genKey))
{
httpRequest.Content = JsonContent.Create(new { genkey = genKey });
}

using HttpResponseMessage? response = await Options.HttpClient.SendAsync(httpRequest, cancellationToken);
string? content = await response.Content.ReadAsStringAsync(cancellationToken);

if (!response.IsSuccessStatusCode)
{
throw new LlmSharpException(
"Failed to abort generation",
(int)response.StatusCode,
content);
}

var result = JsonSerializer.Deserialize<AbortResponse>(content);
return result?.Success == "true";
}
catch (Exception ex) when (ex is not LlmSharpException)
{
throw new LlmSharpException("Failed to abort generation", innerException: ex);
}
}
}
}
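
A short usage sketch for AbortGenerateAsync, again assuming a constructed client; the genkey value is a placeholder. The method maps the endpoint's JSON "success" field to a bool.

// Abort whatever generation is currently in flight.
bool aborted = await client.AbortGenerateAsync();

// Or target a specific generation by its genkey (placeholder value).
bool abortedByKey = await client.AbortGenerateAsync(genKey: "example-genkey");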
57 changes: 57 additions & 0 deletions SpongeEngine.KoboldSharp/Extra/CountTokensAsync.cs
@@ -0,0 +1,57 @@
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using SpongeEngine.LLMSharp.Core.Exceptions;

namespace SpongeEngine.KoboldSharp
{
public partial class KoboldSharpClient
{
public class CountTokensRequest
{
[JsonPropertyName("prompt")]
public string Prompt { get; set; } = string.Empty;

[JsonPropertyName("special")]
public bool IncludeSpecialTokens { get; set; } = true;
}

public class CountTokensResponse
{
[JsonPropertyName("value")]
public int Count { get; set; }

[JsonPropertyName("ids")]
public List<int> TokenIds { get; set; } = new();
}

/// <summary>
/// Counts tokens in a given prompt string.
/// </summary>
public async Task<CountTokensResponse> CountTokensAsync(CountTokensRequest request, CancellationToken cancellationToken = default)
{
try
{
using HttpRequestMessage httpRequest = new(HttpMethod.Post, "api/extra/tokencount");
httpRequest.Content = JsonContent.Create(request);

using HttpResponseMessage? response = await Options.HttpClient.SendAsync(httpRequest, cancellationToken);
string? content = await response.Content.ReadAsStringAsync(cancellationToken);

if (!response.IsSuccessStatusCode)
{
throw new LlmSharpException(
"Failed to count tokens",
(int)response.StatusCode,
content);
}

return JsonSerializer.Deserialize<CountTokensResponse>(content) ?? throw new LlmSharpException("Failed to deserialize token count response");
}
catch (Exception ex) when (ex is not LlmSharpException)
{
throw new LlmSharpException("Failed to count tokens", innerException: ex);
}
}
}
}
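
A usage sketch for CountTokensAsync built only from the request and response types defined above, assuming a constructed client.

var countRequest = new KoboldSharpClient.CountTokensRequest
{
    Prompt = "The quick brown fox",
    IncludeSpecialTokens = true,
};

KoboldSharpClient.CountTokensResponse tokenCount = await client.CountTokensAsync(countRequest);

// "value" maps to Count and "ids" to TokenIds, per the JsonPropertyName attributes above.
Console.WriteLine($"Tokens: {tokenCount.Count} (ids: {string.Join(", ", tokenCount.TokenIds)})");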