Skip to content

Commit

Permalink
Merge pull request #138 from RogerBarreto/message-tool-fix
Browse files Browse the repository at this point in the history
Fix: Avoid generating empty tool messages when there is no content for them
  • Loading branch information
awaescher authored Nov 7, 2024
2 parents 6d27f97 + cd0cf70 commit 62081c7
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 39 deletions.
83 changes: 44 additions & 39 deletions src/MicrosoftAi/AbstractionMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,37 +70,38 @@ public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessag

if (options?.AdditionalProperties?.Any() ?? false)
{
TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop, v => request.Options.Stop = v);
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = v);
TryAddOllamaOption<bool?>(options, OllamaOption.F16kv, v => request.Options.F16kv = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.FrequencyPenalty, v => request.Options.FrequencyPenalty = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LogitsAll, v => request.Options.LogitsAll = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.LowVram, v => request.Options.LowVram = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MainGpu, v => request.Options.MainGpu = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MinP, v => request.Options.MinP = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.MiroStat, v => request.Options.MiroStat = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatEta, v => request.Options.MiroStatEta = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.MiroStatTau, v => request.Options.MiroStatTau = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.Numa, v => request.Options.Numa = (bool?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumBatch, v => request.Options.NumBatch = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumCtx, v => request.Options.NumCtx = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGpu, v => request.Options.NumGpu = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumGqa, v => request.Options.NumGqa = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumKeep, v => request.Options.NumKeep = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumPredict, v => request.Options.NumPredict = (int?)v);
TryAddOllamaOption<int?>(options, OllamaOption.NumThread, v => request.Options.NumThread = (int?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.PenalizeNewline, v => request.Options.PenalizeNewline = (bool?)v);
TryAddOllamaOption<float?>(options, OllamaOption.PresencePenalty, v => request.Options.PresencePenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.RepeatLastN, v => request.Options.RepeatLastN = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.RepeatPenalty, v => request.Options.RepeatPenalty = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.Seed, v => request.Options.Seed = (int?)v);
TryAddOllamaOption<string[]?>(options, OllamaOption.Stop,
v => request.Options.Stop = (v as IEnumerable<string>)?.ToArray());
TryAddOllamaOption<float?>(options, OllamaOption.Temperature, v => request.Options.Temperature = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TfsZ, v => request.Options.TfsZ = (float?)v);
TryAddOllamaOption<int?>(options, OllamaOption.TopK, v => request.Options.TopK = (int?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TopP, v => request.Options.TopP = (float?)v);
TryAddOllamaOption<float?>(options, OllamaOption.TypicalP, v => request.Options.TypicalP = (float?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMlock, v => request.Options.UseMlock = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.UseMmap, v => request.Options.UseMmap = (bool?)v);
TryAddOllamaOption<bool?>(options, OllamaOption.VocabOnly, v => request.Options.VocabOnly = (bool?)v);
}

return request;
Expand All @@ -113,10 +114,10 @@ public static ChatRequest ToOllamaSharpChatRequest(IList<ChatMessage> chatMessag
/// <param name="microsoftChatOptions">The chat options from the Microsoft abstraction</param>
/// <param name="option">The Ollama setting to add</param>
/// <param name="optionSetter">The setter to set the Ollama option if available in the chat options</param>
private static void TryAddOllamaOption<T>(ChatOptions microsoftChatOptions, OllamaOption option, Action<object?> optionSetter)
{
	// The value is deliberately passed through untyped: each call site performs its
	// own conversion (e.g. materializing an IEnumerable<string> into an array for
	// Stop) instead of relying on a direct (T) cast here, which would throw an
	// InvalidCastException for compatible-but-different runtime types.
	if ((microsoftChatOptions?.AdditionalProperties?.TryGetValue(option.Name, out var value) ?? false) && value is not null)
		optionSetter(value);
}

/// <summary>
Expand Down Expand Up @@ -200,13 +201,17 @@ private static IEnumerable<Message> ToOllamaSharpMessages(IList<ChatMessage> cha
var images = cm.Contents.OfType<ImageContent>().Select(ToOllamaImage).Where(s => !string.IsNullOrEmpty(s)).ToArray();
var toolCalls = cm.Contents.OfType<FunctionCallContent>().Select(ToOllamaSharpToolCall).ToArray();

yield return new Message
// Only generate a message if there is text content, images, or tool calls
if (cm.Text is not null || images.Length > 0 || toolCalls.Length > 0)
{
Content = cm.Text,
Images = images.Length > 0 ? images : null,
Role = ToOllamaSharpRole(cm.Role),
ToolCalls = toolCalls.Length > 0 ? toolCalls : null,
};
yield return new Message
{
Content = cm.Text,
Images = images.Length > 0 ? images : null,
Role = ToOllamaSharpRole(cm.Role),
ToolCalls = toolCalls.Length > 0 ? toolCalls : null,
};
}

// If the message contains a function result, add it as a separate tool message
foreach (var frc in cm.Contents.OfType<FunctionResultContent>())
Expand Down
88 changes: 88 additions & 0 deletions test/AbstractionMapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,53 @@ public void Maps_Messages_With_Tools()
tool.Type.Should().Be("function");
}

/// <summary>
/// Verifies that any IEnumerable&lt;string&gt; supplied as the "stop" additional
/// property is mapped to the request's stop sequences, and that a null value
/// leaves the stop sequences unset.
/// </summary>
[TestCaseSource(nameof(StopSequencesTestData))]
public void Maps_Messages_With_IEnumerable_StopSequences(object? enumerable)
{
	var chatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
	{
		new()
		{
			AdditionalProperties = [],
			AuthorName = "a1",
			RawRepresentation = null,
			Role = Microsoft.Extensions.AI.ChatRole.User,
			Text = "What's the weather in Honululu?"
		}
	};

	var options = new ChatOptions()
	{
		AdditionalProperties = new AdditionalPropertiesDictionary() { ["stop"] = enumerable }
	};

	var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(chatMessages, options, stream: true, JsonSerializerOptions.Default);

	var stopSequences = chatRequest.Options.Stop;
	var typedEnumerable = (IEnumerable<string>?)enumerable;

	// A null "stop" value must not be mapped into an (empty) array.
	if (typedEnumerable is null)
	{
		stopSequences.Should().BeNull();
		return;
	}

	// typedEnumerable is proven non-null by the guard above, so no
	// null-conditional or null-coalescing fallback is needed here.
	stopSequences.Should().HaveCount(typedEnumerable.Count());
}

/// <summary>
/// Supplies different IEnumerable&lt;string&gt; implementations (plus null) so the
/// mapper's handling of the "stop" option is exercised regardless of the
/// concrete collection type the caller happens to use.
/// </summary>
public static IEnumerable<TestCaseData> StopSequencesTestData
{
	get
	{
		// A lazily evaluated LINQ projection over a JsonElement array.
		var deferredJsonStrings = JsonSerializer.Deserialize<JsonElement>("[\"stop1\", \"stop2\"]").EnumerateArray().Select(e => e.GetString());
		yield return new TestCaseData((object?)deferredJsonStrings);

		// No stop sequences provided at all.
		yield return new TestCaseData((object?)(IEnumerable<string>?)null);

		// Common concrete collection types with varying element counts.
		yield return new TestCaseData((object?)new List<string> { "stop1", "stop2", "stop3", "stop4" });
		yield return new TestCaseData((object?)new string[] { "stop1", "stop2", "stop3" });
		yield return new TestCaseData((object?)new HashSet<string> { "stop1", "stop2", });
		yield return new TestCaseData((object?)new Stack<string>(new[] { "stop1" }));
		yield return new TestCaseData((object?)new Queue<string>(new[] { "stop1" }));
	}
}

[Test]
public void Maps_Messages_With_ToolResponse()
{
Expand Down Expand Up @@ -316,6 +363,47 @@ public void Maps_Messages_With_MultipleToolResponse()
user.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.User);
}

/// <summary>
/// A chat message that carries only function results (no text, images or tool
/// calls) must not produce an empty message of its own — the mapper should
/// emit only the tool messages created from each function result.
/// </summary>
[Test]
public void Maps_Messages_WithoutContent_MultipleToolResponse()
{
	var firstResult = new FunctionResultContent(
		callId: "123",
		name: "Function1",
		result: new { Temperature = 40 });

	var secondResult = new FunctionResultContent(
		callId: "456",
		name: "Function2",
		result: new { Summary = "This is a tool result test" });

	var aiChatMessages = new List<Microsoft.Extensions.AI.ChatMessage>
	{
		new()
		{
			AdditionalProperties = [],
			AuthorName = "a1",
			RawRepresentation = null,
			Role = Microsoft.Extensions.AI.ChatRole.User,
			Contents = [firstResult, secondResult]
		}
	};

	var chatRequest = AbstractionMapper.ToOllamaSharpChatRequest(aiChatMessages, new(), stream: true, JsonSerializerOptions.Default);
	var mappedMessages = chatRequest.Messages?.ToList();

	// The user message itself has no content, so only the two tool messages remain.
	mappedMessages.Should().HaveCount(2);

	var firstTool = mappedMessages[0];
	firstTool.Content.Should().Contain("\"Temperature\":40");
	firstTool.Content.Should().Contain("\"CallId\":\"123\"");
	firstTool.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);

	var secondTool = mappedMessages[1];
	secondTool.Content.Should().Contain("\"Summary\":\"This is a tool result test\"");
	secondTool.Content.Should().Contain("\"CallId\":\"456\"");
	secondTool.Role.Should().Be(OllamaSharp.Models.Chat.ChatRole.Tool);
}

[Test]
public void Maps_Options()
{
Expand Down

0 comments on commit 62081c7

Please sign in to comment.