Skip to content

Commit

Permalink
refactor: ♻️ enhance cache-models and model-options loading (#20)
Browse files: browse the repository at this point in the history
  • Loading branch information
mehdihadeli authored Nov 30, 2024
1 parent fec4cc3 commit d5c8e27
Show file tree
Hide file tree
Showing 25 changed files with 420 additions and 440 deletions.
54 changes: 41 additions & 13 deletions src/AIAssist/Commands/CodeAssistCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ IOptions<AppOptions> appOptions
private readonly AppOptions _appOptions = appOptions.Value;
private readonly Model _chatModel =
cacheModels.GetModel(llmOptions.Value.ChatModel)
?? throw new KeyNotFoundException($"Model '{llmOptions.Value.ChatModel}' not found in the ModelCache.");
?? throw new ArgumentNullException($"Model '{llmOptions.Value.ChatModel}' not found in the ModelCache.");
private readonly Model? _embeddingModel = cacheModels.GetModel(llmOptions.Value.EmbeddingsModel);

private static bool _running = true;
Expand Down Expand Up @@ -122,10 +122,22 @@ public override async Task<int> ExecuteAsync(CommandContext context, Settings se
SetupOptions(settings);

spectreUtilities.SummaryTextLine("Code assist mode is activated!");
spectreUtilities.SummaryTextLine(
$"Chat model: {_chatModel.Name} | Embedding model: {_embeddingModel?.Name ?? "-"} | CodeAssistType: {_chatModel.ModelOption.CodeAssistType} | CodeDiffType: {_chatModel.ModelOption.CodeDiffType}"
spectreUtilities.NormalText("Chat model: ");
spectreUtilities.HighlightTextLine(_chatModel.Name);

spectreUtilities.NormalText("Embedding model: ");
spectreUtilities.HighlightTextLine(_embeddingModel?.Name ?? "-");

spectreUtilities.NormalText("CodeAssistType: ");
spectreUtilities.HighlightTextLine(_chatModel.CodeAssistType.ToString());

spectreUtilities.NormalText("CodeDiffType: ");
spectreUtilities.HighlightTextLine(_chatModel.CodeDiffType.ToString());

spectreUtilities.NormalTextLine(
"Please 'Ctrl+H' to see all available commands in the code assist mode.",
decoration: Decoration.Bold
);
spectreUtilities.SummaryTextLine("Please 'Ctrl+H' to see all available commands in the code assist mode.");
spectreUtilities.WriteRule();

await AnsiConsole
Expand Down Expand Up @@ -189,42 +201,42 @@ private void SetupOptions(Settings settings)

if (!string.IsNullOrEmpty(settings.ChatModelApiKey))
{
_chatModel.ModelOption.ApiKey = settings.ChatModelApiKey.Trim();
_chatModel.ApiKey = settings.ChatModelApiKey.Trim();
}

if (!string.IsNullOrEmpty(settings.ChatApiVersion))
{
_chatModel.ModelOption.ApiVersion = settings.ChatApiVersion.Trim();
_chatModel.ApiVersion = settings.ChatApiVersion.Trim();
}

if (!string.IsNullOrEmpty(settings.ChatDeploymentId))
{
_chatModel.ModelOption.DeploymentId = settings.ChatDeploymentId.Trim();
_chatModel.DeploymentId = settings.ChatDeploymentId.Trim();
}

if (!string.IsNullOrEmpty(settings.ChatBaseAddress))
{
_chatModel.ModelOption.BaseAddress = settings.ChatBaseAddress.Trim();
_chatModel.BaseAddress = settings.ChatBaseAddress.Trim();
}

if (!string.IsNullOrEmpty(settings.EmbeddingsModelApiKey) && _embeddingModel is not null)
{
_embeddingModel.ModelOption.ApiKey = settings.EmbeddingsModelApiKey.Trim();
_embeddingModel.ApiKey = settings.EmbeddingsModelApiKey.Trim();
}

if (!string.IsNullOrEmpty(settings.EmbeddingsApiVersion) && _embeddingModel is not null)
{
_embeddingModel.ModelOption.ApiVersion = settings.EmbeddingsApiVersion.Trim();
_embeddingModel.ApiVersion = settings.EmbeddingsApiVersion.Trim();
}

if (!string.IsNullOrEmpty(settings.EmbeddingsDeploymentId) && _embeddingModel is not null)
{
_embeddingModel.ModelOption.DeploymentId = settings.EmbeddingsDeploymentId.Trim();
_embeddingModel.DeploymentId = settings.EmbeddingsDeploymentId.Trim();
}

if (!string.IsNullOrEmpty(settings.EmbeddingsBaseAddress) && _embeddingModel is not null)
{
_embeddingModel.ModelOption.BaseAddress = settings.EmbeddingsBaseAddress.Trim();
_embeddingModel.BaseAddress = settings.EmbeddingsBaseAddress.Trim();
}

_appOptions.ContextWorkingDirectory = !string.IsNullOrEmpty(settings.ContextWorkingDirectory)
Expand All @@ -246,21 +258,37 @@ private void SetupOptions(Settings settings)
if (settings.CodeDiffType is not null)
{
_llmOptions.CodeDiffType = settings.CodeDiffType.Value;
_chatModel.CodeDiffType = settings.CodeDiffType.Value;

if (_embeddingModel != null)
_embeddingModel.CodeDiffType = settings.CodeDiffType.Value;
}

if (settings.CodeAssistType is not null)
{
_llmOptions.CodeAssistType = settings.CodeAssistType.Value;
_chatModel.CodeAssistType = settings.CodeAssistType.Value;

if (_embeddingModel != null)
_embeddingModel.CodeAssistType = settings.CodeAssistType.Value;
}

if (settings.Threshold is not null && _embeddingModel is not null)
if (settings.Threshold is not null)
{
_llmOptions.Threshold = settings.Threshold.Value;
_chatModel.Threshold = settings.Threshold.Value;

if (_embeddingModel != null)
_embeddingModel.Threshold = settings.Threshold.Value;
}

if (settings.Temperature is not null)
{
_llmOptions.Temperature = settings.Temperature.Value;
_chatModel.Temperature = settings.Temperature.Value;

if (_embeddingModel != null)
_embeddingModel.Temperature = settings.Temperature.Value;
}
}
}
Expand Down
34 changes: 21 additions & 13 deletions src/AIAssist/Extensions/DependencyInjectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ private static void AddCodeAssistDependencies(HostApplicationBuilder builder)

var chatModel = cacheModels.GetModel(llmOptions.Value.ChatModel);

ICodeAssist codeAssist = factory.Create(chatModel.ModelOption.CodeAssistType);
ArgumentNullException.ThrowIfNull(chatModel);

ICodeAssist codeAssist = factory.Create(chatModel.CodeAssistType);

return new CodeAssistantManager(codeAssist, codeDiffManager);
});
Expand Down Expand Up @@ -284,17 +286,19 @@ private static void AddClientDependencies(HostApplicationBuilder builder)
var options = sp.GetRequiredService<IOptions<LLMOptions>>().Value;
var policyOptions = sp.GetRequiredService<IOptions<PolicyOptions>>().Value;

