From 363f7e0abe379f34d09ec9a0a3fc0bc9889cf6d3 Mon Sep 17 00:00:00 2001 From: Thomas Cleberg Date: Thu, 12 Sep 2024 16:34:58 -0500 Subject: [PATCH 1/6] Add support for `revision` dataset parameter --- docs/config.qmd | 1 + src/axolotl/utils/data/sft.py | 10 +++++- tests/test_datasets.py | 67 +++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 99a69a097..8329f3553 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -90,6 +90,7 @@ datasets: shards: # Optional[int] number of shards to split data into name: # Optional[str] name of dataset configuration to load train_on_split: train # Optional[str] name of dataset split to load from + revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 7d6922cbf..f46152a7c 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -242,6 +242,7 @@ def for_d_in_datasets(dataset_configs): name=config_dataset.name, streaming=True, token=use_auth_token, + revision=config_dataset.revision, ) ds_from_hub = True except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): @@ -319,6 +320,7 @@ def for_d_in_datasets(dataset_configs): data_files=config_dataset.data_files, streaming=False, split=None, + revision=config_dataset.revision, ) else: ds = load_from_disk(config_dataset.path) @@ -331,6 +333,7 @@ def for_d_in_datasets(dataset_configs): data_files=config_dataset.path, streaming=False, split=None, + revision=config_dataset.revision, ) else: raise ValueError( @@ -346,6 +349,7 @@ def for_d_in_datasets(dataset_configs): streaming=False, data_files=config_dataset.data_files, token=use_auth_token, + revision=config_dataset.revision, **load_ds_kwargs, ) elif ds_from_cloud and remote_file_system: @@ -363,6 +367,7 @@ def for_d_in_datasets(dataset_configs): streaming=False, split=None, storage_options=storage_options, + revision=config_dataset.revision, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -373,6 +378,7 @@ def for_d_in_datasets(dataset_configs): streaming=False, split=None, storage_options=storage_options, + revision=config_dataset.revision, ) else: if isinstance(config_dataset.data_files, str): @@ -380,6 +386,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=config_dataset.data_files, + revision=config_dataset.revision, ) elif isinstance(config_dataset.data_files, list): fp = [] @@ -389,6 +396,7 @@ def for_d_in_datasets(dataset_configs): repo_id=config_dataset.path, repo_type="dataset", filename=file, + revision=config_dataset.revision, ) ) else: @@ -433,8 +441,8 @@ def for_d_in_datasets(dataset_configs): config_dataset=config_dataset, tokenizer=tokenizer, cfg=cfg, - dataset=ds, d_base_type=d_base_type, + dataset=ds, d_prompt_style=d_prompt_style, processor=processor, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index a274b7b89..5a631b2e6 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -267,6 +267,73 @@ def 
test_load_from_single_json(self): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_revision(self): + """Verify that processing data from the hub works with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "revision": "foo", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + def test_load_local_hub_with_revision(self): + """Verify that a local copy of a hub dataset can be loaded with a specific revision""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") + tmp_ds_path.mkdir(parents=True, exist_ok=True) + snapshot_download( + repo_id="mhenrichsen/alpaca_2k_test", + repo_type="dataset", + local_dir=tmp_ds_path, + revision="foo", + ) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "ds_type": "parquet", + "type": "alpaca", + "data_files": [ + "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", + ], + "revision": "foo", + }, + ], + } + ) + + dataset, _ = load_tokenized_prepared_datasets( + self.tokenizer, cfg, prepared_path + ) + + assert len(dataset) == 2000 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + shutil.rmtree(tmp_ds_path) if __name__ == "__main__": unittest.main() From 01aed9d0e4636b6f9a77257f820e68204f18fb54 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 14 Sep 2024 13:03:36 -0400 Subject: [PATCH 2/6] only use revision on hf hub backed datasets --- src/axolotl/utils/data/sft.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index f46152a7c..39eb2c4e0 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -320,7 +320,6 @@ def for_d_in_datasets(dataset_configs): data_files=config_dataset.data_files, streaming=False, split=None, - revision=config_dataset.revision, ) else: ds = load_from_disk(config_dataset.path) @@ -333,7 +332,6 @@ def for_d_in_datasets(dataset_configs): data_files=config_dataset.path, streaming=False, split=None, - revision=config_dataset.revision, ) else: raise ValueError( @@ -367,7 +365,6 @@ def for_d_in_datasets(dataset_configs): streaming=False, split=None, storage_options=storage_options, - revision=config_dataset.revision, ) elif config_dataset.path.startswith("https://"): ds_type = get_ds_type(config_dataset) @@ -378,7 +375,6 @@ def for_d_in_datasets(dataset_configs): streaming=False, split=None, storage_options=storage_options, - revision=config_dataset.revision, ) else: if isinstance(config_dataset.data_files, str): From b29ccb8588a1f1839f8ee36f5a28c97063ed22bc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 14 Sep 2024 13:22:09 -0400 Subject: [PATCH 3/6] use revision tied to head --- tests/test_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5a631b2e6..845d7345b 100644 --- 
a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -279,7 +279,7 @@ def test_load_hub_with_revision(self): { "path": "mhenrichsen/alpaca_2k_test", "type": "alpaca", - "revision": "foo", + "revision": "d05c1cb", }, ], } @@ -319,7 +319,7 @@ def test_load_local_hub_with_revision(self): "data_files": [ "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", ], - "revision": "foo", + "revision": "d05c1cb", }, ], } @@ -335,5 +335,6 @@ def test_load_local_hub_with_revision(self): assert "labels" in dataset.features shutil.rmtree(tmp_ds_path) + if __name__ == "__main__": unittest.main() From a41f8d3b65a3f084b5e9d434539466792a5caaee Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 14 Sep 2024 23:10:57 -0400 Subject: [PATCH 4/6] set download to use revision --- tests/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 845d7345b..ce91da54e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -303,7 +303,7 @@ def test_load_local_hub_with_revision(self): repo_id="mhenrichsen/alpaca_2k_test", repo_type="dataset", local_dir=tmp_ds_path, - revision="foo", + revision="d05c1cb", ) prepared_path = Path(tmp_dir) / "prepared" From 744b993b07e7b5dcd56843a63d97055aaff07d41 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 11 Oct 2024 22:12:47 +0700 Subject: [PATCH 5/6] feat: add config to model validator class --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 47796add6..1c33b5907 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -125,6 +125,7 @@ class SFTDataset(BaseModel): drop_system_message: Optional[bool] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class UserDefinedDPOType(BaseModel): @@ -146,6 +147,7 @@ class DPODataset(BaseModel): split: Optional[str] = None type: Optional[Union[UserDefinedDPOType, str]] = None data_files: Optional[List[str]] = None + revision: Optional[str] = None class UserDefinedKTOType(BaseModel): @@ -167,6 +169,7 @@ class KTODataset(BaseModel): type: Optional[Union[UserDefinedKTOType, str]] = None data_files: Optional[List[str]] = None trust_remote_code: Optional[bool] = False + revision: Optional[str] = None class RLType(str, Enum): From 48836aee8326b6e8c285320fe1373ef2e481482c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 11 Oct 2024 22:15:08 +0700 Subject: [PATCH 6/6] feat: add revision config to RL and tests for it --- src/axolotl/utils/data/rl.py | 1 + tests/test_datasets.py | 70 ++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index d0324e1eb..35bd5fcbb 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg): ds = load_dataset( # pylint: disable=invalid-name ds_cfg["path"], split=ds_cfg["split"], + revision=ds_cfg.get("revision", None), ) split_datasets.insert(i, ds) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ce91da54e..f8b463a03 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -12,6 +12,7 @@ from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets +from axolotl.utils.data.rl import load_prepare_dpo_datasets 
from axolotl.utils.dict import DictDefault @@ -267,6 +268,40 @@ def test_load_from_single_json(self): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_dpo(self): + """Verify that processing dpo data from the hub works""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + def test_load_hub_with_revision(self): """Verify that processing data from the hub works with a specific revision""" with tempfile.TemporaryDirectory() as tmp_dir: @@ -294,6 +329,41 @@ def test_load_hub_with_revision(self): assert "attention_mask" in dataset.features assert "labels" in dataset.features + def test_load_hub_with_revision_with_dpo(self): + """Verify that processing dpo data from the hub works with a specific revision""" + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "revision": "ea82cff", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + } + ], + } + ) + + train_dataset, _ = load_prepare_dpo_datasets(cfg) + + assert len(train_dataset) == 1800 + assert "conversation" in train_dataset.features + def test_load_local_hub_with_revision(self): """Verify that a local copy of a hub dataset can be loaded with a specific revision""" with tempfile.TemporaryDirectory() as tmp_dir:
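
A minimal usage sketch of the new parameter, assuming the revisions pinned in
the tests above (`d05c1cb`, `ea82cff`) stand in for real commit hashes on the
Hugging Face Hub; per the docs/config.qmd addition, any commit hash, tag, or
branch name is accepted, and omitting the field tracks the latest version:

    # SFT config pinning a Hub dataset to a fixed revision
    datasets:
      - path: mhenrichsen/alpaca_2k_test
        type: alpaca
        revision: d05c1cb  # commit hash, tag, or branch name

    # Separate config for a DPO run; SFTDataset, DPODataset, and KTODataset
    # all gained the optional `revision` field in PATCH 5/6
    rl: dpo
    datasets:
      - path: fozziethebeat/alpaca_messages_2k_dpo_test
        type: chat_template.default
        revision: ea82cff

Note that after PATCH 2/6 the revision is only forwarded for Hub-backed loads,
so local datasets (load_from_disk, local json/parquet files) ignore it, as the
docs/config.qmd comment states.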