Skip to content

Commit

Permalink
feat(chat): Support extracting code blocks (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
skarllot authored Jan 6, 2025
1 parent bb3ee24 commit daa69d4
Show file tree
Hide file tree
Showing 17 changed files with 571 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Do not provide positive reinforcement or comments on good decisions. Focus solel

public Result<object, string> Parse(string key, string input) => key switch
{
JsonResponseKey => ContentDeserializer
JsonResponseKey => JsonContentDeserializer
.TryDeserialize(input, jsonContext.ImmutableListReviewerFeedbackResponse)
.Select(static object (x) => x),
_ => $"Unknown output key '{key}'"
Expand Down
3 changes: 2 additions & 1 deletion src/FlowPair/Chats/Models/ChatScript.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ public sealed record ChatScript(
StepInstruction: _ => curr.Select(v => v + 1),
MultiStepInstruction: x => Enumerable.Range(0, x.Messages.Count)
.Select((_, i) => i == 0 ? curr.First() + 1 : 1),
JsonConvertInstruction: _ => curr.Select(v => v + 1)))
JsonConvertInstruction: _ => curr.Select(v => v + 1),
CodeExtractInstruction: _ => curr.Select(v => v + 1)))
.Sum();

public static Option<ChatScript> FindChatScriptForFile(
Expand Down
22 changes: 22 additions & 0 deletions src/FlowPair/Chats/Models/ChatThread.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,28 @@ public Result<ChatThread, string> RunJsonInstruction(
}
}

public Result<ChatThread, string> RunCodeInstruction(
Instruction.CodeExtractInstruction instruction,
IProxyCompleteChatHandler completeChatHandler)
{
try
{
if (IsInterrupted)
{
return this;
}

return Enumerable.Range(0, MaxJsonRetries)
.TryAggregate(
AddMessages(instruction.ToMessage(StopKeyword)),
(chat, _) => chat.CompleteChatAndDeserialize(instruction.OutputKey, completeChatHandler));
}
finally
{
Progress.Increment(1);
}
}

private ChatThread AddMessages(params ReadOnlySpan<Message> newMessages) =>
this with { Messages = [..Messages, ..newMessages] };

Expand Down
13 changes: 12 additions & 1 deletion src/FlowPair/Chats/Models/ChatWorkspace.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ public Result<ChatWorkspace, string> RunInstruction(
return instruction.Match(
StepInstruction: x => RunStepInstruction(x, completeChatHandler),
MultiStepInstruction: x => RunMultiStepInstruction(x, completeChatHandler),
JsonConvertInstruction: x => RunJsonInstruction(x, completeChatHandler));
JsonConvertInstruction: x => RunJsonInstruction(x, completeChatHandler),
CodeExtractInstruction: x => RunCodeInstruction(x, completeChatHandler));
}

private Result<ChatWorkspace, string> RunStepInstruction(
Expand Down Expand Up @@ -51,4 +52,14 @@ private Result<ChatWorkspace, string> RunJsonInstruction(
.Sequence()
.Select(list => new ChatWorkspace(ChatThreads: list.ToImmutableList()));
}

private Result<ChatWorkspace, string> RunCodeInstruction(
Instruction.CodeExtractInstruction instruction,
IProxyCompleteChatHandler completeChatHandler)
{
return ChatThreads.AsParallel()
.Select(thread => thread.RunCodeInstruction(instruction, completeChatHandler))
.Sequence()
.Select(list => new ChatWorkspace(ChatThreads: list.ToImmutableList()));
}
}
5 changes: 5 additions & 0 deletions src/FlowPair/Chats/Models/CodeSnippet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace Ciandt.FlowTools.FlowPair.Chats.Models;