var cacheModels = sp.GetRequiredService<ICacheModels>();
ArgumentException.ThrowIfNullOrEmpty(options.ChatModel);

var cacheModels = sp.GetRequiredService<ICacheModels>();
var chatModel = cacheModels.GetModel(options.ChatModel);
ArgumentNullException.ThrowIfNull(chatModel);

client.Timeout = TimeSpan.FromSeconds(policyOptions.TimeoutSeconds);

var chatApiKey =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatModelApiKey)
?? chatModel.ModelOption.ApiKey;
?? chatModel.ApiKey;

switch (chatModel.ModelInformation.AIProvider)
switch (chatModel.AIProvider)
{
case AIProvider.Openai:
{
Expand All @@ -303,7 +307,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)

var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress)
?? chatModel.ModelOption.BaseAddress
?? chatModel.BaseAddress
?? "https://api.openai.com";

client.BaseAddress = new Uri(baseAddress.Trim());
Expand All @@ -320,7 +324,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)

var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress)
?? chatModel.ModelOption.BaseAddress;
?? chatModel.BaseAddress;
ArgumentException.ThrowIfNullOrEmpty(baseAddress);

client.BaseAddress = new Uri(baseAddress.Trim());
Expand All @@ -332,7 +336,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)
{
var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.ChatBaseAddress)
?? chatModel.ModelOption.BaseAddress
?? chatModel.BaseAddress
?? "http://localhost:11434";

// https://github.com/ollama/ollama/blob/main/docs/api.md
Expand All @@ -359,15 +363,17 @@ private static void AddClientDependencies(HostApplicationBuilder builder)
var cacheModels = sp.GetRequiredService<ICacheModels>();

ArgumentException.ThrowIfNullOrEmpty(options.EmbeddingsModel);

var embeddingModel = cacheModels.GetModel(options.EmbeddingsModel);
ArgumentNullException.ThrowIfNull(embeddingModel);

client.Timeout = TimeSpan.FromSeconds(policyOptions.TimeoutSeconds);

var embeddingsApiKey =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsModelApiKey)
?? embeddingModel.ModelOption.ApiKey;
?? embeddingModel.ApiKey;

switch (embeddingModel.ModelInformation.AIProvider)
switch (embeddingModel.AIProvider)
{
case AIProvider.Openai:
{
Expand All @@ -376,7 +382,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)

var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress)
?? embeddingModel.ModelOption.BaseAddress
?? embeddingModel.BaseAddress
?? "https://api.openai.com";

client.BaseAddress = new Uri(baseAddress.Trim());
Expand All @@ -393,7 +399,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)

var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress)
?? embeddingModel.ModelOption.BaseAddress;
?? embeddingModel.BaseAddress;
ArgumentException.ThrowIfNullOrEmpty(baseAddress);

