Questions about training big-lama and the full-checkpoint
Closed this issue · 17 comments
Hi, thanks again for your excellent work.
Is the big-lama model trained on the places-challenge dataset? Does it perform significantly better than a big-lama trained on places2-standard?
Is it possible to release the full checkpoints of the big-lama model, so we can finetune it on other data? Thanks.
Could you also share the training log or time of big-lama? Thanks so much.
Is the big-lama model trained on places-challenge dataset?
Not exactly Places Challenge - it was trained on a subset of 157 categories from Places Challenge. Please refer to the supplementary material for the exact list of these categories.
Does it perform significantly better than a big-lama trained on places2-standard?
The difference is quite noticeable to the naked eye, but the improvement from standard -> subset-of-challenge is smaller than that from the most important contributions of our paper (e.g. masks, architecture and segm-pl).
Could you also share the training log or time of big-lama? Thanks so much.
It took approximately 12 days to train this big-lama on 8x V100 32GB with a total batch size of 120 (8 GPUs x 15 samples each).
Is it possible to release the full checkpoints of the big-lama model, so we can finetune it on other data?
I've just uploaded the full checkpoint to https://disk.yandex.ru/d/wJ2Ee0f1HvasDQ, subfolder big-lama-with-discr
- unlike the other checkpoints, this one includes the discriminator and SegmPL weights.
Please share your experience with finetuning - whether it helps and how much.
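If you want to check what a given checkpoint actually contains before fine-tuning, here is a small sketch. It assumes the usual PyTorch Lightning layout, where the weights live under a top-level "state_dict" key; the file path is only an example.

```python
def top_level_modules(state_dict):
    # Parameter names look like "generator.model.1.weight"; the first
    # dotted component identifies the submodule they belong to.
    return sorted({name.split(".", 1)[0] for name in state_dict})

# With a real checkpoint you would pass the loaded dict, e.g.:
#   import torch
#   top_level_modules(torch.load("big-lama-with-discr/best.ckpt",
#                                map_location="cpu")["state_dict"])
demo = {
    "generator.model.1.weight": None,
    "discriminator.model0.conv.weight": None,
    "loss_segm_pl.impl.conv1.weight": None,
}
print(top_level_modules(demo))  # ['discriminator', 'generator', 'loss_segm_pl']
```

If the discriminator and SegmPL prefixes show up in the list, you have the full training checkpoint rather than a weights-only export.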
Thanks so much! That is super helpful!
I'll close this issue for now - feel free to reopen if you have any issues with fine-tuning.
Hello,
I am having some issues loading big-lama-with-discr for finetuning. Please correct me if I am wrong, but I notice that the SegmPL weights are named loss_segm_pl.impl... in the .ckpt, while the current trainer loads them as loss_resnet_pl.impl...
https://github.com/saic-mdal/lama/blob/ede702b19b027ad2c0380419b2b71a90fe90a14f/saicinpainting/training/trainers/base.py#L110
After modifying this, I get the following error:
KeyError: 'Trying to restore training state but checkpoint contains only the model. This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
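For context, PyTorch Lightning raises this KeyError when the loaded checkpoint lacks optimizer/scheduler state, which is exactly what save_weights_only=True produces. A sketch of that check, based on my reading of the 1.2.x sources (not an official API, so verify against your installed version):

```python
def looks_weights_only(ckpt):
    # Lightning 1.2.x refuses to restore training state when either of
    # these keys is missing from the loaded checkpoint dict (assumption
    # based on the checkpoint_connector sources).
    return "optimizer_states" not in ckpt or "lr_schedulers" not in ckpt

print(looks_weights_only({"state_dict": {}}))    # True: model weights only
print(looks_weights_only({"state_dict": {},
                          "optimizer_states": [],
                          "lr_schedulers": []}))  # False: full training state
```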
@yzhouas did you have any success with this? I am wondering if it is just me.
Apparently, this is a known issue in PyTorch Lightning; with the suggested PyTorch Lightning 1.2.9, the problem seems to be here:
# restore training state
self.restore_training_state(checkpoint)
So, a very ugly hack would be to bypass it like this:
# restore training state
try:
    self.restore_training_state(checkpoint)
except KeyError:
    rank_zero_warn(
        "File at `resume_from_checkpoint` contains only the model; "
        "skipping restore of the training state."
    )
Hi @affromero !
Yeah, I forgot that we changed the name of this variable after training big-lama... Another possible solution is to just strip loss_segm_pl.impl...
from the checkpoint altogether - it is initialized from a fixed ade20k checkpoint anyway.
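That stripping step can be done with a few lines. A minimal sketch, assuming the standard Lightning layout with the weights under ckpt["state_dict"] (the file names are examples; loading/saving the real file goes through torch.load/torch.save):

```python
def strip_segm_pl(ckpt, prefix="loss_segm_pl."):
    # Drop every entry whose name starts with the stale prefix; the loss
    # is re-initialized from the fixed ade20k checkpoint at startup anyway.
    ckpt["state_dict"] = {
        k: v for k, v in ckpt["state_dict"].items()
        if not k.startswith(prefix)
    }
    return ckpt

# With the real checkpoint:
#   import torch
#   ckpt = torch.load("big-lama-with-discr/best.ckpt", map_location="cpu")
#   torch.save(strip_segm_pl(ckpt), "big-lama-with-discr-stripped.ckpt")
demo = {"state_dict": {"generator.w": 1, "loss_segm_pl.impl.conv1.weight": 2}}
print(sorted(strip_segm_pl(demo)["state_dict"]))  # ['generator.w']
```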
Trying to restore training state but checkpoint contains only the model.
I have not faced this issue yet. Have you resolved it?
Hi @windj007,
I looked into the supplementary material but was not able to find which categories from Places Challenge were used for training Big-LaMa. Could you please list these categories? Also, why didn't you use the entire Places Challenge set for training Big-LaMa?
Thank you
Hi @windj007 ,
I am having the same issue loading big-lama-with-discr for finetuning; please correct me if I made any mistake.
I ran this command:
python bin/train.py -cn big-lama location=my_dataset data.batch_size=10 +trainer.kwargs.resume_from_checkpoint=path\\to\\big-lama-with-discr\\best.ckpt
and got this error message:
RuntimeError: Error(s) in loading state_dict for DefaultInpaintingTrainingModule:
Missing key(s) in state_dict: "loss_resnet_pl.impl.conv1.weight", "loss_resnet_pl......
Unexpected key(s) in state_dict: "loss_segm_pl.impl.conv1.weight", "loss_segm_pl.impl....
I modified base.py line 109 from:
if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
    self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
to:
if self.config.losses.get("segm_pl", {"weight": 0})['weight'] > 0:
    self.loss_segm_pl = ResNetPL(**self.config.losses.segm_pl)
The Missing key error disappeared, but I still get the Unexpected key error:
Unexpected key(s) in state_dict: "loss_segm_pl.impl.conv1.weight", "loss_segm_pl.impl...
Do you have any suggestion for this?
@marcelsan The list is there, on page 5.
why haven't you used the entire Places Challenge for training Big-Lama?
Bigger datasets need bigger models - and smaller models work better when the dataset is more focused. And Big-LaMa is not that big in terms of number of trainable parameters.
The Missing key error disappeared, but I still get the Unexpected key error:
The quick solution is a couple of comments above:
Another possible solution is to just strip loss_segm_pl.impl... from the checkpoint altogether - anyway it is initialized from a fixed ade20k checkpoint.
I should fix and re-upload the checkpoint, but I have not found time yet...
@windj007
Thanks for your reply.
I just removed "loss_segm_pl" from the checkpoint and it worked.
Sharing the stripped checkpoint here:
https://drive.google.com/file/d/1YTiKZ1hQnKvTEbXIxFXjGg61pBAch_N7/view?usp=sharing
@Liang-Sen thank you!
I summarized the experience above and trained big-lama as follows. If I made any mistakes, please correct me.
1. Modified pytorch_lightning/trainer/connectors/checkpoint_connector.py line 106:
https://github.com/PyTorchLightning/pytorch-lightning/blob/f9f4853f3663404362c7de8614a504b0403c25b8/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L106
from:
# restore training state
self.restore_training_state(checkpoint)
to:
# restore training state
try:
    self.restore_training_state(checkpoint)
except KeyError:
    rank_zero_warn(
        "File at `resume_from_checkpoint` contains only the model; "
        "skipping restore of the training state."
    )
2. Modified lama-main/saicinpainting/training/trainers/base.py line 109 from:
if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
    self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
to:
if self.config.losses.get("segm_pl", {"weight": 0})['weight'] > 0:
    self.loss_segm_pl = ResNetPL(**self.config.losses.segm_pl)
3. Ran:
python bin/train.py -cn big-lama location=my_dataset data.batch_size=10 +trainer.kwargs.resume_from_checkpoint=abspath\\to\\big-lama-with-discr-remove-loss_segm_pl.ckpt
Checkpoint with loss_segm_pl removed, shared by @Liang-Sen:
https://drive.google.com/file/d/1YTiKZ1hQnKvTEbXIxFXjGg61pBAch_N7/view?usp=sharing
@windj007 I just need to run inference with lama-fourier-with-discr. As mentioned, I downloaded the weights from https://drive.google.com/file/d/1YTiKZ1hQnKvTEbXIxFXjGg61pBAch_N7/view?usp=sharing shared by @Liang-Sen. Could you please provide the config file for lama-fourier-with-discr?