Merge pull request #1533 from h2oai/fix_grounded_template_token_counting
Fix grounded template token counting
pseudotensor authored Apr 6, 2024
2 parents f00b6ac + a1af6e6 commit 9d1179a
Showing 2 changed files with 93 additions and 32 deletions.
52 changes: 44 additions & 8 deletions src/gen.py
@@ -2499,7 +2499,8 @@ def get_non_lora_model(base_model, model_loader, load_half,
return model


def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False, verbose=False):
def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False,
verbose=False):
inference_server, headers, username, password = get_hf_server(inference_server)
gr_client = None
hf_client = None
@@ -2508,7 +2509,8 @@ def get_client_from_inference_server(inference_server, base_model=None, raise_co

if base_model and is_gradio_vision_model(base_model):
from gradio_utils.grclient import GradioClient
gr_client = GradioClient(inference_server, check_hash=False, verbose=verbose, serialize=is_gradio_version4, **gradio_auth)
gr_client = GradioClient(inference_server, check_hash=False, verbose=verbose, serialize=is_gradio_version4,
**gradio_auth)
gr_client.setup()
elif headers is None:
try:
@@ -4104,8 +4106,10 @@ def evaluate(

pre_instruction = ''
if guided_json and response_format == 'json_object' and (json_vllm or
inference_server and inference_server.startswith('anthropic') and
is_json_model(base_model, inference_server, json_vllm=json_vllm)):
inference_server and inference_server.startswith(
'anthropic') and
is_json_model(base_model, inference_server,
json_vllm=json_vllm)):
# for vLLM or claude-3, support schema if given
# can't give schema both in prompt and tool/guided_json, messes model up
pass
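Note: the reflowed condition above only skips prompt-side schema injection when the backend can take the schema natively, either vLLM with guided_json or an Anthropic endpoint whose model passes is_json_model; supplying the schema both in the prompt and via tool/guided_json confuses the model. A minimal standalone restatement of that gate (a sketch; model_supports_json stands in for the repo's is_json_model call):

```python
# Sketch of the gating condition above; model_supports_json stands in for
# is_json_model(base_model, inference_server, json_vllm=json_vllm).
def schema_goes_native(guided_json, response_format, json_vllm,
                       inference_server, model_supports_json):
    vllm_native = bool(json_vllm)
    anthropic_native = (inference_server or '').startswith('anthropic') and model_supports_json
    return bool(guided_json) and response_format == 'json_object' \
        and (vllm_native or anthropic_native)
```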
@@ -5934,6 +5938,7 @@ def get_limited_prompt(instruction,
truncation_generation=False,
gradio_server=False,
attention_sinks=False,
doing_grounding=False,
):
if gradio_server or not inference_server:
# can listen to truncation_generation
@@ -6080,7 +6085,16 @@

if text_context_list is None:
text_context_list = []
num_doc_tokens = sum([get_token_count(x + docs_joiner_default, tokenizer) for x in text_context_list])

num_doc_overhead_tokens = count_overhead_tokens(tokenizer, doing_grounding=doing_grounding)
if doing_grounding:
docs_joiner = "Document xx"
else:
docs_joiner = docs_joiner_default
# handle overhead by lowering locally max input tokens, since not removable
max_input_tokens -= num_doc_overhead_tokens

num_doc_tokens = sum([get_token_count(x + docs_joiner, tokenizer) for x in text_context_list])

num_prompt_tokens0 = (num_system_tokens or 0) + \
(num_instruction_tokens or 0) + \
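Note: the key change above measures the fixed token overhead of the grounded generation template once (via count_overhead_tokens, added further down) and subtracts it from max_input_tokens, since dropping documents cannot remove it; each document is then charged for its text plus the joiner. A standalone sketch of that accounting, assuming an HF-style tokenizer with encode() (the repo uses its own get_token_count helper instead):

```python
# Standalone sketch of the accounting above (assumes a tokenizer with .encode()).
def doc_token_budget(text_context_list, tokenizer, max_input_tokens,
                     num_doc_overhead_tokens, joiner="Document xx"):
    # the grounded template's overhead cannot be trimmed by dropping documents,
    # so shrink the budget instead of charging it to any one document
    usable_input_tokens = max_input_tokens - num_doc_overhead_tokens
    num_doc_tokens = sum(len(tokenizer.encode(x + joiner))
                         for x in text_context_list)
    return usable_input_tokens, num_doc_tokens
```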
@@ -6213,12 +6227,14 @@ def get_limited_prompt(instruction,
# update max_new_tokens
# limit so max_new_tokens = prompt + new < max
# otherwise model can fail etc. e.g. for distilgpt2 asking for 1024 tokens is enough to fail if prompt=1 token
if truncation_generation:
if not attention_sinks:
max_new_tokens = max(1, min(max_new_tokens, model_max_length - num_prompt_tokens))

if max_new_tokens < min_max_new_tokens:
if os.getenv('HARD_ASSERTS'):
if max_new_tokens < min_max_new_tokens:
raise ValueError("Invalid max_new_tokens=%s" % max_new_tokens)
raise ValueError("Invalid max_new_tokens=%s" % max_new_tokens)
else:
max_new_tokens = max(32, max_new_tokens)

if prompter is None:
# get prompter
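Note: the clamp above now applies the context-window cap whenever attention sinks are not in use (previously only under truncation_generation), and an undersized max_new_tokens raises only when HARD_ASSERTS is set; otherwise it is floored at 32. A standalone restatement (a sketch, not the repo function):

```python
import os

# Standalone restatement of the clamp above.
def clamp_max_new_tokens(max_new_tokens, min_max_new_tokens,
                         model_max_length, num_prompt_tokens,
                         attention_sinks=False):
    if not attention_sinks:
        # keep prompt + generation within the model's context window
        max_new_tokens = max(1, min(max_new_tokens,
                                    model_max_length - num_prompt_tokens))
    if max_new_tokens < min_max_new_tokens:
        if os.getenv('HARD_ASSERTS'):
            raise ValueError("Invalid max_new_tokens=%s" % max_new_tokens)
        # soften in production: accept a small generation budget
        max_new_tokens = max(32, max_new_tokens)
    return max_new_tokens
```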
@@ -6251,6 +6267,26 @@ def get_limited_prompt(instruction,
top_k_docs, one_doc_size, truncation_generation, system_prompt


def count_overhead_tokens(tokenizer, doing_grounding=False):
if doing_grounding:
from openai_server.backend_utils import structure_to_messages
system_prompt = ''
instruction = 'foo'
chat_conversation = []
prompt = tokenizer.apply_grounded_generation_template(
structure_to_messages(instruction,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
chat_conversation),
documents=[dict(text='foo')],
citation_mode="accurate", # or "fast"
tokenize=False,
add_generation_prompt=True,
)
return get_token_count(prompt, tokenizer)
else:
return 0


def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
"""
max_input_tokens: Over all LLM calls, upper limit of total token count,
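Note: count_overhead_tokens (added above) renders the grounded generation template once with placeholder content to measure the template's fixed wrapper cost. A rough usage sketch; the model name, the bare gen import path, and the 8192 budget are illustrative assumptions, not repo defaults:

```python
# Rough usage sketch; CohereForAI/c4ai-command-r-v01 and the bare `gen` import
# path are illustrative assumptions.
from transformers import AutoTokenizer
from gen import count_overhead_tokens  # assumes src/ is on sys.path

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
overhead = count_overhead_tokens(tokenizer, doing_grounding=True)
max_input_tokens = 8192 - overhead  # shrink the document budget by the fixed overhead
```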
73 changes: 49 additions & 24 deletions src/gpt_langchain.py
@@ -2244,7 +2244,8 @@ def get_llm(use_openai_model=False,
cls = H2OChatOpenAI
# FIXME: Support context, iinput
if inf_type == 'vllm_chat':
if is_json_model(model_name, inference_server, json_vllm=json_vllm) and response_format == 'json_object':
if is_json_model(model_name, inference_server,
json_vllm=json_vllm) and response_format == 'json_object':
# vllm without guided_json can't make json directly
kwargs_extra.update(dict(type=response_format if guided_json else 'text'))
async_output = False # https://github.com/h2oai/h2ogpt/issues/928
@@ -2352,6 +2353,7 @@ def get_llm(use_openai_model=False,
# NOTE: claude requires keys of properties to match pattern '^[a-zA-Z0-9_-]{1,64}$'
# i.e. no spaces, while vLLM can handle spaces.
if is_json_model(model_name, inference_server) and guided_json and response_format == 'json_object':
# https://docs.anthropic.com/claude/docs/tool-use#specifying-tools
model_kwargs = dict(tools=[
{
"name": "JSON",
@@ -7435,6 +7437,7 @@ def get_chain(query=None,
gradio_server=gradio_server,
attention_sinks=attention_sinks,
hyde_level=hyde_level,
doing_grounding=doing_grounding,
)

# NOTE: if map_reduce, then no need to auto reduce chunks
Expand Down Expand Up @@ -7473,7 +7476,12 @@ def get_chain(query=None,
docs_with_score],
)
# get updated llm
llm_kwargs.update(max_new_tokens=max_new_tokens, context=context, iinput=iinput, system_prompt=system_prompt)
llm_kwargs.update(max_new_tokens=max_new_tokens,
max_input_tokens=max_input_tokens,
max_total_input_tokens=max_total_input_tokens,
context=context,
iinput=iinput,
system_prompt=system_prompt)
if external_handle_chat_conversation:
# should already have attribute, checking sanity
assert hasattr(llm, 'chat_conversation')
@@ -7529,7 +7537,7 @@ def get_chain(query=None,
tokenizer,
max_input_tokens=max_input_tokens,
docs_token_handling=docs_token_handling,
joiner=docs_joiner,
joiner=docs_joiner if not doing_grounding else "Document xx",
non_doc_prompt=estimated_full_prompt,
verbose=verbose)
# in case docs_with_score grew due to splitting, limit again by top_k_docs
@@ -7556,7 +7564,7 @@
# imperfect calculation, so will see how testing does
assert max_new_tokens >= min_max_new_tokens - 50, "%s %s" % (max_new_tokens, min_max_new_tokens)
# get updated llm
llm_kwargs.update(max_new_tokens=max_new_tokens)
llm_kwargs.update(max_new_tokens=max_new_tokens, max_input_tokens=max_input_tokens)
llm, model_name, streamer, prompt_type_out, async_output, only_new_text, gradio_server = \
get_llm(**llm_kwargs)

@@ -7649,20 +7657,42 @@ def get_chain(query=None,
template='{context}{question}', # ignored
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
documents = [merge_dict(dict(text=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set}) for x in docs]
from openai_server.backend_utils import structure_to_messages
conversation = structure_to_messages(query,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
chat_conversation)
query_with_docs = tokenizer.apply_grounded_generation_template(
conversation,
documents=documents,
citation_mode="accurate", # or "fast"
tokenize=False,
add_generation_prompt=True,
)

while True:
conversation = structure_to_messages(query,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
chat_conversation)
documents = [merge_dict(dict(text=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set}) for x in docs]
query_with_docs = tokenizer.apply_grounded_generation_template(
conversation,
documents=documents,
citation_mode="accurate", # or "fast"
tokenize=False,
add_generation_prompt=True,
)
grounded_tokens = len(tokenizer.encode(query_with_docs))
if grounded_tokens > max_input_tokens and len(docs) > 0:
if docs_ordering_type in ['best_first']:
docs.pop()
elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
docs.pop(0)
elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
del docs[len(docs) // 2]
else:
raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)

elif grounded_tokens > max_input_tokens and len(chat_conversation) > 0:
chat_conversation = []
elif grounded_tokens > max_input_tokens and system_prompt:
system_prompt = ''
else:
if grounded_tokens > max_input_tokens:
print("Failed to fit grounded tokens: %s %s" % (grounded_tokens, max_input_tokens))
break

chain_kwargs = dict(input_documents=[], question=query_with_docs)
else:
if use_template:
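Note: this is the core of the fix. Instead of rendering the grounded prompt once, it is re-rendered and measured in a loop that sheds context in priority order until it fits max_input_tokens: first drop the least relevant document (whose position depends on docs_ordering_type), then the chat history, then the system prompt. A compact standalone restatement (render_len is a hypothetical stand-in for apply_grounded_generation_template plus tokenizer.encode):

```python
# Compact restatement of the fitting loop above; render_len(docs, conv, sys) is a
# hypothetical callable returning the token length of the rendered grounded prompt.
def fit_grounded_prompt(docs, chat_conversation, system_prompt,
                        max_input_tokens, docs_ordering_type, render_len):
    while render_len(docs, chat_conversation, system_prompt) > max_input_tokens:
        if docs:
            if docs_ordering_type in ['best_first']:
                docs.pop()                # least relevant document is last
            elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
                docs.pop(0)               # least relevant document is first
            elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
                del docs[len(docs) // 2]  # least relevant document is in the middle
            else:
                raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
        elif chat_conversation:
            chat_conversation = []        # next, sacrifice prior chat turns
        elif system_prompt:
            system_prompt = ''            # finally, drop the system prompt
        else:
            break                         # nothing left to drop; prompt stays oversized
    return docs, chat_conversation, system_prompt
```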
@@ -7768,13 +7798,8 @@ def get_max_input_tokens(llm=None, tokenizer=None, inference_server=None, model_
# don't trust that fake tokenizer (e.g. GGUF/GGML) will make lots of tokens normally, allow more input
max_input_tokens = model_max_length - min(256, max_new_tokens)
else:
if 'falcon' in model_name or inference_server.startswith('http'):
# allow for more input for falcon, assume won't make as long outputs as default max_new_tokens
# Also allow if TGI or Gradio, because we tell it input may be same as output, even if model can't actually handle
max_input_tokens = model_max_length - min(256, max_new_tokens)
else:
# trust that maybe model will make so many tokens, so limit input
max_input_tokens = model_max_length - max_new_tokens
# trust that maybe model will make so many tokens, so limit input
max_input_tokens = model_max_length - max_new_tokens

return max_input_tokens

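Note: the last hunk removes the falcon/HTTP special case, so with a real tokenizer the full generation budget is always reserved. In effect (a sketch of the resulting rule):

```python
# Sketch of the simplified budget rule after this change.
def max_input_tokens_for(model_max_length, max_new_tokens, real_tokenizer=True):
    if not real_tokenizer:
        # fake tokenizer (e.g., GGUF/GGML): don't trust it to emit many tokens,
        # so reserve only a small slice and allow more input
        return model_max_length - min(256, max_new_tokens)
    # trust that the model may really emit max_new_tokens, so limit input accordingly
    return model_max_length - max_new_tokens
```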
