
Commit 2d10911

support user defined prompters, pretokenized datasets in config, local parquet, local arrow files
winglian committed Aug 7, 2023
1 parent 10405b9 commit 2d10911
Showing 3 changed files with 98 additions and 7 deletions.
7 changes: 5 additions & 2 deletions src/axolotl/prompt_strategies/__init__.py
@@ -3,14 +3,17 @@
 import importlib


-def load(strategy, tokenizer, cfg):
+def load(strategy, tokenizer, cfg, ds_cfg):
     try:
         load_fn = "load"
         if strategy.split(".")[-1].startswith("load_"):
             load_fn = strategy.split(".")[-1]
             strategy = ".".join(strategy.split(".")[:-1])
         mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
         func = getattr(mod, load_fn)
-        return func(tokenizer, cfg)
+        load_kwargs = {}
+        if strategy == "user_defined":
+            load_kwargs["ds_cfg"] = ds_cfg
+        return func(tokenizer, cfg, **load_kwargs)
     except Exception:  # pylint: disable=broad-exception-caught
         return None
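
Note: the dispatch rule above is easiest to see in isolation. The sketch below is not part of the commit and uses illustrative strategy names; it mirrors the string-splitting logic: a strategy whose last dot-separated segment starts with "load_" names the loader function inside the module, otherwise the module-level "load" is used, and only the "user_defined" strategy additionally receives the per-dataset config.

    def resolve_strategy(strategy: str) -> tuple[str, str]:
        # Standalone re-implementation of the lookup rule, for illustration only.
        load_fn = "load"
        if strategy.split(".")[-1].startswith("load_"):
            load_fn = strategy.split(".")[-1]
            strategy = ".".join(strategy.split(".")[:-1])
        return strategy, load_fn

    assert resolve_strategy("alpaca_chat") == ("alpaca_chat", "load")
    assert resolve_strategy("alpaca_chat.load_qa") == ("alpaca_chat", "load_qa")
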
67 changes: 67 additions & 0 deletions src/axolotl/prompt_strategies/user_defined.py
@@ -0,0 +1,67 @@
"""
User Defined prompts with configuration from the YML config
"""

from typing import Tuple

from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)


class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Prompt Tokenization Strategy for user defined prompts
"""


class UserDefinedPrompter(SystemDataPrompter):
"""
Prompter for user defined prompts
"""


def load(tokenizer, cfg, ds_cfg=None):
if not ds_cfg:
raise ValueError("Missing dataset prompt configuration")

system_prompt = ""
if ds_cfg["system_prompt"] and not ds_cfg["field_system"]:
system_prompt = ds_cfg["system_prompt"]

def parse_instruction_fields(
self, prompt # pylint: disable=unused-argument
) -> Tuple[str, str, str, str]:
return (
prompt[ds_cfg["field_instruction"]],
prompt[ds_cfg["field_input"]]
if ds_cfg["field_input"] and ds_cfg["field_input"] in prompt
else "",
prompt[ds_cfg["field_output"]],
prompt[ds_cfg["field_system"]]
if ds_cfg["field_system"] and ds_cfg["field_system"] in prompt
else system_prompt,
)

def match_prompt_style(self):
self.turn_format = ds_cfg["format"]
self.turn_no_input_format = (
ds_cfg["no_input_format"]
if "no_input_format" in ds_cfg
else ds_cfg["format"]
)
self.system_format = ds_cfg["system_format"]

prompter = UserDefinedPrompter()
prompter.match_prompt_style = match_prompt_style

strat = UserDefinedPromptTokenizationStrategy(
prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

strat.parse_instruction_fields = parse_instruction_fields
return strat
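
Note: as a concrete illustration (not part of the commit), a user-defined dataset entry in the YML config supplies the keys read above, and each row is mapped onto the (instruction, input, output, system) tuple. The key names come from the code; the column names, prompt templates, and values below are hypothetical.

    ds_cfg = {
        "field_instruction": "question",
        "field_input": "context",
        "field_output": "answer",
        "field_system": "",
        "system_prompt": "You are a helpful assistant.",
        "format": "[INST] {instruction} {input} [/INST]",
        "no_input_format": "[INST] {instruction} [/INST]",
        "system_format": "{system}\n",
    }
    row = {"question": "What is 2 + 2?", "context": "", "answer": "4"}

    # field_system is empty here, so the fixed system_prompt is used as the fallback,
    # matching the else-branch of parse_instruction_fields above.
    parsed = (
        row[ds_cfg["field_instruction"]],
        row[ds_cfg["field_input"]] if ds_cfg["field_input"] in row else "",
        row[ds_cfg["field_output"]],
        ds_cfg["system_prompt"],
    )
    # parsed == ("What is 2 + 2?", "", "4", "You are a helpful assistant.")
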
31 changes: 26 additions & 5 deletions src/axolotl/utils/data.py
@@ -117,8 +117,15 @@ def load_tokenized_prepared_datasets(
                     split=None,
                 )
             elif local_path.is_file():
+                ds_type = "json"
+                if d.ds_type:
+                    ds_type = d.ds_type
+                elif d.data_files and ".parquet" in d.data_files[0]:
+                    ds_type = "parquet"
+                elif d.data_files and ".arrow" in d.data_files[0]:
+                    ds_type = "arrow"
                 ds = load_dataset(
-                    "json",
+                    ds_type,
                     name=d.name,
                     data_files=d.path,
                     streaming=False,
@@ -155,13 +162,27 @@
                     )
                 else:
                     ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
+
+            d_base_type = d_prompt_style = None
             d_type = d.type
-            d_type_split = d_type.split(":")
-            d_base_type = d_type_split[0]
-            d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
+            if isinstance(d_type, str):
+                d_type_split = d_type.split(":")
+                d_base_type = d_type_split[0]
+                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
             if "train" in ds:
                 ds = ds["train"]
-            if ds_strategy := load(d.type, tokenizer, cfg):
+            if (
+                "input_ids" in ds.features
+                and "attention_mask" in ds.features
+                and "labels" in ds.features
+            ):
+                # dataset is already tokenized, just drop it straight in
+                datasets.append(ds)
+            elif isinstance(d.type, object):
+                ds_strategy = load("user_defined", tokenizer, cfg, d.type)
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+                datasets.append(ds_wrapper)
+            elif ds_strategy := load(d.type, tokenizer, cfg, d):
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "alpaca":
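
Note: a short sketch of the two new paths through this loop (illustrative, not from the commit). An explicit ds_type in the dataset config, or failing that the extension of the first data file, now selects the matching Hugging Face datasets builder ("json", "parquet" or "arrow"), and a dataset that already carries input_ids, attention_mask and labels columns is appended without re-tokenization.

    from datasets import Dataset, load_dataset

    # Local parquet/arrow files map onto the corresponding packaged builders,
    # mirroring the ds_type selection above (the file path here is hypothetical):
    # ds = load_dataset("parquet", data_files="data/train.parquet", streaming=False)

    # A pretokenized dataset is recognised purely by its column names.
    pretokenized = Dataset.from_dict(
        {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]], "labels": [[1, 2, 3]]}
    )
    required = {"input_ids", "attention_mask", "labels"}
    assert required.issubset(pretokenized.features)
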
Expand Down
