-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdqn.py
21 lines (16 loc) · 819 Bytes
/
dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
def calc_loss(batch, net, tgt_net, gamma, device="cpu"):
    """Compute the one-step DQN temporal-difference loss for a batch.

    Args:
        batch: Tuple ``(states, actions, rewards, dones, next_states)``;
            each element must be convertible with ``torch.as_tensor``.
        net: Online network being trained. Called on the states; its output
            is indexed along dimension 1 by the taken actions.
        tgt_net: Target network used for the bootstrap value. It is evaluated
            under ``torch.no_grad()``, so no gradient flows into it.
        gamma: Discount factor applied to the bootstrapped next-state value.
        device: Device the batch tensors are placed on.

    Returns:
        Scalar ``nn.MSELoss`` (mean reduction) between ``Q(s, a)`` from
        ``net`` and the target ``r + gamma * max_a' Q_tgt(s', a')``, where
        the bootstrap term is zeroed for transitions flagged in ``dones``.
    """
    states, actions, rewards, dones, next_states = batch
    states_v = torch.as_tensor(states, device=device)
    next_states_v = torch.as_tensor(next_states, device=device)
    # gather() expects int64 (Long) indices; numpy action buffers are often int32.
    actions_v = torch.as_tensor(actions, device=device).long()
    done_mask = torch.as_tensor(dones, dtype=torch.bool, device=device)

    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

    # Cast rewards to the Q-value dtype: float64 rewards (numpy's default)
    # would otherwise promote the target to float64, and a dtype mismatch
    # between MSELoss input and target can fail during backward().
    rewards_v = torch.as_tensor(rewards, device=device).to(state_action_values.dtype)

    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(dim=1).values
        # Terminal transitions have no successor state to bootstrap from.
        next_state_values = next_state_values.masked_fill(done_mask, 0.0)

    expected_state_action_values = rewards_v + gamma * next_state_values
    return nn.MSELoss()(state_action_values, expected_state_action_values)