Skip to content

Commit

Permalink
Merge pull request #218 from hsm207/weaviate-vectorstore
Browse files Browse the repository at this point in the history
Add weaviate support
  • Loading branch information
pseudotensor authored Jun 2, 2023
2 parents faa6e67 + 8070b4b commit 9371e51
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 9 deletions.
2 changes: 1 addition & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def main(
:param document_choice: Default document choice when taking subset of collection
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
:param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
:param use_openai_model: Whether to use OpenAI model for use with vector db
:param hf_embedding_model: Which HF embedding model to use for vector db
Expand Down
35 changes: 34 additions & 1 deletion gpt_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
if db_type == 'faiss':
from langchain.vectorstores import FAISS
db = FAISS.from_documents(sources, embedding)

elif db_type == 'weaviate':
import weaviate
from weaviate.embedded import EmbeddedOptions
from langchain.vectorstores import Weaviate

# TODO: add support for connecting via docker compose
client = weaviate.Client(
embedded_options=EmbeddedOptions()
)
index_name = langchain_mode.replace(' ', '_').capitalize()
db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False, index_name=index_name)

elif db_type == 'chroma':
collection_name = langchain_mode.replace(' ', '_')
os.makedirs(persist_directory, exist_ok=True)
Expand All @@ -74,12 +87,32 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo

return db

def _get_unique_sources_in_weaviate(db):
batch_size=100
id_source_list = []
result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)

while result['objects']:
id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
last_id = id_source_list[-1][0]
result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)

unique_sources = {source for _, source in id_source_list}
return unique_sources

def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
if not sources:
return db
if db_type == 'faiss':
db.add_documents(sources)
elif db_type == 'weaviate':
if avoid_dup:
unique_sources = _get_unique_sources_in_weaviate(db)
sources = [x for x in sources if x.metadata['source'] not in unique_sources]
if len(sources) == 0:
return db
db.add_documents(documents=sources)

