Skip to content

Commit

Permalink
update ref val
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Aug 8, 2024
1 parent 665c233 commit 8659101
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,8 @@ def test_loss_trainer(self):
returns[idx].unsqueeze(0),
)

assert abs(pg_loss.item() - 2.2643) < 0.0001
assert abs(v_loss.item() - 0.1015) < 0.0001
assert abs(pg_loss.item() - 1.8226) < 0.0001
assert abs(v_loss.item() - 0.1260) < 0.0001

# check if we get same results with masked parts removed
pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss(
Expand All @@ -609,8 +609,8 @@ def test_loss_trainer(self):
apply_mask(advantages[idx], mask[idx]).unsqueeze(0),
apply_mask(returns[idx], mask[idx]).unsqueeze(0),
)
assert abs(pg_loss_unmasked.item() - 2.2643) < 0.0001
assert abs(v_loss_unmasked.item() - 0.1015) < 0.0001
assert abs(pg_loss_unmasked.item() - 1.8226) < 0.0001
assert abs(v_loss_unmasked.item() - 0.1260) < 0.0001

@parameterized.expand(
[
Expand Down

0 comments on commit 8659101

Please sign in to comment.