Fix tokenizing labels #214

Merged 4 commits on Jun 15, 2023
src/axolotl/prompt_strategies/alpaca_chat.py (30 additions, 5 deletions)
@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
 
 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
-    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
+
+
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    prompt_input = "{instruction} {input} "
+    prompt_no_input = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
 
 
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
 
 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
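For a sense of what the two new prompters emit, here is a small illustrative sketch (not code from this PR). It assumes build_prompt() yields the formatted prompt string, as the next(iter(...)) calls in the strategies suggest, and that AlpacaChatPrompter's chat framing is filled in by match_prompt_style():

# Illustrative sketch only; the printed strings are assumptions, not PR output.
from axolotl.prompt_strategies.alpaca_chat import (
    AlpacaChatPrompter,
    NoSystemPrompter,
)

no_sys = NoSystemPrompter()
# prompt_input is "{instruction} {input} ", so there is no system preamble:
print(next(no_sys.build_prompt("Name a color.", "in French")))
# -> "Name a color. in French "

chat = AlpacaChatPrompter()
# CHAT style: the USER/ASSISTANT system prompt followed by a chat-framed turn,
# with the exact layout coming from match_prompt_style()
print(next(chat.build_prompt("Name a color.", "in French")))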
src/axolotl/prompt_tokenizers.py (17 additions, 15 deletions)
@@ -96,25 +96,27 @@ def tokenize_prompt(self, prompt):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, response)
-        tokenized_full_prompt = self._tokenize(full_prompt)
-        if not self.train_on_inputs:
-            user_prompt = next(
-                iter(
-                    self.prompter.build_prompt(
-                        instruction,
-                        input,
-                    )
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
                 )
             )
-            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
-        return tokenized_full_prompt
+        return tokenized_prompt
 
     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
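The net effect of the rewrite: instead of tokenizing the full prompt and masking by slicing (where a single-pass tokenization can merge tokens across the prompt/response seam, so the sliced mask can land on the wrong tokens), the strategy tokenizes the user prompt without an EOS token, sets all of its labels to -100 (the index PyTorch's CrossEntropyLoss ignores) when train_on_inputs is off, then tokenizes the response with the BOS stripped and an EOS appended and concatenates the two. A minimal sketch of the resulting label layout, with invented token ids:

# Minimal sketch of the label layout tokenize_prompt now produces when
# train_on_inputs=False. Token ids are invented for illustration.
IGNORE_INDEX = -100  # ignored by PyTorch's CrossEntropyLoss

prompt_ids = [1, 450, 2009]    # BOS + user prompt tokens, no EOS appended
response_ids = [673, 889, 2]   # response tokens: BOS stripped, EOS appended

input_ids = prompt_ids + response_ids
attention_mask = [1] * len(input_ids)
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids

# Loss is computed only over the response span:
assert all(label == IGNORE_INDEX for label in labels[: len(prompt_ids)])
assert labels[len(prompt_ids):] == response_ids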
tests/test_prompt_tokenizers.py (45 additions, 3 deletions)

@@ -6,8 +6,12 @@
 
 from transformers import AutoTokenizer
 
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompter
+from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
 
 logging.basicConfig(level="INFO")
 
@@ -29,7 +33,6 @@ def setUp(self) -> None:
         )
 
     def test_sharegpt_integration(self):
-        print(Path(__file__).parent)
         with open(
             Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
         ) as fin:
@@ -53,6 +56,45 @@ def test_sharegpt_integration(self):
             self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
             self.assertEqual(example[fields], tokenized_conversation[fields])
 
+    def test_no_sys_prompt(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = NoSystemPrompter()
+        # pylint: disable=duplicate-code
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
+            "output": "world!",
+        }
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(3186)
+        assert example["labels"][world_idx] == 3186
+        assert example["labels"][world_idx - 1] == -100
+
+    def test_alpaca(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        # pylint: disable=duplicate-code
+        prompter = AlpacaPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(6324)
+        assert example["labels"][world_idx] == 6324
+        assert example["labels"][world_idx - 1] == -100
+
 
 if __name__ == "__main__":
     unittest.main()
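Both new tests pin the same boundary property: the label of the token just before the response is -100, while the first response token (id 3186 for "world" and, by the same pattern, id 6324 for "Hi" under the tokenizer configured in setUp) keeps its own id as its label. A generic restatement of that property, with invented ids:

# Generic restatement of what the new tests check; ids are invented.
example = {
    "input_ids": [1, 22172, 3186, 29991, 2],
    "labels": [-100, -100, 3186, 29991, 2],
}

boundary = example["labels"].count(-100)  # the -100 mask is a contiguous prefix
assert example["labels"][boundary - 1] == -100  # last prompt token stays masked
assert example["labels"][boundary:] == example["input_ids"][boundary:]  # response labels mirror ids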