elif db_type == 'chroma':
if avoid_dup:
collection = db.get()
Expand Down Expand Up @@ -919,7 +952,7 @@ def _run_qa_db(query=None,
:param chunk:
:param chunk_size:
:param user_path: user path to glob recursively from
:param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
:param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
:param model_name: model name, used to switch behaviors
:param model: pre-initialized model, else will make new one
:param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
Expand Down
56 changes: 49 additions & 7 deletions make_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def make_db_main(use_openai_embedding: bool = False,
pre_load_caption_model: bool = False,
caption_gpu: bool = True,
enable_ocr: bool = False,
db_type: str = 'chroma',
):
"""
# To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
Expand Down Expand Up @@ -85,10 +86,11 @@ def make_db_main(use_openai_embedding: bool = False,
:param pre_load_caption_model: See generate.py
:param caption_gpu: Caption images on GPU if present
:param enable_ocr: Whether to enable OCR on images
:param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' are supported.
:return: None
"""

db_type = 'chroma'


if download_all:
print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
Expand Down Expand Up @@ -142,6 +144,49 @@ def make_db_main(use_openai_embedding: bool = False,
sources = [x for x in sources if 'exception' not in x.metadata]

assert len(sources) > 0, "No sources found"
if db_type == 'chroma':
db = _create_or_update_chroma_db(sources, use_openai_embedding, persist_directory, add_if_exists, verbose, hf_embedding_model, collection_name)
elif db_type == 'weaviate':
db = _create_or_update_weaviate_db(sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model, collection_name)
else:
raise ValueError(f"db_type={db_type} not supported")

assert db is not None
if verbose:
print("DONE", flush=True)
return db

def _create_or_update_weaviate_db(sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model, collection_name):
    """Create a new weaviate db from sources, or append sources to an existing index.

    :param sources: documents to index
    :param use_openai_embedding: whether to embed with OpenAI instead of a HF model
    :param add_if_exists: if True, add to any existing index instead of rebuilding it
    :param verbose: whether to print progress messages
    :param hf_embedding_model: HF embedding model name, used when not using OpenAI
    :param collection_name: base name for the weaviate index (class)
    :return: the Weaviate vector store
    """
    import weaviate
    from weaviate.embedded import EmbeddedOptions
    from langchain.vectorstores import Weaviate

    # TODO: add support for connecting via docker compose
    client = weaviate.Client(
        embedded_options=EmbeddedOptions()
    )

    # weaviate class names must be capitalized; spaces are not allowed
    index_name = collection_name.replace(' ', '_').capitalize()

    if not add_if_exists:
        # BUG FIX: only delete the class when it actually exists; the original
        # called delete_class unconditionally (the exists() check only gated the print)
        if client.schema.exists(index_name):
            if verbose:
                print("Removing %s" % index_name, flush=True)
            client.schema.delete_class(index_name)

        if verbose:
            print("Generating db", flush=True)
        db = get_db(sources,
                    use_openai_embedding=use_openai_embedding,
                    db_type='weaviate',
                    persist_directory=None,
                    langchain_mode='UserData',
                    hf_embedding_model=hf_embedding_model)
    else:
        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
        db = Weaviate(embedding_function=embedding, client=client, by_text=False, index_name=index_name)
        add_to_db(db, sources, db_type='weaviate')

    # BUG FIX: the original fell off the end and returned None, so the caller's
    # `assert db is not None` always failed on the weaviate path; the chroma
    # counterpart returns db, and so must this one
    return db

def _create_or_update_chroma_db(sources, use_openai_embedding, persist_directory, add_if_exists, verbose, hf_embedding_model, collection_name):
if not os.path.isdir(persist_directory) or not add_if_exists:
if os.path.isdir(persist_directory):
if verbose:
Expand All @@ -151,7 +196,7 @@ def make_db_main(use_openai_embedding: bool = False,
print("Generating db", flush=True)
db = get_db(sources,
use_openai_embedding=use_openai_embedding,
db_type=db_type,
db_type='chroma',
persist_directory=persist_directory,
langchain_mode='UserData',
hf_embedding_model=hf_embedding_model)
Expand All @@ -161,12 +206,9 @@ def make_db_main(use_openai_embedding: bool = False,
db = Chroma(embedding_function=embedding,
persist_directory=persist_directory,
collection_name=collection_name)
add_to_db(db, sources, db_type=db_type)
assert db is not None
if verbose:
print("DONE", flush=True)
add_to_db(db, sources, db_type='chroma')

return db


if __name__ == "__main__":
fire.Fire(make_db_main)
3 changes: 3 additions & 0 deletions reqs_optional/requirements_optional_langchain.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ tabulate==0.9.0
# to check licenses
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
pip-licenses==4.3.0

# weaviate vector db
weaviate-client==3.19.2
18 changes: 18 additions & 0 deletions tests/test_langchain_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,24 @@ def test_qa_daidocs_db_chunk_hf_chroma():
)
check_ret(ret)

@wrap_test_forked
def test_qa_wiki_db_chunk_hf_weaviate():
    """Smoke test: QA over the chunked 'wiki' source using a weaviate vector db
    with a local gpt4all llama model and HF embeddings."""
    from gpt4all_llm import get_model_tokenizer_gpt4all
    from gpt_langchain import _run_qa_db

    base_model = 'llama'
    model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)

    question = "What are the main differences between Linux and Windows?"
    # chunk_size is chars for each of k=4 chunks: 4*4*128 = 2048 chars ~ 512 tokens
    ret = _run_qa_db(query=question,
                     use_openai_model=False,
                     use_openai_embedding=False,
                     text_limit=None,
                     chunk=True,
                     chunk_size=128,
                     langchain_mode='wiki',
                     db_type='weaviate',
                     prompt_type='wizard2',
                     model_name=base_model,
                     model=model,
                     tokenizer=tokenizer,
                     )
    check_ret(ret)

@pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
@wrap_test_forked
Expand Down

0 comments on commit 9371e51

Please sign in to comment.