public sealed record CodeSnippet(
string Content,
string? Language = null);
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,11 @@ partial record JsonConvertInstruction(string OutputKey, string Message, string J
```
""");
}

partial record CodeExtractInstruction(string OutputKey, string Message)
{
public Message ToMessage(string stopKeyword) => new(
SenderRole.User,
Message.Replace(ChatScript.StopKeywordPlaceholder, stopKeyword));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace Ciandt.FlowTools.FlowPair.Chats.Services;

public static class ContentDeserializer
public static class JsonContentDeserializer
{
public static Result<T, string> TryDeserialize<T>(
ReadOnlySpan<char> content,
Expand Down
38 changes: 38 additions & 0 deletions src/FlowPair/Chats/Services/MarkdownCodeExtractor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System.Collections.Immutable;
using System.Text.RegularExpressions;
using Ciandt.FlowTools.FlowPair.Chats.Models;
using Ciandt.FlowTools.FlowPair.Common;
using FxKit.Parsers;

namespace Ciandt.FlowTools.FlowPair.Chats.Services;

public static partial class MarkdownCodeExtractor
{
public static Result<ImmutableList<CodeSnippet>, string> TryExtract(string content)
{
return
from value in StringParser.NonNullOrWhiteSpace(content)
.OkOr("Code not found on empty content")
select CodeBlockRegex().Matches(content)
.Select(m => new CodeSnippet(m.Groups[2].Value, m.Groups[1].Value))
.ToImmutableList();
}

public static Result<CodeSnippet, string> TryExtractSingle(string content)
{
return
from value in StringParser.NonNullOrWhiteSpace(content)
.OkOr("Code not found on empty content")
from code in CodeBlockRegex().Matches(content)
.Select(m => new CodeSnippet(m.Groups[2].Value, m.Groups[1].Value))
.TrySingle()
.MapErr(
p => p.Match(
Empty: "No code block found",
MoreThanOneElement: "More than one code block found"))
select code;
}

[GeneratedRegex(@"```(\w*)\s*\n([\s\S]*?)\r?\n\s*```", RegexOptions.Singleline)]
private static partial Regex CodeBlockRegex();
}
32 changes: 32 additions & 0 deletions src/FlowPair/Common/CollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,36 @@ public static string AggregateToStringLines<T>(this IEnumerable<T> collection, F
: curr.Append(selector(item)))
.ToString();
}

public static Result<TSource, SingleElementProblem> TrySingle<TSource>(this IEnumerable<TSource> source)
where TSource : notnull
{
if (source is IList<TSource> list)
{
switch (list.Count)
{
case 0:
return SingleElementProblem.Empty;
case 1:
return list[0];
}
}
else
{
using var e = source.GetEnumerator();

if (!e.MoveNext())
{
return SingleElementProblem.Empty;
}

var result = e.Current;
if (!e.MoveNext())
{
return result;
}
}

return SingleElementProblem.MoreThanOneElement;
}
}
10 changes: 10 additions & 0 deletions src/FlowPair/Common/SingleElementProblem.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using FxKit.CompilerServices;

namespace Ciandt.FlowTools.FlowPair.Common;

[EnumMatch]
public enum SingleElementProblem
{
Empty,
MoreThanOneElement,
}
2 changes: 1 addition & 1 deletion tests/FlowPair.Tests/Chats/Models/ChatScriptTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public void TotalStepsShouldCalculateCorrectlyWithSingleAfterMultiStepInstructio
Preamble: "Preamble",
Messages: ImmutableList.Create("Step1", "Step2", "Step3"),
Ending: "Ending"),
new Instruction.StepInstruction("Message2"),
new Instruction.CodeExtractInstruction("Key1", "Message2"),
]);

// Act
Expand Down
93 changes: 93 additions & 0 deletions tests/FlowPair.Tests/Chats/Models/ChatThreadTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,97 @@ public void RunStepInstructionShouldNotAddMessageWhenInterrupted()
_handler.ReceivedCalls().Should().BeEmpty();
_messageParser.ReceivedCalls().Should().BeEmpty();
}

[Fact]
public void RunCodeInstructionShouldAddMessageAndIncrementProgress()
{
// Arrange
const string outputKey = "CodeKey";
_messageParser
.Parse(outputKey, Arg.Any<string>())
.Returns(Unit());

var chatThread = new ChatThread(
Progress: _progressTask,
ModelType: LlmModelType.Gpt4,
StopKeyword: "<STOP>",
Messages: [new Message(SenderRole.User, "Initial")],
MessageParser: _messageParser);

var codeInstruction = new Instruction.CodeExtractInstruction(
OutputKey: outputKey,
Message: "Extract Code");

// Act
var result = chatThread.RunCodeInstruction(codeInstruction, _handler);

// Assert
result.IsOk.Should().BeTrue();
var updatedThread = result.Unwrap();
updatedThread.Messages.Should().HaveCount(3);
updatedThread.LastMessage!.Content.Should().Be(CompletionResponse);
_progressTask.Value.Should().Be(1);
_handler.ReceivedCalls().Should().HaveCount(1);
_messageParser.ReceivedCalls().Should().HaveCount(1);
}

[Fact]
public void RunCodeInstructionShouldRetryAddMessagesAndIncrementProgress()
{
// Arrange
const string outputKey = "CodeKey";
_messageParser
.Parse(outputKey, Arg.Any<string>())
.Returns("First try", "Second try", Unit());

var chatThread = new ChatThread(
Progress: _progressTask,
ModelType: LlmModelType.Gpt4,
StopKeyword: "<STOP>",
Messages: [new Message(SenderRole.User, "Initial")],
MessageParser: _messageParser);

var codeInstruction = new Instruction.CodeExtractInstruction(
OutputKey: outputKey,
Message: "Extract Code");

// Act
var result = chatThread.RunCodeInstruction(codeInstruction, _handler);

// Assert
result.IsOk.Should().BeTrue();
var updatedThread = result.Unwrap();
updatedThread.Messages.Should().HaveCount(7);
updatedThread.LastMessage!.Content.Should().Be(CompletionResponse);
_progressTask.Value.Should().Be(1);
_handler.ReceivedCalls().Should().HaveCount(3);
_messageParser.ReceivedCalls().Should().HaveCount(3);
}

[Fact]
public void RunCodeInstructionShouldNotAddMessageWhenInterrupted()
{
// Arrange
var chatThread = new ChatThread(
Progress: _progressTask,
ModelType: LlmModelType.Gpt4,
StopKeyword: "<STOP>",
Messages: [new Message(SenderRole.Assistant, "Interrupted <STOP>")],
MessageParser: _messageParser);

var codeInstruction = new Instruction.CodeExtractInstruction(
OutputKey: "CodeKey",
Message: "Extract Code");

// Act
var result = chatThread.RunCodeInstruction(codeInstruction, _handler);

// Assert
result.IsOk.Should().BeTrue();
var updatedThread = result.Unwrap();
updatedThread.Messages.Should().HaveCount(1);
_progressTask.Value.Should().Be(1);
_handler.ReceivedCalls().Should().BeEmpty();
_messageParser.ReceivedCalls().Should().BeEmpty();
}
}
50 changes: 50 additions & 0 deletions tests/FlowPair.Tests/Chats/Models/ChatWorkspaceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,56 @@ public void RunJsonInstructionShouldUpdateAllThreads()
_messageParser.ReceivedCalls().Should().HaveCount(2);
}

[Fact]
public void RunCodeInstructionShouldUpdateSingleThread()
{
// Arrange
const string outputKey = "TestKey";
_messageParser
.Parse(outputKey, Arg.Any<string>())
.Returns(Unit());

var workspace = new ChatWorkspace([CreateChatThread()]);
var codeInstruction = new Instruction.CodeExtractInstruction(
outputKey,
"Code Message");

// Act
var result = workspace.RunInstruction(codeInstruction, _handler);

// Assert
result.Should().BeOk()
.ChatThreads.Should().HaveCount(1)
.And.Subject.Should().ContainSingle(t => t.Messages.Count == 2);
_handler.ReceivedCalls().Should().HaveCount(1);
_messageParser.ReceivedCalls().Should().HaveCount(1);
}

[Fact]
public void RunCodeInstructionShouldUpdateAllThreads()
{
// Arrange
const string outputKey = "TestKey";
_messageParser
.Parse(outputKey, Arg.Any<string>())
.Returns(Unit());

var workspace = new ChatWorkspace([CreateChatThread(), CreateChatThread()]);
var codeInstruction = new Instruction.CodeExtractInstruction(
outputKey,
"Code Message");

// Act
var result = workspace.RunInstruction(codeInstruction, _handler);

// Assert
result.Should().BeOk()
.ChatThreads.Should().HaveCount(2)
.And.Subject.Should().OnlyContain(t => t.Messages.Count == 2);
_handler.ReceivedCalls().Should().HaveCount(2);
_messageParser.ReceivedCalls().Should().HaveCount(2);
}

private ChatThread CreateChatThread(ImmutableList<Message>? messages = null) =>
new(
Progress: _progressTask,
Expand Down
2 changes: 1 addition & 1 deletion tests/FlowPair.Tests/Chats/Services/ChatServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void RunShouldReturnValidFeedbackWhenChatScriptIsValid()
_chatDefinition
.Parse(outputKey, Arg.Any<string>())
.Returns(
c => ContentDeserializer
c => JsonContentDeserializer
.TryDeserialize((string)c[1], AgentJsonContext.Default.ImmutableListReviewerFeedbackResponse)
.Select(static object (x) => x));
_chatDefinition
Expand Down
Loading

0 comments on commit daa69d4

Please sign in to comment.