Skip to content

Commit

Permalink
fix trainer.py:multistep_trainer:policy.forward args bug
Browse files Browse the repository at this point in the history
  • Loading branch information
SolenoidWGT committed Jan 12, 2023
1 parent 09d4490 commit a334a0c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ding/framework/middleware/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def detect_alive(self, expected, timeout):

task.off(self._event_name_detect)
logging.info(
"Barrier detect node done, node:[{}] has connected with {} active nodes!".format(self.node_id, expected)
"Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
)


Expand Down
4 changes: 2 additions & 2 deletions ding/framework/middleware/functional/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

if ctx.train_data is None: # no enough data from data fetcher
return
data = ctx.train_data.to(policy._device)
train_output = policy.forward(data)
# data = ctx.train_data.to(policy._device)
train_output = policy.forward(ctx.train_data)
nonlocal last_log_iter
if ctx.train_iter - last_log_iter >= log_freq:
loss = np.mean([o['total_loss'] for o in train_output])
Expand Down

0 comments on commit a334a0c

Please sign in to comment.