Add an optional on-disk cache for the RAG FAISS index.

Switches the model in crllm_config.toml, adds the cache_enabled and
cache_path settings to the [rag] section, and moves vector store
construction out of add_rag_context into a new _get_vector_store helper
that loads the index from cache_path when it exists and saves it there
after building it.

diff --git a/crllm_config.toml b/crllm_config.toml
index 95f13ae..ff61c08 100644
--- a/crllm_config.toml
+++ b/crllm_config.toml
@@ -1,7 +1,8 @@
 [crllm]
 loader = "git"
-model = "deepseek-coder-v2:latest"
+model = "qwen2.5-coder:32b"
 git_changed_lines = true
 
 [rag]
-enabled = true
\ No newline at end of file
+enabled = true
+cache_enabled = true
\ No newline at end of file
diff --git a/src/crllm/config/config.toml b/src/crllm/config/config.toml
index f13bff4..6ee7131 100644
--- a/src/crllm/config/config.toml
+++ b/src/crllm/config/config.toml
@@ -16,6 +16,8 @@ enabled = false
 embedding_model = "all-minilm"
 src_path = "./"
 src_glob = "**/*.py"
+cache_enabled = false
+cache_path = ".crllm/index"
 
 [prompt]
 
diff --git a/src/crllm/model/model.py b/src/crllm/model/model.py
index 357db52..6692529 100644
--- a/src/crllm/model/model.py
+++ b/src/crllm/model/model.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 import logging
+import os
 from langchain_community.document_loaders.generic import GenericLoader
 from langchain_community.document_loaders.parsers import LanguageParser
 from langchain_community.vectorstores import FAISS
@@ -35,6 +36,39 @@ def generate(self, prompt_template, prompt_args):
         return result.content
 
     def add_rag_context(self, prompt_args, rag_config):
+        vectorstore = self._get_vector_store(rag_config)
+
+        def format_docs(docs):
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        prompt_args["context"] = format_docs(
+            vectorstore.similarity_search(prompt_args["code"], 3)
+        )
+
+    def _get_vector_store(self, rag_config):
+        """Return the FAISS vector store used for RAG context lookups.
+
+        When ``rag_config["cache_enabled"]`` is truthy and
+        ``rag_config["cache_path"]`` exists, the store is loaded from that
+        path. Otherwise it is built from the files matched by
+        ``src_path``/``src_glob`` and, if caching is enabled, saved to
+        ``cache_path``. An existing cache is never refreshed here; delete
+        it to force a rebuild after the sources change.
+        """
+        embedding_config = {"model": rag_config["embedding_model"]}
+
+        if rag_config["cache_enabled"]:
+            if os.path.exists(rag_config["cache_path"]):
+                # SECURITY: FAISS.load_local unpickles the stored docstore,
+                # so this flag is only safe while cache_path holds an index
+                # written by this tool, never one from an untrusted source.
+                vectorstore = FAISS.load_local(
+                    rag_config["cache_path"],
+                    embeddings=self._get_embeddings(embedding_config),
+                    allow_dangerous_deserialization=True,
+                )
+                return vectorstore
+
         loader = GenericLoader.from_filesystem(
             path=rag_config["src_path"],
             glob=rag_config["src_glob"],
@@ -44,18 +78,20 @@ def add_rag_context(self, prompt_args, rag_config):
 
         docs = loader.load()
 
-        embedding_config = {"model": rag_config["embedding_model"]}
-
         vectorstore = FAISS.from_documents(
             documents=docs, embedding=self._get_embeddings(embedding_config)
         )
 
-        def format_docs(docs):
-            return "\n\n".join(doc.page_content for doc in docs)
+        if rag_config["cache_enabled"]:
+            # dirname() is "" for a bare folder name such as "index", and
+            # os.makedirs("") raises FileNotFoundError, so only create a
+            # parent directory when the path actually has one.
+            cache_dir = os.path.dirname(rag_config["cache_path"])
+            if cache_dir:
+                os.makedirs(cache_dir, exist_ok=True)
+            vectorstore.save_local(rag_config["cache_path"])
 
-        prompt_args["context"] = format_docs(
-            vectorstore.similarity_search(prompt_args["code"], 3)
-        )
+        return vectorstore
 
     @abstractmethod
     def _get_model(self, model_config):