TypeError: 'LightningSAM' object is not iterable (with PyTorch Lightning)
shogi880 opened this issue · 5 comments
shogi880 commented
Hi, nice to meet you.
SAM is an exciting paper. Great work, thank you!
I want to use the SAM optimizer with PyTorch Lightning, but I ran into an error and cannot find any information about it.
Could you kindly help me?
Here is my code:
class my_class(LightningModule):
    def __init__(self):
        super().__init__()
        self.SAM = True
        if self.SAM:
            self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        if self.SAM:
            optimizer = self.optimizers()
            # first forward-backward pass
            self.manual_backward(loss, optimizer)
            optimizer.first_step(zero_grad=True)
            # second forward-backward pass
            loss_2 = self.shared_step(batch)
            self.manual_backward(loss_2, optimizer)
            optimizer.second_step(zero_grad=True)
        return loss

    def configure_optimizers(self):
        if self.SAM:
            base_optimizer = torch.optim.SGD
            optimizer = SAM(self.parameters(), base_optimizer, lr=0.01)
            return optimizer
shogi880 commented
Here is the error traceback:
Traceback (most recent call last):
File "relic_finetuner.py", line 126, in <module>
cli_main()
File "relic_finetuner.py", line 122, in cli_main
trainer.fit(model, datamodule=dm)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
self._run(model)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
self._dispatch()
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
self.accelerator.start_training(self)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
return self._run_train()
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
self.fit_loop.run()
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
epoch_output = self.epoch_loop.run(train_dataloader)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 101, in run
super().run(batch, batch_idx, dataloader_idx)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
self.advance(*args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 153, in advance
result = self._run_optimization(batch_idx, split_batch)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 204, in _run_optimization
result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 307, in _training_step
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 193, in training_step
return self.training_type_plugin.training_step(*step_kwargs.values())
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 172, in training_step
return self.model.training_step(*args, **kwargs)
File "/root/share/lightning-bolts/pl_bolts/models/self_supervised/ssl_finetuner.py", line 260, in training_step
self.manual_backward(loss, optimizer)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1461, in manual_backward
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 588, in backward
self.trainer.accelerator.backward(result, optimizer, *args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 276, in backward
self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 78, in backward
model.backward(closure_loss, optimizer, *args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1481, in backward
loss.backward(*args, **kwargs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/autograd/__init__.py", line 142, in backward
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/autograd/__init__.py", line 65, in _tensor_or_tensors_to_tuple
return tuple(tensors)
TypeError: 'LightningSAM' object is not iterable
katsura-jp commented
Hi, @shogi880
I had the same error and read the source code.
The optimizer argument has been removed from manual_backward in version 1.4, as indicated by the deprecation warning in version 1.3:
https://github.com/PyTorchLightning/pytorch-lightning/blob/1.3.1/pytorch_lightning/core/lightning.py#L1244
Are you using version 1.4+?
So I followed the example in the documentation and modified it as follows:
def training_step(self, batch, batch_idx):
    optimizer = self.optimizers()

    # first forward-backward pass
    loss_1 = self.compute_loss(batch)
    self.manual_backward(loss_1)
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_2 = self.compute_loss(batch)
    self.manual_backward(loss_2)
    optimizer.second_step(zero_grad=True)

    return loss_1
I hope this helps.
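For completeness, here is a minimal end-to-end sketch that puts the two pieces together: configure_optimizers returning SAM plus the manual-optimization training_step above. It assumes PyTorch Lightning 1.4+ and the SAM optimizer with first_step/second_step as used in this thread; the module name, the toy linear model, the compute_loss helper, and the import path for SAM are placeholders, not code from the thread.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

from sam import SAM  # SAM optimizer; adjust the import path to your setup


class SAMLitModule(LightningModule):  # placeholder name
    def __init__(self):
        super().__init__()
        # SAM needs two forward-backward passes per batch,
        # so automatic optimization must be disabled.
        self.automatic_optimization = False
        self.model = nn.Linear(32, 10)  # toy model, stands in for your network

    def compute_loss(self, batch):
        x, y = batch
        return F.cross_entropy(self.model(x), y)

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        # first forward-backward pass: gradients at the current weights
        loss_1 = self.compute_loss(batch)
        self.manual_backward(loss_1)  # no optimizer argument in PL 1.4+
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass: gradients at the perturbed weights
        loss_2 = self.compute_loss(batch)
        self.manual_backward(loss_2)
        optimizer.second_step(zero_grad=True)

        return loss_1

    def configure_optimizers(self):
        base_optimizer = torch.optim.SGD
        return SAM(self.parameters(), base_optimizer, lr=0.01)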
shogi880 commented
Hi @katsura-jp
Thanks for your super helpful comments.