diff --git a/tests/test_glm_seq2seq.py b/tests/test_glm_seq2seq.py index 0fb0b154..0d008304 100644 --- a/tests/test_glm_seq2seq.py +++ b/tests/test_glm_seq2seq.py @@ -24,8 +24,8 @@ def test_init_trainer_pytorch(self): eval_interval=100, log_interval=50, experiment_name='glm_large', + fp16=True, pytorch_device='cuda', - load_dir=None, lr=1e-4) print("downloading...") diff --git a/tests/test_glm_superglue.py b/tests/test_glm_superglue.py index 5c35d2c2..0735db09 100644 --- a/tests/test_glm_superglue.py +++ b/tests/test_glm_superglue.py @@ -27,6 +27,7 @@ def test_init_trainer_pytorch(self): experiment_name='glm_large', pytorch_device='cuda', load_dir=None, + fp16=True, lr=1e-4, save_epoch=10) print("downloading...")