vwxyzjn/cleanrl

LSTM weights should have separate orthogonal initializations for each gate

Jammf opened this issue · 0 comments

Jammf commented

Problem Description

The LSTM weight matrices in ppo_atari_lstm.py appear to be initialized incorrectly, if the goal is to have a separate orthogonal matrix for each gate. Since lstm.weight_ih_l0 and lstm.weight_hh_l0 have the four gate matrices concatenated together, shouldn't each of the four parts of the fused weight matrix be initialized to an orthogonal matrix separately?
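For reference, PyTorch stacks the four gate matrices (input, forget, cell, output) along dim 0 of each weight parameter, which is why a single `orthogonal_` call spans all four gates at once:

```python
import torch

lstm = torch.nn.LSTM(512, 128)
# hidden-hidden weights: (4 * hidden_size, hidden_size)
print(lstm.weight_hh_l0.shape)  # torch.Size([512, 128])
# input-hidden weights: (4 * hidden_size, input_size)
print(lstm.weight_ih_l0.shape)  # torch.Size([512, 512])
```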

Current Behavior

As a minimal example, checking just the $W_{hi}$ component of the hidden-hidden weights:

import torch
lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0, 1.0)
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> False

Expected Behavior

import torch
lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0[:128], 1.0)  # init a view with only W_hi
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> True

Possible Solution

  self.lstm = nn.LSTM(512, 128)
  for name, param in self.lstm.named_parameters():
      if "bias" in name:
          nn.init.constant_(param, 0)
      elif "weight" in name:
-         nn.init.orthogonal_(param, 1.0)
+         nn.init.orthogonal_(param[:128], 1.0)
+         nn.init.orthogonal_(param[128:128*2], 1.0)
+         nn.init.orthogonal_(param[128*2:128*3], 1.0)
+         nn.init.orthogonal_(param[128*3:], 1.0)
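The hardcoded 128 above can be avoided by deriving the gate block size from the parameter shape, so the same loop works for any hidden size. A sketch of this generalization (`init_lstm_per_gate` is a hypothetical helper name, not from the cleanrl codebase):

```python
import torch
import torch.nn as nn

def init_lstm_per_gate(lstm: nn.LSTM, gain: float = 1.0) -> None:
    """Orthogonally initialize each of the four gate blocks of every
    LSTM weight matrix separately, and zero the biases."""
    for name, param in lstm.named_parameters():
        if "bias" in name:
            nn.init.constant_(param, 0)
        elif "weight" in name:
            # PyTorch concatenates the four gate matrices along dim 0,
            # so each gate block has param.shape[0] // 4 rows.
            gate_rows = param.shape[0] // 4
            for g in range(4):
                nn.init.orthogonal_(param[g * gate_rows:(g + 1) * gate_rows], gain)

lstm = nn.LSTM(512, 128)
init_lstm_per_gate(lstm)
```

After this, each 128x128 block of `weight_hh_l0` is itself orthogonal, matching the expected behavior shown above.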