input state-action pair into Rainbow DQN
junhuang-ifast opened this issue · 1 comment
Hi, I was thinking of incorporating the action (in addition to the state) as a state-action pair input into the Rainbow DQN model, but I am unsure where to insert it. The code below shows the places where I am thinking of adding the action (as input to the model), but I am unsure whether it is appropriate to add it there or not (please see the "<----" markers).
```python
def _compute_dqn_loss(self, samples: Dict[str, np.ndarray], gamma: float) -> torch.Tensor:
    """Return categorical dqn loss."""
    device = self.device  # for shortening the following lines
    state = torch.FloatTensor(samples["obs"]).to(device)
    next_state = torch.FloatTensor(samples["next_obs"]).to(device)
    action = torch.LongTensor(samples["acts"]).to(device)
    reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
    done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

    # Categorical DQN algorithm
    delta_z = float(self.v_max - self.v_min) / (self.atom_size - 1)

    with torch.no_grad():
        # Double DQN
        next_state_EDIT = np.concatenate([next_state, action])  # <---- concat action
        next_action = self.dqn(next_state_EDIT).argmax(1)        # <---- edited state as input
        next_dist = self.dqn_target.dist(next_state_EDIT)        # <---- edited state as input
        next_dist = next_dist[range(self.batch_size), next_action]

        t_z = reward + (1 - done) * gamma * self.support
        t_z = t_z.clamp(min=self.v_min, max=self.v_max)
        b = (t_z - self.v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()

        offset = (
            torch.linspace(
                0, (self.batch_size - 1) * self.atom_size, self.batch_size
            ).long()
            .unsqueeze(1)
            .expand(self.batch_size, self.atom_size)
            .to(self.device)
        )

        proj_dist = torch.zeros(next_dist.size(), device=self.device)
        proj_dist.view(-1).index_add_(
            0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
        )
        proj_dist.view(-1).index_add_(
            0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
        )

    state_EDIT = np.concatenate([state, action])  # <---- concat action
    dist = self.dqn.dist(state_EDIT)              # <---- edited state as input
    log_p = torch.log(dist[range(self.batch_size), action])
    elementwise_loss = -(proj_dist * log_p).sum(1)

    return elementwise_loss


def select_action(self, state: np.ndarray) -> np.ndarray:
    """Select an action from the input state."""
    # NoisyNet: no epsilon greedy action selection
    state_EDIT = np.concatenate([state, action])  # <---- concat action
    selected_action = self.dqn(
        torch.FloatTensor(state_EDIT).to(self.device)  # <---- edited state as input
    ).argmax()
    selected_action = selected_action.detach().cpu().numpy()

    if not self.is_test:
        self.transition = [state, selected_action]

    return selected_action
```
I have seen state-action pairs used as input to the Q function of soft actor critic before, but not in DQN. So I am unsure whether it is logical to do this, especially in `self.dqn.dist(state_EDIT)` and `selected_action = self.dqn(torch.FloatTensor(state_EDIT).to(self.device)).argmax()`.
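To be concrete, this is roughly the concatenation I have in mind (a minimal sketch only; the shapes and names are illustrative, and the action would need to be one-hot encoded and concatenated with `torch.cat` rather than `np.concatenate` on torch tensors as in my snippet above):

```python
import torch
import torch.nn.functional as F

# Illustrative sizes, not the repo's actual config.
batch_size, obs_dim, n_actions = 32, 8, 4

state = torch.randn(batch_size, obs_dim)              # (batch, obs_dim)
action = torch.randint(0, n_actions, (batch_size,))   # (batch,) discrete action indices

# One-hot encode the discrete action and concatenate along the feature dimension.
action_onehot = F.one_hot(action, num_classes=n_actions).float()   # (batch, n_actions)
state_action = torch.cat([state, action_onehot], dim=-1)           # (batch, obs_dim + n_actions)

print(state_action.shape)  # torch.Size([32, 12])
```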
Any ideas on this? thanks :)
First, I would like to know why you want to use state-action pairs with DQN.
DQN is a method for problems with a small, discrete action space, so the network is designed to take only the state as input and predict the values of all actions at once. The state-action input you describe is usually employed for continuous action spaces (as in the SAC critic you mention), where it is intractable to predict the values of all state-action pairs.
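To make the distinction concrete, here is a minimal sketch contrasting the two network shapes (module names and layer sizes are illustrative, not the classes used in this repo):

```python
import torch
import torch.nn as nn

obs_dim, act_dim, n_actions = 8, 2, 4  # illustrative sizes

# DQN-style network: input is the state only; output is one value per discrete
# action, so action selection is a single argmax over the output vector.
discrete_q = nn.Sequential(
    nn.Linear(obs_dim, 128), nn.ReLU(),
    nn.Linear(128, n_actions),
)

# Continuous-control critic (e.g. as in SAC): input is the state concatenated
# with the action; output is a single Q-value for that specific pair.
continuous_q = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 128), nn.ReLU(),
    nn.Linear(128, 1),
)

state = torch.randn(1, obs_dim)
q_values = discrete_q(state)             # shape (1, n_actions)
greedy_action = q_values.argmax(dim=1)   # cheap argmax: all actions' values in one pass

action = torch.randn(1, act_dim)         # a continuous action vector
q_sa = continuous_q(torch.cat([state, action], dim=-1))  # shape (1, 1)
```

With the state-action form you lose that cheap argmax over actions, which is exactly what `select_action` and the Double DQN target in your snippet rely on.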