Skip to content

Commit

Permalink
Remove AIContent.ModelId, add StreamingChatCompletionUpdate.ModelId (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Oct 22, 2024
1 parent e1eb9bd commit 651546f
Show file tree
Hide file tree
Showing 18 changed files with 29 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ public IList<AIContent> Contents
/// <summary>Gets or sets the finish reason for the operation.</summary>
public ChatFinishReason? FinishReason { get; set; }

/// <summary>Gets or sets the model ID using in the creation of the chat completion of which this update is a part.</summary>
public string? ModelId { get; set; }

/// <inheritdoc/>
public override string ToString() => Text ?? string.Empty;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ protected AIContent()
[JsonIgnore]
public object? RawRepresentation { get; set; }

/// <summary>
/// Gets or sets the model ID used to generate the content.
/// </summary>
public string? ModelId { get; set; }

/// <summary>Gets or sets additional properties for the content.</summary>
public AdditionalPropertiesDictionary? AdditionalProperties { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ public async Task<ChatCompletion> CompleteAsync(
if (toolCall is ChatCompletionsFunctionToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name))
{
FunctionCallContent callContent = ParseCallContentFromJsonString(ftc.Arguments, toolCall.Id, ftc.Name);
callContent.ModelId = response.Model;
callContent.RawRepresentation = toolCall;

returnMessage.Contents.Add(callContent);
Expand All @@ -109,7 +108,6 @@ public async Task<ChatCompletion> CompleteAsync(
{
returnMessage.Contents.Add(new TextContent(choice.Message.Content)
{
ModelId = response.Model,
RawRepresentation = choice.Message
});
}
Expand Down Expand Up @@ -173,17 +171,15 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
CompletionId = chatCompletionUpdate.Id,
CreatedAt = chatCompletionUpdate.Created,
FinishReason = finishReason,
ModelId = modelId,
RawRepresentation = chatCompletionUpdate,
Role = streamedRole,
};

// Transfer over content update items.
if (chatCompletionUpdate.ContentUpdate is string update)
{
completionUpdate.Contents.Add(new TextContent(update)
{
ModelId = modelId,
});
completionUpdate.Contents.Add(new TextContent(update));
}

// Transfer over tool call updates.
Expand Down Expand Up @@ -218,6 +214,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
CompletionId = completionId,
CreatedAt = createdAt,
FinishReason = finishReason,
ModelId = modelId,
Role = streamedRole,
};

Expand All @@ -230,9 +227,6 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
fci.Arguments?.ToString() ?? string.Empty,
fci.CallId!,
fci.Name!);

callContent.ModelId = modelId;

completionUpdate.Contents.Add(callContent);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,24 +125,25 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
continue;
}

string? modelId = chunk.Model ?? Metadata.ModelId;

StreamingChatCompletionUpdate update = new()
{
Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null,
CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null,
AdditionalProperties = ParseOllamaChatResponseProps(chunk),
FinishReason = ToFinishReason(chunk),
ModelId = modelId,
};

string? modelId = chunk.Model ?? Metadata.ModelId;

if (chunk.Message is { } message)
{
update.Contents.Add(new TextContent(message.Content) { ModelId = modelId });
update.Contents.Add(new TextContent(message.Content));
}

if (ParseOllamaChatResponseUsage(chunk) is { } usage)
{
update.Contents.Add(new UsageContent(usage) { ModelId = modelId });
update.Contents.Add(new UsageContent(usage));
}

yield return update;
Expand Down
19 changes: 6 additions & 13 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public async Task<ChatCompletion> CompleteAsync(
// Populate its content from those in the OpenAI response content.
foreach (ChatMessageContentPart contentPart in response.Content)
{
if (ToAIContent(contentPart, response.Model) is AIContent aiContent)
if (ToAIContent(contentPart) is AIContent aiContent)
{
returnMessage.Contents.Add(aiContent);
}
Expand All @@ -125,7 +125,6 @@ public async Task<ChatCompletion> CompleteAsync(
if (!string.IsNullOrWhiteSpace(toolCall.FunctionName))
{
var callContent = ParseCallContentFromBinaryData(toolCall.FunctionArguments, toolCall.Id, toolCall.FunctionName);
callContent.ModelId = response.Model;
callContent.RawRepresentation = toolCall;

returnMessage.Contents.Add(callContent);
Expand Down Expand Up @@ -214,6 +213,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
CompletionId = chatCompletionUpdate.CompletionId,
CreatedAt = chatCompletionUpdate.CreatedAt,
FinishReason = finishReason,
ModelId = modelId,
RawRepresentation = chatCompletionUpdate,
Role = streamedRole,
};
Expand All @@ -239,7 +239,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
{
foreach (ChatMessageContentPart contentPart in chatCompletionUpdate.ContentUpdate)
{
if (ToAIContent(contentPart, modelId) is AIContent aiContent)
if (ToAIContent(contentPart) is AIContent aiContent)
{
completionUpdate.Contents.Add(aiContent);
}
Expand Down Expand Up @@ -292,10 +292,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs

// TODO: Add support for prompt token details (e.g. cached tokens) once it's exposed in OpenAI library.

completionUpdate.Contents.Add(new UsageContent(usageDetails)
{
ModelId = modelId
});
completionUpdate.Contents.Add(new UsageContent(usageDetails));
}

// Now yield the item.
Expand All @@ -310,6 +307,7 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
CompletionId = completionId,
CreatedAt = createdAt,
FinishReason = finishReason,
ModelId = modelId,
Role = streamedRole,
};

Expand All @@ -322,9 +320,6 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
fci.Arguments?.ToString() ?? string.Empty,
fci.CallId!,
fci.Name!);

callContent.ModelId = modelId;

completionUpdate.Contents.Add(callContent);
}
}
Expand Down Expand Up @@ -531,9 +526,8 @@ private sealed class OpenAIChatToolJson

