Skip to content

Commit

Permalink
Merge pull request #136 from awaescher/merge-129
Browse files Browse the repository at this point in the history
Merge 129
  • Loading branch information
awaescher authored Nov 5, 2024
2 parents 1ccce48 + 48e1ce7 commit 6d27f97
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 63 deletions.
23 changes: 12 additions & 11 deletions demo/OllamaApiConsole.csproj
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<NoWarn>IDE0065;IDE0055;IDE0011</NoWarn>
</PropertyGroup>

<ItemGroup>
<ItemGroup>
<!--
SixLabors.ImageSharp added explicitly to fix CVE-2024-41131: https://github.com/advisories/GHSA-63p8-c4ww-9cg7
and can be removed once Spectre.Console.ImageSharp uses a version greater than 3.1.4
-->
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.5" />
<PackageReference Include="Spectre.Console" Version="0.49.1" />
<PackageReference Include="Spectre.Console.ImageSharp" Version="0.49.1" />
</ItemGroup>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\src\OllamaSharp.csproj" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\src\OllamaSharp.csproj" />
</ItemGroup>

</Project>
32 changes: 27 additions & 5 deletions src/MicrosoftAi/AbstractionMapper.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.Extensions.AI;
using OllamaSharp.Models;
Expand Down Expand Up @@ -44,13 +45,14 @@ public static class AbstractionMapper
/// <param name="chatMessages">A list of chat messages.</param>
/// <param name="options">Optional chat options to configure the request.</param>
/// <param name="stream">Indicates if the request should be streamed.</param>
public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessages, ChatOptions? options, bool stream)
/// <param name="serializerOptions">Serializer options</param>
public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessages, ChatOptions? options, bool stream, JsonSerializerOptions serializerOptions)
{
var request = new ChatRequest
{
Format = options?.ResponseFormat == ChatResponseFormat.Json ? "json" : null,
KeepAlive = null,
Messages = ToOllamaSharpMessages(chatMessages),
Messages = ToOllamaSharpMessages(chatMessages, serializerOptions),
Model = options?.ModelId ?? "", // will be set OllamaApiClient.SelectedModel if not set
Options = new Models.RequestOptions
{
Expand Down Expand Up @@ -190,17 +192,37 @@ private static string ToFunctionTypeString(JsonObject? schema)
/// Converts a list of Microsoft.Extensions.AI.<see cref="ChatMessage"/> to a list of Ollama <see cref="Message"/>.
/// </summary>
/// <param name="chatMessages">The chat messages to convert.</param>
private static IEnumerable<Message> ToOllamaSharpMessages(IList<ChatMessage> chatMessages)
/// <param name="serializerOptions">Serializer options</param>
private static IEnumerable<Message> ToOllamaSharpMessages(IList<ChatMessage> chatMessages, JsonSerializerOptions serializerOptions)
{
foreach (var cm in chatMessages)
{
var images = cm.Contents.OfType<ImageContent>().Select(ToOllamaImage).Where(s => !string.IsNullOrEmpty(s)).ToArray();
var toolCalls = cm.Contents.OfType<FunctionCallContent>().Select(ToOllamaSharpToolCall).ToArray();

yield return new Message
{
Content = cm.Text,
Images = cm.Contents.OfType<ImageContent>().Select(ToOllamaImage).Where(s => !string.IsNullOrEmpty(s)).ToArray(),
Images = images.Length > 0 ? images : null,
Role = ToOllamaSharpRole(cm.Role),
ToolCalls = cm.Contents.OfType<FunctionCallContent>().Select(ToOllamaSharpToolCall),
ToolCalls = toolCalls.Length > 0 ? toolCalls : null,
};

// If the message contains a function result, add it as a separate tool message
foreach (var frc in cm.Contents.OfType<FunctionResultContent>())
{
var jsonResult = JsonSerializer.SerializeToElement(frc.Result, serializerOptions);

yield return new Message
{
Content = JsonSerializer.Serialize(new OllamaFunctionResultContent
{
CallId = frc.CallId,
Result = jsonResult,
}, serializerOptions),
Role = Models.Chat.ChatRole.Tool,
};
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/MicrosoftAi/OllamaFunctionResultContent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace OllamaSharp.MicrosoftAi;

using System.Text.Json;

internal sealed class OllamaFunctionResultContent
{
public string? CallId { get; set; }
public JsonElement Result { get; set; }
}
4 changes: 2 additions & 2 deletions src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -385,15 +385,15 @@ private async Task EnsureSuccessStatusCodeAsync(HttpResponseMessage response)
/// <inheritdoc/>
async Task<ChatCompletion> IChatClient.CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options, CancellationToken cancellationToken)
{
var request = MicrosoftAi.AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: false);
var request = MicrosoftAi.AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: false, OutgoingJsonSerializerOptions);
var response = await ChatAsync(request, cancellationToken).StreamToEndAsync().ConfigureAwait(false);
return MicrosoftAi.AbstractionMapper.ToChatCompletion(response, response?.Model ?? request.Model ?? SelectedModel) ?? new ChatCompletion([]);
}

/// <inheritdoc/>
async IAsyncEnumerable<StreamingChatCompletionUpdate> IChatClient.CompleteStreamingAsync(IList<ChatMessage> chatMessages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken)
{
var request = MicrosoftAi.AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true);
var request = MicrosoftAi.AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true, OutgoingJsonSerializerOptions);
await foreach (var response in ChatAsync(request, cancellationToken).ConfigureAwait(false))
yield return MicrosoftAi.AbstractionMapper.ToStreamingChatCompletionUpdate(response);
}
Expand Down
1 change: 1 addition & 0 deletions src/OllamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<SignAssembly>True</SignAssembly>
<AssemblyOriginatorKeyFile>..\OllamaSharp.snk</AssemblyOriginatorKeyFile>
<NoWarn>IDE0065;IDE0055;IDE0011;S3881</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand Down
100 changes: 81 additions & 19 deletions test/AbstractionMapperTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Text.Json;
using FluentAssertions;
using Microsoft.Extensions.AI;
using NUnit.Framework;
Expand All @@ -8,12 +9,9 @@

namespace Tests;

#pragma warning disable CS8602 // Dereference of a possibly null reference.
#pragma warning disable CS8604 // Possible null reference argument.

public partial class AbstractionMapperTests
public class AbstractionMapperTests
{
public partial class ToOllamaSharpChatRequestMethod : AbstractionMapperTests
public class ToOllamaSharpChatRequestMethod : AbstractionMapperTests
{
[Test]
public void Maps_Partial_Options_Class()
Expand All @@ -26,7 +24,7 @@ public void Maps_Partial_Options_Class()

var options = new ChatOptions { Temperature = 0.5f, /* other properties are left out */ };

var request = AbstractionMapper.ToOllamaSharpChatRequest(messages, options, stream: true);
var request = AbstractionMapper.ToOllamaSharpChatRequest(messages, options, stream: true, JsonSerializerOptions.Default);

request.Options.F16kv.Should().BeNull();
request.Options.FrequencyPenalty.Should().BeNull();
Expand Down Expand Up @@ -95,7 +93,7 @@ public void Maps_Messages()
},
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true);
var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true, JsonSerializerOptions.Default);

