From 5e491bceae6fe24d99724eb67debd2e21c2296c1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 2 Oct 2024 11:33:49 -0400
Subject: [PATCH] fix mllama patch and don't save prepared ds when skipping

---
 src/axolotl/monkeypatch/attention/mllama.py | 16 +++++++---------
 src/axolotl/utils/data/sft.py               |  2 +-
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py
index 41cdc80ef1..885965452a 100644
--- a/src/axolotl/monkeypatch/attention/mllama.py
+++ b/src/axolotl/monkeypatch/attention/mllama.py
@@ -217,13 +217,11 @@ def patch_mllama():
     MllamaPreTrainedModel._supports_flash_attn_2 = (  # pylint: disable=protected-access
         True
     )
-    MLLAMA_TEXT_ATTENTION_CLASSES.set(
-        "flash_attention_2", MllamaTextSelfFlashAttention2
-    )
-    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES.set(
-        "flash_attention_2", MllamaTextCrossFlashAttention2
-    )
+    MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
+    MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MllamaTextCrossFlashAttention2
     # fallback to SDPA
-    MLLAMA_VISION_ATTENTION_CLASSES.set(
-        "flash_attention_2", MLLAMA_VISION_ATTENTION_CLASSES.get("sdpa")
-    )
+    MLLAMA_VISION_ATTENTION_CLASSES[
+        "flash_attention_2"
+    ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 2b84825548..7d6922cbf2 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -457,7 +457,7 @@ def for_d_in_datasets(dataset_configs):
         if not cfg.skip_prepare_dataset:
             dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
-        if cfg.local_rank == 0:
+        if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(str(prepared_ds_path))
             if cfg.push_dataset_to_hub:
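
Note (reviewer sketch, not part of the patch): the mllama hunk drops calls to a
.set() method, which plain Python dicts do not have, and registers the
flash-attention classes by ordinary item assignment on the attention-class
registries, pointing the vision entry at the existing SDPA class as a fallback.
Below is a minimal, self-contained sketch of that registration pattern; the
names are toy stand-ins, not the real transformers objects.

    # Toy stand-ins for the *_ATTENTION_CLASSES registries in transformers'
    # mllama modeling code; the real dicts map an attn_implementation string
    # (e.g. "eager", "sdpa") to an attention class.
    class SdpaAttention:
        pass

    class FlashAttention2:
        pass

    TEXT_ATTENTION_CLASSES = {"sdpa": SdpaAttention}
    VISION_ATTENTION_CLASSES = {"sdpa": SdpaAttention}

    # Register the flash-attention implementation by plain dict assignment
    # (dicts have no .set() method, hence the fix in the patch)...
    TEXT_ATTENTION_CLASSES["flash_attention_2"] = FlashAttention2

    # ...and fall back to SDPA where no flash-attention class exists,
    # mirroring the vision-attention fallback in the mllama hunk.
    VISION_ATTENTION_CLASSES["flash_attention_2"] = VISION_ATTENTION_CLASSES["sdpa"]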