/// <summary>Creates an <see cref="AIContent"/> from a <see cref="ChatMessageContentPart"/>.</summary>
/// <param name="contentPart">The content part to convert into a content.</param>
/// <param name="modelId">The model ID.</param>
/// <returns>The constructed <see cref="AIContent"/>, or null if the content part could not be converted.</returns>
private static AIContent? ToAIContent(ChatMessageContentPart contentPart, string? modelId)
private static AIContent? ToAIContent(ChatMessageContentPart contentPart)
{
AIContent? aiContent = null;

Expand Down Expand Up @@ -564,7 +558,6 @@ private sealed class OpenAIChatToolJson
(additionalProperties ??= [])[nameof(contentPart.Refusal)] = refusal;
}

aiContent.ModelId = modelId;
aiContent.AdditionalProperties = additionalProperties;
aiContent.RawRepresentation = contentPart;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ next.Contents[0] is not TextContent ||
TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
ModelId = textContent.ModelId,
};

StreamingChatCompletionUpdate coalesced = new()
Expand All @@ -141,6 +140,7 @@ next.Contents[0] is not TextContent ||
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
ModelId = update.ModelId,
Role = update.Role,

// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
Expand All @@ -160,16 +160,15 @@ next.Contents[0] is not TextContent ||
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;

TextContent nextContent = (TextContent)next.Contents[0];
var nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);

coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.ModelId ??= next.ModelId;
coalesced.Role ??= next.Role;

coalescedContent.ModelId ??= nextContent.ModelId;
}

// Complete the coalescing by patching the text of the coalesced node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion(
finishReason ??= update.FinishReason;
role ??= update.Role;
items.AddRange(update.Contents);
modelId ??= update.Contents.FirstOrDefault(c => c.ModelId is not null)?.ModelId;
modelId ??= update.ModelId;
}

