fix user defined dataset types
winglian committed Aug 16, 2023
1 parent e99170a commit 0650a0f
Showing 3 changed files with 64 additions and 30 deletions.
4 changes: 3 additions & 1 deletion src/axolotl/prompt_strategies/__init__.py
@@ -2,6 +2,8 @@
 
 import importlib
 
+from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
+
 
 def load(strategy, tokenizer, cfg, ds_cfg):
     try:
@@ -13,7 +15,7 @@ def load(strategy, tokenizer, cfg, ds_cfg):
         func = getattr(mod, load_fn)
         load_kwargs = {}
         if strategy == "user_defined":
-            load_kwargs["ds_cfg"] = ds_cfg
+            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
         return func(tokenizer, cfg, **load_kwargs)
     except Exception:  # pylint: disable=broad-exception-caught
         return None
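
For context, a minimal sketch (not part of the commit) of what the new wrapper does: any keys missing from the dataset's type mapping fall back to the defaults of the dataclass added in user_defined.py below, and dict-style access keeps working through __getitem__. The column names used here are made up for illustration.

# Illustrative sketch, not from the commit: wrapping a raw config mapping.
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

ds_cfg = {
    "field_instruction": "question",  # hypothetical column names
    "field_output": "answer",
    "format": "{instruction} {input} ",
}

cfg_obj = UserDefinedDatasetConfig(**ds_cfg)
print(cfg_obj.field_input)       # "input" -- falls back to the dataclass default
print(cfg_obj["field_output"])   # "answer" -- dict-style access via __getitem__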
85 changes: 58 additions & 27 deletions src/axolotl/prompt_strategies/user_defined.py
@@ -2,59 +2,79 @@
 User Defined prompts with configuration from the YML config
 """
 
-from typing import Tuple
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional, Tuple
 
 from axolotl.prompt_strategies.alpaca_w_system import (
     InstructionWSystemPromptTokenizingStrategy,
     SystemDataPrompter,
 )
 
 
-class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
+@dataclass
+class UserDefinedDatasetConfig:
     """
-    Prompt Tokenization Strategy for user defined prompts
+    dataclass configuration representing a userdefined dataset type
     """
 
+    system_prompt: str = ""
+    field_system: str = "system"
+    field_instruction: str = "instruction"
+    field_input: str = "input"
+    field_output: str = "output"
+    format: str = "{instruction} {input} "
+    no_input_format: str = "{instruction} "
+    system_format: str = "{system}"
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
 
-class UserDefinedPrompter(SystemDataPrompter):
+class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
     """
-    Prompter for user defined prompts
+    Prompt Tokenization Strategy for user defined prompts
     """
 
 
-def load(tokenizer, cfg, ds_cfg=None):
+def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
     if not ds_cfg:
         raise ValueError("Missing dataset prompt configuration")
 
     system_prompt = ""
-    if ds_cfg["system_prompt"] and not ds_cfg["field_system"]:
-        system_prompt = ds_cfg["system_prompt"]
+    if ds_cfg.system_prompt and not ds_cfg.field_system:
+        system_prompt = ds_cfg.system_prompt
 
     def parse_instruction_fields(
-        self, prompt  # pylint: disable=unused-argument
+        field_instruction,
+        field_input,
+        field_output,
+        field_system,
+        system_prompt,
+        prompt,
     ) -> Tuple[str, str, str, str]:
         return (
-            prompt[ds_cfg["field_instruction"]],
-            prompt[ds_cfg["field_input"]]
-            if ds_cfg["field_input"] and ds_cfg["field_input"] in prompt
-            else "",
-            prompt[ds_cfg["field_output"]],
-            prompt[ds_cfg["field_system"]]
-            if ds_cfg["field_system"] and ds_cfg["field_system"] in prompt
-            else system_prompt,
+            prompt[field_instruction],
+            prompt[field_input] if field_input in prompt else "",
+            prompt[field_output] if field_output in prompt else "",
+            prompt[field_system] if field_system in prompt else system_prompt,
        )
 
-    def match_prompt_style(self):
-        self.turn_format = ds_cfg["format"]
-        self.turn_no_input_format = (
-            ds_cfg["no_input_format"]
-            if "no_input_format" in ds_cfg
-            else ds_cfg["format"]
-        )
-        self.system_format = ds_cfg["system_format"]
+    turn_format = ds_cfg.format
+    turn_no_input_format = ds_cfg.no_input_format
+    system_format = ds_cfg.system_format
+
+    class UserDefinedPrompter(SystemDataPrompter):
+        """
+        Prompter for user defined prompts
+        """
+
+    def match_prompt_style(self):
+        self.turn_format = turn_format
+        self.turn_no_input_format = turn_no_input_format
+        self.system_format = system_format
 
     prompter = UserDefinedPrompter()
     prompter.match_prompt_style = match_prompt_style
 
     strat = UserDefinedPromptTokenizationStrategy(
         prompter,
@@ -63,5 +83,16 @@ def match_prompt_style(self):
         cfg.sequence_len,
     )
 
-    strat.parse_instruction_fields = parse_instruction_fields
+    setattr(
+        strat,
+        "parse_instruction_fields",
+        partial(
+            parse_instruction_fields,
+            ds_cfg.field_instruction,
+            ds_cfg.field_input,
+            ds_cfg.field_output,
+            ds_cfg.field_system,
+            system_prompt,
+        ),
+    )
     return strat
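
The heart of this change is that parse_instruction_fields no longer closes over ds_cfg through an unused self parameter; the config-derived field names are pre-bound with functools.partial, so the callable attached to the strategy needs only the prompt row. A standalone sketch of that binding pattern follows (the field names and sample row are made up for illustration, not taken from the commit):

# Standalone sketch of the partial-binding pattern used above (illustrative only).
from functools import partial
from typing import Tuple


def parse_instruction_fields(
    field_instruction, field_input, field_output, field_system, system_prompt, prompt
) -> Tuple[str, str, str, str]:
    return (
        prompt[field_instruction],
        prompt[field_input] if field_input in prompt else "",
        prompt[field_output] if field_output in prompt else "",
        prompt[field_system] if field_system in prompt else system_prompt,
    )


# Binding the config-derived arguments up front leaves a callable that takes
# only the prompt row, matching how the tokenization strategy calls it.
parse = partial(parse_instruction_fields, "question", "context", "answer", "persona", "")
row = {"question": "What is 2 + 2?", "context": "basic arithmetic", "answer": "4"}
print(parse(row))  # ('What is 2 + 2?', 'basic arithmetic', '4', '')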
5 changes: 3 additions & 2 deletions src/axolotl/utils/data.py
@@ -41,6 +41,7 @@
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
@@ -221,8 +222,8 @@ def load_tokenized_prepared_datasets(
             ):
                 # dataset is already tokenized, just drop it straight in
                 datasets.append(ds)
-            elif isinstance(d.type, object):
-                ds_strategy = load("user_defined", tokenizer, cfg, d.type)
+            elif isinstance(d.type, DictDefault):
+                ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif ds_strategy := load(d.type, tokenizer, cfg, d):
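
The data.py change matters because isinstance(d.type, object) is true for every Python value, including plain string type names such as "alpaca", so the old branch could not tell an inline user-defined prompt spec apart from a built-in strategy name. A small sketch of the distinction; the dict contents are illustrative, and it assumes DictDefault accepts a plain dict the way addict's Dict does:

# Illustrative sketch, not from the commit.
from axolotl.utils.dict import DictDefault

string_type = "alpaca"  # built-in strategy name from the YML config
inline_type = DictDefault(  # hypothetical inline user-defined prompt spec
    {"field_instruction": "question", "field_output": "answer", "format": "{instruction} "}
)

print(isinstance(string_type, object))       # True  -- the old check matched everything
print(isinstance(inline_type, DictDefault))  # True  -- the new check matches only dict specs
print(isinstance(string_type, DictDefault))  # False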
