Improve split and merge #1612

Merged
merged 9 commits on May 11, 2024
164 changes: 46 additions & 118 deletions src/gpt_langchain.py
@@ -85,7 +85,8 @@
is_vision_model, is_gradio_vision_model, is_json_model
from src.serpapi import H2OSerpAPIWrapper
from utils_langchain import StreamingGradioCallbackHandler, _chunk_sources, _add_meta, add_parser, fix_json_meta, \
- load_general_summarization_chain, H2OHuggingFaceHubEmbeddings, make_sources_file
+ load_general_summarization_chain, H2OHuggingFaceHubEmbeddings, make_sources_file, select_docs_with_score, \
+ split_merge_docs

# to check imports
# find ./src -name '*.py' | xargs awk '{ if (sub(/\\$/, "")) printf "%s ", $0; else print; }' | grep 'from langchain\.' | sed 's/^[ \t]*//' > go.py
@@ -1702,6 +1703,8 @@ def reducer(accumulator, element):

llm_output = {"token_usage": token_usage, "model_name": self.model_name}
self.count_output_tokens += token_usage.get('completion_tokens', 0)
+ if self.count_output_tokens == 0:
+     self.count_output_tokens += sum([self.get_num_tokens(x[0].text) for x in generations if len(x) > 0])
return LLMResult(generations=generations, llm_output=llm_output)

def _generate(
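The two added lines above give `count_output_tokens` a fallback for backends that return an empty `completion_tokens` in `token_usage`: when the reported count is zero, the output tokens are re-counted directly from the generated text. A minimal sketch of the same pattern (the helper and class names here are illustrative, not the project's API):

```python
from dataclasses import dataclass

@dataclass
class Gen:
    text: str

def output_tokens_with_fallback(token_usage, generations, count_tokens):
    """Prefer the backend-reported completion_tokens; otherwise re-count
    tokens from the first candidate text of each generation."""
    reported = token_usage.get('completion_tokens', 0)
    if reported:
        return reported
    return sum(count_tokens(g[0].text) for g in generations if len(g) > 0)

# Toy usage: a whitespace tokenizer stands in for the model tokenizer.
print(output_tokens_with_fallback({}, [[Gen("hello world")]], lambda t: len(t.split())))  # 2
```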
@@ -3613,6 +3616,9 @@ def file_to_doc(file,

is_public=False,
from_ui=True,

hf_embedding_model=None,
use_openai_embedding=False,
):
# SOME AUTODETECTION LOGIC FOR URL VS TEXT

@@ -3669,7 +3675,9 @@
set_audio_types1 = set_audio_types

assert db_type is not None
- chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
+ chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type,
+ hf_embedding_model=hf_embedding_model, use_openai_embedding=use_openai_embedding,
+ verbose=verbose)
add_meta = functools.partial(_add_meta, headsize=headsize, filei=filei)
# FIXME: if zip, file index order will not be correct if other files involved
path_to_docs_func = functools.partial(path_to_docs,
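As elsewhere in this PR, `functools.partial` threads the new `hf_embedding_model` and `use_openai_embedding` settings into `_chunk_sources` once, so downstream call sites keep their old shape; presumably this lets the chunker size chunks against the embedding model rather than a fixed default. A sketch of the binding pattern, with a hypothetical `chunk_docs` standing in for `_chunk_sources`:

```python
import functools

def chunk_docs(docs, chunk=True, chunk_size=512, db_type=None,
               hf_embedding_model=None, use_openai_embedding=False, verbose=False):
    # Toy chunker: real code would pick a splitter based on the embedding model
    # and database type instead of slicing raw characters.
    if verbose:
        print("chunking with %s for db_type=%s" % (hf_embedding_model or 'default', db_type))
    if not chunk:
        return docs
    return [d[i:i + chunk_size] for d in docs for i in range(0, len(d), chunk_size)]

# Bind the embedding-related settings once; callers pass only the documents.
chunk_sources = functools.partial(chunk_docs, chunk=True, chunk_size=8, db_type="chroma",
                                  hf_embedding_model="some/embedding-model", verbose=True)
print(chunk_sources(["a long document that will be split"]))
```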
@@ -3723,6 +3731,9 @@

is_public=is_public,
from_ui=from_ui,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)

if file is None:
@@ -4619,6 +4630,9 @@ def path_to_doc1(file,

is_public=False,
from_ui=True,

hf_embedding_model=None,
use_openai_embedding=False,
):
assert db_type is not None
if verbose:
@@ -4681,6 +4695,9 @@
selected_file_types=selected_file_types,
is_public=is_public,
from_ui=from_ui,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)
except BaseException as e:
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
@@ -4764,6 +4781,9 @@ def path_to_docs(path_or_paths,
selected_file_types=None,

from_ui=True,

use_openai_embedding=False,
hf_embedding_model=None,
):
if verbose:
print("BEGIN Consuming path_or_paths=%s url=%s text=%s" % (path_or_paths, url, text), flush=True)
@@ -4906,6 +4926,9 @@

is_public=is_public,
from_ui=from_ui,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)

if is_public:
@@ -4923,6 +4946,7 @@ def no_tqdm(x):
filei0 = filei

if n_jobs != 1 and len(globs_non_image_types) > 1:
+ kwargs['hf_embedding_model'] = None  # can't fork and use CUDA
# avoid nesting, e.g. upload 1 zip and then inside many files
# harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
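Dropping `hf_embedding_model` from `kwargs` before the parallel ingest matters because a CUDA-initialized model generally cannot survive a `fork()`: each joblib multiprocessing worker must load (or skip) the GPU-backed model itself instead of inheriting it from the parent, exactly as the added comment says. A sketch of the precaution with illustrative names:

```python
from joblib import Parallel, delayed

def ingest_one(path, hf_embedding_model=None):
    # Runs in a forked worker: use a worker-local model (or none) rather than
    # an object that was initialized on the GPU in the parent process.
    return "%s: embedded with %s" % (path, hf_embedding_model or "worker-local default")

paths = ["a.txt", "b.txt", "c.txt"]
kwargs = dict(hf_embedding_model="some/gpu-embedding-model")
if len(paths) > 1:
    kwargs['hf_embedding_model'] = None  # can't fork and use CUDA
    results = Parallel(n_jobs=2, backend='multiprocessing')(
        delayed(ingest_one)(p, **kwargs) for p in paths)
else:
    results = [ingest_one(p, **kwargs) for p in paths]
print(results)
```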
@@ -5520,7 +5544,9 @@ def _make_db(use_openai_embedding=False,

sources = []
if not db:
- chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type)
+ chunk_sources = functools.partial(_chunk_sources, chunk=chunk, chunk_size=chunk_size, db_type=db_type,
+ hf_embedding_model=hf_embedding_model,
+ use_openai_embedding=use_openai_embedding, verbose=verbose)
if langchain_mode in ['wiki_full']:
from read_wiki_full import get_all_documents
small_test = None
@@ -5602,6 +5628,9 @@

is_public=False,
from_ui=True,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)
new_metadata_sources = set([x.metadata['source'] for x in sources1])
if new_metadata_sources:
@@ -6777,110 +6806,6 @@ def _get_docs_with_score(query, k_db,
return docs_with_score


def select_docs_with_score(docs_with_score, top_k_docs, one_doc_size):
if top_k_docs > 0:
docs_with_score = docs_with_score[:top_k_docs]
elif one_doc_size is not None:
docs_with_score = [(docs_with_score[0][:one_doc_size], docs_with_score[0][1])]
else:
# do nothing
pass
return docs_with_score
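This helper (and `split_merge_docs` below) is removed here because it now lives in `utils_langchain` and is imported at the top of the file. Its behavior is a plain truncation of the ranked list; a quick usage sketch:

```python
ranked = [("doc A", 0.12), ("doc B", 0.34), ("doc C", 0.71)]  # (doc, score) pairs
# top_k_docs > 0 keeps the first k entries; one_doc_size is only consulted when a
# single document must itself be cut down to fit the context window.
print(select_docs_with_score(ranked, top_k_docs=2, one_doc_size=None))
# -> [('doc A', 0.12), ('doc B', 0.34)]
```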


class H2OCharacterTextSplitter(RecursiveCharacterTextSplitter):
@classmethod
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
def _huggingface_tokenizer_length(text: str) -> int:
return get_token_count(text, tokenizer)

return cls(length_function=_huggingface_tokenizer_length, **kwargs)
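The override makes the recursive splitter measure length in model tokens instead of characters. A usage sketch with the Hugging Face `transformers` tokenizer API (the model name is only an example, and `get_token_count` from this module must be available):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here
splitter = H2OCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer, chunk_size=128, chunk_overlap=0)  # chunk_size is counted in tokens
chunks = splitter.split_text("some long text " * 200)
print(len(chunks), "chunks, each at most ~128 tokens")
```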


def split_merge_docs(docs_with_score, tokenizer=None, max_input_tokens=None, docs_token_handling=None,
joiner=docs_joiner_default,
non_doc_prompt='',
do_split=True,
verbose=False):
# NOTE: Could use joiner=\n\n, but if PDF and continues, might want just full continue with joiner=''
# NOTE: assume max_input_tokens already processed if was -1 and accounts for model_max_len and is per-llm call
if max_input_tokens is not None:
max_input_tokens -= get_token_count(non_doc_prompt, tokenizer)

if docs_token_handling in ['chunk']:
return docs_with_score, 0
elif docs_token_handling in [None, 'split_or_merge']:
assert tokenizer
tokens_before_split = [get_token_count(x + joiner, tokenizer) for x in
[x[0].page_content for x in docs_with_score]]
# skip split if not necessary, since expensive for some reason
do_split &= any([x > max_input_tokens for x in tokens_before_split])
if do_split:

if verbose:
print('tokens_before_split=%s' % tokens_before_split, flush=True)

# see if need to split
# account for joiner tokens
joiner_tokens = get_token_count(joiner, tokenizer)
doc_chunk_size = max(64, min(max_input_tokens,
max(64, max_input_tokens - joiner_tokens * len(docs_with_score))))
text_splitter = H2OCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer, chunk_size=doc_chunk_size, chunk_overlap=0
)
[x[0].metadata.update(dict(docscore=x[1], doci=doci, ntokens=tokens_before_split[doci])) for doci, x in
enumerate(docs_with_score)]
docs = [x[0] for x in docs_with_score]
# only split those that need to be split, else recursive splitter goes too nuts and takes too long
docs_to_split = [x for x in docs if x.metadata['ntokens'] > doc_chunk_size]
docs_to_not_split = [x for x in docs if x.metadata['ntokens'] <= doc_chunk_size]
docs_split_new = flatten_list([text_splitter.split_documents([x]) for x in docs_to_split])
docs_new = docs_to_not_split + docs_split_new
doci_new = [x.metadata['doci'] for x in docs_new]
# order back by doci
docs_new = [x for _, x in sorted(zip(doci_new, docs_new), key=lambda pair: pair[0])]
docs_with_score = [(x, x.metadata['docscore']) for x in docs_new]

tokens_after_split = [get_token_count(x + joiner, tokenizer) for x in
[x[0].page_content for x in docs_with_score]]
if verbose:
print('tokens_after_split=%s' % tokens_after_split, flush=True)

docs_with_score_new = []
k = 0
while k < len(docs_with_score):
# means use max_input_tokens to ensure model gets no more than max_input_tokens each map
top_k_docs, one_doc_size, num_doc_tokens = \
get_docs_tokens(tokenizer,
text_context_list=[x[0].page_content for x in docs_with_score[k:]],
max_input_tokens=max_input_tokens)
docs_with_score1 = select_docs_with_score(docs_with_score[k:], top_k_docs, one_doc_size)
new_score = docs_with_score1[0][1]
new_page_content = joiner.join([x[0].page_content for x in docs_with_score1])
new_metadata = docs_with_score1[0][0].metadata.copy()
new_metadata['source'] = joiner.join(set([x[0].metadata['source'] for x in docs_with_score1]))
doc1 = Document(page_content=new_page_content, metadata=new_metadata)
docs_with_score_new.append((doc1, new_score))

if do_split:
assert one_doc_size is None, "Split failed: %s" % one_doc_size
elif one_doc_size is not None:
# chopped
assert top_k_docs == 1
assert top_k_docs >= 1
k += top_k_docs

tokens_after_merge = [get_token_count(x + joiner, tokenizer) for x in
[x[0].page_content for x in docs_with_score_new]]
if verbose:
print('tokens_after_merge=%s' % tokens_after_merge, flush=True)

max_tokens_after_merge = max(tokens_after_merge) if tokens_after_merge else 0
return docs_with_score_new, max_tokens_after_merge
else:
raise ValueError("No such docs_token_handling=%s" % docs_token_handling)
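`split_merge_docs` (also moved to `utils_langchain`) first splits any retrieved chunk whose token count exceeds the per-call budget, then greedily merges consecutive chunks back together until each merged document just fits `max_input_tokens`, so the model sees as few, as full, prompts as possible. A hedged usage sketch; the tokenizer, documents, and budget are illustrative, and the call relies on this module's helpers (`get_token_count`, `get_docs_tokens`, `docs_joiner_default`):

```python
from transformers import AutoTokenizer
from langchain.docstore.document import Document

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for the LLM's tokenizer
docs_with_score = [
    (Document(page_content="word " * 3000, metadata=dict(source="a.pdf")), 0.2),  # too big, gets split
    (Document(page_content="short text", metadata=dict(source="b.pdf")), 0.4),    # small, gets merged
]
merged, max_tokens = split_merge_docs(docs_with_score, tokenizer,
                                      max_input_tokens=1024,
                                      docs_token_handling='split_or_merge',
                                      verbose=True)
print(len(merged), "merged docs; largest is", max_tokens, "tokens")
```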


def get_single_document(document_choice, db, extension=None):
if isinstance(document_choice, str):
document_choice = [document_choice]
@@ -7797,8 +7722,6 @@ def get_chain(query=None,
text_context_list=text_context_list,
chunk_id_filter=chunk_id_filter)

- if top_k_docs == -1:
- top_k_docs = len(db_documents)
# similar to langchain's chroma's _results_to_docs_and_scores
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
for result in zip(db_documents, db_metadatas)]
@@ -7831,7 +7754,8 @@
]
docs_with_score = docs_with_score2

- docs_with_score = docs_with_score[:top_k_docs]
+ top_k_docs_sample = len(db_documents) if top_k_docs == -1 else top_k_docs
+ docs_with_score = docs_with_score[:top_k_docs_sample]
docs = [x[0] for x in docs_with_score]
scores = [x[1] for x in docs_with_score]
else:
@@ -7960,15 +7884,16 @@
get_llm(**llm_kwargs)

# avoid craziness
+ top_k_docs_sample = len(docs_with_score) if top_k_docs == -1 else top_k_docs
if 0 < top_k_docs_trial < max_chunks:
# avoid craziness
if top_k_docs == -1:
- top_k_docs = top_k_docs_trial
+ top_k_docs_sample = top_k_docs_trial
else:
- top_k_docs = min(top_k_docs, top_k_docs_trial)
+ top_k_docs_sample = min(top_k_docs, top_k_docs_trial)
elif top_k_docs_trial >= max_chunks:
- top_k_docs = max_chunks
- docs_with_score = select_docs_with_score(docs_with_score, top_k_docs, one_doc_size)
+ top_k_docs_sample = max_chunks
+ docs_with_score = select_docs_with_score(docs_with_score, top_k_docs_sample, one_doc_size)
else:
# don't reduce, except listen to top_k_docs and max_total_input_tokens
one_doc_size = None
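The refactor above replaces in-place clamping of `top_k_docs` with a separate `top_k_docs_sample`, so the `-1` sentinel ("use everything retrieved") and the `max_chunks` cap stay local and the caller-supplied `top_k_docs` is left untouched for later checks. In sketch form (names mirror the diff, values are illustrative):

```python
def effective_top_k(top_k_docs, top_k_docs_trial, max_chunks, n_docs):
    # -1 means "no user limit": start from everything retrieved.
    top_k_docs_sample = n_docs if top_k_docs == -1 else top_k_docs
    if 0 < top_k_docs_trial < max_chunks:
        top_k_docs_sample = top_k_docs_trial if top_k_docs == -1 else min(top_k_docs, top_k_docs_trial)
    elif top_k_docs_trial >= max_chunks:
        top_k_docs_sample = max_chunks
    return top_k_docs_sample

print(effective_top_k(top_k_docs=-1, top_k_docs_trial=0, max_chunks=32, n_docs=100))   # 100
print(effective_top_k(top_k_docs=-1, top_k_docs_trial=50, max_chunks=32, n_docs=100))  # 32
```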
@@ -8002,13 +7927,13 @@
# nothing, just getting base amount for each call
)

# group docs if desired/can to fill context to avoid multiple LLM calls or too large chunks
docs_with_score, max_doc_tokens = split_merge_docs(docs_with_score,
tokenizer,
max_input_tokens=max_input_tokens,
docs_token_handling=docs_token_handling,
joiner=docs_joiner if not doing_grounding else "Document xx",
non_doc_prompt=estimated_full_prompt,
hf_embedding_model=hf_embedding_model,
verbose=verbose)
# in case docs_with_score grew due to splitting, limit again by top_k_docs
if top_k_docs > 0:
@@ -8540,11 +8465,11 @@ def get_sources_answer(query, docs, answer,
answer_sources)
if verbose or True:
if t_run is not None and int(t_run) > 0:
- sorted_sources_urls += 'Total Time: %d [s]<p>' % t_run
+ sorted_sources_urls += 'Total Time: %d [s]<br>' % t_run
if count_input_tokens and count_output_tokens:
- sorted_sources_urls += 'Input Tokens: %s | Output Tokens: %d<p>' % (
+ sorted_sources_urls += 'Input Tokens: %s | Output Tokens: %d<br>' % (
count_input_tokens, count_output_tokens)
- sorted_sources_urls += "Total document chunks used: %s<p>" % len(docs)
+ sorted_sources_urls += "Total document chunks used: %s<br>" % len(docs)
sorted_sources_urls += f"<font size=\"{font_size}\"></ul></p>{source_postfix}</font>"
title_overall = "Sources"
sorted_sources_urls = f"""<details><summary><font size="{font_size}">{title_overall}</font></summary><font size="{font_size}">{sorted_sources_urls}</font></details>"""
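Switching the stats lines from `<p>` to `<br>` produces simple line breaks inside the collapsible sources block instead of opening block-level paragraphs that are never closed. A tiny sketch of the markup that results (font size and values are illustrative):

```python
stats = 'Total Time: %d [s]<br>' % 12
stats += 'Input Tokens: %s | Output Tokens: %d<br>' % (1500, 250)
stats += 'Total document chunks used: %s<br>' % 8
html = '<details><summary><font size="2">Sources</font></summary>' \
       '<font size="2">%s</font></details>' % stats
print(html)
```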
@@ -8950,6 +8875,9 @@ def _update_user_db(file,

is_public=is_public,
from_ui=from_ui,

use_openai_embedding=use_openai_embedding,
hf_embedding_model=hf_embedding_model,
)
exceptions = [x for x in sources if x.metadata.get('exception')]
exceptions_strs = [x.metadata['exception'] for x in exceptions]
12 changes: 7 additions & 5 deletions src/gradio_runner.py
@@ -5376,11 +5376,11 @@ def larger_str(x, y):
text_output],
outputs=[text_output, chat_exception_text, speech_bot],
)
- retry_user_args = dict(fn=functools.partial(user, retry=True),
+ retry_user_args = dict(fn=functools.partial(user, retry=True, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
inputs=inputs_list + [text_output],
outputs=text_output,
)
- undo_user_args = dict(fn=functools.partial(user, undo=True),
+ undo_user_args = dict(fn=functools.partial(user, undo=True, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
inputs=inputs_list + [text_output],
outputs=text_output,
)
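The retry and undo handlers previously fell back to `user()`'s default for `sanitize_user_prompt`; binding it explicitly through `functools.partial` keeps their behavior in line with the normal submit path. A minimal, self-contained Gradio-style sketch of the pattern (component names and the toy sanitizer are illustrative):

```python
import functools
import gradio as gr

def user(message, history, retry=False, undo=False, sanitize_user_prompt=False):
    if sanitize_user_prompt:
        message = message.replace("<", "&lt;").replace(">", "&gt;")  # toy sanitizer
    if undo and history:
        return history[:-1]
    if retry and history:
        message = history[-1][0]  # re-ask the last user prompt
    return history + [[message, None]]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    retry_btn = gr.Button("Retry")
    retry_user_args = dict(fn=functools.partial(user, retry=True, sanitize_user_prompt=True),
                           inputs=[msg, chatbot], outputs=chatbot)
    retry_btn.click(**retry_user_args)
# demo.launch()
```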
@@ -5566,10 +5566,11 @@ def clear_all():
inputs=[my_db_state, requests_state, guest_name, retry_btn, retry_btn],
outputs=[my_db_state, requests_state, retry_btn],
queue=queue)
- submit_event3a = submit_event31.then(**user_args, api_name='retry' if allow_api else False)
+ submit_event3a = submit_event31.then(**retry_user_args,
+ api_name='retry' if allow_api else False)
# if retry, no longer the saved chat
submit_event3a2 = submit_event3a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue)
- submit_event3b = submit_event3a2.then(**user_args2, api_name='retry2' if allow_api else False)
+ submit_event3b = submit_event3a2.then(**retry_user_args2, api_name='retry2' if allow_api else False)
submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
.then(clear_instruct, None, iinput)
submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else False,
@@ -7017,7 +7018,8 @@ def show_doc(db1s, selection_docs_state1, requests_state1,
try:
with open(file, 'rt') as f:
content = f.read()
- content = f"```text\n{content}\n```"
+ #content = f"```text\n{content}\n```"
+ content = text_to_html(content, api=api)
return dummy1, dummy1, dummy1, gr.update(visible=True, value=content), dummy1, dummy1, dummy1, dummy1
except:
return dummy_ret
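Rendering plain files through `text_to_html(content, api=api)` instead of wrapping them in a markdown code fence lets the document viewer show them as proper, escaped HTML. `text_to_html` is the project's own helper; under that assumption, a minimal stand-in could look like:

```python
import html

def text_to_html(text, api=False):
    # Hypothetical stand-in: escape the raw file content so it renders safely,
    # and keep line breaks visible in the viewer.
    escaped = html.escape(text).replace("\n", "<br>\n")
    return escaped if api else '<div style="font-family: monospace">%s</div>' % escaped

print(text_to_html("<script>alert('x')</script>\nline two"))
```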
12 changes: 11 additions & 1 deletion src/make_db.py
@@ -51,7 +51,11 @@ def glob_to_db(user_path, chunk=True, chunk_size=512, verbose=False,
db_type=None,
selected_file_types=None,

- is_public=False):
+ is_public=False,
+
+ hf_embedding_model=None,
+ use_openai_embedding=False,
+ ):
assert db_type is not None

loaders_and_settings = dict(
@@ -100,6 +104,9 @@ def glob_to_db(user_path, chunk=True, chunk_size=512, verbose=False,

db_type=db_type,
is_public=is_public,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)
sources1 = path_to_docs(user_path,
url=url,
@@ -371,6 +378,9 @@ def make_db_main(use_openai_embedding: bool = False,
selected_file_types=selected_file_types,

is_public=False,

hf_embedding_model=hf_embedding_model,
use_openai_embedding=use_openai_embedding,
)
exceptions = [x for x in sources if x.metadata.get('exception')]
print("Exceptions: %s/%s %s" % (len(exceptions), len(sources), exceptions), flush=True)