Commit 5e491bc
fix mllama patch and don't save prepared ds when skipping
winglian committed Oct 2, 2024
1 parent 2835f0e commit 5e491bc
Showing 2 changed files with 8 additions and 10 deletions.
src/axolotl/monkeypatch/attention/mllama.py: 16 changes (7 additions, 9 deletions)
@@ -217,13 +217,11 @@ def patch_mllama():
     MllamaPreTrainedModel._supports_flash_attn_2 = (  # pylint: disable=protected-access
         True
     )
-    MLLAMA_TEXT_ATTENTION_CLASSES.set(
-        "flash_attention_2", MllamaTextSelfFlashAttention2
-    )
-    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES.set(
-        "flash_attention_2", MllamaTextCrossFlashAttention2
-    )
+    MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
+    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MllamaTextCrossFlashAttention2
     # fallback to SDPA
-    MLLAMA_VISION_ATTENTION_CLASSES.set(
-        "flash_attention_2", MLLAMA_VISION_ATTENTION_CLASSES.get("sdpa")
-    )
+    MLLAMA_VISION_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
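
Note on the mllama.py change: the patch now registers the flash-attention implementations by plain item assignment instead of calling .set() on the registries. Assuming MLLAMA_TEXT_ATTENTION_CLASSES, MLLAMA_TEXT_CROSS_ATTENTION_CLASSES, and MLLAMA_VISION_ATTENTION_CLASSES are ordinary Python dicts mapping implementation names to attention classes (which is what the switch to item assignment suggests), the old .set(...) calls would raise AttributeError. A minimal sketch of the pattern, using hypothetical stand-in classes:

class SdpaAttention:
    pass

class FlashAttention2:
    pass

# stand-in for one of the MLLAMA_*_ATTENTION_CLASSES registries
ATTENTION_CLASSES = {"sdpa": SdpaAttention}

# registering a new backend is a plain dict assignment
ATTENTION_CLASSES["flash_attention_2"] = FlashAttention2

# plain dicts have no .set() method, so the old call would fail:
# ATTENTION_CLASSES.set("flash_attention_2", FlashAttention2)  # AttributeError

The vision registry has no flash-attention implementation, so the patch points "flash_attention_2" at the existing "sdpa" entry as a fallback, as the comment in the diff notes.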
src/axolotl/utils/data/sft.py: 2 changes (1 addition, 1 deletion)
@@ -457,7 +457,7 @@ def for_d_in_datasets(dataset_configs):
         if not cfg.skip_prepare_dataset:
             dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
-        if cfg.local_rank == 0:
+        if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(str(prepared_ds_path))
             if cfg.push_dataset_to_hub:
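
Note on the sft.py change: the added and not cfg.skip_prepare_dataset guard keeps rank 0 from writing the merged dataset to the prepared-dataset cache path when preparation is skipped, since in that case process_datasets_for_packing never ran and there is nothing prepared to save. A minimal standalone sketch of the guard, with a hypothetical stand-in for the config object:

from types import SimpleNamespace

# hypothetical stand-in for axolotl's cfg object
cfg = SimpleNamespace(local_rank=0, skip_prepare_dataset=True)

if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
    print("would save the prepared dataset to disk")
else:
    print("skipping save: dataset preparation was skipped")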
