diff --git a/generate.py b/generate.py index 620d66b4d..1bc5b9b7e 100644 --- a/generate.py +++ b/generate.py @@ -224,7 +224,7 @@ def main( :param document_choice: Default document choice when taking subset of collection :param load_db_if_exists: Whether to load chroma db if exists or re-generate db :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually - :param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk + :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk :param use_openai_embedding: Whether to use OpenAI embeddings for vector db :param use_openai_model: Whether to use OpenAI model for use with vector db :param hf_embedding_model: Which HF embedding model to use for vector db diff --git a/gpt_langchain.py b/gpt_langchain.py index 6bd6de39d..c8201f1cb 100644 --- a/gpt_langchain.py +++ b/gpt_langchain.py @@ -55,6 +55,19 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo if db_type == 'faiss': from langchain.vectorstores import FAISS db = FAISS.from_documents(sources, embedding) + + elif db_type == 'weaviate': + import weaviate + from weaviate.embedded import EmbeddedOptions + from langchain.vectorstores import Weaviate + + # TODO: add support for connecting via docker compose + client = weaviate.Client( + embedded_options=EmbeddedOptions() + ) + index_name = langchain_mode.replace(' ', '_').capitalize() + db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False, index_name=index_name) + elif db_type == 'chroma': collection_name = langchain_mode.replace(' ', '_') os.makedirs(persist_directory, exist_ok=True) @@ -74,12 +87,32 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo return db +def _get_unique_sources_in_weaviate(db): + batch_size=100 + id_source_list = [] + result = db._client.data_object.get(class_name=db._index_name, limit=batch_size) + + while 
result['objects']: + id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']] + last_id = id_source_list[-1][0] + result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id) + + unique_sources = {source for _, source in id_source_list} + return unique_sources def add_to_db(db, sources, db_type='faiss', avoid_dup=True): if not sources: return db if db_type == 'faiss': db.add_documents(sources) + elif db_type == 'weaviate': + if avoid_dup: + unique_sources = _get_unique_sources_in_weaviate(db) + sources = [x for x in sources if x.metadata['source'] not in unique_sources] + if len(sources) == 0: + return db + db.add_documents(documents=sources) + elif db_type == 'chroma': if avoid_dup: collection = db.get() @@ -919,7 +952,7 @@ def _run_qa_db(query=None, :param chunk: :param chunk_size: :param user_path: user path to glob recursively from - :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db + :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db :param model_name: model name, used to switch behaviors :param model: pre-initialized model, else will make new one :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None diff --git a/make_db.py b/make_db.py index dafb2a564..95477e272 100644 --- a/make_db.py +++ b/make_db.py @@ -45,6 +45,7 @@ def make_db_main(use_openai_embedding: bool = False, pre_load_caption_model: bool = False, caption_gpu: bool = True, enable_ocr: bool = False, + db_type: str = 'chroma', ): """ # To make UserData db for generate.py, put pdfs, etc. into path user_path and run: @@ -85,10 +86,11 @@ def make_db_main(use_openai_embedding: bool = False, :param pre_load_caption_model: See generate.py :param caption_gpu: Caption images on GPU if present :param enable_ocr: Whether to enable OCR on images + :param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported. 
def _create_or_update_weaviate_db(sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model, collection_name):
    """Create a new Weaviate index from sources, or add to an existing one.

    :param sources: list of langchain documents to index
    :param use_openai_embedding: whether to embed with OpenAI instead of HF
    :param add_if_exists: if True, append to the existing index instead of
        dropping it and regenerating from scratch
    :param verbose: whether to print progress messages
    :param hf_embedding_model: HF embedding model name (used when not OpenAI)
    :param collection_name: logical collection name, normalized to a valid
        Weaviate class name (no spaces, first letter capitalized)
    :return: the langchain ``Weaviate`` vector store
    """
    import weaviate
    from weaviate.embedded import EmbeddedOptions
    from langchain.vectorstores import Weaviate

    # TODO: add support for connecting via docker compose
    client = weaviate.Client(
        embedded_options=EmbeddedOptions()
    )

    # Weaviate class names must start with a capital letter and contain no spaces.
    index_name = collection_name.replace(' ', '_').capitalize()
    # NOTE(review): the regenerate path below calls get_db(..., langchain_mode='UserData'),
    # which derives its own index name from langchain_mode -- confirm collection_name
    # is always 'UserData' here, otherwise the deleted and created classes differ.

    if not add_if_exists:
        # Drop any existing index so it is rebuilt from scratch.  The original
        # only used the exists() check to gate the print and deleted
        # unconditionally; guard the deletion itself as well.
        if client.schema.exists(index_name):
            if verbose:
                print("Removing %s" % index_name, flush=True)
            client.schema.delete_class(index_name)

        if verbose:
            print("Generating db", flush=True)
        db = get_db(sources,
                    use_openai_embedding=use_openai_embedding,
                    db_type='weaviate',
                    persist_directory=None,
                    langchain_mode='UserData',
                    hf_embedding_model=hf_embedding_model)
    else:
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        db = Weaviate(embedding_function=embedding, client=client, by_text=False, index_name=index_name)
        add_to_db(db, sources, db_type='weaviate')

    # BUG FIX: the original fell off the end and implicitly returned None, so
    # the caller's `assert db is not None` always failed for weaviate.
    return db
@wrap_test_forked
def test_qa_wiki_db_chunk_hf_weaviate():
    """QA over the chunked 'wiki' collection using a Weaviate vector db and a local llama model."""
    from gpt4all_llm import get_model_tokenizer_gpt4all
    base_model = 'llama'
    model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)

    from gpt_langchain import _run_qa_db
    question = "What are the main differences between Linux and Windows?"
    # chunk_size is chars for each of k=4 chunks; 4*4*128 = 2048 chars ~ 512 tokens
    answer = _run_qa_db(query=question,
                        use_openai_model=False,
                        use_openai_embedding=False,
                        text_limit=None,
                        chunk=True,
                        chunk_size=128 * 1,
                        langchain_mode='wiki',
                        db_type='weaviate',
                        prompt_type='wizard2',
                        model_name=base_model,
                        model=model,
                        tokenizer=tokenizer,
                        )
    check_ret(answer)