Add Granite code support (#1336)
* feat(models): Add models.json blocks for Granite Code 3b and 8b

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Initial model params for granite code 3b

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(model config): Fix model configs for Granite Code

* Use the right tokenizer_file name
* Use the right transformer_params_key based on the file name in
model_params
* Use the updated name to indicate HF tokenizers

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(granite): Add model params for granite-code-8b

Something isn't quite working with this model yet, but the config should be
accurate at this point.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(deps): Add tokenizers to the deps explicitly

It was implicitly being pulled in via lm_eval -> transformers, but it's
better to have it explicit since we use it directly

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(tokenizer): Add basic support for jinja2 template rendering for HF tokenizers

This is a much simplified version of the corresponding logic in
transformers. I opted for this so that the full transformers dependency is
not added here.

CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1522

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
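
For illustration only (not part of this commit): the core of the approach is compiling the tokenizer's template string with jinja2 and rendering it against the message list. The template string below is invented for the example; real templates ship in each model's tokenizer config.

# Illustration only: a made-up HF-style chat template rendered with jinja2.
import jinja2

EXAMPLE_TEMPLATE = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)

template = jinja2.Template(EXAMPLE_TEMPLATE)
print(
    template.render(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        add_generation_prompt=True,
    )
)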

* fix(chat): Add HFTokenizerChatFormatter and use it for HF tokenizers

This will allow the jinja2 templates for HF tokenizers to be applied
without needing to hard-code the formatter logic. This will likely need to
be duplicated in the embedded code version of chat.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
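
Roughly, the new formatter defers all layout decisions to the tokenizer's chat template instead of hard-coding tokens. A simplified sketch follows; the real HFTokenizerChatFormatter lives in torchchat/generate.py and may differ in detail.

# Simplified sketch only; not the torchchat implementation.
from typing import Dict, List


class HFTokenizerChatFormatterSketch:
    def __init__(self, tokenizer):
        # tokenizer is expected to expose apply_chat_template() and encode()
        self.tokenizer = tokenizer

    def encode_dialog_prompt(
        self, dialog: List[Dict[str, str]], add_generation_prompt: bool = True
    ):
        # All formatting comes from the tokenizer's jinja2 chat template.
        rendered = self.tokenizer.apply_chat_template(
            dialog, add_generation_prompt=add_generation_prompt
        )
        return self.tokenizer.encode(rendered)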

* fix(deps): Add jinja2 as an explicit dep

It was getting pulled in implicitly via flask and lm_eval -> transformers,
but better to have it explicit.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(log): Add env-based LOG_LEVEL config to CLI

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat(log): Add better logging in model and generate

In generate, there were a number of commented-out log lines. These are safe
to leave in as long as lazy string interpolation is used.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
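
For context, "lazy string interpolation" means letting the logging module format the message only when the level is enabled, rather than building the string eagerly (values below are illustrative):

import logging

logger = logging.getLogger(__name__)

# Lazy: the format string is only rendered if DEBUG logging is enabled.
logger.debug("generated %d tokens in %.2f s", 42, 0.37)

# Eager: the f-string is built even when DEBUG logging is disabled.
logger.debug(f"generated {42} tokens in {0.37:.2f} s")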

* feat(generate): Make prepending BOS model-configurable

And disable it for Granite Code models

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
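
A hypothetical sketch of the idea; the actual config field name and plumbing in torchchat may differ:

# Hypothetical sketch: make BOS prepending a per-model switch. The real
# torchchat config field and call sites may be named differently.
from typing import List


def encode_prompt(tokenizer, text: str, prepend_bos: bool = True) -> List[int]:
    tokens = tokenizer.encode(text)
    if prepend_bos:
        # Granite Code models would pass prepend_bos=False via their config.
        tokens = [tokenizer.bos_id()] + tokens
    return tokens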

* fix(chat): Refactor chat template logic to encapsulate all formatting in classes

The formatted strings may not be perfectly 1:1 with the previous impl, but
they should be in line with the official model guidelines:

* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
* https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
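
The general shape of the refactor, as a sketch (the concrete classes in torchchat/generate.py may differ):

# Sketch only: one small interface, one concrete formatter per model family.
from abc import ABC, abstractmethod
from typing import Dict, List


class ChatFormatterSketch(ABC):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_dialog_prompt(
        self, dialog: List[Dict[str, str]], add_generation_prompt: bool = True
    ):
        """Encode a list of {'role': ..., 'content': ...} messages into a prompt."""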

* fix(chat): Fix small formatting bugs in llama3 chat formatter

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* test: Add initial unit tests for chat formatters

There's no formal execution framework for pytest yet, but these were
helpful in ensuring that the formatting was working correctly!

To run them, install pytest and run `pytest tests/`

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(logging): Disable logging in generate unless set in the env

There is an incompatibility with logging and torch._dynamo, so this
disables it unless the developer asks for it explicitly.

NOTE: The TC team has stated that they have holistic logging on the roadmap
so this is a short-term solution pending a more robust approach.

REF: https://github.com/pytorch/torchchat/actions/runs/11963066986/job/33493237302#step:14:3599

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
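
The workaround amounts to something like the following sketch; the actual code in generate.py may differ:

# Sketch of the short-term workaround: only configure logging in generate.py
# when explicitly requested, to avoid clashing with torch._dynamo's logging.
import logging
import os

_log_level = os.getenv("LOG_LEVEL")
if _log_level:
    logging.basicConfig(level=_log_level.upper(), format="%(message)s")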

* fix: Remove trailing \n from llama3 <|eot_id|>

There's inconsistency in the documentation on whether or not there should
be a \n after <|eot_id|>, but this maintains consistency with the previous
formatting.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
2 people authored and vmpuri committed Feb 4, 2025
1 parent ff2d53c commit 5e16167
Showing 10 changed files with 469 additions and 75 deletions.
4 changes: 4 additions & 0 deletions install/requirements.txt
@@ -9,6 +9,10 @@ gguf
# Tiktoken tokenizer for Llama 3 and other advanced models
tiktoken

# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
tokenizers
jinja2

# Miscellaneous
snakeviz
sentencepiece
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
"""
Global pytest config, fixtures, and helpers go here!
"""

# Standard
import os
import sys

# Make sure tests can import torchchat
sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
)
216 changes: 216 additions & 0 deletions tests/test_chat_formatters.py
@@ -0,0 +1,216 @@
"""
Unit tests for chat formatters
"""

# Third Party
import pytest

# Local
from torchchat.generate import (
    HFTokenizerChatFormatter,
    Llama2ChatFormatter,
    Llama3ChatFormatter,
)

## Helpers #####################################################################

class DummyTokenizer:
    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
    def encode(self, text, *_, **__):
        return text


class DummySPTokenizer(DummyTokenizer):
    """Emulated Sentencepiece tokenizer with bos/eos"""
    bos = "<s>"
    eos = "</s>"


class DummyLlama3Tokenizer(DummyTokenizer):
    class _IdentityDict:
        def __getitem__(self, key):
            return key
    special_tokens = _IdentityDict()


class DummyHFTokenizer(DummyTokenizer):
    """Dummy made up chat template scheme"""
    # Sequence
    bos = "<bos>"
    # Turn
    bot = "<bot>"
    eot = "<eot>"
    # Role
    bor = "<bor>"
    eor = "<eor>"
    def apply_chat_template(self, messages, add_generation_prompt):
        out = [self.bos]
        role = None
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
        if add_generation_prompt and role != "assistant":
            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
        return "\n".join(out)


def check_rendering(fmt, messages, expected, add_generation_prompt):
    """Render messages and compare to expected output"""
    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected


def make_message(role, text):
    return {"role": role, "content": text}


SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
USER1 = "Hello world!"
ASSISTANT1 = "Greetings! How can I help you?"
USER2 = "Why is the sky blue?"
ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."


# Stock sets of messages to test
MSGS_NO_SYS = [
    make_message("user", USER1),
]
MSGS_SYS_USR = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
]
MSGS_SYS_USR_ASST = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
]
MSGS_MULTI_TURN = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
    make_message("user", USER2),
    make_message("assistant", ASSISTANT2),
]

## Llama2ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
        # sys, usr
        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>
{USER1} [/INST]"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>
{USER1} [/INST] {ASSISTANT1} </s>
"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>
{USER1} [/INST] {ASSISTANT1} </s>
<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
"""),
    ]
)
def test_llama2_chat_formatter(messages, expected):
    """Tests for Llama2 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
    """
    tok = DummySPTokenizer()
    fmt = Llama2ChatFormatter(tok)
    # NOTE: add_generation_prompt not used by Llama2
    check_rendering(fmt, messages, expected, True)

## Llama3ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{USER1}<|eot_id|>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{USER1}<|eot_id|>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{ASSISTANT1}<|eot_id|>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{ASSISTANT2}<|eot_id|>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
    """Tests for Llama3 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
    """
    tok = DummyLlama3Tokenizer()
    fmt = Llama3ChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    check_rendering(fmt, messages, expected, add_generation_prompt)

## HFTokenizerChatFormatter ####################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<bos>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>
<bot><bor>user<eor>{USER2}<eot>
<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_hf_chat_formatter(messages, expected, add_generation_prompt):
    tok = DummyHFTokenizer()
    fmt = HFTokenizerChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
    check_rendering(fmt, messages, expected, add_generation_prompt)
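
To run just these tests without any extra harness (equivalent to the `pytest tests/` instruction in the commit message, assuming pytest is installed):

# Equivalent to `pytest tests/test_chat_formatters.py -v` from the repo root.
import pytest

raise SystemExit(pytest.main(["tests/test_chat_formatters.py", "-v"]))
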
28 changes: 27 additions & 1 deletion tokenizer/hf_tokenizer.py
@@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
from typing import Dict, List, Optional
import json
import os

# Third Party
import jinja2
from tokenizers import Tokenizer

# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
        # Load the tokenizer itself
        self._tokenizer = Tokenizer.from_file(tokenizer_path)

        # Load the chat template if we have a config path
        self._chat_template: Optional[jinja2.Template] = None

        # If available, parse bos/eos tokens from the tokenizer config
        self._bos_id, self._eos_id = None, None
        if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
                self._bos_id = self._tokenizer.token_to_id(bos_token)
            if eos_token is not None:
                self._eos_id = self._tokenizer.token_to_id(eos_token)
            if chat_template_str := tok_config.get("chat_template"):
                self._chat_template = jinja2.Template(chat_template_str)

        # If no eos/bos tokens found, go looking for them!
        if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
        if len(candidate_toks) == 1:
            return candidate_toks[0]["id"]

    ## Interface ##

    def encode(
        self,
        s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:

    def eos_id(self) -> int:
        return self._eos_id

    ## Additional Public Methods ##

    def has_chat_template(self) -> bool:
        return bool(self._chat_template)

    def apply_chat_template(
        self,
        dialog: List[Dict[str, str]],
        add_generation_prompt: bool = False,
    ) -> str:
        """If configured with a chat template, apply it to the list of messages
        """
        if not self._chat_template:
            raise ValueError("No chat template configured!")
        return self._chat_template.render(
            messages=dialog, add_generation_prompt=add_generation_prompt
        )
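
A usage sketch for the new methods; the class name is assumed to be HFTokenizer and the model path is illustrative:

# Usage sketch only: the class name and path below are assumptions.
from tokenizer.hf_tokenizer import HFTokenizer

tok = HFTokenizer("checkpoints/granite-code-3b")  # illustrative path
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write hello world in Python."},
]
if tok.has_chat_template():
    prompt = tok.apply_chat_template(dialog, add_generation_prompt=True)
    tokens = tok.encode(prompt)
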
10 changes: 9 additions & 1 deletion torchchat/cli/cli.py
@@ -17,7 +17,15 @@
    allowable_params_table,
)

logging.basicConfig(level=logging.INFO, format="%(message)s")
_log_level_env = os.getenv("LOG_LEVEL", "INFO")
try:
    _log_level = getattr(logging, _log_level_env.upper())
except AttributeError:
    print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
    _log_level = logging.INFO


logging.basicConfig(level=_log_level, format="%(message)s")
logger = logging.getLogger(__name__)

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
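
The effective behavior of the new LOG_LEVEL handling (illustration, not part of the diff; uses a default argument where the change above uses try/except): valid level names are honored and anything else falls back to INFO.

# Illustration of the fallback behavior: valid names map to logging levels,
# anything unrecognized falls back to INFO.
import logging

for value in ("debug", "WARNING", "not-a-level"):
    level = getattr(logging, value.upper(), logging.INFO)
    print(f"LOG_LEVEL={value!r} -> {logging.getLevelName(level)}")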