Bug: local variable 'grad' referenced before assignment
DanceInDark opened this issue · 0 comments
DanceInDark commented
Hi, when I run `RN50_FP16_4GPU.sh` in `sparse_learning/imagenet/tuned_resnet`, I get this error:
```
Traceback (most recent call last):
  File "/home/z0132/workspace/codes/sparse_learning/imagenet/tuned_resnet/./main.py", line 811, in <module>
    main()
  File "/home/z0132/workspace/codes/sparse_learning/imagenet/tuned_resnet/./main.py", line 114, in main
    train_net(args, logger)
  File "/home/z0132/workspace/codes/sparse_learning/imagenet/tuned_resnet/./main.py", line 241, in train_net
    train_loop(args, model_and_loss, optimizer, adjust_learning_rate(args), train_loader, val_loader, args.epochs,
  File "/home/z0132/workspace/codes/sparse_learning/imagenet/tuned_resnet/./main.py", line 359, in train_loop
    model_and_loss.mask.at_end_of_epoch()
  File "/home/z0132/workspace/codes/sparse_learning/sparselearning/core.py", line 273, in at_end_of_epoch
    self.truncate_weights()
  File "/home/z0132/workspace/codes/sparse_learning/sparselearning/core.py", line 378, in truncate_weights
    self.gather_statistics()
  File "/home/z0132/workspace/codes/sparse_learning/sparselearning/core.py", line 440, in gather_statistics
    self.name2variance[name] = self.redistribution_func(self, name, weight, mask)
  File "/home/z0132/workspace/codes/sparse_learning/sparselearning/funcs.py", line 35, in momentum_redistribution
    grad = masking.get_momentum_for_weight(weight)
  File "/home/z0132/workspace/codes/sparse_learning/sparselearning/core.py", line 524, in get_momentum_for_weight
    return grad
UnboundLocalError: local variable 'grad' referenced before assignment
```
The function `get_momentum_for_weight` is:

```python
def get_momentum_for_weight(self, weight):
    if 'exp_avg' in self.optimizer.state[weight]:
        adam_m1 = self.optimizer.state[weight]['exp_avg']
        adam_m2 = self.optimizer.state[weight]['exp_avg_sq']
        grad = adam_m1/(torch.sqrt(adam_m2) + 1e-08)
    elif 'momentum_buffer' in self.optimizer.state[weight]:
        grad = self.optimizer.state[weight]['momentum_buffer']
    return grad
```
It seems that some cases have not been considered: if the optimizer state for a weight contains neither `'exp_avg'` nor `'momentum_buffer'`, `grad` is never assigned before the `return`.
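For illustration, here is a minimal sketch of a defensive rewrite that fails with a clear message when neither optimizer-state key is present. It is written as a standalone function taking the optimizer explicitly; the error message and the choice to raise rather than silently fall back are my assumptions, not the repo's actual fix:

```python
import torch

def get_momentum_for_weight(optimizer, weight):
    """Return a momentum-based saliency signal for `weight`.

    Covers Adam-style optimizers ('exp_avg' / 'exp_avg_sq') and SGD with
    momentum ('momentum_buffer'), and raises a descriptive error instead
    of UnboundLocalError when the optimizer has no state for this weight.
    """
    state = optimizer.state[weight]
    if 'exp_avg' in state:  # Adam family: use m1 / (sqrt(m2) + eps)
        return state['exp_avg'] / (torch.sqrt(state['exp_avg_sq']) + 1e-08)
    if 'momentum_buffer' in state:  # SGD with momentum
        return state['momentum_buffer']
    # Neither key present: the optimizer has not populated momentum state
    # for this parameter (e.g. optimizer.step() has not run yet, or SGD
    # was constructed with momentum=0). Fail loudly with the real cause.
    raise KeyError(
        f"no momentum state for parameter of shape {tuple(weight.shape)}; "
        "check that optimizer.step() has run and momentum is enabled"
    )
```

Raising a `KeyError` here surfaces the actual cause (empty optimizer state) at the point of failure, instead of the confusing `UnboundLocalError` seen in the traceback above.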
My torch version is 1.13.1, and the code does not seem to work with it. Is there any plan to upgrade for newer PyTorch versions?