Skip to content

Commit

Permalink
finally fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SalmanMohammadi committed Sep 15, 2024
1 parent 4c8e4d4 commit 05620fe
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/recipes/test_lora_dpo_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
"log_every_n_steps=1",
"gradient_accumulation_steps=1",
"clip_grad_norm=100",
"tokenizer.max_seq_len=512",
] + dummy_stack_exchange_dataset_config()

@pytest.mark.parametrize("save_adapter_weights_only", [False, True])
Expand Down Expand Up @@ -93,6 +94,8 @@ def test_training_state_on_resume(

expected_loss_values = get_loss_values_from_metric_logger(log_file)

resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)
# Resume training
cmd_2 = f"""
tune run lora_dpo_single_device \
Expand All @@ -106,7 +109,7 @@ def test_training_state_on_resume(
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
metric_logger.filename={resumed_log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
""".split()
Expand All @@ -116,10 +119,10 @@ def test_training_state_on_resume(
runpy.run_path(TUNE_PATH, run_name="__main__")

# Second epoch only
loss_values = get_loss_values_from_metric_logger(log_file)[:2]
resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)

torch.testing.assert_close(
loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
resumed_loss_values[:2], expected_loss_values[2:], rtol=1e-5, atol=1e-5
)

@pytest.mark.integration_test
Expand Down

0 comments on commit 05620fe

Please sign in to comment.