Skip to content

Commit

Permalink
Add ChatOptions.Seed (#5587)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Oct 31, 2024
1 parent e18a055 commit 17e5ecd
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ public class ChatOptions
/// <summary>Gets or sets the presence penalty for generating chat responses.</summary>
public float? PresencePenalty { get; set; }

/// <summary>Gets or sets a seed value used by a service to control the reproducability of results.</summary>
public long? Seed { get; set; }

/// <summary>
/// Gets or sets the response format for the chat request.
/// </summary>
Expand Down Expand Up @@ -74,6 +77,7 @@ public virtual ChatOptions Clone()
TopK = TopK,
FrequencyPenalty = FrequencyPenalty,
PresencePenalty = PresencePenalty,
Seed = Seed,
ResponseFormat = ResponseFormat,
ModelId = ModelId,
ToolMode = ToolMode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
result.NucleusSamplingFactor = options.TopP;
result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature;
result.Seed = options.Seed;

if (options.StopSequences is { Count: > 0 } stopSequences)
{
Expand All @@ -306,11 +307,6 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
{
switch (prop.Key)
{
// These properties are strongly-typed on the ChatCompletionsOptions class but not on the ChatOptions class.
case nameof(result.Seed) when prop.Value is long seed:
result.Seed = seed;
break;

// Propagate everything else to the ChatCompletionOptions' AdditionalProperties.
default:
if (prop.Value is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ private OllamaChatRequest ToOllamaChatRequest(IList<ChatMessage> chatMessages, C
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value);
TransferMetadataValue<int>(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value);
TransferMetadataValue<long>(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value);
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value);
Expand Down Expand Up @@ -314,6 +313,11 @@ private OllamaChatRequest ToOllamaChatRequest(IList<ChatMessage> chatMessages, C
{
(request.Options ??= new()).top_k = topK;
}

if (options.Seed is long seed)
{
(request.Options ??= new()).seed = seed;
}
}

return request;
Expand Down
10 changes: 3 additions & 7 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,9 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
result.TopP = options.TopP;
result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature;
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates.
result.Seed = options.Seed;
#pragma warning restore OPENAI001

if (options.StopSequences is { Count: > 0 } stopSequences)
{
Expand Down Expand Up @@ -426,13 +429,6 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
result.AllowParallelToolCalls = allowParallelToolCalls;
}

#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
if (additionalProperties.TryGetValue(nameof(result.Seed), out long seed))
{
result.Seed = seed;
}
#pragma warning restore OPENAI001

if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
{
result.TopLogProbabilityCount = topLogProbabilityCountInt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion(
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "response_format"), responseFormat);
}

if (options.AdditionalProperties?.TryGetValue("seed", out long seed) is true)
if (options.Seed is long seed)
{
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "seed"), seed);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public void Constructor_Parameterless_PropsDefaulted()
Assert.Null(options.TopK);
Assert.Null(options.FrequencyPenalty);
Assert.Null(options.PresencePenalty);
Assert.Null(options.Seed);
Assert.Null(options.ResponseFormat);
Assert.Null(options.ModelId);
Assert.Null(options.StopSequences);
Expand All @@ -33,6 +34,7 @@ public void Constructor_Parameterless_PropsDefaulted()
Assert.Null(clone.TopK);
Assert.Null(clone.FrequencyPenalty);
Assert.Null(clone.PresencePenalty);
Assert.Null(options.Seed);
Assert.Null(clone.ResponseFormat);
Assert.Null(clone.ModelId);
Assert.Null(clone.StopSequences);
Expand Down Expand Up @@ -69,6 +71,7 @@ public void Properties_Roundtrip()
options.TopK = 42;
options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f;
options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId";
options.StopSequences = stopSequences;
Expand All @@ -82,6 +85,7 @@ public void Properties_Roundtrip()
Assert.Equal(42, options.TopK);
Assert.Equal(0.4f, options.FrequencyPenalty);
Assert.Equal(0.5f, options.PresencePenalty);
Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, options.ResponseFormat);
Assert.Equal("modelId", options.ModelId);
Assert.Same(stopSequences, options.StopSequences);
Expand All @@ -96,6 +100,7 @@ public void Properties_Roundtrip()
Assert.Equal(42, clone.TopK);
Assert.Equal(0.4f, clone.FrequencyPenalty);
Assert.Equal(0.5f, clone.PresencePenalty);
Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat);
Assert.Equal("modelId", clone.ModelId);
Assert.Equal(stopSequences, clone.StopSequences);
Expand Down Expand Up @@ -126,6 +131,7 @@ public void JsonSerialization_Roundtrips()
options.TopK = 42;
options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f;
options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId";
options.StopSequences = stopSequences;
Expand All @@ -148,6 +154,7 @@ public void JsonSerialization_Roundtrips()
Assert.Equal(42, deserialized.TopK);
Assert.Equal(0.4f, deserialized.FrequencyPenalty);
Assert.Equal(0.5f, deserialized.PresencePenalty);
Assert.Equal(12345, deserialized.Seed);
Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.Equal("modelId", deserialized.ModelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ public async Task MultipleMessages_NonStreaming()
],
"presence_penalty": 0.5,
"frequency_penalty": 0.75,
"model": "gpt-4o-mini",
"seed": 42
"seed": 42,
"model": "gpt-4o-mini"
}
""";

Expand Down Expand Up @@ -303,7 +303,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42L },
Seed = 42,
});
Assert.NotNull(response);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public async Task PromptBasedFunctionCalling_NoArgs()
ModelId = "llama3:8b",
Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")],
Temperature = 0,
AdditionalProperties = new() { ["seed"] = 0L },
Seed = 0,
});

Assert.Single(response.Choices);
Expand Down Expand Up @@ -83,7 +83,7 @@ public async Task PromptBasedFunctionCalling_WithArgs()
{
Tools = [stockPriceTool, irrelevantTool],
Temperature = 0,
AdditionalProperties = new() { ["seed"] = 0L },
Seed = 0,
});

Assert.Single(response.Choices);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42 },
Seed = 42,
});
Assert.NotNull(response);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
AdditionalProperties = new() { ["seed"] = 42 },
Seed = 42,
});
Assert.NotNull(response);

Expand Down

0 comments on commit 17e5ecd

Please sign in to comment.