RL-Adventure/3.dueling dqn.ipynb missing forward?
laz8 opened this issue · 0 comments
laz8 commented
```python
import numpy as np
import torch
from torch.autograd import Variable

def compute_td_loss(batch_size):
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    # Variable is deprecated in PyTorch >= 0.4; plain tensors behave the same
    state = Variable(torch.FloatTensor(np.float32(state)))
    next_state = Variable(torch.FloatTensor(np.float32(next_state)))
    action = Variable(torch.LongTensor(action))
    reward = Variable(torch.FloatTensor(reward))
    done = Variable(torch.FloatTensor(done))

    # Calling the model invokes forward() through nn.Module.__call__
    q_values = current_model(state)
    next_q_values = target_model(next_state)

    # Q(s, a) for the actions actually taken in the batch
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    # Greedy value of the next state from the target network
    next_q_value = next_q_values.max(1)[0]
    # Bellman target; (1 - done) drops the bootstrap term for terminal states
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss
```
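The target line above can be checked by hand with plain NumPy. A toy version with made-up numbers (assuming `gamma = 0.99`, two transitions, the second terminal):

```python
import numpy as np

gamma = 0.99
reward = np.array([1.0, 0.0])
next_q_value = np.array([2.0, 3.0])
done = np.array([0.0, 1.0])  # second transition hit a terminal state

# Same Bellman target as in compute_td_loss: r + gamma * max_a' Q(s', a') * (1 - done)
expected_q_value = reward + gamma * next_q_value * (1 - done)
# For the terminal transition only the reward survives
```

The `(1 - done)` mask is what prevents bootstrapping past episode boundaries.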
There is no explicit `forward()` call — is it missing?
Edit: sorry, my issue was caused by `Variable`, not by a "missing forward". The code works without calling `forward()` explicitly and the result is the same. This can be closed.
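For anyone landing here with the same question: `current_model(state)` does call `forward()`. `nn.Module.__call__` dispatches to the subclass's `forward` (plus hook bookkeeping), which is why `model(x)` and `model.forward(x)` return the same result. A minimal sketch of the mechanism, without the real hook machinery:

```python
class Module:
    """Simplified stand-in for torch.nn.Module (hooks omitted)."""
    def __call__(self, *args, **kwargs):
        # The real nn.Module also runs forward/backward hooks here
        return self.forward(*args, **kwargs)

class TinyNet(Module):
    def forward(self, x):
        return x * 2

net = TinyNet()
assert net(3) == net.forward(3)
```

Calling the module rather than `forward()` directly is the recommended usage, since the real `__call__` also triggers registered hooks.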