isaac-sim/IsaacGymEnvs

Mismatch between the size of the MLP built during training and the one built during test


Hello, I have a mismatch between the size of the MLP network built during training and the one built during testing. As a result, whenever I load a checkpoint with the option test=True, the weights cannot be assigned correctly.

Specifically, during training I run:
python train.py task=HumanoidAMP checkpoint=./runs/HumanoidAMP/nn/HumanoidAMP_100.pth test=False
I have the following log:
Box(-1.0, 1.0, (28,), float32) Box(-inf, inf, (105,), float32)
current training device: cuda:0
build mlp: 105
build mlp: 105
build mlp: 210
sigma
actor_mlp.0.weight
actor_mlp.0.bias
actor_mlp.2.weight
actor_mlp.2.bias
critic_mlp.0.weight
critic_mlp.0.bias
critic_mlp.2.weight
critic_mlp.2.bias
value.weight
value.bias
mu.weight
mu.bias
_disc_mlp.0.weight
_disc_mlp.0.bias
_disc_mlp.2.weight
_disc_mlp.2.bias
_disc_logits.weight
_disc_logits.bias
RunningMeanStd: (1,)
RunningMeanStd: (105,)
RunningMeanStd: (210,)

And when I run:
python train.py task=HumanoidAMP checkpoint=./runs/HumanoidAMP/nn/HumanoidAMP_100.pth test=True
I get the log:
Box(-1.0, 1.0, (28,), float32) Box(-inf, inf, (105,), float32)
current training device: cuda:0
build mlp: 105
build mlp: 105
build mlp: 105
sigma
actor_mlp.0.weight
actor_mlp.0.bias
actor_mlp.2.weight
actor_mlp.2.bias
critic_mlp.0.weight
critic_mlp.0.bias
critic_mlp.2.weight
critic_mlp.2.bias
value.weight
value.bias
mu.weight
mu.bias
_disc_mlp.0.weight
_disc_mlp.0.bias
_disc_mlp.2.weight
_disc_mlp.2.bias
_disc_logits.weight
_disc_logits.bias
RunningMeanStd: (1,)
RunningMeanStd: (105,)
RunningMeanStd: (105,)

With the error:
RuntimeError: Error(s) in loading state_dict for Network:
	size mismatch for a2c_network._disc_mlp.0.weight: copying a param with shape torch.Size([512, 210]) from checkpoint, the shape in current model is torch.Size([512, 105]).
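For reference, here is a minimal sketch (not IsaacGymEnvs code; the layer sizes and the build_disc helper are stand-ins chosen to match the logs above) that reproduces this class of error: a checkpoint saved from a discriminator MLP with a 210-dim input cannot be loaded into a model rebuilt with a 105-dim input.

```python
import torch
import torch.nn as nn

def build_disc(in_dim):
    # Stand-in for the AMP discriminator MLP (_disc_mlp in the logs):
    # first Linear layer has weight shape [512, in_dim].
    return nn.Sequential(nn.Linear(in_dim, 512), nn.ReLU(), nn.Linear(512, 1))

train_model = build_disc(210)  # input size seen in the training log
test_model = build_disc(105)   # input size seen in the test log

try:
    # Same failure mode as loading the checkpoint with test=True.
    test_model.load_state_dict(train_model.state_dict())
except RuntimeError as e:
    print(e)  # size mismatch for 0.weight: [512, 210] vs [512, 105]
```

So the checkpoint itself is fine; the problem is that the network is constructed with a different discriminator input size depending on the mode.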

I have been debugging this for several days now, but I cannot spot where this mismatch is created. Does anyone have a clue as to why, and where in the code, this is happening?
Thank you in advance