davda54/sam

TypeError: 'LightningSAM' object is not iterable. (with Pytorch Lightning.)

shogi880 opened this issue · 5 comments

Hi, nice to meet you.
Your SAM is exciting paper. Very good job! Thank you!
I want to use SAM optimizer with pytorch lightning. But I had an error, and I can not find any information about it.
Could you kindly help me?

Here is my code:

class my_class(LightningModule):
    def __init__(self):
            self.SAM = True
            if self.SAM:
                self.automatic_optimization = False
    
    def training_step(self, batch, batch_idx):
            loss = self.shared_step(batch)
            
            if self.SAM:
                optimizer = self.optimizers()
                # first forward-backward pass
                self.manual_backward(loss, optimizer)
                optimizer.first_step(zero_grad=True)
    
                # second forward-backward pass
                loss_2 = self.shared_step(batch)
                self.manual_backward(loss_2, optimizer)
                optimizer.second_step(zero_grad=True)
            return loss
    
     def configure_optimizers(self):
            if self.SAM:
                base_optimizer = torch.optim.SGD
                optimizer = SAM(self.parameters(), base_optimizer, lr=0.01)
            return optimizer

the ERROR code.

Traceback (most recent call last):
  File "relic_finetuner.py", line 126, in <module>
    cli_main()
  File "relic_finetuner.py", line 122, in cli_main
    trainer.fit(model, datamodule=dm)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
    self._dispatch()
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 130, in advance
    batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 101, in run
    super().run(batch, batch_idx, dataloader_idx)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 153, in advance
    result = self._run_optimization(batch_idx, split_batch)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 204, in _run_optimization
    result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 307, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 193, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 172, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/root/share/lightning-bolts/pl_bolts/models/self_supervised/ssl_finetuner.py", line 260, in training_step
    self.manual_backward(loss, optimizer)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1461, in manual_backward
    self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 588, in backward
    self.trainer.accelerator.backward(result, optimizer, *args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 276, in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 78, in backward
    model.backward(closure_loss, optimizer, *args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1481, in backward
    loss.backward(*args, **kwargs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/autograd/__init__.py", line 142, in backward
    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
  File "/root/.pyenv/versions/3.8.11/lib/python3.8/site-packages/torch/autograd/__init__.py", line 65, in _tensor_or_tensors_to_tuple
    return tuple(tensors)
TypeError: 'LightningSAM' object is not iterable

Hi, @shogi880
I had the same error and read the source code.
optimizer has been removed from arguments of manual_backward in ver 1.4, as shown in the warning in ver 1.3.
https://github.com/PyTorchLightning/pytorch-lightning/blob/1.3.1/pytorch_lightning/core/lightning.py#L1244

Are you using version 1.4+?

So I followed Document's example and modified it as follows:

def training_step(self, batch, batch_idx):
    optimizer = self.optimizers()

    # first forward-backward pass
    loss_1 = self.compute_loss(batch)
    self.manual_backward(loss_1)
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_2 = self.compute_loss(batch)
    self.manual_backward(loss_2)
    optimizer.second_step(zero_grad=True)

    return loss_1

I hope this helps.

Hi @katsura-jp
Thanks for your super helpful comments.