Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision #1912

Merged
1 change: 1 addition & 0 deletions docs/config.qmd
@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
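For reference, the same pinning can be expressed programmatically, the way the new tests below build their configs. A minimal sketch using axolotl's DictDefault; the dataset path and commit hash are just the ones reused in the tests, not required values:

```python
from axolotl.utils.dict import DictDefault

# Minimal dataset config pinned to a specific Hub revision
# (commit hash, tag, or branch name); omit "revision" to use the latest.
cfg = DictDefault(
    {
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
                "revision": "d05c1cb",
            },
        ],
    }
)
```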
3 changes: 3 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -125,6 +125,7 @@ class SFTDataset(BaseModel):
drop_system_message: Optional[bool] = None

trust_remote_code: Optional[bool] = False
revision: Optional[str] = None


class UserDefinedDPOType(BaseModel):
@@ -146,6 +147,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None


class UserDefinedKTOType(BaseModel):
@@ -167,6 +169,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None


class RLType(str, Enum):
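Because the new field defaults to None, existing configs without a revision key validate exactly as before, and a provided value is simply carried through to the loaders below. A rough sketch of that behaviour with a simplified stand-in model (not the real SFTDataset, which carries many more fields):

```python
from typing import Optional

from pydantic import BaseModel


class DatasetCfg(BaseModel):
    """Simplified stand-in for SFTDataset/DPODataset/KTODataset."""

    path: str
    revision: Optional[str] = None


print(DatasetCfg(path="mhenrichsen/alpaca_2k_test").revision)  # None -> latest
print(DatasetCfg(path="mhenrichsen/alpaca_2k_test", revision="d05c1cb").revision)  # pinned
```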
1 change: 1 addition & 0 deletions src/axolotl/utils/data/rl.py
@@ -90,6 +90,7 @@ def load_split(dataset_cfgs, _cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

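datasets.load_dataset already accepts a revision argument, so forwarding the config value is enough to read the split from a pinned commit instead of the branch head. A standalone sketch of the equivalent call, assuming the datasets library is installed; the path and hash are taken from the DPO test below:

```python
from datasets import load_dataset

# revision=None (the default) reads the latest commit on the default branch;
# a commit hash, tag, or branch name pins the download to that snapshot.
ds = load_dataset(
    "fozziethebeat/alpaca_messages_2k_dpo_test",
    split="train",
    revision="ea82cff",
)
print(len(ds))
```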
6 changes: 5 additions & 1 deletion src/axolotl/utils/data/sft.py
@@ -242,6 +242,7 @@ def for_d_in_datasets(dataset_configs):
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +347,7 @@ def for_d_in_datasets(dataset_configs):
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -380,6 +382,7 @@ def for_d_in_datasets(dataset_configs):
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -389,6 +392,7 @@
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -433,8 +437,8 @@ for_d_in_datasets(dataset_configs):
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style,
processor=processor,
)
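The hf_hub_download calls get the same treatment, so single files listed under data_files resolve against the pinned commit as well. A standalone sketch, assuming huggingface_hub is installed; the repo, filename, and hash mirror the local-copy test below:

```python
from huggingface_hub import hf_hub_download

# Fetch one file from a dataset repo at a specific revision and return its local path.
local_path = hf_hub_download(
    repo_id="mhenrichsen/alpaca_2k_test",
    repo_type="dataset",
    filename="alpaca_2000.parquet",
    revision="d05c1cb",
)
print(local_path)
```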
138 changes: 138 additions & 0 deletions tests/test_datasets.py
@@ -12,6 +12,7 @@
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault


@@ -267,6 +268,143 @@ def test_load_from_single_json(self):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features

def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""

cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)

dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features

def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""

cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features

def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)

prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)

dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)

assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)


if __name__ == "__main__":
unittest.main()