[test] attempt to fix CI test for PT 2.0 #225
Conversation
The documentation is not available anymore as the PR was closed or merged.
tests/test_ppo_trainer.py (Outdated)
```python
if is_torch_greater_2_0():
    self.assertTrue(isinstance(ppo_trainer.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR))
else:
    self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR))
```
In PT < 2.0 the schedulers are wrapped inside a scheduler attribute.
EDIT: what I said is wrong, accelerate wraps it inside AcceleratedScheduler only for Python != 3.8.
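For context, a minimal sketch of the unwrapping this implies; the helper name unwrap_scheduler is hypothetical, and it assumes the wrapper exposes the underlying torch scheduler through a .scheduler attribute as described above:

```python
# Hypothetical helper (not part of this PR): reach through accelerate's
# wrapper if one is present, otherwise return the scheduler itself.
def unwrap_scheduler(scheduler):
    return getattr(scheduler, "scheduler", scheduler)
```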
Tracking the issue at: huggingface/accelerate#1200
```python
if _is_python_greater_3_8:
    from importlib.metadata import version

    torch_version = version("torch")
else:
    import pkg_resources

    torch_version = pkg_resources.get_distribution("torch").version
return torch_version >= "2.0"
```
This is a way I have found to do version checking with respect to the Python version; I would be curious to know if you have better techniques.
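One alternative worth noting (a sketch, not what this PR does, and it assumes packaging is available as a dependency): compare parsed versions instead of raw strings, since plain string comparison is lexicographic ("10.0" < "2.0" is True):

```python
from packaging import version as pkg_version

def is_torch_greater_2_0():
    import torch

    # Parse the installed version so strings like "2.0.1" or
    # "2.0.0+cu117" compare correctly against "2.0".
    return pkg_version.parse(torch.__version__) >= pkg_version.parse("2.0")
```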
What does this PR do?
This PR fixes the failing tests related to the PyTorch 2.0 upgrade.

In PyTorch 2.0, the learning rate schedulers have been redesigned, so we need a bit of refactoring in trl to deal with unexpected behaviors. In PyTorch 2.0 it seems that the class _LRScheduler has been deprecated in favor of LRScheduler.

In PyTorch 2.0: the built-in schedulers inherit from LRScheduler.
In PyTorch < 2.0: they inherit from _LRScheduler.
Hence the check isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler) no longer passes with PyTorch >= 2.0.
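A minimal sketch of a version-agnostic check (an assumed approach, not necessarily the exact fix in this PR), picking whichever base class the installed PyTorch provides:

```python
import torch

# Assumption: torch.optim.lr_scheduler exposes LRScheduler from 2.0 onwards
# and only _LRScheduler before.
if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
    LRSchedulerBase = torch.optim.lr_scheduler.LRScheduler  # PyTorch >= 2.0
else:
    LRSchedulerBase = torch.optim.lr_scheduler._LRScheduler  # PyTorch < 2.0

def is_lr_scheduler(obj):  # hypothetical helper name
    return isinstance(obj, LRSchedulerBase)
```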