diff --git a/requirements_env.txt b/requirements_env.txt new file mode 100644 index 0000000000..f8acbf73c2 --- /dev/null +++ b/requirements_env.txt @@ -0,0 +1,315 @@ +accelerate==0.34.1 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.9.0 +aiosignal==1.3.1 +aiostream==0.5.2 +alembic==1.13.1 +annotated-types==0.6.0 +annoy==1.17.3 +ansible==6.7.0 +ansible-core==2.13.13 +ansible-vault==2.1.0 +anyio==3.7.1 +appdirs==1.4.4 +art==6.0 +asgiref==3.7.2 +async-timeout==4.0.2 +attrdict==2.0.1 +attrs==22.2.0 +awscli==1.32.75 +-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl +backoff==2.2.1 +base58==2.1.1 +beartype==0.17.2 +bitnet==0.2.1 +bitsandbytes==0.42.0 +bittensor==6.7.0 +black==23.7.0 +blinker==1.7.0 +boto3==1.34.75 +botocore==1.34.75 +cachetools==5.3.3 +cachy==0.1.1 +certifi==2023.7.22 +cffi==1.16.0 +cfgv==3.3.1 +chai-guanaco==1.2.4 +charset-normalizer==3.2.0 +cleo==0.6.8 +click==8.1.7 +cloudpickle==2.0.0 +cohere==4.11.2 +colorama==0.4.4 +coloredlogs==15.0.1 +CoLT5-attention==0.10.20 +contextlib2==21.6.0 +contourpy==1.2.0 +cryptography==41.0.3 +cycler==0.12.1 +cytoolz==0.12.3 +databricks-cli==0.18.0 +dataclasses-json==0.5.7 +datasets==2.11.0 +ddt==1.6.0 +decorator==5.1.1 +deepspeed==0.15.0 +# Editable Git install with no remote (dialogpt==0.1) +-e /Users/wing/Projects/ml/dialogpt/src +dill==0.3.6 +distlib==0.3.6 +docker==7.0.0 +docker-pycreds==0.4.0 +docstring-parser==0.15 +docutils==0.16 +ecdsa==0.18.0 +einops==0.7.0 +einops-exts==0.0.4 +einx==0.1.3 +entrypoints==0.4 +eth-hash==0.6.0 +eth-keys==0.5.0 +eth-typing==4.0.0 +eth-utils==2.3.1 +evaluate==0.4.0 +exceptiongroup==1.1.1 +fastapi==0.109.2 +fastcore==1.5.29 +ffmpy==0.4.0 +filelock==3.12.2 +-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet +fire==0.5.0 +first==2.0.2 +flake8==7.0.0 +Flask==3.0.1 +fonttools==4.47.2 +frozendict==2.4.1 +frozenlist==1.3.3 +fschat @ 
git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe +fsspec==2023.6.0 +fuzzywuzzy==0.18.0 +gitdb==4.0.10 +GitPython==3.1.31 +google-pasta==0.2.0 +gradio==4.42.0 +gradio_client==1.3.0 +greenlet==2.0.2 +grpclib==0.4.7 +gunicorn==21.2.0 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +httpcore==0.17.3 +httpx==0.24.1 +huggingface-hub==0.23.4 +humanfriendly==10.0 +hyperframe==6.0.1 +identify==2.5.24 +idna==3.4 +immutables==0.20 +importlib-metadata==6.7.0 +importlib-resources==6.1.1 +inflection==0.5.1 +iniconfig==2.0.0 +itsdangerous==2.1.2 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.3.2 +jsonlines==3.1.0 +jsonschema==2.6.0 +kiwisolver==1.4.5 +langchain==0.0.144 +Levenshtein==0.24.0 +libcst==1.1.0 +liger-kernel==0.0.0 +lion-pytorch==0.1.2 +llama-cpp-python==0.1.36 +llvmlite==0.40.1 +local-attention==1.9.0 +loguru==0.7.0 +Mako==1.3.2 +Markdown==3.5.2 +markdown-it-py==3.0.0 +markdown2==2.4.10 +MarkupSafe==2.1.2 +marshmallow==3.19.0 +marshmallow-enum==1.5.1 +matplotlib==3.8.2 +mccabe==0.7.0 +mdurl==0.1.2 +MEGABYTE-pytorch==0.0.7 +-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit +mlflow==2.10.0 +modal==0.62.77 +more-itertools==10.2.0 +mpmath==1.2.1 +msgpack==1.0.7 +msgpack-numpy-opentensor==0.5.0 +multidict==6.0.4 +multiprocess==0.70.14 +munch==2.5.0 +mypy==1.3.0 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +netaddr==0.10.1 +networkx==3.0rc1 +nh3==0.2.14 +nodeenv==1.8.0 +nomic==2.0.2 +numba==0.57.1 +numexpr==2.8.4 +numpy==1.24.4 +oauthlib==3.2.2 +openai==0.27.4 +openapi==1.1.0 +openapi-schema-pydantic==1.2.4 +optimum==1.8.6 +orjson==3.10.7 +packaging==23.1 +pandas==2.0.0 +parameterized==0.9.0 +password-strength==0.0.3.post2 +pastel==0.1.1 +pathos==0.3.0 +pathspec==0.11.1 +pathtools==0.1.2 +peft==0.11.1 +pendulum==3.0.0 +Pillow==9.5.0 +pip-tools==1.11.0 +platformdirs==3.2.0 +pluggy==1.4.0 +poetry==0.7.1 +pox==0.3.2 +ppft==1.7.6.6 +pre-commit==3.3.2 +prettytable==3.10.0 +prompt-toolkit==3.0.39 
+protobuf==3.20.2 +protobuf3-to-dict==0.1.5 +psutil==5.9.5 +psycopg==3.1.18 +PuLP==2.8.0 +py==1.11.0 +py-bip39-bindings==0.1.11 +py-cpuinfo==9.0.0 +py-ed25519-zebra-bindings==1.0.1 +py-sr25519-bindings==0.2.0 +pyarrow==11.0.0 +pyasn1==0.6.0 +pycodestyle==2.11.1 +pycparser==2.21 +pycryptodome==3.20.0 +pydantic==2.5.3 +pydantic_core==2.14.6 +pydub==0.25.1 +pyfiglet==0.8.post1 +pyflakes==3.2.0 +Pygments==2.15.1 +PyJWT==2.8.0 +pylev==1.4.0 +PyNaCl==1.5.0 +pynvml==11.5.0 +pyparsing==2.4.7 +pyrsistent==0.14.11 +pytest==8.0.2 +pytest-asyncio==0.23.4 +python-dateutil==2.8.2 +python-dotenv==1.0.1 +python-Levenshtein==0.24.0 +python-multipart==0.0.9 +pytz==2023.3 +PyYAML==6.0.1 +querystring-parser==1.2.4 +rapidfuzz==3.6.1 +regex==2023.6.3 +requests==2.31.0 +requests-toolbelt==0.8.0 +resolvelib==0.8.1 +responses==0.18.0 +retry==0.9.2 +rich==13.7.0 +rsa==4.7.2 +ruff==0.6.3 +s3transfer==0.10.1 +safetensors==0.4.5 +sagemaker==2.148.0 +scalecodec==1.2.7 +schedulefree==1.2.1 +schema==0.7.5 +scikit-learn==1.4.0 +scipy==1.9.3 +seaborn==0.13.2 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==1.19.1 +setproctitle==1.3.2 +shellingham==1.5.4 +shortuuid==1.0.11 +shtab==1.6.5 +sigtools==4.0.1 +six==1.16.0 +skypilot==0.4.1 +smdebug-rulesconfig==1.0.1 +smmap==5.0.0 +sniffio==1.3.0 +SQLAlchemy==1.4.47 +sqlparse==0.4.4 +starlette==0.36.3 +substrate-interface==1.5.2 +svgwrite==1.4.3 +sympy==1.11.1 +synchronicity==0.6.7 +tabulate==0.9.0 +tblib==1.7.0 +tenacity==8.2.2 +tensor-parallel==2.0.0 +termcolor==2.2.0 +text2art==0.2.0 +threadpoolctl==3.2.0 +tiktoken==0.6.0 +time-machine==2.14.1 +timm==0.9.16 +tokenizers==0.19.1 +tokenmonster==1.1.12 +toml==0.9.6 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.2.0 +torchdata==0.6.1 +torchdiffeq==0.2.3 +TorchFix==0.4.0 +torchtext==0.15.2 +torchvision==0.17.0 +tqdm==4.66.2 +transformers==4.44.2 +trl==0.9.6 +typer==0.12.5 +types-certifi==2021.10.8.3 +types-requests==2.31.0.20240125 +types-setuptools==69.0.0.20240125 +types-toml==0.10.8.7 
+typing==3.7.4.3 +typing-inspect==0.8.0 +typing_extensions==4.9.0 +tyro==0.5.18 +tzdata==2023.3 +unique-names-generator==1.0.2 +urllib3==2.2.2 +uvicorn==0.22.0 +vector_quantize_pytorch==1.14.1 +virtualenv==20.23.0 +voyager==2.0.2 +wandb==0.16.2 +watchfiles==0.21.0 +wavedrom==2.0.3.post3 +wcwidth==0.2.6 +websocket-client==1.7.0 +websockets==12.0 +Werkzeug==3.0.1 +wonderwords==2.2.0 +xxhash==3.2.0 +yarl==1.8.2 +zetascale==2.2.7 +zipp==3.15.0 diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e12462c000..aab29e2670 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -27,6 +27,7 @@ register_chatml_template, register_llama3_template, ) +from axolotl.utils.trainer import disable_datasets_caching LOG = logging.getLogger("axolotl.cli.preprocess") @@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.warning(msg) parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": - load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - else: - load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + with disable_datasets_caching(): + if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": + load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.download: model_name = parsed_cfg.base_model diff --git a/src/axolotl/core/chat/__init__.py b/src/axolotl/core/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/chat/format/__init__.py b/src/axolotl/core/chat/format/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py new file mode 100644 index 0000000000..315d101a86 --- /dev/null +++ b/src/axolotl/core/chat/format/chatml.py @@ -0,0 +1,34 @@ +""" +ChatML transformation functions for MessageContents +""" 
+from typing import Optional + +from ..messages import MessageContents, Messages +from .shared import wrap_tools + + +def format_message( + message: Messages, + message_index: Optional[int] = None, # pylint: disable=unused-argument +) -> Messages: + if message.is_chat_formatted: + return message + + # prepend the role prefix within a MessageContents to message.content + message.content.insert( + 0, + MessageContents( + type="text", + value=f"<|im_start|>{message.role}\n", + weight=0, + ), + ) + message.content.append( + MessageContents(type="text", value="<|im_end|>", weight=message.weight) + ) + message.content.append(MessageContents(type="text", value="\n", weight=0)) + + message = wrap_tools(message) + + message.is_chat_formatted = True + return message diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py new file mode 100644 index 0000000000..17fa7aa8d4 --- /dev/null +++ b/src/axolotl/core/chat/format/llama3x.py @@ -0,0 +1,45 @@ +""" +Llama 3.x chat formatting functions for MessageContents +""" +from typing import Optional + +from ..messages import MessageContents, Messages +from .shared import wrap_tools + + +def format_message(message: Messages, message_index: Optional[int] = None) -> Messages: + if message.is_chat_formatted: + return message + + message_role = message.role + if message.role == "tool": + message_role = "ipython" + + # prepend the role prefix within a MessageContents to message.content + message.content.insert( + 0, + MessageContents( + type="text", + value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n", + weight=0, + ), + ) + + message.content.append( + MessageContents(type="text", value="<|eot_id|>", weight=message.weight) + ) + + message = wrap_tools(message) + + if message_index == 0: + message.content.insert( + 0, + MessageContents( + type="text", + value="<|begin_of_text|>", + weight=0, + ), + ) + + message.is_chat_formatted = True + return message diff --git 
a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py new file mode 100644 index 0000000000..9efa2353db --- /dev/null +++ b/src/axolotl/core/chat/format/shared.py @@ -0,0 +1,47 @@ +""" +shared functions for format transforms +""" +from axolotl.core.chat.messages import MessageContents, Messages + + +def wrap_tools(message: Messages): + # loop over message.content by index to find tool calls, we need to wrap each with tags, + # so be wary of indexing issues when changing the list while iterating. + # iterate over the range in reverse order to avoid index shifting + for i in range(len(message.content) - 1, -1, -1): + if message.content[i].type == "tool_call": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool call content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + elif message.content[i].type == "tool_response": + # append a MessageContents text tag after + message.content.insert( + i + 1, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + # make sure the actual tool response content ends with a newline + message.content[i].has_newline = True + # prepend a MessageContents text tag before + message.content.insert( + i, + MessageContents( + type="text", value="\n", weight=message.weight + ), + ) + + return message diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py new file mode 100644 index 0000000000..c879bf477b --- /dev/null +++ b/src/axolotl/core/chat/messages.py @@ -0,0 +1,230 @@ +""" +internal message representations of chat messages +""" +import json +from enum import Enum +from typing import Any, Callable, List, Optional, Union + +from pydantic import BaseModel +from 
transformers import PreTrainedTokenizer + + +class MessageRoles(str, Enum): + """ + Message roles for the system, user, assistant, and tools + """ + + system = "system" # pylint: disable=invalid-name + user = "user" # pylint: disable=invalid-name + assistant = "assistant" # pylint: disable=invalid-name + tool = "tool" # pylint: disable=invalid-name + ipython = ( # pylint: disable=invalid-name + # for responses from builtin tools + "ipython" + ) + + +class MessageContentTypes(str, Enum): + """ + Message content types for text, image, audio, tool calls, and tool responses + """ + + special_token = "special_token" # pylint: disable=invalid-name # nosec B105 + text = "text" # pylint: disable=invalid-name + image = "image" # pylint: disable=invalid-name + audio = "audio" # pylint: disable=invalid-name + tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant + tool_response = "tool_response" # pylint: disable=invalid-name + + +class SpecialToken(str, Enum): + """ + Special tokens for beginning of string and end of string + """ + + bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105 + eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105 + + +class ToolCallFunction(BaseModel): + """ + Tool call function with name and arguments + """ + + name: str + arguments: dict[str, str] + + +class Tool(BaseModel): + """ + Tool with description, function, and parameters + """ + + description: str + function: ToolCallFunction + parameters: dict[str, str] # .properties + + +class ToolCallContents(BaseModel): + """ + Tool call contents with name, arguments, and optional id + """ + + name: str + arguments: dict[str, Union[str, int]] + id: Optional[str] = None # pylint: disable=invalid-name + + def __str__(self) -> str: + data = {"name": self.name, "arguments": self.arguments} + if self.id is not None: + data["id"] = self.id + return json.dumps(data) + + +class 
ToolResponseContents(BaseModel): + """ + Tool response contents with name, content, and optional id + """ + + name: str + content: Union[str, dict[str, Union[str, int, float]]] + id: Optional[str] = None # pylint: disable=invalid-name + + def __str__(self) -> str: + data = {"name": self.name, "content": self.content} + if self.id is not None: + data["id"] = self.id + return json.dumps(data) + + +class MessageContents(BaseModel): + """ + Message contents with type, value, metadata, weight, newline, and end of contents + """ + + type: Union[str, MessageContentTypes] + value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken] + meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata + weight: Optional[Union[int, float]] = None + has_newline: bool = False + eoc: bool = False # end of contents + + def __str__(self) -> str: + str_val = str(self.value) + if self.has_newline and not str_val.endswith("\n"): + str_val += "\n" + return str_val + + +class Messages(BaseModel): + """ + Messages with role, content, metadata, weight, and chat formatting + """ + + role: Union[MessageRoles, str] # allows for arbitrary roles + content: List["MessageContents"] + meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata + weight: Optional[Union[int, float]] = None + is_chat_formatted: bool = False + + def __str__(self) -> str: + return "".join(str(c) for c in self.content) + + def tokenized( + self, tokenizer: PreTrainedTokenizer, ignore_index=-100 + ) -> dict[str, List[int]]: + # iterate over the contents, tokenizing the concatenated string values up to the current MessageContents + # returns a dictionary mapping w input_ids, attention_mask, and labels + input_ids: List[int] = [] + labels: List[int] = [] + pending_input_ids: List[int] = [] + pending_weight = self.weight + running_content = "" + for _, msg_content in enumerate(self.content): + # TODO also handle non-text content types + if msg_content.type in [ + 
MessageContentTypes.text.value, + MessageContentTypes.tool_call.value, + MessageContentTypes.tool_response.value, + ]: + running_content += str(msg_content) + tok_results = tokenizer(running_content, add_special_tokens=False) + tok_input_ids = tok_results["input_ids"] + if pending_input_ids: + new_pending_inputs = tok_input_ids[ + len(input_ids) : len(input_ids) + len(pending_input_ids) + ] + if new_pending_inputs != pending_input_ids: + # logging.warning("tokenization mismatch from concatenation.") + pending_input_ids = new_pending_inputs + input_ids.extend(pending_input_ids) + if pending_weight: + labels.extend(pending_input_ids) + else: + labels.extend([ignore_index] * len(pending_input_ids)) + pending_input_ids = tok_results["input_ids"][len(input_ids) :] + pending_weight = self.weight and msg_content.weight not in [0, 0.0] + input_ids.extend(pending_input_ids) + if pending_weight: + labels.extend(pending_input_ids) + else: + labels.extend([ignore_index] * len(pending_input_ids)) + attention_mask = [1] * len(input_ids) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +class Chats(BaseModel): + """ + top level data structure for chat conversations + """ + + conversation: List[Messages] + + def __str__(self) -> str: + return "".join(str(c) for c in self.conversation) + + def tokenized( + self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100 + ) -> dict[str, List[int]]: + input_ids = [] + attention_mask = [] + labels = [] + for msg in self.conversation: + msg_results = msg.tokenized(tokenizer, ignore_index) + input_ids.extend(msg_results["input_ids"]) + attention_mask.extend(msg_results["attention_mask"]) + labels.extend(msg_results["labels"]) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +class ChatFormattedChats(Chats): + """ + Chat formatted chats with formatter and optional train on inputs + """ + + formatter: Callable # [[Union[dict, 
Chats]], Chats] + train_on_inputs: bool = False + + def model_post_init(self, __context): + for i, msg in enumerate(self.conversation): + self.conversation[i] = self.formatter(msg, message_index=i) + if self.train_on_inputs: + self.conversation[i].weight = 1 + + +class PreferenceChats(BaseModel): + """ + representation for preference data for chat + """ + + prompt: List[Messages] + chosen: Messages + rejected: Messages diff --git a/src/axolotl/core/datasets/__init__.py b/src/axolotl/core/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py new file mode 100644 index 0000000000..e74c247d2c --- /dev/null +++ b/src/axolotl/core/datasets/chat.py @@ -0,0 +1,55 @@ +""" +chat dataset module +""" +import os +from typing import Callable, Optional, Union + +from datasets import Dataset +from transformers import PreTrainedTokenizer + +from axolotl.core.chat.messages import ChatFormattedChats + + +class TokenizedChatDataset(Dataset): + """ + Tokenized chat dataset + """ + + def __init__( + self, + data: Dataset, + model_transform: Union[PreTrainedTokenizer, Callable], + *args, + message_transform: Optional[Callable] = None, + formatter=None, + process_count: Optional[int] = None, + keep_in_memory: Optional[bool] = False, + **kwargs, + ): + def map_fn(ex): + if message_transform is not None: + ex = message_transform(ex) + if formatter is not None: + ex = ChatFormattedChats( + formatter=formatter, + **ex, + ) + else: + ex = ChatFormattedChats( + **ex, + ) + return ex.tokenized(model_transform) + + process_or_cpu_count: int = ( + process_count or os.cpu_count() # type: ignore[assignment] + ) + num_proc = min(64, process_or_cpu_count) + features = data.features.keys() + tokenized_data = data.map( + map_fn, + num_proc=num_proc, + keep_in_memory=keep_in_memory, + remove_columns=features, + desc="Tokenizing Chats", + ) + super().__init__(tokenized_data.data, *args, **kwargs) diff 
--git a/src/axolotl/core/datasets/transforms/__init__.py b/src/axolotl/core/datasets/transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py new file mode 100644 index 0000000000..98d5f171a7 --- /dev/null +++ b/src/axolotl/core/datasets/transforms/chat_builder.py @@ -0,0 +1,150 @@ +""" +This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. +""" +from typing import Any, Mapping, Union + + +def chat_message_transform_builder( # pylint: disable=dangerous-default-value + train_on_inputs=False, + conversations_field: str = "conversations", + message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" + message_field_content: Union[str, list[str]] = [ + "value", + "text", + "content", + ], # commonly "content" + message_field_training: Union[str, list[str]] = [ + "train", + "weight", + ], # commonly "weight" +): + """Builds a transform that takes a row from the dataset and converts it to a Chat + + Args: + train_on_inputs (bool, optional): + If True, the transform will train on the inputs. If False, the transform will train on the targets. + Defaults to False. + conversations_field (str, optional): + The field name of the conversations. Defaults to "conversations". + message_field_role (str | list[str], optional): + The field name of the role. Defaults to "role". + message_field_content (str | list[str], optional): + The field name of the message content. Defaults to "content". + message_field_training (str | list[str], optional): + The field name of the train/weight. Defaults to "weight". + + Returns: + Callable: + A function that takes a list of conversations and returns a list of messages. 
+ """ + + message_field_role = ( + [message_field_role] + if isinstance(message_field_role, str) + else message_field_role + ) + message_field_content = ( + [message_field_content] + if isinstance(message_field_content, str) + else message_field_content + ) + message_weight_fields = ( + [message_field_training] + if isinstance(message_field_training, str) + else message_field_training + ) + + role_value_mappings = { + "system": "system", + "user": "user", + "human": "user", + "assistant": "assistant", + "gpt": "assistant", + "tool": "tool", + "ipython": "ipython", + } + if train_on_inputs: + role_default_weights_mappings = { + "system": 1, + "user": 1, + "assistant": 1, + "tool": 1, + "ipython": 1, + } + else: + role_default_weights_mappings = { + "system": 0, + "user": 0, + "assistant": 1, + "tool": 0, + "ipython": 0, + } + + def transform_builder(sample: Mapping[str, Any]): + if conversations_field not in sample: + raise ValueError(f"Field '{conversations_field}' not found in sample.") + # if none of the role fields are in the message, raise an error + if not any( + role in sample[conversations_field][0] for role in message_field_role + ): + raise ValueError("No role field found in message.") + role_field = next( + role + for role in message_field_role + if role in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_content + ): + raise ValueError("No message_content field found in message.") + message_content_field = next( + field + for field in message_field_content + if field in sample[conversations_field][0] + ) + if not any( + field in sample[conversations_field][0] for field in message_field_training + ): + message_weight_field = None + else: + message_weight_field = next( + field + for field in message_weight_fields + if field in sample[conversations_field][0] + ) + + messages = [] + for message in sample[conversations_field]: + role = role_value_mappings[message[role_field]] + weight = ( + 
int(message[message_weight_field]) + if message_weight_field + else role_default_weights_mappings[role] + ) + + # TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents + if isinstance(message[message_content_field], str): + messages.append( + { + "role": role, + "content": [ + { + "type": "text", + "value": message[message_content_field], + } + ], + "weight": weight, + } + ) + else: + messages.append( + { + "role": role, + "content": message[message_content_field], + "weight": weight, + } + ) + + return {"conversation": messages} + + return transform_builder diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 66cd5deeb9..74da20c5e1 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -11,6 +11,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): try: + if strategy == "messages": + from .messages import load as messages_load + + return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" if strategy.split(".")[-1].startswith("load_"): load_fn = strategy.split(".")[-1] @@ -31,4 +35,5 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return None except Exception as exc: # pylint: disable=broad-exception-caught LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") - return None + raise exc + return None diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py new file mode 100644 index 0000000000..d014d93a6b --- /dev/null +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -0,0 +1,34 @@ +"""Module to load message prompt strategies.""" + +import importlib +import inspect +import logging + +LOG = logging.getLogger("axolotl.prompt_strategies.messages") + + +def load(tokenizer, cfg, ds_cfg, processor=None): + try: + strategy = ds_cfg.get("input_transform", "chat") + # pylint: 
disable=duplicate-code + load_fn = "load" + if strategy.split(".")[-1].startswith("load_"): + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module( + f".{strategy}", "axolotl.prompt_strategies.messages" + ) + func = getattr(mod, load_fn) + load_kwargs = {} + sig = inspect.signature(func) + if "ds_cfg" in sig.parameters: + load_kwargs["ds_cfg"] = ds_cfg + if "processor" in sig.parameters: + load_kwargs["processor"] = processor + return func(tokenizer, cfg, **load_kwargs) + except ModuleNotFoundError: + return None + except Exception as exc: # pylint: disable=broad-exception-caught + LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") + raise exc + return None diff --git a/src/axolotl/prompt_strategies/messages/chat.py b/src/axolotl/prompt_strategies/messages/chat.py new file mode 100644 index 0000000000..35d7649026 --- /dev/null +++ b/src/axolotl/prompt_strategies/messages/chat.py @@ -0,0 +1,84 @@ +""" +Chat dataset wrapping strategy for new internal messages representations +""" +from typing import Any, Callable, Dict, Optional + +from axolotl.core.datasets.chat import TokenizedChatDataset +from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder +from axolotl.prompt_tokenizers import DatasetWrappingStrategy + + +class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): + """ + Chat dataset wrapping strategy for new internal messages representations + """ + + def __init__( + self, + processor, + message_transform=None, + formatter=None, + **kwargs, # pylint: disable=unused-argument + ): + """ + :param processor: tokenizer or image processor + :param kwargs: + """ + self.processor = processor + self.dataset = None + self.message_transform = message_transform + self.formatter = formatter + + def wrap_dataset( + self, + dataset, + process_count: Optional[int] = None, + keep_in_memory: Optional[bool] = False, + **kwargs, # pylint: 
disable=unused-argument + ): + self.dataset = TokenizedChatDataset( + dataset, + message_transform=self.message_transform, + model_transform=self.processor, + formatter=self.formatter, + process_count=process_count, + keep_in_memory=keep_in_memory, + ) + return self.dataset + + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + ds_cfg = ds_cfg or {} + + field_messages = ds_cfg.get("field_messages") + message_field_role = ds_cfg.get("message_field_role") + message_field_content = ds_cfg.get("message_field_content") + message_field_training = ds_cfg.get("message_field_training") + + builder_kwargs = {} + if field_messages: + builder_kwargs["conversations_field"] = field_messages + if message_field_role: + builder_kwargs["message_field_role"] = message_field_role + if message_field_content: + builder_kwargs["message_field_content"] = message_field_content + if message_field_training: + builder_kwargs["message_field_training"] = message_field_training + + chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml")) + format_message = ( + lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment + ) + if chat_template == "chatml": + from axolotl.core.chat.format.chatml import format_message # noqa F811 + if chat_template.startswith("llama3"): + from axolotl.core.chat.format.llama3x import format_message # noqa F811 + message_transform: Callable = chat_message_transform_builder( + train_on_inputs=ds_cfg.get("train_on_inputs", False), + **builder_kwargs, + ) + strategy = ChatMessageDatasetWrappingStrategy( + tokenizer, message_transform=message_transform, formatter=format_message + ) + + return strategy diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 11dd084a85..51d497a23c 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -30,6 +30,12 @@ class InvalidDataException(Exception): """ +class DatasetWrappingStrategy(abc.ABC): + """ + Abstract class 
for wrapping datasets for Chat Messages + """ + + class PromptTokenizingStrategy(abc.ABC): """ Abstract class for tokenizing strategies diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 40f4a36abb..3304c62f28 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -102,10 +102,12 @@ class SFTDataset(BaseModel): path: Optional[str] = None split: Optional[str] = None type: Optional[Union[str, UserDefinedPrompterType]] = None + input_transform: Optional[str] = None shards: Optional[int] = None conversation: Optional[str] = None chat_template: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None + input_format: Optional[str] = None name: Optional[str] = None ds_type: Optional[str] = None train_on_split: Optional[str] = None diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 39eb2c4e04..163059c2b8 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -23,6 +23,7 @@ AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, AlpacaReflectionPTStrategy, + DatasetWrappingStrategy, GPTeacherPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, @@ -573,7 +574,7 @@ def get_dataset_wrapper( d_base_type, dataset, d_prompt_style=None, - processor=None, + processor=None, # pylint: disable=unused-argument ): dataset_wrapper = None dataset_prompter = None @@ -608,15 +609,16 @@ def get_dataset_wrapper( ) elif cfg.skip_prepare_dataset: dataset_wrapper = dataset - elif ds_strategy := load( - config_dataset.type, tokenizer, cfg, config_dataset, processor=processor - ): - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, - dataset, - **ds_kwargs, - ) + elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): + if 
isinstance(ds_strategy, DatasetWrappingStrategy): + dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) + else: + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, + dataset, + **ds_kwargs, + ) elif d_base_type == "alpaca": dataset_prompter = AlpacaPrompter(d_prompt_style) ds_strategy = AlpacaPromptTokenizingStrategy( diff --git a/tests/core/chat/__init__.py b/tests/core/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/chat/format/__init__.py b/tests/core/chat/format/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py new file mode 100644 index 0000000000..b3be56c590 --- /dev/null +++ b/tests/core/chat/test_messages.py @@ -0,0 +1,197 @@ +""" +Tests for the chat messages module +""" +import unittest + +import pytest +from transformers import AddedToken, AutoTokenizer + +from axolotl.core.chat.format.chatml import format_message +from axolotl.core.chat.messages import ChatFormattedChats, Chats + + +@pytest.fixture(scope="session", name="llama_tokenizer") +def llama_tokenizer_fixture(): + return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B") + + +@pytest.fixture(scope="session", name="chatml_tokenizer") +def llama_tokenizer_w_chatml(llama_tokenizer): + llama_tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + llama_tokenizer.add_tokens( + [ + AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False), + ] + ) + + return llama_tokenizer + + +@pytest.fixture(scope="session", name="chat_msgs") +def chat_msgs_fixture(): + return { + "conversation": [ + { + "role": "system", + "content": [ + {"type": "text", "value": "You are a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "value": "What is today's stock price of Apple?"}, + 
], + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_call", + "value": { + "name": "get_date", + "arguments": {}, + }, + }, + { + "type": "tool_call", + "value": { + "name": "get_stock_price", + "arguments": {"symbol": "AAPL"}, + }, + }, + ], + "weight": 1, + }, + { + "role": "tool", + "content": [ + { + "type": "tool_response", + "value": { + "name": "get_date", + "content": {"date": "2024-09-09"}, + }, + }, + { + "type": "tool_response", + "value": { + "name": "get_stock_price", + "content": {"symbol": "AAPL", "price": 123.45}, + }, + }, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "value": "The stock price of Apple is $123.45.\n", + "weight": 0, + }, + { + "type": "text", + "value": "The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.", + }, + { + "type": "text", + "value": "The stock price of Apple on September 9, 2024 is $123.45.", + }, + ], + "weight": 1, + }, + ] + } + + +class TestMessagesCase: + """ + Test cases for the chat messages module + """ + + def test_tool_call_stringify(self, chat_msgs): + chat_msgs_as_obj = Chats(**chat_msgs) + assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str( + chat_msgs_as_obj.conversation[2].content[1].value + ) + + def test_chatml_formatted_wrapper(self, chat_msgs): + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + target_chatml = """<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +What is today's stock price of Apple?<|im_end|> +<|im_start|>assistant + +{"name": "get_date", "arguments": {}} + + +{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}} + +<|im_end|> +<|im_start|>tool + +{"name": "get_date", "content": {"date": "2024-09-09"}} + + +{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}} + +<|im_end|> +<|im_start|>assistant +The stock price of Apple is $123.45. 
+The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n""" + assert target_chatml == str(chat_msg_formatted) + + def test_chatml_formatting_tool_call(self, chat_msgs): + chat_msgs_as_obj = Chats(**chat_msgs) + target_chatml_turn2 = """<|im_start|>assistant\n\n{"name": "get_date", "arguments": {}}\n\n\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n\n<|im_end|>\n""" + assert target_chatml_turn2 == str( + format_message(chat_msgs_as_obj.conversation[2]) + ) + + def test_train_labels(self, chatml_tokenizer, chat_msgs): + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer) + # fmt: off + target_labels = [ + -100, -100, -100, # role + 27, 14506, 13735, 397, 5018, 609, 794, + 330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524, + 14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794, + 330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314, + 794, 330, 84016, 43, 96742, 524, 14506, 13735, 397, + 128256, # <|im_end|> + -100 # trailing newline + ] + # fmt: on + assert tokenized["labels"] == target_labels + + def test_train_labels_2(self, chatml_tokenizer, chat_msgs): + # also test if individual contents are set not to train + chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message) + tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer) + # fmt: off + target_labels = [ + -100, -100, -100, # role + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response + 27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430, + 315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457, + 5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315, + 8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400, + 4513, 13, 1774, 13, + 128256, # <|im_end|> + -100, # 
trailing newline + ] + # fmt: on + assert tokenized["labels"] == target_labels + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/messages/__init__.py b/tests/prompt_strategies/messages/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py new file mode 100644 index 0000000000..96c4b6cbbf --- /dev/null +++ b/tests/prompt_strategies/messages/test_chat.py @@ -0,0 +1,62 @@ +""" +tests for chat_template prompt strategy +""" +# pylint: disable=duplicate-code +import logging +import unittest + +from axolotl.prompt_strategies.messages.chat import load +from axolotl.utils.dict import DictDefault + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +class TestMessagesChatLlama3: + """ + Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy. + """ + + def test_llama3_load(self, llama3_tokenizer, assistant_dataset): + LOG.info("Loading llama-3 tokenizer with assistant dataset") + strategy = load( + llama3_tokenizer, + DictDefault( + { + "train_on_inputs": False, + "sequence_len": 512, + } + ), + DictDefault( + { + "chat_template": "llama3", + "message_field_role": "role", + "message_field_content": "content", + "field_messages": "messages", + } + ), + ) + res = strategy.wrap_dataset(assistant_dataset) + input_ids = res[0]["input_ids"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 882, 128007, # user header + 271, 15339, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 15339, 128009, # assistant response eot + 128006, 882, 128007, + 271, 19045, 29474, 128009, + 128006, 78191, 128007, + 271, 19045, 29474, 128009, + ] + # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: 
{input_ids} != {expected_input_ids}" + + +if __name__ == "__main__": + unittest.main()