Skip to content

Commit

Permalink
feat: ✨ Use Batch Api for embeddings (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdihadeli authored Nov 26, 2024
1 parent 6afbef9 commit 2cf05f6
Show file tree
Hide file tree
Showing 22 changed files with 319 additions and 85 deletions.
4 changes: 3 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# AI Assist

> AI assistant for coding, chat, code explanation, review with supporting local and online language models.
> `Context Aware` AI assistant for coding, chat, code explanation, review with supporting local and online language models.
`AIAssist` is compatible with [OpenAI](https://platform.openai.com/docs/api-reference/introduction) and [Azure AI Services](https://azure.microsoft.com/en-us/products/ai-services) through apis or [Ollama models](https://ollama.com/search) through [ollama engine](https://ollama.com/) locally.

> [!TIP]
> You can use ollama and its models that are more compatible with code like [deepseek-v2.5](https://ollama.com/library/deepseek-v2.5) or [qwen2.5-coder](https://ollama.com/library/qwen2.5-coder) locally. To use local models, you will need to run [Ollama](https://github.com/ollama/ollama) process first. For running ollama you can use [ollama docker](https://ollama.com/blog/ollama-is-now-available-as-an-official-docker-image) container.
Note: `vscode` and `jetbrains` plugins are planned and will be added soon.

## Features

-`Context Aware` AI code assistant through [ai embeddings](src/AIAssistant/Services/CodeAssistStrategies/EmbeddingCodeAssist.cs), which is based on Retrieval Augmented Generation (RAG), or [tree-sitter application summarization](src/AIAssistant/Services/CodeAssistStrategies/TreeSitterCodeAssistSummary.cs) to summarize application context for understanding by AI.
Expand Down
8 changes: 4 additions & 4 deletions src/AIAssistant/Commands/CodeAssistCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ public sealed class Settings : CommandSettings
[Description("[grey] the type of code assist. it can be `embedding` or `summary`.[/].")]
public CodeAssistType? CodeAssistType { get; set; }

[CommandOption("--threshold <threshold")]
[CommandOption("--threshold <threshold>")]
[Description("[grey] the threshold is a value for using in the `embedding`.[/].")]
public decimal? Threshold { get; set; }

[CommandOption("--temperature <temperature")]
[CommandOption("--temperature <temperature>")]
[Description(
"[grey] the temperature is a value for controlling creativity or randomness on the llm response.[/]."
)]
public decimal? Temperature { get; set; }

[CommandOption("--chat-api-key <key>")]
[CommandOption("--chat-api-key <chat-api-key>")]
[Description("[grey] the chat model api key.[/].")]
public string? ChatModelApiKey { get; set; }

Expand Down Expand Up @@ -159,7 +159,7 @@ await AnsiConsole

console.Write(new Rule());

userInput = "can you remove all comments from Add.cs file?";
//userInput = "can you remove all comments from Add.cs file?";
_running = await internalCommandProcessor.ProcessCommand(userInput, scope);
}

Expand Down
3 changes: 2 additions & 1 deletion src/AIAssistant/Contracts/IEmbeddingService.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections;
using AIAssistant.Chat.Models;
using AIAssistant.Data;
using AIAssistant.Dtos;
Expand All @@ -9,7 +10,7 @@ namespace AIAssistant.Contracts;
public interface IEmbeddingService
{
Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
IEnumerable<CodeFileMap> codeFilesMap,
IList<CodeFileMap> codeFilesMap,
ChatSession chatSession
);
Task<GetRelatedEmbeddingsResult> GetRelatedEmbeddings(string userQuery, ChatSession chatSession);
Expand Down
2 changes: 1 addition & 1 deletion src/AIAssistant/Contracts/ILLMClientManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface ILLMClientManager
CancellationToken cancellationToken = default
);
Task<GetEmbeddingResult> GetEmbeddingAsync(
string input,
IList<string> inputs,
string? path,
CancellationToken cancellationToken = default
);
Expand Down
8 changes: 8 additions & 0 deletions src/AIAssistant/Dtos/GetBatchEmbeddingResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace AIAssistant.Dtos;

/// <summary>
/// Result of a batch embedding request: one embedding vector per submitted input,
/// plus aggregate token usage and cost for the whole batch.
/// Declared as a record for value semantics, consistent with <c>GetEmbeddingResult</c>.
/// </summary>
/// <param name="Embeddings">One embedding vector per input, in submission order.</param>
/// <param name="TotalTokensCount">Total tokens consumed by the batch request.</param>
/// <param name="TotalCost">Total cost reported for the batch request.</param>
public record GetBatchEmbeddingResult(IList<IList<double>> Embeddings, int TotalTokensCount, decimal TotalCost);
6 changes: 5 additions & 1 deletion src/AIAssistant/Dtos/GetEmbeddingResult.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
namespace AIAssistant.Dtos;

public record GetEmbeddingResult(IList<double> Embeddings, int TotalTokensCount, decimal TotalCost);
/// <summary>
/// Result of an embedding request: one embedding vector per submitted input
/// (a batch request yields several vectors), plus aggregate usage totals.
/// </summary>
/// <param name="Embeddings">One embedding vector per input, in submission order.</param>
/// <param name="TotalTokensCount">Total tokens consumed by the request.</param>
/// <param name="TotalCost">Total cost reported for the request.</param>
public record GetEmbeddingResult(IList<IList<double>> Embeddings, int TotalTokensCount, decimal TotalCost);
18 changes: 18 additions & 0 deletions src/AIAssistant/Models/FileBatch.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace AIAssistant.Models;

/// <summary>
/// Represents a batch of files and their chunks to be processed in a single embedding request.
/// </summary>
/// <summary>
/// Represents a batch of files and their chunks to be processed in a single embedding request.
/// </summary>
public class FileBatch
{
    /// <summary>Files included in this batch, each paired with its chunked inputs.</summary>
    public IList<FileChunkGroup> Files { get; set; } = new List<FileChunkGroup>();

    /// <summary>Running total of tokens across every chunk currently in the batch.</summary>
    public int TotalTokens { get; set; }

    /// <summary>
    /// Combines all chunked inputs for this batch into a single list for API calls.
    /// </summary>
    public IList<string> GetBatchInputs()
    {
        var inputs = new List<string>();
        foreach (var group in Files)
        {
            inputs.AddRange(group.Chunks);
        }

        return inputs;
    }
}
14 changes: 14 additions & 0 deletions src/AIAssistant/Models/FileChunkGroup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using TreeSitter.Bindings.CustomTypes.TreeParser;

namespace AIAssistant.Models;

/// <summary>
/// Represents a file and its associated chunks for embedding.
/// </summary>
/// <summary>
/// Represents a file and its associated chunks for embedding.
/// </summary>
public class FileChunkGroup
{
    public FileChunkGroup(CodeFileMap file, List<string> chunks)
    {
        File = file;
        Chunks = chunks;
    }

    /// <summary>The source file these chunks were produced from.</summary>
    public CodeFileMap File { get; }

    /// <summary>The chunked pieces of this file's embedding input.</summary>
    public IList<string> Chunks { get; }

    /// <summary>All chunks joined with newlines into one input string.</summary>
    public string Input => string.Join("\n", Chunks);
}
214 changes: 196 additions & 18 deletions src/AIAssistant/Services/EmbeddingService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,86 @@
using AIAssistant.Dtos;
using AIAssistant.Models;
using BuildingBlocks.LLM;
using BuildingBlocks.Utils;
using TreeSitter.Bindings.CustomTypes.TreeParser;

namespace AIAssistant.Services;

public class EmbeddingService(
ILLMClientManager llmClientManager,
ICodeEmbeddingsRepository codeEmbeddingsRepository,
IPromptManager promptManager
IPromptManager promptManager,
ITokenizer tokenizer
) : IEmbeddingService
{
public async Task<AddEmbeddingsForFilesResult> AddOrUpdateEmbeddingsForFiles(
IEnumerable<CodeFileMap> codeFilesMap,
IList<CodeFileMap> codeFilesMap,
ChatSession chatSession
)
{
int totalTokens = 0;
decimal totalCost = 0;

IList<CodeEmbedding> codeEmbeddings = new List<CodeEmbedding>();
var fileEmbeddingsMap = new Dictionary<string, List<IList<double>>>();

// Group files and manage batching using the updated tokenizer logic
var fileBatches = await BatchFilesByTokenLimitAsync(codeFilesMap, maxBatchTokens: 8192);

foreach (var batch in fileBatches)
{
var batchInputs = batch.GetBatchInputs();
var embeddingResult = await llmClientManager.GetEmbeddingAsync(batchInputs, null);

int resultIndex = 0;
foreach (var fileChunkGroup in batch.Files)
{
// Extract embeddings for the current file's chunks
var fileEmbeddings = embeddingResult
.Embeddings.Skip(resultIndex)
.Take(fileChunkGroup.Chunks.Count)
.ToList();

resultIndex += fileChunkGroup.Chunks.Count;

// Group embeddings by file path
if (!fileEmbeddingsMap.TryGetValue(fileChunkGroup.File.RelativePath, out List<IList<double>>? value))
{
value = new List<IList<double>>();
fileEmbeddingsMap[fileChunkGroup.File.RelativePath] = value;
}

value.AddRange(fileEmbeddings);
}

totalTokens += embeddingResult.TotalTokensCount;
totalCost += embeddingResult.TotalCost;
}

foreach (var codeFileMap in codeFilesMap)
// Merge and create final embeddings for each file
var codeEmbeddings = new List<CodeEmbedding>();
foreach (var entry in fileEmbeddingsMap)
{
var input = promptManager.GetEmbeddingInputString(codeFileMap.TreeSitterFullCode);
var embeddingResult = await llmClientManager.GetEmbeddingAsync(input, codeFileMap.RelativePath);
var filePath = entry.Key;
var embeddings = entry.Value;

// Merge embeddings for the file
var mergedEmbedding = MergeEmbeddings(embeddings);

// Retrieve the original file details from codeFilesMap
var fileDetails = codeFilesMap.First(file => file.RelativePath == filePath);

codeEmbeddings.Add(
new CodeEmbedding
{
RelativeFilePath = codeFileMap.RelativePath,
TreeSitterFullCode = codeFileMap.TreeSitterFullCode,
TreeOriginalCode = codeFileMap.TreeOriginalCode,
Code = codeFileMap.OriginalCode,
RelativeFilePath = fileDetails.RelativePath,
TreeSitterFullCode = fileDetails.TreeSitterFullCode,
TreeOriginalCode = fileDetails.TreeOriginalCode,
Code = fileDetails.OriginalCode,
SessionId = chatSession.SessionId,
Embeddings = embeddingResult.Embeddings,
Embeddings = mergedEmbedding,
}
);

totalTokens += embeddingResult.TotalTokensCount;
totalCost += embeddingResult.TotalCost;
}

// we can replace it with an embedded database like `chromadb`, it can give us n of most similarity items
await codeEmbeddingsRepository.AddOrUpdateCodeEmbeddings(codeEmbeddings);

return new AddEmbeddingsForFilesResult(totalTokens, totalCost);
Expand All @@ -59,7 +96,7 @@ public async Task<GetRelatedEmbeddingsResult> GetRelatedEmbeddings(string userQu

// Find relevant code based on the user query
var relevantCodes = codeEmbeddingsRepository.Query(
embeddingsResult.Embeddings,
embeddingsResult.Embeddings.First(),
chatSession.SessionId,
llmClientManager.EmbeddingThreshold
);
Expand All @@ -82,6 +119,147 @@ public IEnumerable<CodeEmbedding> QueryByFilter(

public async Task<GetEmbeddingResult> GenerateEmbeddingForUserInput(string userInput)
{
return await llmClientManager.GetEmbeddingAsync(userInput, null);
return await llmClientManager.GetEmbeddingAsync(new List<string> { userInput }, null);
}

/// <summary>
/// Groups the files' embedding inputs into batches whose combined token count stays within
/// <paramref name="maxBatchTokens"/>, so each batch can go out as a single embedding API call.
/// A file's chunks may be split across consecutive batches; input order is preserved.
/// </summary>
/// <param name="codeFilesMap">Files whose tree-sitter code is chunked and batched.</param>
/// <param name="maxBatchTokens">Token budget for one batch (and therefore for one chunk).</param>
/// <returns>Non-empty batches in processing order.</returns>
private async Task<List<FileBatch>> BatchFilesByTokenLimitAsync(
    IEnumerable<CodeFileMap> codeFilesMap,
    int maxBatchTokens
)
{
    var fileBatches = new List<FileBatch>();
    var currentBatch = new FileBatch();

    foreach (var file in codeFilesMap)
    {
        // Convert the full code to an input string and split into chunks.
        // Use maxBatchTokens (previously a hard-coded 8192) so that no single chunk
        // can ever exceed the batch budget it must fit into.
        var input = promptManager.GetEmbeddingInputString(file.TreeSitterFullCode);
        var chunks = await SplitTextIntoChunksAsync(input, maxTokens: maxBatchTokens);

        // Count tokens for all chunks concurrently.
        var tokenCounts = await Task.WhenAll(chunks.Select(chunk => tokenizer.GetTokenCount(chunk)));

        for (int i = 0; i < chunks.Count; i++)
        {
            var chunk = chunks[i];
            var tokenCount = tokenCounts[i];

            // Close the current batch when adding this chunk would exceed the budget.
            if (currentBatch.TotalTokens + tokenCount > maxBatchTokens && currentBatch.Files.Count > 0)
            {
                fileBatches.Add(currentBatch);
                currentBatch = new FileBatch();
            }

            // Append the chunk to this file's group in the current batch, creating the
            // group on first use (reference equality identifies the file).
            var fileGroup = currentBatch.Files.FirstOrDefault(f => f.File == file);
            if (fileGroup is null)
            {
                currentBatch.Files.Add(new FileChunkGroup(file, new List<string> { chunk }));
            }
            else
            {
                fileGroup.Chunks.Add(chunk);
            }

            currentBatch.TotalTokens += tokenCount;
        }
    }

    // Flush the final, partially filled batch.
    if (currentBatch.Files.Count > 0)
    {
        fileBatches.Add(currentBatch);
    }

    return fileBatches;
}

/// <summary>
/// Splits <paramref name="text"/> into space-delimited chunks, each at most
/// <paramref name="maxTokens"/> tokens as measured by the tokenizer. Token counts
/// are only sampled every 50 words (and at end of input) to limit tokenizer calls.
/// </summary>
/// <remarks>
/// Fixes two content-loss bugs in the previous implementation: (1) words trimmed off an
/// oversized chunk are pushed back onto the unconsumed stream and start the next chunk
/// instead of being discarded; (2) end of input is detected via the word index rather
/// than <c>currentChunk.Count == words.Length</c>, which was wrong after any chunk had
/// already been emitted. A single word that alone exceeds the limit is still skipped,
/// matching prior behavior.
/// </remarks>
/// <param name="text">Text to split; words are separated by single spaces.</param>
/// <param name="maxTokens">Maximum tokens allowed per emitted chunk.</param>
/// <returns>Chunks in original word order.</returns>
private async Task<List<string>> SplitTextIntoChunksAsync(string text, int maxTokens)
{
    var words = text.Split(' ');
    var chunks = new List<string>();
    var currentChunk = new List<string>();

    int index = 0;
    while (index < words.Length)
    {
        currentChunk.Add(words[index]);
        index++;

        // Only measure tokens at 50-word boundaries or once the input is exhausted.
        bool atBoundary = currentChunk.Count % 50 == 0 || index == words.Length;
        if (!atBoundary)
        {
            continue;
        }

        var currentText = string.Join(" ", currentChunk);
        var currentTokenCount = await tokenizer.GetTokenCount(currentText);

        if (currentTokenCount <= maxTokens)
        {
            // Still within budget: finalize only when there is no more input to grow with.
            if (index == words.Length)
            {
                chunks.Add(currentText);
                currentChunk.Clear();
            }

            continue;
        }

        // Over budget: trim words off the end until the chunk fits, returning each
        // trimmed word to the unconsumed stream so no content is lost.
        while (currentTokenCount > maxTokens && currentChunk.Count > 1)
        {
            currentChunk.RemoveAt(currentChunk.Count - 1);
            index--;
            currentText = string.Join(" ", currentChunk);
            currentTokenCount = await tokenizer.GetTokenCount(currentText);
        }

        // Emit the trimmed chunk if it fits; a lone word above the limit is dropped.
        if (currentTokenCount <= maxTokens)
        {
            chunks.Add(currentText);
        }

        currentChunk.Clear();
    }

    return chunks;
}

/// <summary>
/// Collapses several embedding vectors into a single vector by element-wise averaging.
/// </summary>
/// <param name="embeddings">Vectors to merge; all must share the same dimensionality.</param>
/// <returns>The averaged vector with the common dimensionality.</returns>
/// <exception cref="ArgumentException">Thrown when the list is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when vector lengths differ.</exception>
private IList<double> MergeEmbeddings(IList<IList<double>> embeddings)
{
    if (embeddings is null || embeddings.Count == 0)
        throw new ArgumentException("The embeddings list cannot be null or empty.");

    int size = embeddings[0].Count;

    // Validate dimensionality up front, before any arithmetic.
    foreach (var vector in embeddings)
    {
        if (vector.Count != size)
            throw new InvalidOperationException("All embeddings must have the same dimensionality.");
    }

    // Average each dimension across all vectors (same accumulation order as before,
    // so results are bit-for-bit identical).
    var merged = new double[size];
    for (int d = 0; d < size; d++)
    {
        double sum = 0;
        foreach (var vector in embeddings)
        {
            sum += vector[d];
        }

        merged[d] = sum / embeddings.Count;
    }

    return merged;
}
}
Loading

0 comments on commit 2cf05f6

Please sign in to comment.