chatRequest.Messages.Should().HaveCount(3);

Expand Down Expand Up @@ -148,7 +146,7 @@ public void Maps_Base64_Images()
},
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true);
var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true, JsonSerializerOptions.Default);

chatRequest.Messages.Should().HaveCount(2);

Expand Down Expand Up @@ -180,7 +178,7 @@ public void Maps_Byte_Array_Images()
}
};

var request = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true);
var request = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true, JsonSerializerOptions.Default);
request.Messages.Single().Images.Single().Should().Be("QUJD");
}

Expand All @@ -207,7 +205,7 @@ public void Does_Not_Support_Image_Links()

Action act = () =>
{
var request = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true);
var request = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, null, stream: true, JsonSerializerOptions.Default);
request.Messages.Should().NotBeEmpty(); // access .Messages to invoke the evaluation of IEnumerable<Message>
};

Expand All @@ -234,7 +232,7 @@ public void Maps_Messages_With_Tools()
Tools = [new WeatherFunction()]
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true);
var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true, JsonSerializerOptions.Default);

var tool = chatRequest.Tools.Single();
tool.Function.Description.Should().Be("Gets the current weather for a current location");
Expand All @@ -251,6 +249,73 @@ public void Maps_Messages_With_Tools()
tool.Type.Should().Be("function");
}

[Test]
public void Maps_Messages_With_ToolResponse()
{
var chatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
{
new()
{
AdditionalProperties = [],
AuthorName = "a1",
RawRepresentation = null,
Role = Microsoft.Extensions.AI.ChatRole.Tool,
Text = "The weather in Honolulu is 25°C."
}
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, new(), stream: true, JsonSerializerOptions.Default);

var tool = chatRequest.Messages.Single();
tool.Content.Should().Contain("The weather in Honolulu is 25°C.");
tool.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
}

