Add weaviate support #218

Merged (9 commits) on Jun 2, 2023
generate.py: 2 changes (1 addition, 1 deletion)
@@ -224,7 +224,7 @@ def main(
     :param document_choice: Default document choice when taking subset of collection
     :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
     :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
-    :param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
+    :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
     :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
     :param use_openai_model: Whether to use OpenAI model for use with vector db
     :param hf_embedding_model: Which HF embedding model to use for vector db
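With that one-line docstring change, generate.py's existing db_type argument now advertises the new backend. Assuming generate.py is launched through fire the same way make_db.py is (see its __main__ block below), selecting the backend would look like: python generate.py --db_type=weaviate. That launch line is inferred from the parameter list, not shown in this diff.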
gpt_langchain.py: 35 changes (34 additions, 1 deletion)
@@ -55,6 +55,19 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
     if db_type == 'faiss':
         from langchain.vectorstores import FAISS
         db = FAISS.from_documents(sources, embedding)
+
+    elif db_type == 'weaviate':
+        import weaviate
+        from weaviate.embedded import EmbeddedOptions
+        from langchain.vectorstores import Weaviate
+
+        # TODO: add support for connecting via docker compose
+        client = weaviate.Client(
+            embedded_options=EmbeddedOptions()
+        )
+        index_name = langchain_mode.replace(' ', '_').capitalize()
+        db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False, index_name=index_name)
+
     elif db_type == 'chroma':
         collection_name = langchain_mode.replace(' ', '_')
         os.makedirs(persist_directory, exist_ok=True)
@@ -74,12 +87,32 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo

     return db

+def _get_unique_sources_in_weaviate(db):
+    batch_size = 100
+    id_source_list = []
+    result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
+
+    while result['objects']:
+        id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
+        last_id = id_source_list[-1][0]
+        result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
+
+    unique_sources = {source for _, source in id_source_list}
+    return unique_sources
+
 def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
     if not sources:
         return db
     if db_type == 'faiss':
         db.add_documents(sources)
+    elif db_type == 'weaviate':
+        if avoid_dup:
+            unique_sources = _get_unique_sources_in_weaviate(db)
+            sources = [x for x in sources if x.metadata['source'] not in unique_sources]
+        if len(sources) == 0:
+            return db
+        db.add_documents(documents=sources)
+
     elif db_type == 'chroma':
         if avoid_dup:
             collection = db.get()
@@ -919,7 +952,7 @@ def _run_qa_db(query=None,
     :param chunk:
     :param chunk_size:
     :param user_path: user path to glob recursively from
-    :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
+    :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
     :param model_name: model name, used to switch behaviors
     :param model: pre-initialized model, else will make new one
     :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
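Taken together, the gpt_langchain.py changes (a) teach get_db() to build a LangChain Weaviate store backed by an embedded, in-process Weaviate server, and (b) teach add_to_db() to skip documents whose source is already indexed, using _get_unique_sources_in_weaviate() to page through all stored objects. A minimal self-contained sketch of the same pattern follows; the embedding model name and the sample documents are illustrative assumptions, not values taken from this PR:

import weaviate
from weaviate.embedded import EmbeddedOptions
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Weaviate

# Embedded mode: weaviate-client starts a local Weaviate server inside this
# process and persists it to disk, so no separate deployment is needed.
client = weaviate.Client(embedded_options=EmbeddedOptions())

# Weaviate class names must start with an uppercase letter, which is why the
# diff applies .capitalize() to langchain_mode / collection_name.
index_name = 'user data'.replace(' ', '_').capitalize()  # -> 'User_data'

embedding = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
sources = [Document(page_content='Linux is open source.', metadata={'source': 'user_path/linux.txt'}),
           Document(page_content='Windows is closed source.', metadata={'source': 'user_path/windows.txt'})]

# by_text=False makes LangChain search by our embedding vectors rather than
# Weaviate's own text2vec modules, matching the get_db() call above.
db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client,
                             by_text=False, index_name=index_name)
print(db.similarity_search('open source operating systems', k=1))

The after=last_id cursor in _get_unique_sources_in_weaviate() is the weaviate-client 3.x way to list a class exhaustively: a bare limit= caps the result set, so the loop keeps re-issuing data_object.get() from the last seen object id until a page comes back empty.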
make_db.py: 56 changes (49 additions, 7 deletions)
@@ -45,6 +45,7 @@ def make_db_main(use_openai_embedding: bool = False,
                  pre_load_caption_model: bool = False,
                  caption_gpu: bool = True,
                  enable_ocr: bool = False,
+                 db_type: str = 'chroma',
                  ):
     """
     # To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
@@ -85,10 +86,11 @@
     :param pre_load_caption_model: See generate.py
     :param caption_gpu: Caption images on GPU if present
     :param enable_ocr: Whether to enable OCR on images
+    :param db_type: Type of db to create. Currently 'chroma' and 'weaviate' are supported.
     :return: None
     """

-    db_type = 'chroma'


     if download_all:
         print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
@@ -142,6 +144,49 @@
     sources = [x for x in sources if 'exception' not in x.metadata]

     assert len(sources) > 0, "No sources found"
+    if db_type == 'chroma':
+        db = _create_or_update_chroma_db(sources, use_openai_embedding, persist_directory, add_if_exists, verbose, hf_embedding_model, collection_name)
+    elif db_type == 'weaviate':
+        db = _create_or_update_weaviate_db(sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model, collection_name)
+    else:
+        raise ValueError(f"db_type={db_type} not supported")
+
+    assert db is not None
+    if verbose:
+        print("DONE", flush=True)
+    return db
+
+def _create_or_update_weaviate_db(sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model, collection_name):
+    import weaviate
+    from weaviate.embedded import EmbeddedOptions
+    from langchain.vectorstores import Weaviate
+
+    # TODO: add support for connecting via docker compose
+    client = weaviate.Client(
+        embedded_options=EmbeddedOptions()
+    )
+
+    index_name = collection_name.replace(' ', '_').capitalize()
+
+    if not add_if_exists:
+        if verbose and client.schema.exists(index_name):
+            print("Removing %s" % index_name, flush=True)
+        client.schema.delete_class(index_name)
+
+        if verbose:
+            print("Generating db", flush=True)
+        db = get_db(sources,
+                    use_openai_embedding=use_openai_embedding,
+                    db_type='weaviate',
+                    persist_directory=None,
+                    langchain_mode='UserData',
+                    hf_embedding_model=hf_embedding_model)
+    else:
+        embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+        db = Weaviate(embedding_function=embedding, client=client, by_text=False, index_name=index_name)
+        add_to_db(db, sources, db_type='weaviate')
+
+    return db  # make_db_main asserts the returned db is not None
+
+def _create_or_update_chroma_db(sources, use_openai_embedding, persist_directory, add_if_exists, verbose, hf_embedding_model, collection_name):
     if not os.path.isdir(persist_directory) or not add_if_exists:
         if os.path.isdir(persist_directory):
             if verbose:
@@ -151,7 +196,7 @@
print("Generating db", flush=True)
db = get_db(sources,
use_openai_embedding=use_openai_embedding,
db_type=db_type,
db_type='chroma',
persist_directory=persist_directory,
langchain_mode='UserData',
hf_embedding_model=hf_embedding_model)
     else:
         embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
         db = Chroma(embedding_function=embedding,
                     persist_directory=persist_directory,
                     collection_name=collection_name)
-        add_to_db(db, sources, db_type=db_type)
-    assert db is not None
-    if verbose:
-        print("DONE", flush=True)
+        add_to_db(db, sources, db_type='chroma')

     return db


 if __name__ == "__main__":
     fire.Fire(make_db_main)
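Both new code paths still carry "# TODO: add support for connecting via docker compose". For reference, connecting the same 3.x client to an external server, assuming a Weaviate instance already listening on the default port of Weaviate's docker-compose quickstart (a sketch of the TODO, not something this PR wires up):

import weaviate

# URL is the docker-compose quickstart default and is an assumption here.
client = weaviate.Client(url='http://localhost:8080')
assert client.is_ready()  # True once the server is up

# Everything downstream is unchanged: pass this client to
# langchain.vectorstores.Weaviate exactly as get_db() does with the
# embedded client.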
reqs_optional/requirements_optional_langchain.txt: 3 changes (3 additions, 0 deletions)
@@ -40,3 +40,6 @@ tabulate==0.9.0
 # to check licenses
 # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
 pip-licenses==4.3.0
+
+# weaviate vector db
+weaviate-client==3.19.2
tests/test_langchain_units.py: 18 changes (18 additions, 0 deletions)
@@ -162,6 +162,24 @@ def test_qa_daidocs_db_chunk_hf_chroma():
                      )
     check_ret(ret)

+@wrap_test_forked
+def test_qa_wiki_db_chunk_hf_weaviate():
+
+    from gpt4all_llm import get_model_tokenizer_gpt4all
+    model_name = 'llama'
+    model, tokenizer, device = get_model_tokenizer_gpt4all(model_name)
+
+    from gpt_langchain import _run_qa_db
+    query = "What are the main differences between Linux and Windows?"
+    # chunk_size is chars for each of k=4 chunks
+    ret = _run_qa_db(query=query, use_openai_model=False, use_openai_embedding=False, text_limit=None, chunk=True,
+                     chunk_size=128 * 1,  # characters, and if k=4, then 4*4*128 = 2048 chars ~ 512 tokens
+                     langchain_mode='wiki',
+                     db_type='weaviate',
+                     prompt_type='wizard2',
+                     model_name=model_name, model=model, tokenizer=tokenizer,
+                     )
+    check_ret(ret)

 @pytest.mark.skipif(not have_openai_key, reason="requires OpenAI key to run")
 @wrap_test_forked
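Assuming the repo's usual pytest setup (wrap_test_forked, per its name, runs each test in a forked subprocess), the new test can be exercised on its own; the first run may be slow, since it fetches a gpt4all llama model and the embedded Weaviate binary:

# Equivalent to: pytest -v tests/test_langchain_units.py::test_qa_wiki_db_chunk_hf_weaviate
import pytest

pytest.main(['-v', 'tests/test_langchain_units.py::test_qa_wiki_db_chunk_hf_weaviate'])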