
[test] attempt to fix CI test for PT 2.0 #225

Merged · 7 commits into main · Mar 17, 2023

Conversation

younesbelkada (Contributor) commented Mar 16, 2023

What does this PR do?

This PR fixes the failing tests related to the PyTorch 2.0 upgrade.

In PyTorch 2.0, the learning rate schedulers have been redesigned, so trl needs a bit of refactoring to deal with the resulting behavior changes.

In PyTorch 2.0, the class _LRScheduler appears to have been deprecated in favor of LRScheduler.

In PyTorch 2.0:

# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
    pass

In PyTorch < 2.0:

class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:"""

Hence the check isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler) no longer passes with PyTorch >= 2.0.
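
One way to make such a check version-agnostic (a minimal sketch, not necessarily the exact change made in this PR; the names LR_SCHEDULER_BASE and is_lr_scheduler are just illustrative) is to resolve the scheduler base class once, depending on which name the installed PyTorch exposes:

import torch

# LRScheduler only exists as a public name from PyTorch 2.0 onwards;
# older releases only expose _LRScheduler.
if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
    LR_SCHEDULER_BASE = torch.optim.lr_scheduler.LRScheduler
else:
    LR_SCHEDULER_BASE = torch.optim.lr_scheduler._LRScheduler


def is_lr_scheduler(obj):
    # Passes for LambdaLR, ExponentialLR, etc. on both PyTorch < 2.0 and >= 2.0.
    return isinstance(obj, LR_SCHEDULER_BASE)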

HuggingFaceDocBuilderDev commented Mar 16, 2023

The documentation is not available anymore as the PR was closed or merged.

if is_torch_greater_2_0():
    self.assertTrue(isinstance(ppo_trainer.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR))
else:
    self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR))
younesbelkada (Contributor, Author) commented Mar 16, 2023

In PT < 2.0 the scheduler is wrapped inside a scheduler attribute.
EDIT: what I said is wrong; accelerate wraps it inside AcceleratedScheduler only for Python != 3.8.

younesbelkada (Contributor, Author):

Tracking the issue at: huggingface/accelerate#1200

Comment on lines +29 to +37
if _is_python_greater_3_8:
    from importlib.metadata import version

    torch_version = version("torch")
else:
    import pkg_resources

    torch_version = pkg_resources.get_distribution("torch").version
return torch_version >= "2.0"
younesbelkada (Contributor, Author):

It is a way I found to do the version check with respect to the Python version; I would be curious to know if you have better techniques.
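
For reference, one possible alternative (a sketch only, assuming the packaging dependency is acceptable; _is_python_greater_3_8 is the same flag as in the snippet above) is to parse the versions instead of comparing raw strings, since lexicographic comparison can misorder them (e.g. "10.0" >= "2.0" is False as plain strings):

import sys

from packaging import version

_is_python_greater_3_8 = sys.version_info >= (3, 8)


def is_torch_greater_2_0():
    if _is_python_greater_3_8:
        # importlib.metadata is in the standard library from Python 3.8 onwards.
        from importlib.metadata import version as importlib_version

        torch_version = importlib_version("torch")
    else:
        import pkg_resources

        torch_version = pkg_resources.get_distribution("torch").version
    # Compare parsed versions rather than raw strings.
    return version.parse(torch_version) >= version.parse("2.0")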

younesbelkada changed the title from "[test] attempt to fix CI test" to "[test] attempt to fix CI test for PT 2.0" on Mar 16, 2023
younesbelkada requested a review from lvwerra on March 16, 2023
younesbelkada merged commit 6b88bba into main on Mar 17, 2023
younesbelkada deleted the fix-ci branch on March 17, 2023