[Test]
public void Maps_Messages_With_MultipleToolResponse()
{
var aiChatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
{
new()
{
AdditionalProperties = [],
AuthorName = "a1",
RawRepresentation = null,
Role = Microsoft.Extensions.AI.ChatRole.User,
Contents = [
new TextContent("I have found those 2 results"),
new FunctionResultContent(
callId: "123",
name: "Function1",
result: new { Temperature = 40 }),

new FunctionResultContent(
callId: "456",
name: "Function2",
result: new { Summary = "This is a tool result test" }
),
]
}
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(aiChatMessages, new(), stream: true, JsonSerializerOptions.Default);
var chatMessages = chatRequest.Messages?.ToList();

chatMessages.Should().HaveCount(3);

var user = chatMessages[0];
var tool1 = chatMessages[1];
var tool2 = chatMessages[2];
tool1.Content.Should().Contain("\"Temperature\":40");
tool1.Content.Should().Contain("\"CallId\":\"123\"");
tool1.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
tool2.Content.Should().Contain("\"Summary\":\"This is a tool result test\"");
tool2.Content.Should().Contain("\"CallId\":\"456\"");
tool2.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
user.Content.Should().Contain("I have found those 2 results");
user.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.User);
}

[Test]
public void Maps_Options()
{
Expand All @@ -268,7 +333,7 @@ public void Maps_Options()
TopP = 10.1f
};

var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true);
var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true, JsonSerializerOptions.Default);

chatRequest.Format.Should().Be("json");
chatRequest.Model.Should().Be("llama3.1:405b");
Expand Down Expand Up @@ -346,7 +411,7 @@ public void Maps_Ollama_Options()
.AddOllamaOption(OllamaOption.UseMmap, true)
.AddOllamaOption(OllamaOption.VocabOnly, false);

var ollamaRequest = AbstractionMapper.ToOllamaSharpChatRequest([], options, stream: true);
var ollamaRequest = AbstractionMapper.ToOllamaSharpChatRequest([], options, stream: true, JsonSerializerOptions.Default);

ollamaRequest.Options.F16kv.Should().Be(true);
ollamaRequest.Options.FrequencyPenalty.Should().Be(0.11f);
Expand Down Expand Up @@ -507,7 +572,7 @@ public void Maps_ToolCalls()
}
}

public partial class ToOllamaEmbedRequestMethod : AbstractionMapperTests
public class ToOllamaEmbedRequestMethod : AbstractionMapperTests
{
[Test]
public void Maps_Request()
Expand Down Expand Up @@ -544,7 +609,7 @@ public void Maps_KeepAlive_And_Truncate_From_AdditionalProperties()
}
}

public partial class ToGeneratedEmbeddingsMethod : AbstractionMapperTests
public class ToGeneratedEmbeddingsMethod : AbstractionMapperTests
{
[Test]
public void Maps_Response()
Expand Down Expand Up @@ -575,7 +640,4 @@ public void Maps_Response()
mappedResponse.Usage.TotalTokenCount.Should().Be(18);
}
}
}

#pragma warning restore CS8602 // Dereference of a possibly null reference.
#pragma warning restore CS8604 // Possible null reference argument.
}
3 changes: 1 addition & 2 deletions test/ChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public async Task Sends_Assistant_ToolsCall_To_Streamer()

chat.Messages.Last().Role.Should().Be(ChatRole.Assistant);
chat.Messages.Last().ToolCalls.Should().HaveCount(1);
chat.Messages.Last().ToolCalls!.ElementAt(0).Function!.Name.Should().Be("get_current_weather");
chat.Messages.Last().ToolCalls.ElementAt(0).Function.Name.Should().Be("get_current_weather");
}

[Test]
Expand Down Expand Up @@ -117,7 +117,6 @@ public async Task Sends_Messages_As_Defined_Role()
history[1].Content.Should().Be("Hi tool.");
}


[Test]
public async Task Sends_Image_Bytes_As_Base64()
{
Expand Down
6 changes: 1 addition & 5 deletions test/IAsyncEnumerableExtensionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

namespace Tests;

#pragma warning disable CS8602 // Dereference of a possibly null reference.

public class IAsyncEnumerableExtensionTests
{
public class StreamToEndAsyncMethod : IAsyncEnumerableExtensionTests
Expand Down Expand Up @@ -73,6 +71,4 @@ public async Task Throws_If_No_Done_Response_Was_Send()
private static Message CreateMessage(ChatRole role, string content)
=> new() { Role = role, Content = content };
}
}

#pragma warning restore CS8602 // Dereference of a possibly null reference.
}
Loading

0 comments on commit 6d27f97

Please sign in to comment.