Skip to content

Commit

Permalink
[test] attempt to fix CI test for PT 2.0 (#225)
Browse files Browse the repository at this point in the history
* attempt to fix CI test

* attempt to fix CI to PT 2.0

* fix 3.7 issue

* fix

* make quality

* try

* Update tests/test_ppo_trainer.py
  • Loading branch information
younesbelkada authored Mar 17, 2023
1 parent 44f708e commit 6b88bba
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
1 change: 0 additions & 1 deletion tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def test_ppo_step_with_no_ref_sgd(self):
for stat in EXPECTED_STATS:
assert stat in train_stats.keys()

@unittest.skip("TODO: fix this test")
def test_ppo_step_with_no_ref_sgd_lr_scheduler(self):
# initialize dataset
dummy_dataset = self._init_dummy_dataset()
Expand Down
19 changes: 19 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys


# Whether the running interpreter is Python 3.8 or newer (i.e. whether
# `importlib.metadata` is available in the standard library).
# NOTE: the previous check compared `sys.version_info[0]` (the *major*
# version, an int) against the float 3.8 — on any Python 3 interpreter
# `3 < 3.8` is True, so the flag was always False and the legacy
# `pkg_resources` path was always taken. Tuple comparison is the correct
# and idiomatic form.
_is_python_greater_3_8 = sys.version_info >= (3, 8)


def is_peft_available():
    """Return True when the `peft` package can be imported in this environment."""
    peft_spec = importlib.util.find_spec("peft")
    return peft_spec is not None


def is_torch_greater_2_0():
    """Return True if the installed ``torch`` distribution is version 2.0 or newer.

    On Python >= 3.8 the version string is read via the stdlib
    ``importlib.metadata``; on older interpreters it falls back to
    ``pkg_resources`` (setuptools).

    The comparison is done numerically on the major version component.
    The previous implementation compared version *strings*
    (``torch_version >= "2.0"``), which orders lexicographically and
    would, for example, classify a hypothetical ``"10.0"`` as older
    than ``"2.0"``.
    """
    # Inline tuple comparison instead of the module-level flag: comparing
    # sys.version_info[0] (an int) to the float 3.8 was always wrong.
    if sys.version_info >= (3, 8):
        from importlib.metadata import version

        torch_version = version("torch")
    else:
        import pkg_resources

        torch_version = pkg_resources.get_distribution("torch").version
    # Version strings look like "1.13.1" or "2.0.1+cu117"; the leading
    # component is always a plain integer, so compare it numerically.
    return int(torch_version.split(".")[0]) >= 2
13 changes: 11 additions & 2 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
stack_dicts,
stats_to_np,
)
from ..import_utils import is_torch_greater_2_0
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig

Expand Down Expand Up @@ -248,8 +249,16 @@ def __init__(

self.lr_scheduler = lr_scheduler
if self.lr_scheduler is not None:
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
raise ValueError("lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler")
lr_scheduler_class = (
torch.optim.lr_scheduler._LRScheduler
if not is_torch_greater_2_0()
else torch.optim.lr_scheduler.LRScheduler
)

if not isinstance(self.lr_scheduler, lr_scheduler_class):
raise ValueError(
"lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
)

if self.config.adap_kl_ctrl:
self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
Expand Down

0 comments on commit 6b88bba

Please sign in to comment.