Fix(cfg): Add validation for save_strategy and eval_strategy (axolotl-ai-cloud#633)

* Fix(cfg): Check save_strategy cfg conflict with save_steps

* Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps

* chore: add extra check for steps only
NanoCode012 authored Sep 28, 2023
1 parent ecca9ef commit b35694f
Showing 3 changed files with 190 additions and 11 deletions.
18 changes: 18 additions & 0 deletions src/axolotl/utils/config.py
@@ -296,6 +296,24 @@ def validate_config(cfg):
                 cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
                     "sharegpt_simple", "sharegpt"
                 )
+    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+        raise ValueError(
+            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+        )
+
+    if (
+        cfg.evaluation_strategy
+        and cfg.eval_steps
+        and cfg.evaluation_strategy != "steps"
+    ):
+        raise ValueError(
+            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
+        )
+
+    if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+        raise ValueError(
+            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
+        )

     # TODO
     # MPT 7b
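In effect, a step count now has to travel with a "steps" strategy, and any eval setting requires a validation split. A minimal sketch of the new behavior, assuming the same imports the test suite below uses (import paths may differ across axolotl versions):

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Mismatch: save_steps is only meaningful when save_strategy is "steps".
try:
    validate_config(DictDefault({"save_strategy": "epoch", "save_steps": 10}))
except ValueError as err:
    print(err)  # save_strategy and save_steps mismatch. ...

# Consistent combinations pass validation unchanged.
validate_config(DictDefault({"save_strategy": "steps", "save_steps": 10}))
validate_config(DictDefault({"eval_steps": 10, "val_set_size": 0.01}))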
15 changes: 4 additions & 11 deletions src/axolotl/utils/trainer.py
@@ -604,26 +604,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             "sample_packing_efficiency"
         ] = cfg.sample_packing_eff_est

-    if cfg.eval_steps and cfg.evaluation_strategy:
-        # assume if the user set both, they know what they're doing
-        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
+    if cfg.eval_steps:
+        training_arguments_kwargs["evaluation_strategy"] = "steps"
         training_arguments_kwargs["eval_steps"] = cfg.eval_steps
+    elif cfg.evaluation_strategy:
+        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
     elif cfg.val_set_size == 0:
         # no eval set, so don't eval
         training_arguments_kwargs["evaluation_strategy"] = "no"
-    elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
-        # if explicitly set for epoch, just set, and eval steps don't matter
-        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
-    elif cfg.eval_steps:
-        # steps isn't used w/ epochs
-        training_arguments_kwargs["evaluation_strategy"] = "steps"
-        training_arguments_kwargs["eval_steps"] = cfg.eval_steps
     else:
         # we have an eval set, but no steps defined, default to use epoch
         training_arguments_kwargs["evaluation_strategy"] = "epoch"

     if cfg.save_steps:
         # save_steps implies save_strategy of steps
         training_arguments_kwargs["save_strategy"] = "steps"
         training_arguments_kwargs["save_steps"] = cfg.save_steps
     elif cfg.save_strategy:
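With the validation above guaranteeing that a non-"steps" strategy never arrives together with eval_steps, the trainer-side resolution collapses to a single precedence chain: an explicit eval_steps wins, then an explicit evaluation_strategy, then "no" when there is no validation split, and "epoch" as the default. A standalone sketch of that chain (resolve_eval_strategy is a hypothetical helper for illustration, not part of axolotl):

def resolve_eval_strategy(eval_steps, evaluation_strategy, val_set_size):
    """Mirror the rewritten block: return (strategy, steps) for TrainingArguments."""
    if eval_steps:
        return "steps", eval_steps        # eval_steps implies a "steps" strategy
    if evaluation_strategy:
        return evaluation_strategy, None  # honor an explicitly chosen strategy
    if val_set_size == 0:
        return "no", None                 # no eval set, so don't eval
    return "epoch", None                  # eval set present, default to epoch

assert resolve_eval_strategy(10, None, 0.04) == ("steps", 10)
assert resolve_eval_strategy(None, "epoch", 0.04) == ("epoch", None)
assert resolve_eval_strategy(None, None, 0) == ("no", None)
assert resolve_eval_strategy(None, None, 0.04) == ("epoch", None)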
168 changes: 168 additions & 0 deletions tests/test_validation.py
@@ -397,3 +397,171 @@ def test_sharegpt_deprecation(self):
                 for record in self._caplog.records
             )
         assert cfg.datasets[0].type == "sharegpt:load_role"
+
+    def test_no_conflict_save_strategy(self):
+        cfg = DictDefault(
+            {
+                "save_strategy": "epoch",
+                "save_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*save_strategy and save_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "no",
+                "save_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*save_strategy and save_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "steps",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "steps",
+                "save_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "no",
+            }
+        )
+
+        validate_config(cfg)
+
+    def test_no_conflict_eval_strategy(self):
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "eval_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "no",
+                "eval_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "steps",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "steps",
+                "eval_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "no",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "val_set_size": 0,
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+                "val_set_size": 0,
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "val_set_size": 0,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+                "val_set_size": 0.01,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "val_set_size": 0.01,
+            }
+        )
+
+        validate_config(cfg)
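These cases can be exercised in isolation with pytest's keyword selection, e.g. pytest tests/test_validation.py -k no_conflict (assuming the repository's usual test layout).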
