Add a config not to shuffle merged dataset (axolotl-ai-cloud#1394) [skip ci]

* Add a config not to shuffle merged dataset

* Update README.md

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* invert the condition name

* update README

* info -> debug

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
seungduk-yanolja and winglian committed Mar 19, 2024
1 parent c0b75d5 commit 5eebec1
Showing 3 changed files with 15 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -678,6 +678,10 @@ datasets:
     # For `completion` datasets only, uses the provided field instead of `text` column
     field:
 
+# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
+# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
+shuffle_merged_datasets: true
+
 # A list of one or more datasets to eval the model with.
 # You can use either test_datasets, or val_set_size, but not both.
 test_datasets:
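To illustrate the new option, here is a minimal sketch of how it might appear in an axolotl YAML config, parsed with PyYAML; the dataset paths are hypothetical placeholders, not datasets from this commit.

import yaml

# Hypothetical config snippet: two datasets merged in the order listed,
# with shuffling of the merged result turned off.
config_text = """
datasets:
  - path: my_org/dataset_a   # hypothetical dataset path
    type: alpaca
  - path: my_org/dataset_b   # hypothetical dataset path
    type: alpaca
shuffle_merged_datasets: false
"""

cfg = yaml.safe_load(config_text)
print(cfg["shuffle_merged_datasets"])  # -> False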
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -416,6 +416,7 @@ class Config:
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
+    shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
     dataset_shard_idx: Optional[int] = None
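As a rough sketch of the field's behavior, a stripped-down pydantic model with the same declaration defaults to True and accepts an explicit False; this is a simplified stand-in, not axolotl's actual config class.

from typing import Optional

from pydantic import BaseModel


class MiniConfig(BaseModel):  # hypothetical stand-in for the real config model
    shuffle_merged_datasets: Optional[bool] = True


print(MiniConfig().shuffle_merged_datasets)  # True: shuffling stays on by default
print(MiniConfig(shuffle_merged_datasets=False).shuffle_merged_datasets)  # False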
13 changes: 10 additions & 3 deletions src/axolotl/utils/data.py
@@ -443,8 +443,11 @@ def for_d_in_datasets(dataset_configs):
         dataset = concatenate_datasets(datasets)
 
         if len(datasets) > 1:
-            LOG.info("shuffle merged datasets")
-            dataset = dataset.shuffle(seed=seed)
+            if cfg.shuffle_merged_datasets:
+                LOG.debug("shuffle merged datasets")
+                dataset = dataset.shuffle(seed=seed)
+            else:
+                LOG.debug("NOT shuffling merged datasets")
 
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)

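A minimal sketch of this conditional shuffle using the Hugging Face datasets library; merge_datasets and its arguments are hypothetical stand-ins for the surrounding axolotl code.

from datasets import Dataset, concatenate_datasets


def merge_datasets(parts, shuffle_merged_datasets=True, seed=42):
    # Concatenation keeps rows in the order of `parts`.
    dataset = concatenate_datasets(parts)
    if len(parts) > 1 and shuffle_merged_datasets:
        # Interleaves rows from all source datasets.
        dataset = dataset.shuffle(seed=seed)
    return dataset


ds_a = Dataset.from_dict({"text": ["a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})
merged = merge_datasets([ds_a, ds_b], shuffle_merged_datasets=False)
print(merged["text"])  # -> ['a1', 'a2', 'b1', 'b2'], original order preserved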
@@ -847,7 +850,11 @@ def wrap_pretraining_dataset(
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    else:
+        LOG.debug("NOT shuffling merged pretraining datasets")
+
     dataset = dataset.map(
         encode,
         batched=True,
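In the pretraining path the dataset is streamed, so shuffling happens through a fixed-size buffer; the sketch below shows the same toggle on an IterableDataset, with toy data standing in for a real pretraining corpus.

from datasets import Dataset

# Toy stand-in for a streaming pretraining corpus.
stream = Dataset.from_dict(
    {"text": [f"row{i}" for i in range(10)]}
).to_iterable_dataset()

shuffle_merged_datasets = False  # as if read from the axolotl config
if shuffle_merged_datasets:
    # Buffered, approximate shuffle: only `buffer_size` rows are held in memory.
    stream = stream.shuffle(seed=42, buffer_size=4)

print([row["text"] for row in stream][:3])  # -> ['row0', 'row1', 'row2']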
