diff --git a/readme.md b/readme.md
index 521d0b7..991a330 100644
--- a/readme.md
+++ b/readme.md
@@ -1,12 +1,14 @@
 # AI Assist
 
-> AI assistant for coding, chat, code explanation, review with supporting local and online language models.
+> `Context Aware` AI assistant for coding, chat, code explanation, and code review, with support for local and online language models.
 
 `AIAssist` is compatible with [OpenAI](https://platform.openai.com/docs/api-reference/introduction) and [Azure AI Services](https://azure.microsoft.com/en-us/products/ai-services) through apis or [Ollama models](https://ollama.com/search) through [ollama engine](https://ollama.com/) locally.
 
 > [!TIP]
 > You can use ollama and its models that are more compatible with code like [deepseek-v2.5](https://ollama.com/library/deepseek-v2.5) or [qwen2.5-coder](https://ollama.com/library/qwen2.5-coder) locally. To use local models, you will need to run [Ollama](https://github.com/ollama/ollama) process first. For running ollama you can use [ollama docker](https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image) container.
 
+Note: `vscode` and `jetbrains` plugins are planned and will be added soon.
+
 ## Features
 
 - ✅ `Context Aware` ai code assistant through [ai embeddings](src/AIAssistant/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs) which is based on Retrieval Augmented Generation (RAG) or [tree-sitter application summarization](src/AIAssistant/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs) to summarize application context and understanding by AI.
diff --git a/src/AIAssistant/Commands/CodeAssistCommand.cs b/src/AIAssistant/Commands/CodeAssistCommand.cs
index c2d660d..fce8ffb 100644
--- a/src/AIAssistant/Commands/CodeAssistCommand.cs
+++ b/src/AIAssistant/Commands/CodeAssistCommand.cs
@@ -63,17 +63,17 @@ public sealed class Settings : CommandSettings
     [Description("[grey] the type of code assist. it can be `embedding` or `summary`.[/].")]
     public CodeAssistType? CodeAssistType { get; set; }
 
-    [CommandOption("--threshold ")]
     [Description("[grey] the threshold is a value for using in the `embedding`.[/].")]
     public decimal? Threshold { get; set; }
 
-    [CommandOption("--temperature ")]
     [Description(
         "[grey] the temperature is a value for controlling creativity or randomness on the llm response.[/]."
     )]
     public decimal? Temperature { get; set; }
 
-    [CommandOption("--chat-api-key ")]
+    [CommandOption("--chat-api-key ")]
    [Description("[grey] the chat model api key.[/].")]
     public string? ChatModelApiKey { get; set; }
@@ -159,7 +159,7 @@ await AnsiConsole
 
         console.Write(new Rule());
 
-        userInput = "can you remove all comments from Add.cs file?";
+        //userInput = "can you remove all comments from Add.cs file?";
 
         _running = await internalCommandProcessor.ProcessCommand(userInput, scope);
     }
diff --git a/src/AIAssistant/Contracts/IEmbeddingService.cs b/src/AIAssistant/Contracts/IEmbeddingService.cs
index 9546162..da68bf1 100644
--- a/src/AIAssistant/Contracts/IEmbeddingService.cs
+++ b/src/AIAssistant/Contracts/IEmbeddingService.cs
@@ -1,3 +1,4 @@
+using System.Collections;
 using AIAssistant.Chat.Models;
 using AIAssistant.Data;
 using AIAssistant.Dtos;
@@ -9,7 +10,7 @@ namespace AIAssistant.Contracts;
 public interface IEmbeddingService
 {
     Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
-        IEnumerable<CodeFileMap> codeFilesMap,
+        IList<CodeFileMap> codeFilesMap,
         ChatSession chatSession
     );
     Task GetRelatedEmbeddings(string userQuery, ChatSession chatSession);
diff --git a/src/AIAssistant/Contracts/ILLMClientManager.cs b/src/AIAssistant/Contracts/ILLMClientManager.cs
index 8e053f5..ac829e7 100644
--- a/src/AIAssistant/Contracts/ILLMClientManager.cs
+++ b/src/AIAssistant/Contracts/ILLMClientManager.cs
@@ -14,7 +14,7 @@ public interface ILLMClientManager
         CancellationToken cancellationToken = default
     );
     Task<GetEmbeddingResult> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     );
diff --git a/src/AIAssistant/Dtos/GetBatchEmbeddingResult.cs b/src/AIAssistant/Dtos/GetBatchEmbeddingResult.cs
new file mode 100644
index 0000000..f703ac4
--- /dev/null
+++ b/src/AIAssistant/Dtos/GetBatchEmbeddingResult.cs
@@ -0,0 +1,8 @@
+namespace AIAssistant.Dtos;
+
+public class GetBatchEmbeddingResult(IList<IList<double>> embeddings, int totalTokensCount, decimal totalCost)
+{
+    public IList<IList<double>> Embeddings { get; } = embeddings;
+    public int TotalTokensCount { get; } = totalTokensCount;
+    public decimal TotalCost { get; } = totalCost;
+}
diff --git a/src/AIAssistant/Dtos/GetEmbeddingResult.cs b/src/AIAssistant/Dtos/GetEmbeddingResult.cs
index c8292a8..4c5b112 100644
--- a/src/AIAssistant/Dtos/GetEmbeddingResult.cs
+++ b/src/AIAssistant/Dtos/GetEmbeddingResult.cs
@@ -1,3 +1,7 @@
 namespace AIAssistant.Dtos;
 
-public record GetEmbeddingResult(IList<double> Embeddings, int TotalTokensCount, decimal TotalCost);
+public record GetEmbeddingResult(
+    IList<IList<double>> Embeddings, // Multiple embeddings for batch
+    int TotalTokensCount,
+    decimal TotalCost
+);
diff --git a/src/AIAssistant/Models/FileBatch.cs b/src/AIAssistant/Models/FileBatch.cs
new file mode 100644
index 0000000..719cec5
--- /dev/null
+++ b/src/AIAssistant/Models/FileBatch.cs
@@ -0,0 +1,18 @@
+namespace AIAssistant.Models;
+
+/// <summary>
+/// Represents a batch of files and their chunks to be processed in a single embedding request.
+/// </summary>
+public class FileBatch
+{
+    public IList<FileChunkGroup> Files { get; set; } = new List<FileChunkGroup>();
+    public int TotalTokens { get; set; }
+
+    /// <summary>
+    /// Combines all chunked inputs for this batch into a single list for API calls.
+    /// </summary>
+    public IList<string> GetBatchInputs()
+    {
+        return Files.SelectMany(file => file.Chunks).ToList();
+    }
+}
diff --git a/src/AIAssistant/Models/FileChunkGroup.cs b/src/AIAssistant/Models/FileChunkGroup.cs
new file mode 100644
index 0000000..c980a88
--- /dev/null
+++ b/src/AIAssistant/Models/FileChunkGroup.cs
@@ -0,0 +1,14 @@
+using TreeSitter.Bindings.CustomTypes.TreeParser;
+
+namespace AIAssistant.Models;
+
+/// <summary>
+/// Represents a file and its associated chunks for embedding.
+/// </summary>
+public class FileChunkGroup(CodeFileMap file, List<string> chunks)
+{
+    public CodeFileMap File { get; } = file;
+    public IList<string> Chunks { get; } = chunks;
+
+    public string Input => string.Join("\n", Chunks);
+}
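Taken together, `FileBatch` and `FileChunkGroup` keep per-file chunks ordered so that a single embeddings request can carry many inputs and the returned vectors can be mapped back to their files by position. A minimal sketch of that flow; the `CodeFileMap` initialization below is illustrative only (the real instances come from the tree-sitter code-mapping pipeline, and the property setters are assumed here):

```csharp
// Illustrative setup: assumes CodeFileMap exposes a settable RelativePath.
var addFile = new CodeFileMap { RelativePath = "src/Add.cs" };
var subFile = new CodeFileMap { RelativePath = "src/Sub.cs" };

var batch = new FileBatch();
batch.Files.Add(new FileChunkGroup(addFile, new List<string> { "chunk A1", "chunk A2" }));
batch.Files.Add(new FileChunkGroup(subFile, new List<string> { "chunk B1" }));

// GetBatchInputs flattens chunks in file order: ["chunk A1", "chunk A2", "chunk B1"].
// The embeddings API returns one vector per input in the same order, so the
// caller can reassign vectors to files by walking chunk counts with Skip/Take.
IList<string> inputs = batch.GetBatchInputs();
```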
diff --git a/src/AIAssistant/Services/EmbeddingService.cs b/src/AIAssistant/Services/EmbeddingService.cs
index b8a1fdb..96cd63d 100644
--- a/src/AIAssistant/Services/EmbeddingService.cs
+++ b/src/AIAssistant/Services/EmbeddingService.cs
@@ -4,7 +4,6 @@
 using AIAssistant.Dtos;
 using AIAssistant.Models;
 using BuildingBlocks.LLM;
-using BuildingBlocks.Utils;
 using TreeSitter.Bindings.CustomTypes.TreeParser;
 
 namespace AIAssistant.Services;
@@ -12,41 +11,79 @@ namespace AIAssistant.Services;
 public class EmbeddingService(
     ILLMClientManager llmClientManager,
     ICodeEmbeddingsRepository codeEmbeddingsRepository,
-    IPromptManager promptManager
+    IPromptManager promptManager,
+    ITokenizer tokenizer
 ) : IEmbeddingService
 {
     public async Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
-        IEnumerable<CodeFileMap> codeFilesMap,
+        IList<CodeFileMap> codeFilesMap,
         ChatSession chatSession
     )
     {
         int totalTokens = 0;
         decimal totalCost = 0;
 
-        IList<CodeEmbedding> codeEmbeddings = new List<CodeEmbedding>();
+        var fileEmbeddingsMap = new Dictionary<string, List<IList<double>>>();
+
+        // Group files and manage batching using the updated tokenizer logic
+        var fileBatches = await BatchFilesByTokenLimitAsync(codeFilesMap, maxBatchTokens: 8192);
+
+        foreach (var batch in fileBatches)
+        {
+            var batchInputs = batch.GetBatchInputs();
+            var embeddingResult = await llmClientManager.GetEmbeddingAsync(batchInputs, null);
+
+            int resultIndex = 0;
+            foreach (var fileChunkGroup in batch.Files)
+            {
+                // Extract embeddings for the current file's chunks
+                var fileEmbeddings = embeddingResult
+                    .Embeddings.Skip(resultIndex)
+                    .Take(fileChunkGroup.Chunks.Count)
+                    .ToList();
+
+                resultIndex += fileChunkGroup.Chunks.Count;
+
+                // Group embeddings by file path
+                if (!fileEmbeddingsMap.TryGetValue(fileChunkGroup.File.RelativePath, out List<IList<double>>? value))
+                {
+                    value = new List<IList<double>>();
+                    fileEmbeddingsMap[fileChunkGroup.File.RelativePath] = value;
+                }
+
+                value.AddRange(fileEmbeddings);
+            }
+
+            totalTokens += embeddingResult.TotalTokensCount;
+            totalCost += embeddingResult.TotalCost;
+        }
 
-        foreach (var codeFileMap in codeFilesMap)
+        // Merge and create final embeddings for each file
+        var codeEmbeddings = new List<CodeEmbedding>();
+        foreach (var entry in fileEmbeddingsMap)
         {
-            var input = promptManager.GetEmbeddingInputString(codeFileMap.TreeSitterFullCode);
-            var embeddingResult = await llmClientManager.GetEmbeddingAsync(input, codeFileMap.RelativePath);
+            var filePath = entry.Key;
+            var embeddings = entry.Value;
+
+            // Merge embeddings for the file
+            var mergedEmbedding = MergeEmbeddings(embeddings);
+
+            // Retrieve the original file details from codeFilesMap
+            var fileDetails = codeFilesMap.First(file => file.RelativePath == filePath);
 
             codeEmbeddings.Add(
                 new CodeEmbedding
                 {
-                    RelativeFilePath = codeFileMap.RelativePath,
-                    TreeSitterFullCode = codeFileMap.TreeSitterFullCode,
-                    TreeOriginalCode = codeFileMap.TreeOriginalCode,
-                    Code = codeFileMap.OriginalCode,
+                    RelativeFilePath = fileDetails.RelativePath,
+                    TreeSitterFullCode = fileDetails.TreeSitterFullCode,
+                    TreeOriginalCode = fileDetails.TreeOriginalCode,
+                    Code = fileDetails.OriginalCode,
                     SessionId = chatSession.SessionId,
-                    Embeddings = embeddingResult.Embeddings,
+                    Embeddings = mergedEmbedding,
                 }
             );
-
-            totalTokens += embeddingResult.TotalTokensCount;
-            totalCost += embeddingResult.TotalCost;
         }
 
-        // we can replace it with an embedded database like `chromadb`, it can give us n of most similarity items
         await codeEmbeddingsRepository.AddOrUpdateCodeEmbeddings(codeEmbeddings);
 
         return new AddEmbeddingsForFilesResult(totalTokens, totalCost);
@@ -59,7 +96,7 @@ public async Task GetRelatedEmbeddings(string userQu
 
         // Find relevant code based on the user query
         var relevantCodes = codeEmbeddingsRepository.Query(
-            embeddingsResult.Embeddings,
+            embeddingsResult.Embeddings.First(),
             chatSession.SessionId,
             llmClientManager.EmbeddingThreshold
         );
@@ -82,6 +119,147 @@ public IEnumerable<CodeEmbedding> QueryByFilter(
 
     public async Task<GetEmbeddingResult> GenerateEmbeddingForUserInput(string userInput)
     {
-        return await llmClientManager.GetEmbeddingAsync(userInput, null);
+        return await llmClientManager.GetEmbeddingAsync(new List<string> { userInput }, null);
+    }
+
+    private async Task<List<FileBatch>> BatchFilesByTokenLimitAsync(
+        IEnumerable<CodeFileMap> codeFilesMap,
+        int maxBatchTokens
+    )
+    {
+        var fileBatches = new List<FileBatch>();
+        var currentBatch = new FileBatch();
+
+        foreach (var file in codeFilesMap)
+        {
+            // Convert the full code to an input string and split into chunks
+            var input = promptManager.GetEmbeddingInputString(file.TreeSitterFullCode);
+            var chunks = await SplitTextIntoChunksAsync(input, maxTokens: 8192);
+
+            var tokenCountTasks = chunks.Select(chunk => tokenizer.GetTokenCount(chunk));
+            var tokenCounts = await Task.WhenAll(tokenCountTasks);
+
+            // Pair chunks with their token counts
+            var chunkWithTokens = chunks.Zip(
+                tokenCounts,
+                (chunk, tokenCount) => new { Chunk = chunk, TokenCount = tokenCount }
+            );
+
+            foreach (var chunkGroup in chunkWithTokens)
+            {
+                // If adding this chunk would exceed the batch token limit
+                if (currentBatch.TotalTokens + chunkGroup.TokenCount > maxBatchTokens && currentBatch.Files.Count > 0)
+                {
+                    // Finalize the current batch and start a new one
+                    fileBatches.Add(currentBatch);
+                    currentBatch = new FileBatch();
+                }
+
+                // Add this chunk to the current batch
+                if (currentBatch.Files.All(f => f.File != file))
+                {
+                    // If this is the first chunk of this file in the current batch, add a new FileChunkGroup
+                    currentBatch.Files.Add(new FileChunkGroup(file, new List<string> { chunkGroup.Chunk }));
+                }
+                else
+                {
+                    // Add the chunk to the existing FileChunkGroup for this file
+                    var fileGroup = currentBatch.Files.First(f => f.File == file);
+                    fileGroup.Chunks.Add(chunkGroup.Chunk);
+                }
+
+                currentBatch.TotalTokens += chunkGroup.TokenCount;
+            }
+        }
+
+        // Add the last batch if it has content
+        if (currentBatch.Files.Count > 0)
+        {
+            fileBatches.Add(currentBatch);
+        }
+
+        return fileBatches;
+    }
+
+    private async Task<List<string>> SplitTextIntoChunksAsync(string text, int maxTokens)
+    {
+        var words = text.Split(' ');
+        var chunks = new List<string>();
+        var currentChunk = new List<string>();
+
+        foreach (var word in words)
+        {
+            currentChunk.Add(word);
+
+            // Check token count only when the chunk exceeds a certain word threshold
+            if (currentChunk.Count % 50 == 0 || currentChunk.Count == words.Length)
+            {
+                var currentText = string.Join(" ", currentChunk);
+                var currentTokenCount = await tokenizer.GetTokenCount(currentText);
+
+                if (currentTokenCount > maxTokens)
+                {
+                    // Ensure the chunk size is within limits
+                    while (currentTokenCount > maxTokens && currentChunk.Count > 1)
+                    {
+                        currentChunk.RemoveAt(currentChunk.Count - 1);
+                        currentText = string.Join(" ", currentChunk);
+                        currentTokenCount = await tokenizer.GetTokenCount(currentText);
+                    }
+
+                    // Add the finalized chunk only if it fits the token limit
+                    if (currentTokenCount <= maxTokens)
+                    {
+                        chunks.Add(currentText);
+                    }
+
+                    // Start a new chunk with the current word
+                    currentChunk.Clear();
+                    currentChunk.Add(word);
+                }
+            }
+        }
+
+        // Add the final chunk if it has content and is within the token limit
+        if (currentChunk.Count > 0)
+        {
+            var finalText = string.Join(" ", currentChunk);
+            var finalTokenCount = await tokenizer.GetTokenCount(finalText);
+
+            if (finalTokenCount <= maxTokens)
+            {
+                chunks.Add(finalText);
+            }
+        }
+
+        return chunks;
+    }
+
+    private IList<double> MergeEmbeddings(IList<IList<double>> embeddings)
+    {
+        if (embeddings == null || embeddings.Count == 0)
+            throw new ArgumentException("The embeddings list cannot be null or empty.");
+
+        int dimension = embeddings.First().Count;
+        var mergedEmbedding = new double[dimension];
+
+        foreach (var embedding in embeddings)
+        {
+            if (embedding.Count != dimension)
+                throw new InvalidOperationException("All embeddings must have the same dimensionality.");
+
+            for (int i = 0; i < dimension; i++)
+            {
+                mergedEmbedding[i] += embedding[i];
+            }
+        }
+
+        // Average the embeddings to unify them into one
+        for (int i = 0; i < dimension; i++)
+        {
+            mergedEmbedding[i] /= embeddings.Count;
+        }
+
+        return mergedEmbedding;
     }
 }
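`MergeEmbeddings` collapses a file's per-chunk vectors into one vector by element-wise averaging (mean pooling), which keeps the result the same dimensionality as a query embedding for similarity search, at the cost of per-chunk locality. A toy worked example with 3-dimensional vectors:

```csharp
// Two chunk embeddings for one file:
//   [1.0, 0.0, 2.0]
//   [3.0, 2.0, 0.0]
// element-wise sum  -> [4.0, 2.0, 2.0]
// divide by count 2 -> [2.0, 1.0, 1.0]  (the merged file embedding)
var merged = MergeEmbeddings(
    new List<IList<double>> { new List<double> { 1.0, 0.0, 2.0 }, new List<double> { 3.0, 2.0, 0.0 } }
); // => [2.0, 1.0, 1.0]
```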
diff --git a/src/AIAssistant/Services/LLMClientManager.cs b/src/AIAssistant/Services/LLMClientManager.cs
index 1207300..74c818e 100644
--- a/src/AIAssistant/Services/LLMClientManager.cs
+++ b/src/AIAssistant/Services/LLMClientManager.cs
@@ -89,19 +89,20 @@ ICacheModels cacheModels
     }
 
     public async Task<GetEmbeddingResult> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     )
     {
         var llmClientStratgey = _clientFactory.CreateClient(EmbeddingModel.ModelInformation.AIProvider);
 
-        var embeddingResponse = await llmClientStratgey.GetEmbeddingAsync(input, path, cancellationToken);
+        var embeddingResponse = await llmClientStratgey.GetEmbeddingAsync(inputs, path, cancellationToken);
 
         // in embedding output tokens and its cost is 0
-        var inputTokens = embeddingResponse?.TokenUsage?.InputTokens ?? await _tokenizer.GetTokenCount(input);
+        var inputTokens =
+            embeddingResponse?.TokenUsage?.InputTokens ?? await _tokenizer.GetTokenCount(string.Concat(inputs));
         var cost = inputTokens * EmbeddingModel.ModelInformation.InputCostPerToken;
 
-        return new GetEmbeddingResult(embeddingResponse?.Embeddings ?? new List<double>(), inputTokens, cost);
+        return new GetEmbeddingResult(embeddingResponse?.Embeddings ?? new List<IList<double>>(), inputTokens, cost);
     }
 }
diff --git a/src/AIAssistant/aiassist-config.json b/src/AIAssistant/aiassist-config.json
index 6732b65..77b9baf 100644
--- a/src/AIAssistant/aiassist-config.json
+++ b/src/AIAssistant/aiassist-config.json
@@ -1,23 +1,12 @@
 {
   "Test":"Test1",
-  "ModelsOptions": {
-    "azure/gpt-4o": {
-      "CodeDiffType": "CodeBlockDiff",
-      "CodeAssistType": "Embedding",
-      "Temperature": 0.2
-    },
-    "azure/text-embedding-3-large": {
-      "Threshold": 0.3,
-      "Temperature": 0.2
-    }
-  },
   "AppOptions": {
     "ThemeName": "dracula",
     "PrintCostEnabled": true
   },
   "LLMOptions": {
-    "ChatModel": "azure/gpt-4o",
-    "EmbeddingsModel": "azure/text-embedding-3-large"
+    "ChatModel": "ollama/llama3",
+    "EmbeddingsModel": "ollama/nomic-embed-text"
   },
   "Serilog": {
     "MinimumLevel": {
diff --git a/src/BuildingBlocks/LLM/Tokenizers/GptTokenizer.cs b/src/BuildingBlocks/LLM/Tokenizers/GptTokenizer.cs
index 878cf93..9255cfa 100644
--- a/src/BuildingBlocks/LLM/Tokenizers/GptTokenizer.cs
+++ b/src/BuildingBlocks/LLM/Tokenizers/GptTokenizer.cs
@@ -7,12 +7,12 @@ namespace BuildingBlocks.LLM.Tokenizers;
 
 public class GptTokenizer(string modelName = "GPT-4o") : ITokenizer
 {
+    // https://learn.microsoft.com/en-us/dotnet/machine-learning/whats-new/overview#additional-tokenizer-support
+    private readonly Tokenizer _tokenizer = TiktokenTokenizer.CreateForModel(modelName);
+
     public Task<double[]> GetVectorTokens(string prompt)
     {
-        // https://learn.microsoft.com/en-us/dotnet/machine-learning/whats-new/overview#additional-tokenizer-support
-        Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(modelName);
-
-        IReadOnlyList<int> encodedIds = tokenizer.EncodeToIds(prompt);
+        IReadOnlyList<int> encodedIds = _tokenizer.EncodeToIds(prompt);
 
         return Task.FromResult(encodedIds.Select(x => (double)x).ToArray());
     }
@@ -20,8 +20,6 @@ public Task<double[]> GetVectorTokens(string prompt)
     public Task<int> GetTokenCount(string prompt)
     {
         // https://learn.microsoft.com/en-us/dotnet/machine-learning/whats-new/overview
-        Tokenizer tokenizer = TiktokenTokenizer.CreateForModel(modelName);
-
-        return Task.FromResult(tokenizer.CountTokens(prompt));
+        return Task.FromResult(_tokenizer.CountTokens(prompt));
     }
 }
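The `GptTokenizer` change hoists `TiktokenTokenizer.CreateForModel` into a field, so the tokenizer (and its encoding tables) is built once per instance instead of on every call. That matters now that the chunking logic above calls `GetTokenCount` in a loop. Usage is unchanged; a small sketch:

```csharp
var tokenizer = new GptTokenizer(); // defaults to the "GPT-4o" encoding
int count = await tokenizer.GetTokenCount("public class FileBatch { }");
double[] tokenIds = await tokenizer.GetVectorTokens("hello world");
```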
diff --git a/src/Clients/AnthropicClient.cs b/src/Clients/AnthropicClient.cs
index a376046..c2ea97d 100644
--- a/src/Clients/AnthropicClient.cs
+++ b/src/Clients/AnthropicClient.cs
@@ -194,7 +194,7 @@ AsyncPolicyWrap combinedPolicy
     }
 
     public Task<EmbeddingsResponse?> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     )
diff --git a/src/Clients/AzureClient.cs b/src/Clients/AzureClient.cs
index f030a63..f333796 100644
--- a/src/Clients/AzureClient.cs
+++ b/src/Clients/AzureClient.cs
@@ -239,15 +239,20 @@ AsyncPolicyWrap combinedPolicy
     }
 
     public async Task<EmbeddingsResponse?> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     )
     {
-        await ValidateEmbeddingMaxInputToken(input);
-        ValidateRequestSizeAndContent(input);
+        await ValidateEmbeddingMaxInputToken(string.Concat(inputs));
+        ValidateRequestSizeAndContent(string.Concat(inputs));
 
-        var requestBody = new { input = new[] { input }, model = _embeddingModel.Name.Trim() };
+        var requestBody = new
+        {
+            input = inputs,
+            model = _embeddingModel.Name.Trim(),
+            dimensions = _embeddingModel.ModelInformation.EmbeddingDimensions,
+        };
 
         var client = httpClientFactory.CreateClient("llm_embeddings_client");
 
@@ -288,8 +293,6 @@ AsyncPolicyWrap combinedPolicy
 
         HandleException(httpResponseMessage, embeddingResponse);
 
-        var embedding = embeddingResponse.Data.FirstOrDefault()?.Embedding ?? new List<double>();
-
         var inputTokens = embeddingResponse.Usage?.PromptTokens ?? 0;
         var outTokens = embeddingResponse.Usage?.CompletionTokens ?? 0;
         var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken;
@@ -297,6 +300,8 @@ AsyncPolicyWrap combinedPolicy
 
         ValidateEmbeddingMaxToken(inputTokens + outTokens, path);
 
+        var embedding = embeddingResponse.Data?.Select(x => x.Embedding).ToList() ?? new List<IList<double>>();
+
         return new EmbeddingsResponse(
             embedding,
             new TokenUsageResponse(inputTokens, inputCostPerToken, outTokens, outputCostPerToken)
diff --git a/src/Clients/Contracts/ILLMClient.cs b/src/Clients/Contracts/ILLMClient.cs
index 34dc934..5306297 100644
--- a/src/Clients/Contracts/ILLMClient.cs
+++ b/src/Clients/Contracts/ILLMClient.cs
@@ -1,5 +1,4 @@
 using Clients.Dtos;
-using Clients.Models;
 
 namespace Clients.Contracts;
 
@@ -14,7 +13,7 @@ public interface ILLMClient
         CancellationToken cancellationToken = default
     );
     Task<EmbeddingsResponse?> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     );
diff --git a/src/Clients/Dtos/BatchEmbeddingsResponse.cs b/src/Clients/Dtos/BatchEmbeddingsResponse.cs
new file mode 100644
index 0000000..d2019c1
--- /dev/null
+++ b/src/Clients/Dtos/BatchEmbeddingsResponse.cs
@@ -0,0 +1,8 @@
+namespace Clients.Dtos;
+
+public class BatchEmbeddingsResponse(IList<IList<double>> embeddings, int totalTokensCount, decimal totalCost)
+{
+    public IList<IList<double>> Embeddings { get; } = embeddings;
+    public int TotalTokensCount { get; } = totalTokensCount;
+    public decimal TotalCost { get; } = totalCost;
+}
diff --git a/src/Clients/Dtos/EmbeddingsResponse.cs b/src/Clients/Dtos/EmbeddingsResponse.cs
index 30aa4a8..f52a256 100644
--- a/src/Clients/Dtos/EmbeddingsResponse.cs
+++ b/src/Clients/Dtos/EmbeddingsResponse.cs
@@ -1,3 +1,3 @@
 namespace Clients.Dtos;
 
-public record EmbeddingsResponse(IList<double>? Embeddings, TokenUsageResponse? TokenUsage);
+public record EmbeddingsResponse(IList<IList<double>>? Embeddings, TokenUsageResponse? TokenUsage);
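End to end, the embedding path now takes a list of inputs and returns one vector per input. A hedged sketch of a caller against the updated `ILLMClientManager` contract (the `llmClientManager` instance is assumed to come from dependency injection):

```csharp
// One round trip embeds several chunks; result.Embeddings is
// IList<IList<double>>, aligned by index with the inputs list.
var result = await llmClientManager.GetEmbeddingAsync(
    new List<string> { "chunk 1", "chunk 2", "chunk 3" },
    path: null
);

Console.WriteLine(
    $"{result.Embeddings.Count} vectors, {result.TotalTokensCount} tokens, cost {result.TotalCost}"
);
```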
diff --git a/src/Clients/LLMs/models_information_list.json b/src/Clients/LLMs/models_information_list.json
index 5fec366..eafa543 100644
--- a/src/Clients/LLMs/models_information_list.json
+++ b/src/Clients/LLMs/models_information_list.json
@@ -18,7 +18,8 @@
     "InputCostPerToken": 0.00000013,
     "OutputCostPerToken": 0.000000,
     "AIProvider": "Openai",
-    "ModelType": "Embedding"
+    "ModelType": "Embedding",
+    "EmbeddingDimensions": 1024
   },
   "text-embedding-3-small": {
     "MaxTokens": 8191,
@@ -27,7 +28,8 @@
     "InputCostPerToken": 0.00000002,
     "OutputCostPerToken": 0.000000,
     "AIProvider": "Openai",
-    "ModelType": "Embedding"
+    "ModelType": "Embedding",
+    "EmbeddingDimensions": 512
   },
   "azure/gpt-4o": {
     "MaxTokens": 8192,
@@ -46,7 +48,8 @@
     "InputCostPerToken": 0.00000013,
     "OutputCostPerToken": 0.000000,
     "AIProvider": "Azure",
-    "ModelType": "Embedding"
+    "ModelType": "Embedding",
+    "EmbeddingDimensions": 1024
   },
   "azure/text-embedding-3-small": {
     "MaxTokens": 8191,
@@ -54,7 +57,8 @@
     "InputCostPerToken": 0.00000002,
     "OutputCostPerToken": 0.000000,
     "AIProvider": "Azure",
-    "ModelType": "Embedding"
+    "ModelType": "Embedding",
+    "EmbeddingDimensions": 512
   },
   "ollama/codegeex4": {
     "MaxTokens": 32768,
@@ -134,9 +138,9 @@
     "SupportsFunctionCalling": true
   },
   "ollama/llama3.2": {
-    "MaxTokens": 32768,
-    "MaxInputTokens": 8192,
-    "MaxOutputTokens": 8192,
+    "MaxTokens": 128000,
+    "MaxInputTokens": 128000,
+    "MaxOutputTokens": 4000,
     "InputCostPerToken": 0.0,
     "OutputCostPerToken": 0.0,
     "AIProvider": "Ollama",
@@ -153,8 +157,8 @@
     "ModelType": "Chat"
   },
   "ollama/nomic-embed-text": {
-    "MaxTokens": 2048,
-    "MaxInputTokens": 2048,
+    "MaxTokens": 8192,
+    "MaxInputTokens": 8192,
     "InputCostPerToken": 0.0,
     "OutputCostPerToken": 0.0,
     "AIProvider": "Ollama",
     "ModelType": "Embedding",
     "EmbeddingDimensions": 512
   },
@@ -162,8 +166,8 @@
   "ollama/mxbai-embed-large": {
-    "MaxTokens": 4096,
-    "MaxInputTokens": 2048,
+    "MaxTokens": 8192,
+    "MaxInputTokens": 8192,
     "InputCostPerToken": 0.0,
     "OutputCostPerToken": 0.0,
     "AIProvider": "Ollama",
diff --git a/src/Clients/Models/Ollama/Embeddings/LlamaEmbeddingResponse.cs b/src/Clients/Models/Ollama/Embeddings/LlamaEmbeddingResponse.cs
index 8207669..cf1ea86 100644
--- a/src/Clients/Models/Ollama/Embeddings/LlamaEmbeddingResponse.cs
+++ b/src/Clients/Models/Ollama/Embeddings/LlamaEmbeddingResponse.cs
@@ -5,5 +5,5 @@ namespace Clients.Models.Ollama.Embeddings;
 public class LlamaEmbeddingResponse : OllamaResponseBase
 {
     [JsonPropertyName("embeddings")]
-    public IList<IList<double>> Embeddings { get; set; } = default!;
+    public IList<IList<double>>? Embeddings { get; set; } = default!;
 }
diff --git a/src/Clients/Models/OpenAI/Embeddings/OpenAIEmbeddingResponse.cs b/src/Clients/Models/OpenAI/Embeddings/OpenAIEmbeddingResponse.cs
index 578ad7b..cb43556 100644
--- a/src/Clients/Models/OpenAI/Embeddings/OpenAIEmbeddingResponse.cs
+++ b/src/Clients/Models/OpenAI/Embeddings/OpenAIEmbeddingResponse.cs
@@ -5,5 +5,5 @@ namespace Clients.Models.OpenAI.Embeddings;
 public class OpenAIEmbeddingResponse : OpenAIBaseResponse
 {
     [JsonPropertyName("data")]
-    public IList Data { get; set; } = new List();
+    public IList? Data { get; set; } = new List();
 }
diff --git a/src/Clients/OllamaClient.cs b/src/Clients/OllamaClient.cs
index a1d6553..0796659 100644
--- a/src/Clients/OllamaClient.cs
+++ b/src/Clients/OllamaClient.cs
@@ -182,18 +182,18 @@ AsyncPolicyWrap combinedPolicy
     }
 
     public async Task<EmbeddingsResponse?> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     )
     {
-        await ValidateEmbeddingMaxInputToken(input, path);
-        ValidateRequestSizeAndContent(input);
+        await ValidateEmbeddingMaxInputToken(string.Concat(inputs), path);
+        ValidateRequestSizeAndContent(string.Concat(inputs));
 
         // https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
         var requestBody = new
         {
-            input = new[] { input },
+            input = inputs,
             model = _embeddingModel.Name,
             options = new { temperature = _embeddingModel.ModelOption.Temperature },
             keep_alive = "30m",
@@ -219,8 +219,6 @@ AsyncPolicyWrap combinedPolicy
 
         HandleException(httpResponseMessage, embeddingResponse);
 
-        var embedding = embeddingResponse.Embeddings.FirstOrDefault() ?? new List<double>();
-
         var inputTokens = embeddingResponse.PromptEvalCount;
         var outTokens = embeddingResponse.EvalCount;
         var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken;
@@ -228,6 +226,8 @@ AsyncPolicyWrap combinedPolicy
 
         ValidateEmbeddingMaxToken(inputTokens + outTokens, path);
 
+        var embedding = embeddingResponse.Embeddings ?? new List<IList<double>>();
+
         return new EmbeddingsResponse(
             embedding,
             new TokenUsageResponse(inputTokens, inputCostPerToken, outTokens, outputCostPerToken)
diff --git a/src/Clients/OpenAiClient.cs b/src/Clients/OpenAiClient.cs
index 424a3f4..3ffabde 100644
--- a/src/Clients/OpenAiClient.cs
+++ b/src/Clients/OpenAiClient.cs
@@ -204,15 +204,20 @@ AsyncPolicyWrap combinedPolicy
     }
 
     public async Task<EmbeddingsResponse?> GetEmbeddingAsync(
-        string input,
+        IList<string> inputs,
         string? path,
         CancellationToken cancellationToken = default
     )
     {
-        await ValidateEmbeddingMaxInputToken(input, path);
-        ValidateRequestSizeAndContent(input);
+        await ValidateEmbeddingMaxInputToken(string.Concat(inputs), path);
+        ValidateRequestSizeAndContent(string.Concat(inputs));
 
-        var requestBody = new { input = new[] { input }, model = _embeddingModel.Name.Trim() };
+        var requestBody = new
+        {
+            input = inputs,
+            model = _embeddingModel.Name.Trim(),
+            dimensions = _embeddingModel.ModelInformation.EmbeddingDimensions,
+        };
 
         var client = httpClientFactory.CreateClient("llm_embeddings_client");
 
@@ -236,8 +241,6 @@ AsyncPolicyWrap combinedPolicy
 
         HandleException(httpResponseMessage, embeddingResponse);
 
-        var embedding = embeddingResponse.Data.FirstOrDefault()?.Embedding ?? new List<double>();
-
         var inputTokens = embeddingResponse.Usage?.PromptTokens ?? 0;
         var outTokens = embeddingResponse.Usage?.CompletionTokens ?? 0;
         var inputCostPerToken = _embeddingModel.ModelInformation.InputCostPerToken;
@@ -245,6 +248,8 @@ AsyncPolicyWrap combinedPolicy
 
         ValidateEmbeddingMaxToken(inputTokens + outTokens, path);
 
+        var embedding = embeddingResponse.Data?.Select(x => x.Embedding).ToList() ?? new List<IList<double>>();
+
         return new EmbeddingsResponse(
             embedding,
             new TokenUsageResponse(inputTokens, inputCostPerToken, outTokens, outputCostPerToken)