Skip to content

Commit

Permalink
Moved get_embeddings function to utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PromtEngineer committed Feb 5, 2024
1 parent 2018a08 commit fd67d92
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 89 deletions.
94 changes: 39 additions & 55 deletions ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import click
import torch
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from utils import get_embeddings

from constants import (
CHROMA_SETTINGS,
Expand All @@ -18,27 +18,30 @@
SOURCE_DIRECTORY,
)


def file_log(logentry):
file1 = open("file_ingest.log","a")
file1.write(logentry + "\n")
file1.close()
print(logentry + "\n")
file1 = open("file_ingest.log", "a")
file1.write(logentry + "\n")
file1.close()
print(logentry + "\n")


def load_single_document(file_path: str) -> Document:
# Loads a single document from a file path
try:
file_extension = os.path.splitext(file_path)[1]
loader_class = DOCUMENT_MAP.get(file_extension)
if loader_class:
file_log(file_path + ' loaded.')
loader = loader_class(file_path)
else:
file_log(file_path + ' document type is undefined.')
raise ValueError("Document type is undefined")
return loader.load()[0]
file_extension = os.path.splitext(file_path)[1]
loader_class = DOCUMENT_MAP.get(file_extension)
if loader_class:
file_log(file_path + " loaded.")
loader = loader_class(file_path)
else:
file_log(file_path + " document type is undefined.")
raise ValueError("Document type is undefined")
return loader.load()[0]
except Exception as ex:
file_log('%s loading error: \n%s' % (file_path, ex))
return None
file_log("%s loading error: \n%s" % (file_path, ex))
return None


def load_document_batch(filepaths):
logging.info("Loading document batch")
Expand All @@ -48,20 +51,20 @@ def load_document_batch(filepaths):
futures = [exe.submit(load_single_document, name) for name in filepaths]
# collect data
if futures is None:
file_log(name + ' failed to submit')
return None
file_log(name + " failed to submit")
return None
else:
data_list = [future.result() for future in futures]
# return data and file paths
return (data_list, filepaths)
data_list = [future.result() for future in futures]
# return data and file paths
return (data_list, filepaths)


def load_documents(source_dir: str) -> list[Document]:
# Loads all documents from the source documents directory, including nested folders
paths = []
for root, _, files in os.walk(source_dir):
for file_name in files:
print('Importing: ' + file_name)
print("Importing: " + file_name)
file_extension = os.path.splitext(file_name)[1]
source_file_path = os.path.join(root, file_name)
if file_extension in DOCUMENT_MAP.keys():
Expand All @@ -79,21 +82,21 @@ def load_documents(source_dir: str) -> list[Document]:
filepaths = paths[i : (i + chunksize)]
# submit the task
try:
future = executor.submit(load_document_batch, filepaths)
future = executor.submit(load_document_batch, filepaths)
except Exception as ex:
file_log('executor task failed: %s' % (ex))
future = None
file_log("executor task failed: %s" % (ex))
future = None
if future is not None:
futures.append(future)
futures.append(future)
# process all results
for future in as_completed(futures):
# open the file and load the data
try:
contents, _ = future.result()
docs.extend(contents)
except Exception as ex:
file_log('Exception: %s' % (ex))
file_log("Exception: %s" % (ex))

return docs


Expand All @@ -102,11 +105,11 @@ def split_documents(documents: list[Document]) -> tuple[list[Document], list[Doc
text_docs, python_docs = [], []
for doc in documents:
if doc is not None:
file_extension = os.path.splitext(doc.metadata["source"])[1]
if file_extension == ".py":
python_docs.append(doc)
else:
text_docs.append(doc)
file_extension = os.path.splitext(doc.metadata["source"])[1]
if file_extension == ".py":
python_docs.append(doc)
else:
text_docs.append(doc)
return text_docs, python_docs


Expand Down Expand Up @@ -159,29 +162,9 @@ def main(device_type):
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

def get_embeddings():
if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
)
embeddings = get_embeddings()
embeddings = get_embeddings(device_type)

logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")

db = Chroma.from_documents(
Expand All @@ -191,6 +174,7 @@ def get_embeddings():
client_settings=CHROMA_SETTINGS,
)


if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
Expand Down
38 changes: 7 additions & 31 deletions run_localGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

from prompt_template_utils import get_prompt_template
from utils import get_embeddings

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
Expand All @@ -34,7 +35,7 @@
MODEL_BASENAME,
MAX_NEW_TOKENS,
MODELS_PATH,
CHROMA_SETTINGS
CHROMA_SETTINGS,
)


Expand Down Expand Up @@ -125,37 +126,13 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

def get_embeddings():
if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
)
embeddings = get_embeddings()
embeddings = get_embeddings(device_type)

logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")

# load the vectorstore
db = Chroma(
persist_directory=PERSIST_DIRECTORY,
embedding_function=embeddings,
client_settings=CHROMA_SETTINGS
)
db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever()

# get the prompt template and memory if set by the user.
Expand Down Expand Up @@ -243,7 +220,6 @@ def get_embeddings():
is_flag=True,
help="whether to save Q&A pairs to a CSV file (Default is False)",
)

def main(device_type, show_sources, use_history, model_type, save_qa):
"""
Implements the main information retrieval task for a localGPT.
Expand Down Expand Up @@ -296,7 +272,7 @@ def main(device_type, show_sources, use_history, model_type, save_qa):
print("\n> " + document.metadata["source"] + ":")
print(document.page_content)
print("----------------------------------SOURCE DOCUMENTS---------------------------")

# Log the Q&A to CSV only if save_qa is True
if save_qa:
utils.log_to_csv(query, answer)
Expand Down
34 changes: 31 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import csv
from datetime import datetime
from constants import EMBEDDING_MODEL_NAME
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings


def log_to_csv(question, answer):

Expand All @@ -14,12 +19,35 @@ def log_to_csv(question, answer):

# Check if file exists, if not create and write headers
if not os.path.isfile(log_path):
with open(log_path, mode='w', newline='', encoding='utf-8') as file:
with open(log_path, mode="w", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
writer.writerow(["timestamp", "question", "answer"])

# Append the log entry
with open(log_path, mode='a', newline='', encoding='utf-8') as file:
with open(log_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
writer.writerow([timestamp, question, answer])
writer.writerow([timestamp, question, answer])


def get_embeddings(device_type="cuda"):
if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
embed_instruction="Represent the document for retrieval:",
query_instruction="Represent the question for retrieving supporting documents:",
)

elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
query_instruction="Represent this sentence for searching relevant passages:",
)

else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": device_type},
)

0 comments on commit fd67d92

Please sign in to comment.