Skip to content

Commit

Permalink
[.Net] Support tools for AnthropicClient and AnthropicAgent (#2944)
Browse files Browse the repository at this point in the history
* Squash commits : support anthropic tools

* Support tool_choice

* Remove reference from TypeSafeFunctionCallCodeSnippet.cs and add own function in test proj
  • Loading branch information
DavidLuong98 authored Jun 30, 2024
1 parent e743d4d commit 80ecbf9
Show file tree
Hide file tree
Showing 19 changed files with 715 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<ProjectReference Include="..\..\src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj" />
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace AutoGen.Anthropic.Samples;

public static class AnthropicSamples
public static class Create_Anthropic_Agent
{
public static async Task RunAsync()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Single_Anthropic_Tool.cs

using AutoGen.Anthropic.DTO;
using AutoGen.Anthropic.Extensions;
using AutoGen.Anthropic.Utils;
using AutoGen.Core;
using FluentAssertions;

namespace AutoGen.Anthropic.Samples;

#region WeatherFunction

public partial class WeatherFunction
{
/// <summary>
/// Gets the weather based on the location and the unit
/// </summary>
/// <param name="location"></param>
/// <param name="unit"></param>
/// <returns></returns>
[Function]
public async Task<string> GetWeather(string location, string unit)
{
// dummy implementation
return $"The weather in {location} is currently sunny with a tempature of {unit} (s)";
}
}
#endregion
public class Create_Anthropic_Agent_With_Tool
{
public static async Task RunAsync()
{
#region define_tool
var tool = new Tool
{
Name = "GetWeather",
Description = "Get the current weather in a given location",
InputSchema = new InputSchema
{
Type = "object",
Properties = new Dictionary<string, SchemaProperty>
{
{ "location", new SchemaProperty { Type = "string", Description = "The city and state, e.g. San Francisco, CA" } },
{ "unit", new SchemaProperty { Type = "string", Description = "The unit of temperature, either \"celsius\" or \"fahrenheit\"" } }
},
Required = new List<string> { "location" }
}
};

var weatherFunction = new WeatherFunction();
var functionMiddleware = new FunctionCallMiddleware(
functions: [
weatherFunction.GetWeatherFunctionContract,
],
functionMap: new Dictionary<string, Func<string, Task<string>>>
{
{ weatherFunction.GetWeatherFunctionContract.Name!, weatherFunction.GetWeatherWrapper },
});

#endregion

#region create_anthropic_agent

var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ??
throw new Exception("Missing ANTHROPIC_API_KEY environment variable.");

var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey);
var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku,
tools: [tool]); // Define tools for AnthropicClientAgent
#endregion

#region register_middleware

var agentWithConnector = agent
.RegisterMessageConnector()
.RegisterPrintMessage()
.RegisterStreamingMiddleware(functionMiddleware);
#endregion register_middleware

#region single_turn
var question = new TextMessage(Role.Assistant,
"What is the weather like in San Francisco?",
from: "user");
var functionCallReply = await agentWithConnector.SendAsync(question);
#endregion

#region Single_turn_verify_reply
functionCallReply.Should().BeOfType<ToolCallAggregateMessage>();
#endregion Single_turn_verify_reply

#region Multi_turn
var finalReply = await agentWithConnector.SendAsync(chatHistory: [question, functionCallReply]);
#endregion Multi_turn

#region Multi_turn_verify_reply
finalReply.Should().BeOfType<TextMessage>();
#endregion Multi_turn_verify_reply
}
}
2 changes: 1 addition & 1 deletion dotnet/sample/AutoGen.Anthropic.Samples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ internal static class Program
{
public static async Task Main(string[] args)
{
await AnthropicSamples.RunAsync();
await Create_Anthropic_Agent_With_Tool.RunAsync();
}
}
11 changes: 10 additions & 1 deletion dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -16,21 +17,27 @@ public class AnthropicClientAgent : IStreamingAgent
private readonly string _systemMessage;
private readonly decimal _temperature;
private readonly int _maxTokens;
private readonly Tool[]? _tools;
private readonly ToolChoice? _toolChoice;

public AnthropicClientAgent(
AnthropicClient anthropicClient,
string name,
string modelName,
string systemMessage = "You are a helpful AI assistant",
decimal temperature = 0.7m,
int maxTokens = 1024)
int maxTokens = 1024,
Tool[]? tools = null,
ToolChoice? toolChoice = null)
{
Name = name;
_anthropicClient = anthropicClient;
_modelName = modelName;
_systemMessage = systemMessage;
_temperature = temperature;
_maxTokens = maxTokens;
_tools = tools;
_toolChoice = toolChoice;
}

public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null,
Expand Down Expand Up @@ -59,6 +66,8 @@ private ChatCompletionRequest CreateParameters(IEnumerable<IMessage> messages, G
Model = _modelName,
Stream = shouldStream,
Temperature = (decimal?)options?.Temperature ?? _temperature,
Tools = _tools?.ToList(),
ToolChoice = _toolChoice ?? ToolChoice.Auto
};

chatCompletionRequest.Messages = BuildMessages(messages);
Expand Down
107 changes: 94 additions & 13 deletions dotnet/src/AutoGen.Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ public sealed class AnthropicClient : IDisposable
private static readonly JsonSerializerOptions JsonSerializerOptions = new()
{
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
Converters = { new ContentBaseConverter() }
Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter<ToolChoiceType>() }
};

