mickaelChen/ReDO

TypeError: tuple indices must be integers or slices, not str on optimizerRecZ.load_state_dict


knok commented

I tried to resume training, but it fails with the following error:

  File "/home/knok/ml/ReDO/train.py", line 217, in <module>
    optimizerRecZ.load_state_dict(state["optimizerRecZ"])
  File "/opt/p3/lib/python3.7/site-packages/torch/optim/optimizer.py", line 108, in load_state_dict
    saved_groups = state_dict['param_groups']
TypeError: tuple indices must be integers or slices, not str

I can work around it with the following patch:

diff --git a/train.py b/train.py
index bd53fbf..4298230 100644
--- a/train.py
+++ b/train.py
@@ -214,7 +214,7 @@ if opt.iteration > 0:
     optimizerDX.load_state_dict(state["optimizerDX"])
     if opt.wrecZ > 0:
         netRecZ.load_state_dict(state["netRecZ"])
-        optimizerRecZ.load_state_dict(state["optimizerRecZ"])
+        optimizerRecZ.load_state_dict(state["optimizerRecZ"][0])
 else:
     try:
         os.remove(os.path.join(opt.outf, "train.dat"))

I don't know why only the optimizerRecZ state causes this error, but the patch fixes it.
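The error can be reproduced without PyTorch. The sketch below assumes the bad checkpoint entry is the optimizer's state dict wrapped in a one-element tuple (the error message says the object is a tuple); the dict contents are made up for illustration. It shows why the string lookup inside `load_state_dict` fails and why the `[0]` in the patch recovers the dict.

```python
# Stand-in for an optimizer state dict: torch.optim optimizers return a dict
# with "state" and "param_groups" keys from state_dict(). Values are made up.
state_dict = {"state": {}, "param_groups": [{"lr": 0.001, "params": [0]}]}

# Assumed shape of the buggy checkpoint entry: the dict inside a 1-tuple.
saved = (state_dict,)

try:
    # Optimizer.load_state_dict() begins by reading state_dict['param_groups'],
    # which is the line shown in the traceback.
    saved["param_groups"]
except TypeError as err:
    message = str(err)
print(message)  # tuple indices must be integers or slices, not str

# Indexing the tuple first, as the patch does, recovers the original dict.
recovered = saved[0]
print(recovered["param_groups"][0]["lr"])  # 0.001
```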

Thank you for bringing this issue to our attention and for the proposed fix.

There was a bug in how the state of optimizerRecZ was saved. That bug is now fixed, so the patch should not be needed for newly saved checkpoints.

Older checkpoints can be loaded using your method.
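If you want one code path that loads both older and newer checkpoints, a small unwrapping step is enough. This is a sketch, not code from the ReDO repository; `unwrap_optimizer_state` is a hypothetical helper name.

```python
from typing import Any, Dict, Tuple, Union

StateDict = Dict[str, Any]


def unwrap_optimizer_state(saved: Union[StateDict, Tuple[StateDict, ...]]) -> StateDict:
    """Return a plain optimizer state dict from either checkpoint format.

    Hypothetical helper: older checkpoints hold the dict inside a
    one-element tuple, newer ones hold the dict directly.
    """
    if isinstance(saved, tuple):
        if len(saved) != 1:
            raise ValueError(f"expected a one-element tuple, got {len(saved)} elements")
        return saved[0]
    return saved
```

In `train.py` it would be used as `optimizerRecZ.load_state_dict(unwrap_optimizer_state(state["optimizerRecZ"]))`. Raising on tuples of any other length avoids silently loading the wrong object.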

knok commented

Thank you for the fix.
Now I understand why the value became a tuple:
the needless comma at the end of the line turned the value into a tuple.

Yes, this is how you make a tuple with only one element.

>>> print((1))
1
>>> print((1,))
(1,)
>>> print((1,)[0])
1
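The same rule applies to an assignment statement, where no parentheses are needed at all: a trailing comma after the right-hand side is enough to produce a one-element tuple. The save line in `train.py` is not shown in this thread, so `fake_state_dict` below is a hypothetical stand-in for `optimizerRecZ.state_dict()`.

```python
def fake_state_dict():
    # Hypothetical stand-in for optimizerRecZ.state_dict()
    return {"state": {}, "param_groups": []}

# Buggy pattern: the trailing comma makes the right-hand side a one-element
# tuple, even though there are no parentheses.
buggy = fake_state_dict(),
print(type(buggy).__name__)  # tuple

# Fixed pattern: without the comma, the dict is stored as-is.
fixed = fake_state_dict()
print(type(fixed).__name__)  # dict
```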