messages.Add(new ChatMessage(role ?? ChatRole.Assistant, items));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,32 +262,26 @@ public void ItCanBeSerializeAndDeserialized()
[
new TextContent("content-1")
{
ModelId = "model-1",
AdditionalProperties = new() { ["metadata-key-1"] = "metadata-value-1" }
},
new ImageContent(new Uri("https://fake-random-test-host:123"), "mime-type/2")
{
ModelId = "model-2",
AdditionalProperties = new() { ["metadata-key-2"] = "metadata-value-2" }
},
new DataContent(new BinaryData(new[] { 1, 2, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/3")
{
ModelId = "model-3",
AdditionalProperties = new() { ["metadata-key-3"] = "metadata-value-3" }
},
new AudioContent(new BinaryData(new[] { 3, 2, 1 }, options: TestJsonSerializerContext.Default.Options), "mime-type/4")
{
ModelId = "model-4",
AdditionalProperties = new() { ["metadata-key-4"] = "metadata-value-4" }
},
new ImageContent(new BinaryData(new[] { 2, 1, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/5")
{
ModelId = "model-5",
AdditionalProperties = new() { ["metadata-key-5"] = "metadata-value-5" }
},
new TextContent("content-6")
{
ModelId = "model-6",
AdditionalProperties = new() { ["metadata-key-6"] = "metadata-value-6" }
},
new FunctionCallContent("function-id", "plugin-name-function-name", new Dictionary<string, object?> { ["parameter"] = "argument" }),
Expand Down Expand Up @@ -317,15 +311,13 @@ public void ItCanBeSerializeAndDeserialized()
var textContent = deserializedMessage.Contents[0] as TextContent;
Assert.NotNull(textContent);
Assert.Equal("content-1-override", textContent.Text);
Assert.Equal("model-1", textContent.ModelId);
Assert.NotNull(textContent.AdditionalProperties);
Assert.Single(textContent.AdditionalProperties);
Assert.Equal("metadata-value-1", textContent.AdditionalProperties["metadata-key-1"]?.ToString());

var imageContent = deserializedMessage.Contents[1] as ImageContent;
Assert.NotNull(imageContent);
Assert.Equal("https://fake-random-test-host:123/", imageContent.Uri);
Assert.Equal("model-2", imageContent.ModelId);
Assert.Equal("mime-type/2", imageContent.MediaType);
Assert.NotNull(imageContent.AdditionalProperties);
Assert.Single(imageContent.AdditionalProperties);
Expand All @@ -334,7 +326,6 @@ public void ItCanBeSerializeAndDeserialized()
var dataContent = deserializedMessage.Contents[2] as DataContent;
Assert.NotNull(dataContent);
Assert.True(dataContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options)));
Assert.Equal("model-3", dataContent.ModelId);
Assert.Equal("mime-type/3", dataContent.MediaType);
Assert.NotNull(dataContent.AdditionalProperties);
Assert.Single(dataContent.AdditionalProperties);
Expand All @@ -343,7 +334,6 @@ public void ItCanBeSerializeAndDeserialized()
var audioContent = deserializedMessage.Contents[3] as AudioContent;
Assert.NotNull(audioContent);
Assert.True(audioContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 3, 2, 1 }, TestJsonSerializerContext.Default.Options)));
Assert.Equal("model-4", audioContent.ModelId);
Assert.Equal("mime-type/4", audioContent.MediaType);
Assert.NotNull(audioContent.AdditionalProperties);
Assert.Single(audioContent.AdditionalProperties);
Expand All @@ -352,7 +342,6 @@ public void ItCanBeSerializeAndDeserialized()
imageContent = deserializedMessage.Contents[4] as ImageContent;
Assert.NotNull(imageContent);
Assert.True(imageContent.Data?.Span.SequenceEqual(new BinaryData(new[] { 2, 1, 3 }, TestJsonSerializerContext.Default.Options)));
Assert.Equal("model-5", imageContent.ModelId);
Assert.Equal("mime-type/5", imageContent.MediaType);
Assert.NotNull(imageContent.AdditionalProperties);
Assert.Single(imageContent.AdditionalProperties);
Expand All @@ -361,7 +350,6 @@ public void ItCanBeSerializeAndDeserialized()
textContent = deserializedMessage.Contents[5] as TextContent;
Assert.NotNull(textContent);
Assert.Equal("content-6", textContent.Text);
Assert.Equal("model-6", textContent.ModelId);
Assert.NotNull(textContent.AdditionalProperties);
Assert.Single(textContent.AdditionalProperties);
Assert.Equal("metadata-value-6", textContent.AdditionalProperties["metadata-key-6"]?.ToString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ public void Constructor_PropsDefault()
{
DerivedAIContent c = new();
Assert.Null(c.RawRepresentation);
Assert.Null(c.ModelId);
Assert.Null(c.AdditionalProperties);
}

Expand All @@ -26,10 +25,6 @@ public void Constructor_PropsRoundtrip()
c.RawRepresentation = raw;
Assert.Same(raw, c.RawRepresentation);

Assert.Null(c.ModelId);
c.ModelId = "modelId";
Assert.Equal("modelId", c.ModelId);

Assert.Null(c.AdditionalProperties);
AdditionalPropertiesDictionary props = new() { { "key", "value" } };
c.AdditionalProperties = props;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ public void Deserialize_MatchesExpectedData()
Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data!.Value.ToArray());
Assert.Equal("text/plain", content.MediaType);
Assert.True(content.ContainsData);
Assert.Equal("gpt-4", content.ModelId);
Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ public void Constructor_PropsDefault()
FunctionCallContent c = new("callId1", "name");

Assert.Null(c.RawRepresentation);
Assert.Null(c.ModelId);
Assert.Null(c.AdditionalProperties);

Assert.Equal("callId1", c.CallId);
Expand All @@ -39,7 +38,6 @@ public void Constructor_ArgumentsRoundtrip()
FunctionCallContent c = new("id", "name", args);

Assert.Null(c.RawRepresentation);
Assert.Null(c.ModelId);
Assert.Null(c.AdditionalProperties);

Assert.Equal("name", c.Name);
Expand All @@ -57,10 +55,6 @@ public void Constructor_PropsRoundtrip()
c.RawRepresentation = raw;
Assert.Same(raw, c.RawRepresentation);

Assert.Null(c.ModelId);
c.ModelId = "modelId";
Assert.Equal("modelId", c.ModelId);

Assert.Null(c.AdditionalProperties);
AdditionalPropertiesDictionary props = new() { { "key", "value" } };
c.AdditionalProperties = props;
Expand Down Expand Up @@ -322,8 +316,8 @@ public static void CreateFromParsedArguments_ObjectJsonInput_ReturnsElementArgum
[InlineData(typeof(NotSupportedException))]
public static void CreateFromParsedArguments_ParseException_HasExpectedHandling(Type exceptionType)
{
Exception exc = (Exception)Activator.CreateInstance(exceptionType)!;
FunctionCallContent content = FunctionCallContent.CreateFromParsedArguments(exc, "callId", "functionName", ThrowingParser);
var exc = (Exception)Activator.CreateInstance(exceptionType)!;
var content = FunctionCallContent.CreateFromParsedArguments(exc, "callId", "functionName", ThrowingParser);

Assert.Equal("functionName", content.Name);
Assert.Equal("callId", content.CallId);
Expand Down
Loading

0 comments on commit 651546f

Please sign in to comment.