CohereForAI/c4ai-command-r-v01 #1467

Merged · 3 commits · Mar 12, 2024
28 changes: 20 additions & 8 deletions src/gen.py
@@ -2267,10 +2267,17 @@ def get_config(base_model,
config.update({"max_seq_len": 2 * 8192})
if return_model and \
issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
model = AutoModel.from_config(
config,
trust_remote_code=trust_remote_code,
)
try:
model = AutoModel.from_config(
config,
trust_remote_code=trust_remote_code,
)
except Exception as e:
if 'has no attribute' in str(e):
# half-baked hack to transformers by Cohere
model = None
else:
raise
else:
# can't infer
model = None
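
For context: Cohere's initial transformers integration apparently did not support building a bare AutoModel from its config (the 'has no attribute' error caught above), so the guard falls back to model = None rather than failing. A minimal standalone sketch of loading the checkpoint directly instead (dtype choice is an illustrative assumption, not part of this diff):

# Hypothetical standalone sketch; not part of this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# from_pretrained sidesteps the from_config path guarded above
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # assumption: fp16 fits the target GPU
    trust_remote_code=True,
)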
@@ -5700,12 +5707,17 @@ def get_limited_prompt(instruction,
min_max_new_tokens=min_max_new_tokens)

from openai_server.backend_utils import structure_to_messages
use_chat_template = (prompt_type in [None, '', 'plain'] and
hasattr(tokenizer, 'chat_template') and
tokenizer.chat_template)
use_chat_template = prompt_type in [None, '', 'plain'] and \
(hasattr(tokenizer, 'chat_template') and
tokenizer.chat_template not in [None, ''] or
hasattr(tokenizer, 'default_chat_template') and
tokenizer.default_chat_template not in [None, '']
)

if use_chat_template:
messages = structure_to_messages(instruction, system_prompt, history)
messages = structure_to_messages(instruction,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
history)
context2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
iinput = ''
context = ''
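
As a rough illustration of the chat-template branch above (assuming a tokenizer that defines chat_template or default_chat_template, and history as a list of (user, bot) pairs; the real structure_to_messages helper is not reproduced here):

# Illustrative sketch only; mirrors the chat-template branch above.
def build_context(tokenizer, instruction, system_prompt=None, history=()):
    has_template = (getattr(tokenizer, 'chat_template', None) or
                    getattr(tokenizer, 'default_chat_template', None))
    if not has_template:
        return None  # caller falls back to h2oGPT's own prompt templates
    messages = []
    if system_prompt not in (None, '', 'auto'):
        messages.append({"role": "system", "content": system_prompt})
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": instruction})
    return tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True)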
57 changes: 40 additions & 17 deletions src/gpt_langchain.py
@@ -60,7 +60,7 @@
get_list_or_str, have_pillow, only_selenium, only_playwright, only_unstructured_urls, get_short_name, \
get_accordion, have_jq, get_doc, get_source, have_chromamigdb, get_token_count, reverse_ucurve_list, get_size, \
get_test_name_core, download_simple, have_fiftyone, have_librosa, return_good_url, n_gpus_global, \
get_accordion_named, hyde_titles, have_cv2, FullSet, create_relative_symlink, split_list, get_gradio_tmp
get_accordion_named, hyde_titles, have_cv2, FullSet, create_relative_symlink, split_list, get_gradio_tmp, merge_dict
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
LangChainAction, LangChainMode, DocumentChoice, LangChainTypes, font_size, head_acc, super_source_prefix, \
super_source_postfix, langchain_modes_intrinsic, get_langchain_prompts, LangChainAgent, docs_joiner_default, \
@@ -1672,6 +1672,7 @@ async def agenerate_prompt(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)


class H2OChatAnthropic2Sys(H2OChatAnthropic2):
pass

@@ -1685,7 +1686,6 @@ class H2OChatAnthropic3(GenerateStream, ExtraChat, ChatAnthropic3):
# max_new_tokens0: Any = None # FIXME: Doesn't seem to have same max_tokens == -1 for prompts==1



class H2OChatAnthropic3Sys(H2OChatAnthropic3):
pass

@@ -5982,7 +5982,8 @@ def split_merge_docs(docs_with_score, tokenizer=None, max_input_tokens=None, doc
# see if need to split
# account for joiner tokens
joiner_tokens = get_token_count(docs_joiner_default, tokenizer)
doc_chunk_size = max(64, min(max_input_tokens, max(64, max_input_tokens - joiner_tokens * len(docs_with_score))))
doc_chunk_size = max(64, min(max_input_tokens,
max(64, max_input_tokens - joiner_tokens * len(docs_with_score))))
text_splitter = H2OCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer, chunk_size=doc_chunk_size, chunk_overlap=0
)
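
To make the token budgeting concrete with illustrative numbers: for max_input_tokens=3072, ten retrieved documents, and a joiner of 4 tokens, doc_chunk_size = max(64, min(3072, max(64, 3072 - 4 * 10))) = 3032, i.e. each chunk gives up just enough room for the joiner strings, while the outer max(64, ...) keeps the chunk size from collapsing when many documents are packed in.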
@@ -7035,7 +7036,7 @@ def get_chain(query=None,
estimated_prompt_no_docs = template_if_no_docs.format(context='', question=query)

# add metadata to documents and make new copy of docs with them to not contaminate originals
if metadata_in_context and not doc_json_mode:
if metadata_in_context and not doc_json_mode and not hasattr(tokenizer, 'apply_grounded_generation_template'):
docs_with_score = [(Document(page_content='Begin Document:\n\n' +
'Metadata:\n' +
'\n'.join(['%s = %s' % (k, v) for k, v in x.metadata.items() if
@@ -7216,11 +7217,8 @@ def get_chain(query=None,
prompter=prompter)

if doc_json_mode:
def merge_dict(dict1, dict2):
return dict2.update(dict1)

# make copy so don't change originals
if metadata_in_context:
if metadata_in_context and not hasattr(tokenizer, 'apply_grounded_generation_template'):
docs = [Document(page_content=json.dumps(merge_dict(dict(ID=xi, content=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set})),
@@ -7232,21 +7230,46 @@ def merge_dict(dict1, dict2):
for xi, x in enumerate(docs)]

if langchain_action == LangChainAction.QUERY.value:
if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
if hasattr(tokenizer, 'apply_grounded_generation_template'):
assert prompt_type == 'plain'
# https://huggingface.co/CohereForAI/c4ai-command-r-v01
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
template=template,
template='{context}{question}', # ignored
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
documents = [merge_dict(dict(text=x.page_content),
{k: v for k, v in x.metadata.items() if
v and k in metadata_in_context_set}) for x in docs]
from openai_server.backend_utils import structure_to_messages
conversation = structure_to_messages(query,
system_prompt if system_prompt not in [None, '', 'auto'] else None,
chat_conversation)
query_with_docs = tokenizer.apply_grounded_generation_template(
conversation,
documents=documents,
citation_mode="accurate", # or "fast"
tokenize=False,
add_generation_prompt=True,
)
chain_kwargs = dict(input_documents=[], question=query_with_docs)
else:
# unused normally except in testing
assert use_openai_model or prompt_type == 'plain', "Unexpected to use few-shot template for %s %s" % (
model_name, prompt_type)
chain = load_qa_with_sources_chain(llm)
chain_kwargs = dict(input_documents=docs, question=query)
if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
template=template,
)
chain = load_qa_chain(llm, prompt=prompt, verbose=verbose)
else:
# unused normally except in testing
assert use_openai_model or prompt_type == 'plain', "Unexpected to use few-shot template for %s %s" % (
model_name, prompt_type)
chain = load_qa_with_sources_chain(llm)
chain_kwargs = dict(input_documents=docs, question=query)
target = wrapped_partial(chain, chain_kwargs)
elif summarize_action:
if async_output:
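
The grounded-generation branch above follows the usage shown on the Command-R model card; a standalone sketch of the same call, with a made-up query and documents (only the keyword arguments mirror the diff, everything else is illustrative):

# Standalone sketch of the grounded-generation call used above; values are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True)
conversation = [{"role": "user", "content": "What is the biggest penguin in the world?"}]
documents = [
    {"title": "Tall penguins", "text": "Emperor penguins are the tallest, up to 122 cm."},
    {"title": "Penguin habitats", "text": "Emperor penguins live only in Antarctica."},
]
prompt = tokenizer.apply_grounded_generation_template(
    conversation,
    documents=documents,
    citation_mode="accurate",  # or "fast"
    tokenize=False,
    add_generation_prompt=True,
)
# prompt is a single pre-rendered string fed to the model, which is why the
# PromptTemplate above is just '{context}{question}' and effectively ignored.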
6 changes: 6 additions & 0 deletions src/utils.py
@@ -2027,3 +2027,9 @@ def get_lock_file(name):
lock_file = os.path.join(base_path, "%s.lock" % lock_type)
makedirs(os.path.dirname(lock_file)) # ensure made
return lock_file


def merge_dict(dict1, dict2):
ret = dict1.copy()
ret.update(dict2)
return ret
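
A quick note on this helper: the inline version it replaces in gpt_langchain.py returned dict2.update(dict1), which mutates dict2 and returns None; the new function copies dict1 and lets keys from dict2 win on conflict, leaving both inputs untouched. Illustrative use (values are made up):

# Illustrative use, matching how gpt_langchain.py builds per-document records.
doc = merge_dict(dict(ID=0, content="some page text"),
                 {"source": "report.pdf", "page": 3})
# -> {'ID': 0, 'content': 'some page text', 'source': 'report.pdf', 'page': 3}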