client.BaseAddress = new Uri(baseAddress.Trim());
Expand All @@ -405,7 +411,7 @@ private static void AddClientDependencies(HostApplicationBuilder builder)
{
var baseAddress =
Environment.GetEnvironmentVariable(ClientsConstants.Environments.EmbeddingsBaseAddress)
?? embeddingModel.ModelOption.BaseAddress
?? embeddingModel.BaseAddress
?? "http://localhost:11434";

// https://github.com/ollama/ollama/blob/main/docs/api.md
Expand Down Expand Up @@ -505,7 +511,9 @@ private static void AddCodeDiffDependency(HostApplicationBuilder builder)
var cacheModels = sp.GetRequiredService<ICacheModels>();
var chatModel = cacheModels.GetModel(options.Value.ChatModel);

var codeDiffParser = factory.Create(chatModel.ModelOption.CodeDiffType);
ArgumentNullException.ThrowIfNull(chatModel);

var codeDiffParser = factory.Create(chatModel.CodeDiffType);

var codeDiffUpdater = sp.GetRequiredService<ICodeDiffUpdater>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ public Task<IEnumerable<string>> GetCodeTreeContents(IList<string>? codeFiles)

var systemPrompt = promptManager.GetSystemPrompt(
embeddingOriginalTreeCodes,
llmClientManager.ChatModel.ModelOption.CodeAssistType,
llmClientManager.ChatModel.ModelOption.CodeDiffType
llmClientManager.ChatModel.CodeAssistType,
llmClientManager.ChatModel.CodeDiffType
);

// Generate a response from the language model (e.g., OpenAI or Llama)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public Task<IEnumerable<string>> GetCodeTreeContents(IList<string>? codeFiles)

var systemPrompt = promptManager.GetSystemPrompt(
summaryTreeCodes,
llmClientManager.ChatModel.ModelOption.CodeAssistType,
llmClientManager.ChatModel.ModelOption.CodeDiffType
llmClientManager.ChatModel.CodeAssistType,
llmClientManager.ChatModel.CodeDiffType
);

// Generate a response from the language model (e.g., OpenAI or Llama)
Expand Down
15 changes: 9 additions & 6 deletions src/AIAssist/Services/LLMClientManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ ICacheModels cacheModels
_tokenizer = tokenizer;

EmbeddingModel = cacheModels.GetModel(llmOptions.Value.EmbeddingsModel);
ChatModel = cacheModels.GetModel(llmOptions.Value.ChatModel);
EmbeddingThreshold = EmbeddingModel.ModelOption.Threshold;
ChatModel =
cacheModels.GetModel(llmOptions.Value.ChatModel)
?? throw new ArgumentNullException($"Model '{llmOptions.Value.ChatModel}' not found in the CacheModels.");
EmbeddingThreshold = EmbeddingModel?.Threshold ?? 0.2m;
}

public Model ChatModel { get; }
public Model EmbeddingModel { get; }
public Model? EmbeddingModel { get; }
public decimal EmbeddingThreshold { get; }

public async IAsyncEnumerable<string?> GetCompletionStreamAsync(
Expand All @@ -50,7 +52,7 @@ ICacheModels cacheModels

var chatItems = chatSession.GetChatItemsFromHistory();

var llmClientStratgey = _clientFactory.CreateClient(ChatModel.ModelInformation.AIProvider);
var llmClientStratgey = _clientFactory.CreateClient(ChatModel.AIProvider);

var chatCompletionResponseStreams = llmClientStratgey.GetCompletionStreamAsync(
new ChatCompletionRequest(chatItems.Select(x => new ChatCompletionRequestItem(x.Role, x.Prompt))),
Expand Down Expand Up @@ -94,14 +96,15 @@ public async Task<GetEmbeddingResult> GetEmbeddingAsync(
CancellationToken cancellationToken = default
)
{
var llmClientStratgey = _clientFactory.CreateClient(EmbeddingModel.ModelInformation.AIProvider);
ArgumentNullException.ThrowIfNull(EmbeddingModel);
var llmClientStratgey = _clientFactory.CreateClient(EmbeddingModel.AIProvider);

var embeddingResponse = await llmClientStratgey.GetEmbeddingAsync(inputs, path, cancellationToken);

// in embedding output tokens and its cost is 0
var inputTokens =
embeddingResponse?.TokenUsage?.InputTokens ?? await _tokenizer.GetTokenCount(string.Concat(inputs));
var cost = inputTokens * EmbeddingModel.ModelInformation.InputCostPerToken;
var cost = inputTokens * EmbeddingModel.InputCostPerToken;

return new GetEmbeddingResult(embeddingResponse?.Embeddings ?? new List<IList<double>>(), inputTokens, cost);
}
Expand Down
2 changes: 2 additions & 0 deletions src/BuildingBlocks/SpectreConsole/ColorTheme.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ public class ColorTheme
{
public string Name { get; set; } = default!;

public string? Foreground { get; set; } = default!;

[JsonPropertyName("console")]
public ConsoleStyle ConsoleStyle { get; set; } = default!;

Expand Down
Loading

0 comments on commit d5c8e27

Please sign in to comment.