private static readonly JsonSerializerOptions JsonDeserializerOptions = new()
{
Converters = { new ContentBaseConverter() }
Converters = { new ContentBaseConverter(), new JsonPropertyNameEnumConverter<ToolChoiceType>() }
};

public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey)
Expand Down Expand Up @@ -61,33 +61,75 @@ public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAs
using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync());

var currentEvent = new SseEvent();

while (await reader.ReadLineAsync() is { } line)
{
if (!string.IsNullOrEmpty(line))
{
currentEvent.Data = line.Substring("data:".Length).Trim();
if (line.StartsWith("event:"))
{
currentEvent.EventType = line.Substring("event:".Length).Trim();
}
else if (line.StartsWith("data:"))
{
currentEvent.Data = line.Substring("data:".Length).Trim();
}
}
else
else // an empty line indicates the end of an event
{
if (currentEvent.Data == "[DONE]")
continue;
if (currentEvent.EventType == "content_block_start" && !string.IsNullOrEmpty(currentEvent.Data))
{
var dataBlock = JsonSerializer.Deserialize<DataBlock>(currentEvent.Data!);
if (dataBlock != null && dataBlock.ContentBlock?.Type == "tool_use")
{
currentEvent.ContentBlock = dataBlock.ContentBlock;
}
}

if (currentEvent.Data != null)
if (currentEvent.EventType is "message_start" or "content_block_delta" or "message_delta" && currentEvent.Data != null)
{
yield return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)),
cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response");
cancellationToken: cancellationToken);

if (res == null)
{
throw new Exception("Failed to deserialize response");
}

if (res.Delta?.Type == "input_json_delta" && !string.IsNullOrEmpty(res.Delta.PartialJson) &&
currentEvent.ContentBlock != null)
{
currentEvent.ContentBlock.AppendDeltaParameters(res.Delta.PartialJson!);
}
else if (res.Delta is { StopReason: "tool_use" } && currentEvent.ContentBlock != null)
{
if (res.Content == null)
{
res.Content = [currentEvent.ContentBlock.CreateToolUseContent()];
}
else
{
res.Content.Add(currentEvent.ContentBlock.CreateToolUseContent());
}

currentEvent = new SseEvent();
}

yield return res;
}
else if (currentEvent.Data != null)
else if (currentEvent.EventType == "error" && currentEvent.Data != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken);

throw new Exception(res?.Error?.Message);
}

// Reset the current event for the next one
currentEvent = new SseEvent();
if (currentEvent.ContentBlock == null)
{
currentEvent = new SseEvent();
}
}
}
}
Expand All @@ -113,11 +155,50 @@ public void Dispose()

private struct SseEvent
{
public string EventType { get; set; }
public string? Data { get; set; }
public ContentBlock? ContentBlock { get; set; }

public SseEvent(string? data = null)
public SseEvent(string eventType, string? data = null, ContentBlock? contentBlock = null)
{
EventType = eventType;
Data = data;
ContentBlock = contentBlock;
}
}

private class ContentBlock
{
[JsonPropertyName("type")]
public string? Type { get; set; }

[JsonPropertyName("id")]
public string? Id { get; set; }

[JsonPropertyName("name")]
public string? Name { get; set; }

[JsonPropertyName("input")]
public object? Input { get; set; }

public string? parameters { get; set; }

public void AppendDeltaParameters(string deltaParams)
{
StringBuilder sb = new StringBuilder(parameters);
sb.Append(deltaParams);
parameters = sb.ToString();
}

public ToolUseContent CreateToolUseContent()
{
return new ToolUseContent { Id = Id, Name = Name, Input = parameters };
}
}

private class DataBlock
{
[JsonPropertyName("content_block")]
public ContentBlock? ContentBlock { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert,
return JsonSerializer.Deserialize<TextContent>(text, options) ?? throw new InvalidOperationException();
case "image":
return JsonSerializer.Deserialize<ImageContent>(text, options) ?? throw new InvalidOperationException();
case "tool_use":
return JsonSerializer.Deserialize<ToolUseContent>(text, options) ?? throw new InvalidOperationException();
case "tool_result":
return JsonSerializer.Deserialize<ToolResultContent>(text, options) ?? throw new InvalidOperationException();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// JsonPropertyNameEnumCoverter.cs

using System;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace AutoGen.Anthropic.Converters;

internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
{
public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
string value = reader.GetString() ?? throw new JsonException("Value was null.");

foreach (var field in typeToConvert.GetFields())
{
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
if (attribute?.Name == value)
{
return (T)Enum.Parse(typeToConvert, field.Name);
}
}

throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
}

public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
var field = value.GetType().GetField(value.ToString());
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();

if (attribute != null)
{
writer.WriteStringValue(attribute.Name);
}
else
{
writer.WriteStringValue(value.ToString());
}
}
}

8 changes: 8 additions & 0 deletions dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ public class ChatCompletionRequest
[JsonPropertyName("top_p")]
public decimal? TopP { get; set; }

[JsonPropertyName("tools")]
public List<Tool>? Tools { get; set; }

[JsonPropertyName("tool_choice")]
public ToolChoice? ToolChoice { get; set; }

public ChatCompletionRequest()
{
Messages = new List<ChatMessage>();
Expand All @@ -62,4 +68,6 @@ public ChatMessage(string role, List<ContentBase> content)
Role = role;
Content = content;
}

public void AddContent(ContentBase content) => Content.Add(content);
}
Loading

0 comments on commit 80ecbf9

Please sign in to comment.