How to create LMDB dataset for finetuning on custom dataset?
siddagra opened this issue · 11 comments
I am trying to finetune this model on a custom dataset. How would I go about creating an LMDB dataset?
Currently my ground truth file is as follows:
```
img_filepath1.jpg label_text1
img_filepath2.jpg label_text2
```
I tried using the converters from MMOCR and CDistNet (https://github.com/simplify23/CDistNet), but they do not seem to work; they reported that there were no samples in the dataset. I cannot find any converter from `gt.txt` to `train.mdb` in the code provided by you.
Please let me know which one you used or share code for the converter you used if possible. Thank you.
Nvm... My bad. Got it to work. I had to keep the train labels in `data/train/real`; they were in `data/train`.
For future reference, if anyone needs to train on a custom dataset, the directory structure should be as follows:
```
data
├── test
│   ├── data.mdb
│   └── lock.mdb
├── train
│   └── real
│       ├── data.mdb
│       └── lock.mdb
└── val
    ├── data.mdb
    └── lock.mdb
```
The `.mdb` files can be created using https://github.com/chibohe/CdistNet-pytorch/blob/main/tool/create_lmdb_dataset.py
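For readers who want to see what such a converter produces, here is a minimal, hedged sketch of the key/value layout these LMDB text-recognition datasets typically use (the `image-%09d` / `label-%09d` / `num-samples` key convention is common to CRNN-style converters like the linked script; verify against the actual script before relying on it). The function name and `read_image` parameter are illustrative, not part of the repo:

```python
def build_lmdb_records(gt_lines, read_image=lambda path: b""):
    """Turn `filepath label` lines (the gt.txt format above) into the
    byte key/value pairs an LMDB text-recognition dataset expects."""
    records = {}
    n = 0
    for line in gt_lines:
        line = line.strip()
        if not line:
            continue
        path, label = line.split(maxsplit=1)  # label may contain spaces
        n += 1
        records[b"image-%09d" % n] = read_image(path)  # raw encoded image bytes
        records[b"label-%09d" % n] = label.encode("utf-8")
    records[b"num-samples"] = str(n).encode("utf-8")
    return records

# The linked script then writes these pairs into LMDB, roughly:
#   env = lmdb.open("data/train/real", map_size=1 << 32)
#   with env.begin(write=True) as txn:
#       for k, v in records.items():
#           txn.put(k, v)
```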
Change the following in `main.yaml`:

```yaml
train_dir: ???
```

to

```yaml
train_dir: train
```

so that it looks for the `train` subdirectory in the `data` directory. You can change `val_dir` similarly.
Furthermore, if one wants to load the pretrained checkpoint found at torchhub (`parseq-bb5792a6.pt`) in `train.py`, to continue finetuning from the pretrained model, they should add the following import to `train.py`:

```python
from strhub.models.utils import load_from_checkpoint
```

and also add the following in `train.py` after the `model.summarize(...)` line (line 61):

```python
model.summarize(max_depth=1 if config.model.name.startswith('parseq') else 2)
model = load_from_checkpoint("path/to/parseq-bb5792a6.pt", **(config.model)).to('cuda:0')
```
I suggest using the `create_lmdb_dataset.py` script in this repo. The original implementation of `checkImageIsValid()` uses `cv2.imdecode()`, which, as I found out, does not fail for some corrupted images (e.g. some images in MJSynth).
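As an illustration of why `cv2.imdecode()` is not a sufficient validity check: it will happily decode a truncated JPEG. A stricter stdlib-only check can at least catch truncation by inspecting the file's magic bytes. This is a hedged sketch, not the repo's actual `checkImageIsValid()`; a real converter should additionally decode the image (e.g. with PIL) to confirm it is readable:

```python
def check_image_is_valid(image_bytes: bytes) -> bool:
    """Return True if the bytes look like a *complete* JPEG or PNG file.

    This only checks container completeness (start/end markers), not
    that the pixel data decodes correctly.
    """
    if image_bytes.startswith(b"\xff\xd8"):  # JPEG SOI marker
        # A complete JPEG ends with the EOI marker FF D9; truncated
        # files (as in some MJSynth samples) lack it.
        return image_bytes.rstrip(b"\x00").endswith(b"\xff\xd9")
    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG signature
        # A complete PNG ends with the IEND chunk near the tail.
        return b"IEND" in image_bytes[-16:]
    return False
```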
I would like to mention as well that making use of the variable `config.trainer.resume_from_checkpoint` that appears in `train.py`, by adding it in `main.yaml` under `trainer`, leads to a `KeyError: 'state_dict'` at the execution of `trainer.fit()`.
My guess is that if you give the path to the checkpoint via the variable `config.trainer.resume_from_checkpoint`, it gets embedded within the trainer's arguments and causes problems down the line.
Conclusion: just follow @siddagra's way of doing things and hardcode the path to your checkpoint in `train.py`.
Yeah, that is why I used `load_from_checkpoint("parseq-bb5792a6.pt", **(config.model)).to('cuda:0')`.
The issue may be the fact that this is a plain PyTorch checkpoint, whereas a Lightning checkpoint (`.ckpt`) contains more than just the weights needed to resume training. Though I am not sure.
@bmusq please use the `ckpt_path` parameter (`./train.py ckpt_path=/path/to/lightning.ckpt`) instead of `trainer.resume_from_checkpoint`.
Aside from the model weights, PyTorch Lightning's checkpoint also contains some training parameters. `ckpt_path` expects a PyTorch Lightning checkpoint (`*.ckpt`). The released weights (`*.pt`) contain the model parameters only, and won't work with `ckpt_path`.
You may use the code given by @siddagra to finetune the pretrained weights.
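To make the `*.pt` vs `*.ckpt` distinction concrete, here is a minimal sketch with plain dicts standing in for the serialized payloads (the key names follow Lightning's checkpoint convention; the `load_lightning` helper and the exact values are illustrative). It shows why feeding a weights-only file through a Lightning-checkpoint code path raises `KeyError: 'state_dict'`:

```python
# A released *.pt file is just the model parameters...
released_pt = {"encoder.weight": [0.1], "head.bias": [0.0]}

# ...while a Lightning *.ckpt wraps them under 'state_dict'
# alongside trainer state.
lightning_ckpt = {
    "state_dict": {"encoder.weight": [0.1], "head.bias": [0.0]},
    "epoch": 3,                                # illustrative trainer state
    "pytorch-lightning_version": "1.6.0",
}

def load_lightning(checkpoint: dict) -> dict:
    # Lightning's loader indexes checkpoint['state_dict'] directly;
    # this is the access that fails in the tracebacks in this thread.
    return checkpoint["state_dict"]

load_lightning(lightning_ckpt)                 # works
try:
    load_lightning(released_pt)                # weights-only: no such key
except KeyError as e:
    print(f"KeyError: {e}")
```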
@siddagra Following the paper, the model can do spelling correction with a dictionary (in this repo, nltk). For Chinese, does one need to switch to a Chinese equivalent of nltk?
@siddagra @bmusq thanks for your example code, but when I pull the code and add it as you recommend, I get this error:
```
model = load_from_checkpoint("weights/parseq-bb5792a6.pt", **(config.model)).to('cuda:0')
  File "/home/tupk/tupk/ocr-digits/parseq/strhub/models/utils.py", line 85, in load_from_checkpoint
    model = ModelClass.load_from_checkpoint(checkpoint_path, **kwargs)
  File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 161, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 209, in _load_model_state
    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
KeyError: 'state_dict'
```
@phamkhactu fine-tuning is now officially supported with pretrained weights, so this trick of ours is no longer needed. Please see #9.
Hardcoding the path gives:

```
KeyError: 'pytorch-lightning_version'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```
`parseq-bb5792a6.pt` — how do we find this checkpoint?