Clean up stopping to avoid hard-coded llama-3 stop tokens, since that was fixed upstream 11 days ago. #1590

Merged 1 commit on Apr 30, 2024
src/gen.py (15 changes: 8 additions & 7 deletions)
@@ -98,8 +98,9 @@

from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt, \
openai_gpts, get_vllm_extra_dict, anthropic_gpts, google_gpts, mistralai_gpts, groq_gpts, \
- gradio_to_llm, history_for_llm, is_gradio_vision_model, is_json_model, get_use_chat_template, apply_chat_template
+ gradio_to_llm, history_for_llm, is_gradio_vision_model, is_json_model, apply_chat_template
from stopping import get_stopping
+ from prompter_utils import get_use_chat_template

langchain_actions = [x.value for x in list(LangChainAction)]

@@ -3263,12 +3264,12 @@ def get_model(
if base_model in non_hf_types:
from gpt4all_llm import get_model_tokenizer_gpt4all
model, tokenizer_llamacpp, device = get_model_tokenizer_gpt4all(base_model,
- n_jobs=n_jobs,
- gpu_id=gpu_id,
- n_gpus=n_gpus,
- max_seq_len=max_seq_len,
- llamacpp_dict=llamacpp_dict,
- llamacpp_path=llamacpp_path)
+ n_jobs=n_jobs,
+ gpu_id=gpu_id,
+ n_gpus=n_gpus,
+ max_seq_len=max_seq_len,
+ llamacpp_dict=llamacpp_dict,
+ llamacpp_path=llamacpp_path)
# give chance to use tokenizer_base_model
if tokenizer is None:
tokenizer = tokenizer_llamacpp
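This PR moves get_use_chat_template out of src/prompter.py into the new src/prompter_utils.py, and gen.py, prompter.py, and stopping.py now all import it from there. A minimal sketch of calling the relocated helper (the model name and the two branches are illustrative, not the project's exact call site; imports assume the repo's src/ directory is on sys.path):

from transformers import AutoTokenizer  # standard Hugging Face API

from enums import template_prompt_type
from prompter_utils import get_use_chat_template  # new shared location per this PR

# Illustrative checkpoint; any tokenizer that ships a non-empty chat_template behaves the same.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

if get_use_chat_template(tokenizer, prompt_type=template_prompt_type):
    # Defer prompt construction to the tokenizer's built-in chat template.
    pass
else:
    # Fall back to one of h2oGPT's explicit prompt_type templates.
    pass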
src/gpt_langchain.py (2 changes: 1 addition & 1 deletion)
@@ -7980,7 +7980,7 @@ def get_chain(query=None,
context=context,
iinput=iinput,
system_prompt=system_prompt)
- if external_handle_chat_conversation or prompter.prompt_type in ['template', 'unknown']:
+ if external_handle_chat_conversation or prompter.prompt_type in [template_prompt_type, unknown_prompt_type]:
# should already have attribute, checking sanity
assert hasattr(llm, 'chat_conversation')
llm_kwargs.update(chat_conversation=history_to_use_final)
src/prompter.py (38 changes: 10 additions & 28 deletions)
@@ -1,10 +1,13 @@
import ast
import time
+ import os

# also supports imports from this file from other files
from enums import PromptType, gpt_token_mapping, \
anthropic_mapping, google_mapping, mistralai_mapping, groq_mapping, openai_supports_json_mode, noop_prompt_type, \
- unknown_prompt_type, template_prompt_type, user_prompt_for_fake_system_prompt0
+ unknown_prompt_type, user_prompt_for_fake_system_prompt0, template_prompt_type
+ from src.prompter_utils import get_use_chat_template
+ from src.stopping import update_terminate_responses
from src.utils import get_gradio_tmp

non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
@@ -1676,18 +1679,6 @@ def inject_chatsep(prompt_type, prompt, chat_sep=None):
return prompt


- def get_use_chat_template(tokenizer, prompt_type=None):
-     if tokenizer is None:
-         return False
-     use_chat_template = prompt_type in [None, '', unknown_prompt_type, template_prompt_type] and \
-                         (hasattr(tokenizer, 'chat_template') and
-                          tokenizer.chat_template not in [None, ''] or
-                          hasattr(tokenizer, 'default_chat_template') and
-                          tokenizer.default_chat_template not in [None, '']
-                          )
-     return use_chat_template
-
-
class Prompter(object):
def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, repeat_penalty=False,
allowed_repeat_line_length=10, system_prompt=None, tokenizer=None, verbose=False):
@@ -1709,20 +1700,11 @@ def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, r
system_prompt=system_prompt)
self.use_chat_template = False
self.tokenizer = tokenizer
- if tokenizer is not None:
-     self.use_chat_template = get_use_chat_template(tokenizer, prompt_type=prompt_type)
-     if self.use_chat_template:
-         # add terminations
-         if self.terminate_response is None:
-             self.terminate_response = []
-         # like in stopping.py
-         if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
-             self.terminate_response.extend([tokenizer.eos_token])
-         if '<|eot_id|>' in tokenizer.added_tokens_encoder:
-             self.terminate_response.extend(['<|eot_id|>'])
-         if '<|im_end|>' in tokenizer.added_tokens_encoder:
-             self.terminate_response.extend(['<|im_end|>'])
-
+ if self.terminate_response is None:
+     self.terminate_response = []
+ self.use_chat_template = get_use_chat_template(tokenizer, prompt_type=prompt_type)
+ self.terminate_response = update_terminate_responses(self.terminate_response,
+                                                       tokenizer=tokenizer)
self.pre_response = self.PreResponse
self.verbose = verbose

@@ -1743,7 +1725,7 @@ def generate_prompt(self, data_point, reduced=False, context_from_history=None,
In which case we need to put promptA at very front to recover correct behavior
:return:
"""
- if self.prompt_type in ['template', 'unknown']:
+ if self.prompt_type in [template_prompt_type, unknown_prompt_type]:
assert self.use_chat_template
assert self.tokenizer is not None
from src.gen import apply_chat_template
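With prompt_type equal to template_prompt_type or unknown_prompt_type, Prompter.generate_prompt now asserts that a chat template is in use and hands the work to apply_chat_template from src/gen.py. Roughly, that path boils down to the standard Hugging Face call below (a hedged sketch only; the project's wrapper adds its own handling, and the model name and messages are made up):

from transformers import AutoTokenizer

# Illustrative checkpoint with a chat template.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this pull request."},
]
# tokenize=False returns the rendered prompt string; add_generation_prompt=True appends
# the assistant header so generation starts in the right place.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)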
src/prompter_utils.py (13 changes: 13 additions & 0 deletions)
@@ -0,0 +1,13 @@
from src.enums import unknown_prompt_type, template_prompt_type


def get_use_chat_template(tokenizer, prompt_type=None):
    if tokenizer is None:
        return False
    use_chat_template = prompt_type in [None, '', unknown_prompt_type, template_prompt_type] and \
                        (hasattr(tokenizer, 'chat_template') and
                         tokenizer.chat_template not in [None, ''] or
                         hasattr(tokenizer, 'default_chat_template') and
                         tokenizer.default_chat_template not in [None, '']
                         )
    return use_chat_template
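get_use_chat_template returns True only when the prompt type is unset, unknown, or template and the tokenizer actually carries a non-empty chat_template (or default_chat_template). A quick illustration with a stand-in tokenizer (the class below is a test double, not part of h2oGPT; import paths assume the repo root is on sys.path):

from src.enums import template_prompt_type
from src.prompter_utils import get_use_chat_template


class FakeTokenizer:
    # Any non-empty Jinja template string counts as having a chat template.
    chat_template = "{{ messages }}"


assert get_use_chat_template(FakeTokenizer(), prompt_type=template_prompt_type) is True
assert get_use_chat_template(FakeTokenizer(), prompt_type="llama2") is False  # explicit prompt_type wins
assert get_use_chat_template(None) is False  # no tokenizer, no chat template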
src/stopping.py (39 changes: 30 additions & 9 deletions)
@@ -1,9 +1,37 @@
import time

import torch
- from transformers import StoppingCriteria, StoppingCriteriaList
+ from transformers import StoppingCriteria, StoppingCriteriaList, GenerationConfig

from enums import PromptType, t5_type
+ from src.prompter_utils import get_use_chat_template


+ def update_terminate_responses(terminate_response, tokenizer=None):
+     if terminate_response is None:
+         terminate_response = []
+     if tokenizer is not None:
+         # e.g. for dbrx
+         if '<|im_end|>' in tokenizer.added_tokens_encoder:
+             terminate_response.extend(['<|im_end|>'])
+         if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
+             if isinstance(tokenizer.eos_token, str):
+                 terminate_response.extend([tokenizer.eos_token])
+             elif isinstance(tokenizer.eos_token, list):
+                 terminate_response.extend(tokenizer.eos_token)
+
+         if hasattr(tokenizer, 'name_or_path'):
+             reverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
+             generate_eos_token_id = GenerationConfig.from_pretrained(tokenizer.name_or_path).eos_token_id
+             if isinstance(generate_eos_token_id, list):
+                 for eos_token_id in generate_eos_token_id:
+                     terminate_response.extend([reverse_vocab[eos_token_id]])
+             else:
+                 terminate_response.extend([reverse_vocab[generate_eos_token_id]])
+     terminate_response_tmp = terminate_response.copy()
+     terminate_response.clear()
+     [terminate_response.append(x) for x in terminate_response_tmp if x not in terminate_response]
+     return terminate_response


class StoppingCriteriaSub(StoppingCriteria):
@@ -132,14 +160,7 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
encounters += [1] * len(stop)
handle_newlines += [False] * len(stop)

- # e.g. for llama-3
- # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
- if '<|eot_id|>' in tokenizer.added_tokens_encoder:
-     stop_words.extend(['<|eot_id|>'])
- if '<|im_end|>' in tokenizer.added_tokens_encoder:
-     stop_words.extend(['<|im_end|>'])
- if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
-     stop_words.extend([tokenizer.eos_token])
+ stop_words = update_terminate_responses(stop_words, tokenizer=tokenizer)

# get stop tokens
stop_words_ids = [
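With the hard-coded llama-3 stop-token checks removed from get_stopping, stop words now come from update_terminate_responses, which reads them off the tokenizer and its GenerationConfig.eos_token_id. A hedged sketch of what that should yield for a Llama-3 Instruct checkpoint (the model name and Hub access are assumptions; imports assume h2oGPT's src layout is on sys.path):

from transformers import AutoTokenizer

from stopping import update_terminate_responses

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
stop_words = update_terminate_responses([], tokenizer=tok)
# Expected to contain the eos token and '<|eot_id|>' (pulled from GenerationConfig.eos_token_id),
# deduplicated, with no model-specific strings hard-coded anywhere.
print(stop_words)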
src/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
__version__ = "f957748399a19591580f0f99ee13b85f99e3f9fb"
__version__ = "832ad2d4a6b1431105785045a6b218a8451591f9"