vwxyzjn/cleanrl

Bug in actor loss for sac_continuous_action.py

terencenwz opened this issue · 5 comments

Problem Description

In the following lines of sac_continuous_action.py:


min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

should be

min_qf_pi = torch.min(qf1_pi, qf2_pi)
actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

Otherwise,
(alpha * log_pi) - min_qf_pi broadcasts to a [batch_size x batch_size] matrix instead of a [batch_size x 1] column,
and in my tests it gives a different actor loss:
min_qf_pi.shape: torch.Size([8])
log_pi.shape: torch.Size([8, 1])
((alpha * log_pi) - min_qf_pi):
tensor([[ 8.7687, 8.6482, 5.3872, 8.6279, 8.7512, 6.9031, 7.5819, 5.7800],
[ 9.0996, 8.9791, 5.7181, 8.9588, 9.0821, 7.2340, 7.9129, 6.1109],
[ 4.5497, 4.4292, 1.1682, 4.4089, 4.5323, 2.6841, 3.3630, 1.5610],
[ 9.8283, 9.7078, 6.4468, 9.6875, 9.8109, 7.9627, 8.6416, 6.8396],
[ 9.3948, 9.2743, 6.0133, 9.2540, 9.3773, 7.5292, 8.2081, 6.4061],
[ 6.0864, 5.9659, 2.7049, 5.9456, 6.0689, 4.2208, 4.8996, 3.0977],
[ 3.0503, 2.9298, -0.3312, 2.9095, 3.0328, 1.1847, 1.8635, 0.0616],
[ 1.6122, 1.4917, -1.7694, 1.4714, 1.5947, -0.2535, 0.4254, -1.3766]],
device='cuda:0', grad_fn=)
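A minimal sketch reproducing the shape mismatch, assuming random stand-in tensors with a batch size of 8 and a made-up alpha of 0.2 (neither is taken from cleanrl's training loop). Subtracting a [8] vector from an [8, 1] column broadcasts to [8, 8]: the diagonal holds the intended per-sample terms, and each off-diagonal entry pairs sample i's log_pi with sample j's Q-value.

```python
import torch

torch.manual_seed(0)
batch_size = 8
alpha = 0.2  # hypothetical entropy coefficient, for illustration only

# Hypothetical stand-ins for the actor log-probs and the two critic outputs.
log_pi = torch.randn(batch_size, 1)  # shape [8, 1]
qf1_pi = torch.randn(batch_size, 1)  # shape [8, 1]
qf2_pi = torch.randn(batch_size, 1)  # shape [8, 1]

# Original code: .view(-1) flattens to [8], so [8, 1] - [8] broadcasts to [8, 8].
min_qf_pi_flat = torch.min(qf1_pi, qf2_pi).view(-1)
buggy = (alpha * log_pi) - min_qf_pi_flat

# Proposed fix: keep [8, 1] so the subtraction is element-wise.
min_qf_pi = torch.min(qf1_pi, qf2_pi)
fixed = (alpha * log_pi) - min_qf_pi

print(buggy.shape)  # torch.Size([8, 8])
print(fixed.shape)  # torch.Size([8, 1])

# The intended per-sample terms sit on the diagonal of the broadcasted matrix.
print(torch.allclose(torch.diagonal(buggy), fixed.view(-1)))
```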

The corresponding line in the Atari version (sac_atari.py) is correct:

min_qf_values = torch.min(qf1_values, qf2_values)

Yeah, the version using min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1) computes an outer (pairwise) difference through broadcasting, because the two operands have different shapes. The sac_atari version omits the .view(-1) because I ran into the same issue there.

It might be worthwhile to investigate why this hasn't been an issue previously though.

Your fix should work, @terencenwz. If you want to open a PR, I can merge it, provided @dosssman and @vwxyzjn are also fine with it.

Thanks, appreciate it.

Will come back to this in the middle of the week if there are no changes by then.

Thanks for raising this issue, @terencenwz. Alongside the PR, we should probably re-run the benchmark experiments as well, given that this is a performance-impacting change. The specific steps are listed at https://docs.cleanrl.dev/contribution/#rlops-for-performance-impacting-changes

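After further tests, I found that the broadcasted [batch_size x batch_size] matrix has the same mean as the element-wise difference, so the actor_loss is actually unaffected: averaging alpha * log_pi[i] - min_qf_pi[j] over all pairs (i, j) gives mean(alpha * log_pi) - mean(min_qf_pi), which is exactly the mean of the element-wise version.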
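A minimal sketch checking this claim, assuming random leaf tensors stand in for log_pi and min_qf_pi (in real SAC both depend on the actor's parameters, but since the gradients with respect to these two tensors match, the chain rule gives the same actor gradients too). Because the loss is a plain mean and is linear in each term, both the value and the gradients agree; a different reduction (e.g. sum) would not have this property.

```python
import torch

torch.manual_seed(0)
batch_size = 8
alpha = 0.2  # hypothetical entropy coefficient, for illustration only

# Leaf tensors so we can compare the gradients each version of the loss produces.
log_pi = torch.randn(batch_size, 1, requires_grad=True)
min_qf = torch.randn(batch_size, 1, requires_grad=True)

# Buggy loss: [8, 1] - [8] broadcasts to an [8, 8] matrix over all pairs (i, j).
buggy_loss = ((alpha * log_pi) - min_qf.view(-1)).mean()
g_log_pi_buggy, g_q_buggy = torch.autograd.grad(buggy_loss, (log_pi, min_qf))

# Fixed loss: element-wise [8, 1] - [8, 1].
fixed_loss = ((alpha * log_pi) - min_qf).mean()
g_log_pi_fixed, g_q_fixed = torch.autograd.grad(fixed_loss, (log_pi, min_qf))

# mean_ij(a_i - b_j) = mean(a) - mean(b), so the loss values agree.
print(torch.allclose(buggy_loss, fixed_loss, atol=1e-6))
# Each log_pi[i] gets alpha / batch_size and each min_qf[j] gets -1 / batch_size in both cases.
print(torch.allclose(g_log_pi_buggy, g_log_pi_fixed, atol=1e-6))
print(torch.allclose(g_q_buggy, g_q_fixed, atol=1e-6))
```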

@terencenwz Thanks, I will update #383 with your suggested change. I'm planning on running performance tests this week.