From 8fda36313a7fe715aff5bd161a10af9064a501e1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Jun 2024 13:20:32 -0400 Subject: [PATCH 1/3] update to be deprecated evaluation_strategy and c4 dataset --- src/axolotl/core/trainer_builder.py | 16 +- src/axolotl/utils/callbacks/__init__.py | 5 +- src/axolotl/utils/config/__init__.py | 369 ------------------ .../config/models/input/v0_4_1/__init__.py | 26 +- tests/test_validation.py | 22 +- 5 files changed, 36 insertions(+), 402 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index af6eaabf2f..1cad1a8c3f 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1416,17 +1416,15 @@ def build(self, total_num_steps): if not self.cfg.test_datasets and self.cfg.val_set_size == 0: # no eval set, so don't eval - training_arguments_kwargs["evaluation_strategy"] = "no" + training_arguments_kwargs["eval_strategy"] = "no" elif self.cfg.eval_steps: - training_arguments_kwargs["evaluation_strategy"] = "steps" + training_arguments_kwargs["eval_strategy"] = "steps" training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps - elif self.cfg.evaluation_strategy: - training_arguments_kwargs[ - "evaluation_strategy" - ] = self.cfg.evaluation_strategy + elif self.cfg.eval_strategy: + training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy else: # we have an eval set, but no steps defined, default to use epoch - training_arguments_kwargs["evaluation_strategy"] = "epoch" + training_arguments_kwargs["eval_strategy"] = "epoch" if self.cfg.save_steps: training_arguments_kwargs["save_strategy"] = "steps" @@ -1860,10 +1858,10 @@ def build_training_arguments(self, total_num_steps): training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors if self.eval_dataset: - training_args_kwargs["evaluation_strategy"] = "steps" + training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps else: - training_args_kwargs["evaluation_strategy"] = "no" + training_args_kwargs["eval_strategy"] = "no" if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 0bc781fcb4..8768bc2bf7 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -64,10 +64,7 @@ def on_step_end( control: TrainerControl, **kwargs, ): - if ( - args.evaluation_strategy == IntervalStrategy.STEPS - and state.global_step == 1 - ): + if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1: control.should_evaluate = True return control diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 6e5ecda03a..53ff8e1c92 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -1,8 +1,6 @@ """Module for working with config dicts""" -import json import logging import os -from pathlib import Path from typing import Optional import torch @@ -247,370 +245,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): return DictDefault( dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True)) ) - - -def legacy_validate_config(cfg): - """ - This is a "pre-validation" step that handles the yaml configuration before we have any - information about the model architecture - """ - if is_torch_bf16_gpu_available(): - if not cfg.bf16 and not cfg.bfloat16: - LOG.info("bf16 support detected, but not enabled for this configuration.") - else: - if ( - not cfg.merge_lora - and not cfg.is_preprocess - and (cfg.bf16 is True or cfg.bfloat16 is True) - ): - raise ValueError( - "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." - ) - if ( - # pylint: disable=too-many-boolean-expressions - not (cfg.bf16 or cfg.bfloat16) - and (cfg.fp16 or cfg.float16) - and not cfg.adapter - and not cfg.flash_attention - and cfg.sample_packing - ): - LOG.warning( - "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." - ) - # ValueError: Attempting to unscale FP16 gradients. - # OR - # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half - if cfg.max_packed_sequence_len: - raise DeprecationWarning("`max_packed_sequence_len` is no longer supported") - - if cfg.sample_packing and cfg.rl: - raise ValueError("`sample_packing: true` does not work with RLHF training") - - if cfg.sample_packing and not cfg.pad_to_sequence_len: - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - ) - - if cfg.gradient_accumulation_steps and cfg.batch_size: - raise ValueError( - "please set only one of gradient_accumulation_steps or batch_size" - ) - if cfg.batch_size: - LOG.warning( - "%s\n%s", - "batch_size is not recommended. Please use gradient_accumulation_steps instead.", - "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", - ) - if ( - cfg.eval_batch_size - and cfg.micro_batch_size - and cfg.eval_batch_size != cfg.micro_batch_size - ): - LOG.warning( - "eval_batch_size != micro_batch_size. This can lead to VRAM instability." - ) - - if cfg.adapter == "qlora": - if cfg.merge_lora: - # can't merge qlora if loaded in 8bit or 4bit - if cfg.load_in_8bit: - raise ValueError("Can't merge qlora if loaded in 8bit") - - if cfg.gptq: - raise ValueError("Can't merge qlora if gptq") - - if cfg.load_in_4bit: - raise ValueError("Can't merge qlora if loaded in 4bit") - - else: - if cfg.load_in_8bit: - raise ValueError("Can't load qlora in 8bit") - - if cfg.gptq: - raise ValueError("Can't load qlora if gptq") - - if not cfg.load_in_4bit: - raise ValueError("Require cfg.load_in_4bit to be True for qlora") - - if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with QLoRA") - - loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq: - LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") - - if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp): - raise ValueError("Fused modules are not supported with LoRA") - - if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters: - raise ValueError( - "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." - ) - - if cfg.relora_steps: - if cfg.adapter not in ("lora", "qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") - - if cfg.fsdp: - raise ValueError("fsdp not supported with ReLoRA") - - if cfg.deepspeed: - raise ValueError("deepspeed not supported with ReLoRA") - - if cfg.lr_scheduler == "one_cycle": - raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") - - if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with ReLoRA") - - if cfg.trust_remote_code: - LOG.warning( - "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." - ) - - if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True: - raise ValueError( - "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" - ) - - if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: - raise ValueError("FSDP is not supported for falcon models") - - if ( - cfg.base_model and "mpt" in cfg.base_model.lower() - ) and cfg.gradient_checkpointing: - raise ValueError("gradient_checkpointing is not supported for MPT models") - - if cfg.flash_optimum is True: - if cfg.adapter: - LOG.warning("BetterTransformers probably doesn't work with PEFT adapters") - if cfg.fp16 or cfg.bf16: - raise ValueError("AMP is not supported with BetterTransformer") - if cfg.float16 is not True and cfg.bfloat16 is not True: - LOG.warning( - "You should probably set bfloat16 or float16 to true to " - "load the model in float16 for BetterTransformers" - ) - if int(torch.__version__.split(".", maxsplit=1)[0]) < 2: - LOG.warning("torch>=2.0.0 required") - raise ValueError( - f"flash_optimum for BetterTransformers may not be used with {torch.__version__}" - ) - - if cfg.pretraining_dataset and cfg.group_by_length: - LOG.warning( - "You probably want to disable group_by_length as it will force a streamed dataset to download completely." - ) - if cfg.pretraining_dataset and not cfg.max_steps: - raise ValueError( - "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" - ) - - if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( - not cfg.optimizer or "adamw" not in cfg.optimizer - ): - LOG.warning("adamw hyperparameters found, but no adamw optimizer set") - - if cfg.push_to_hub_model_id: - raise ValueError( - "push_to_hub_model_id is deprecated. Please use hub_model_id instead." - ) - - if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]: - LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty." - ) - - if cfg.gptq and cfg.revision_of_model: - raise ValueError( - "revision_of_model is not supported for GPTQ models. " - + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove revision_of_model from the config." - ) - - # if cfg.sample_packing and cfg.sdp_attention: - # # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 - # raise ValueError( - # "sample_packing not compatible with sdp_attention. Use flash_attention" - # ) - - if cfg.sample_packing and cfg.xformers_attention: - raise ValueError( - "sample_packing not compatible with xformers_attention. Use flash_attention" - ) - - if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16): - # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 - LOG.warning( - "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " - "This may work on H100s." - ) - - if cfg.early_stopping_patience: - if not cfg.save_steps or not cfg.eval_steps: - raise ValueError( - "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." - ) - if cfg.save_steps % cfg.eval_steps != 0: - raise ValueError( - "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." - ) - - if cfg.saves_per_epoch and cfg.save_steps: - raise ValueError( - "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." - ) - if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy must be empty or set to `steps` when used with saves_per_epoch." - ) - if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." - ) - if cfg.evals_per_epoch and cfg.eval_steps: - raise ValueError( - "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." - ) - if ( - cfg.evals_per_epoch - and cfg.evaluation_strategy - and cfg.evaluation_strategy != "steps" - ): - raise ValueError( - "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." - ) - if ( - cfg.evaluation_strategy - and cfg.eval_steps - and cfg.evaluation_strategy != "steps" - ): - raise ValueError( - "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps." - ) - - if ( - cfg.val_set_size == 0 - and (cfg.eval_steps or cfg.evaluation_strategy) - and not cfg.test_datasets - ): - raise ValueError( - "eval_steps and evaluation_strategy are not supported with val_set_size == 0" - ) - - if ( - cfg.sample_packing - and cfg.eval_table_size - and cfg.eval_sample_packing is not False - ): - raise ValueError( - "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." - ) - - if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit): - raise ValueError( - "load_in_8bit and load_in_4bit are not supported without setting an adapter." - "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." - ) - - if cfg.rope_scaling: - LOG.warning("`rope_scaling` should now be be a key under `model_config`") - - if cfg.wandb_run_id and not cfg.wandb_name: - cfg.wandb_name = cfg.wandb_run_id - - LOG.warning( - "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." - ) - - if cfg.noisy_embedding_alpha is not None: - # Deprecated, use neftune_noise_alpha - LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") - if cfg.neftune_noise_alpha is None: - cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha - else: - # User is providing both; bail and have them sort out their settings - raise ValueError( - "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" - ) - - if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0: - raise ValueError("neftune_noise_alpha must be > 0.0") - - if cfg.max_memory is not None and cfg.gpu_memory_limit is not None: - raise ValueError( - "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." - ) - - if ( - cfg.unfrozen_parameters - and cfg.gradient_checkpointing_kwargs - and cfg.gradient_checkpointing_kwargs.use_reentrant is True - ): - # https://github.com/huggingface/transformers/issues/21381 - raise ValueError( - "`use_reentrant` must be false when used with partially frozen model." - ) - - if cfg.deepspeed and Path(cfg.deepspeed).is_file(): - with open(cfg.deepspeed, encoding="utf-8") as file: - contents = file.read() - deepspeed_cfg: DictDefault = DictDefault(json.loads(contents)) - if cfg.flash_attention: - if ( - deepspeed_cfg.zero_optimization - and deepspeed_cfg.zero_optimization.stage == 3 - ): - if not ( - ( - deepspeed_cfg.bf16 - and deepspeed_cfg.bf16.enabled # pylint: disable=no-member - is True - ) - or ( - deepspeed_cfg.fp16 - and deepspeed_cfg.fp16.enabled # pylint: disable=no-member - is True - ) - ): - raise ValueError( - "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention" - ) - if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer: - LOG.warning( - f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer." - ) - - if cfg.test_datasets and cfg.val_set_size: - raise ValueError( - "non-zero val_set_size should not be used with test_datasets configuration" - ) - - if cfg.fsdp and "bnb" in cfg.optimizer: - raise ValueError(f"FSDP not compatible with {cfg.optimizer}") - - if cfg.do_causal_lm_eval and cfg.eval_sample_packing: - raise ValueError( - "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" - ) - - if cfg.eval_causal_lm_metrics: - if not isinstance(cfg.eval_causal_lm_metrics, list): - raise ValueError("eval_causal_lm_metrics must be a list") - # only ["sacrebleu", "comet", "ter", "chrf"] supported - if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS: - raise ValueError( - f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" - ) - - # TODO - # MPT 7b - # https://github.com/facebookresearch/bitsandbytes/issues/25 - # no 8bit adaAmw w bf16 - - # GPT-NeoX - # evals broken when extending context len - # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product - # attention_mask = causal_mask + attention_mask - # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3 diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1feb8aae86..aeaf90667d 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -68,6 +68,7 @@ class DeprecatedParameters(BaseModel): rope_scaling: Optional[Any] = None noisy_embedding_alpha: Optional[float] = None dpo_beta: Optional[float] = None + evaluation_strategy: Optional[str] = None @field_validator("max_packed_sequence_len") @classmethod @@ -99,6 +100,13 @@ def validate_dpo_beta(cls, dpo_beta): LOG.warning("dpo_beta is deprecated, use rl_beta instead") return dpo_beta + @field_validator("evaluation_strategy") + @classmethod + def validate_evaluation_strategy(cls, evaluation_strategy): + if evaluation_strategy is not None: + LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") + return evaluation_strategy + class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" @@ -731,7 +739,7 @@ class Config: warmup_ratio: Optional[float] = None eval_steps: Optional[Union[int, float]] = None evals_per_epoch: Optional[Union[int]] = None - evaluation_strategy: Optional[str] = None + eval_strategy: Optional[str] = None save_steps: Optional[Union[int, float]] = None saves_per_epoch: Optional[int] = None save_strategy: Optional[str] = None @@ -1033,21 +1041,21 @@ def check_push_save(cls, data): @classmethod def check_evals(cls, data): if ( - data.get("evaluation_strategy") + data.get("eval_strategy") and data.get("eval_steps") - and data.get("evaluation_strategy") != "steps" + and data.get("eval_strategy") != "steps" ): raise ValueError( - "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps." + "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." ) if ( data.get("val_set_size") == 0 - and (data.get("eval_steps") or data.get("evaluation_strategy")) + and (data.get("eval_steps") or data.get("eval_strategy")) and not data.get("test_datasets") ): raise ValueError( - "eval_steps and evaluation_strategy are not supported with val_set_size == 0" + "eval_steps and eval_strategy are not supported with val_set_size == 0" ) if data.get("evals_per_epoch") and data.get("eval_steps"): raise ValueError( @@ -1055,11 +1063,11 @@ def check_evals(cls, data): ) if ( data.get("evals_per_epoch") - and data.get("evaluation_strategy") - and data.get("evaluation_strategy") != "steps" + and data.get("eval_strategy") + and data.get("eval_strategy") != "steps" ): raise ValueError( - "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." + "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." ) if data.get("do_bench_eval") and not ( diff --git a/tests/test_validation.py b/tests/test_validation.py index 67670b1928..44cb4a4b9b 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -726,7 +726,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): cfg = ( DictDefault( { - "evaluation_strategy": "epoch", + "eval_strategy": "epoch", "eval_steps": 10, } ) @@ -734,14 +734,14 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): ) with pytest.raises( - ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" + ValueError, match=r".*eval_strategy and eval_steps mismatch.*" ): validate_config(cfg) cfg = ( DictDefault( { - "evaluation_strategy": "no", + "eval_strategy": "no", "eval_steps": 10, } ) @@ -749,14 +749,14 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): ) with pytest.raises( - ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" + ValueError, match=r".*eval_strategy and eval_steps mismatch.*" ): validate_config(cfg) cfg = ( DictDefault( { - "evaluation_strategy": "steps", + "eval_strategy": "steps", } ) | minimal_cfg @@ -767,7 +767,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): cfg = ( DictDefault( { - "evaluation_strategy": "steps", + "eval_strategy": "steps", "eval_steps": 10, } ) @@ -790,7 +790,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): cfg = ( DictDefault( { - "evaluation_strategy": "no", + "eval_strategy": "no", } ) | minimal_cfg @@ -801,7 +801,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): cfg = ( DictDefault( { - "evaluation_strategy": "epoch", + "eval_strategy": "epoch", "val_set_size": 0, } ) @@ -810,7 +810,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): with pytest.raises( ValueError, - match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", + match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*", ): validate_config(cfg) @@ -826,7 +826,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): with pytest.raises( ValueError, - match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", + match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*", ): validate_config(cfg) @@ -856,7 +856,7 @@ def test_no_conflict_eval_strategy(self, minimal_cfg): cfg = ( DictDefault( { - "evaluation_strategy": "epoch", + "eval_strategy": "epoch", "val_set_size": 0.01, } ) From fc09867672107fb936693d806a81c33f12b288fa Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 14 Nov 2024 11:58:19 -0500 Subject: [PATCH 2/3] chore: lint --- src/axolotl/utils/config/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 53ff8e1c92..b12ad81136 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -8,7 +8,6 @@ from axolotl.integrations.config import merge_input_args from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS from axolotl.utils.config.models.input.v0_4_1 import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) From 7da216dd91a75717614210537062a60b6555d9b3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 15 Nov 2024 14:24:01 -0500 Subject: [PATCH 3/3] remap eval strategy to new config and add tests --- .../config/models/input/v0_4_1/__init__.py | 13 +++++++++++++ tests/test_validation.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index aeaf90667d..c295eaa168 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1308,6 +1308,19 @@ def check_val_w_test_datasets(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_eval_strategy(cls, data): + if ( + data.get("evaluation_strategy") is not None + and data.get("eval_strategy") is None + ): + LOG.info( + "explicitly setting `eval_strategy` from the `evaluation_strategy`" + ) + data["eval_strategy"] = data.get("evaluation_strategy") + return data + @model_validator(mode="before") @classmethod def check_fsdp_offload_w_8bit_optimizer(cls, data): diff --git a/tests/test_validation.py b/tests/test_validation.py index 44cb4a4b9b..f3f4d18ab8 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1095,6 +1095,24 @@ def test_dpo_beta_deprecation(self, minimal_cfg): assert new_cfg["dpo_beta"] is None assert len(self._caplog.records) == 1 + def test_eval_strategy_remap(self, minimal_cfg): + cfg = ( + DictDefault( + { + "evaluation_strategy": "steps", + } + ) + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + new_cfg = validate_config(cfg) + assert new_cfg.eval_strategy == "steps" + assert ( + "evaluation_strategy is deprecated, use eval_strategy instead" + in self._caplog.records[0].message + ) + class TestValidationCheckModelConfig(BaseValidation): """