Skip to content

Commit

Permalink
calculate scalar q_values to log average_q statistics in categorical …
Browse files Browse the repository at this point in the history
…dqn algorithms.
  • Loading branch information
keisuke-nakata committed Jul 30, 2020
1 parent d420891 commit 4738769
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pfrl/agents/categorical_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def _compute_y_and_t(self, exp_batch):
batch_q_target = self._compute_target_values(exp_batch)
assert batch_q_target.shape == (batch_size, n_atoms)

# for `agent.get_statistics()`
batch_q_scalars = qout.evaluate_actions(batch_actions)
self.q_record.extend(batch_q_scalars.detach().cpu().numpy().ravel())

return batch_q, batch_q_target

def _compute_loss(self, exp_batch, errors_out=None):
Expand Down

0 comments on commit 4738769

Please sign in to comment.