Part of the 9th place solution for the Bristol-Myers Squibb – Molecular Translation challenge, which translates images of chemical structures into InChI (International Chemical Identifier) text.
This repo is partially based on the following resources:
- Y.Nakama's tokenization
- Heng's transformer decoder
- Sam Stainsby's external images creation updated by ZFTurbo
- install and activate the conda environment
- download and extract the data into /data/bms/
- extract and move sample_submission_with_length.csv.gz into /data/bms/
- tokenize training inputs:
python datasets/prepocess2.py
- if you want to use pseudo labeling, execute:
python datasets/pseudo_prepocess2.py your_submission_file.csv
- if you want to use external images, you can create with the following commands:
python r09_create_images_from_allowed_inchi.py
python datasets/extra_prepocess2.py
- and also install apex
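The tokenization step in the list above turns each InChI string into a sequence of tokens that the decoder can predict. As a rough illustration of the idea only, here is a minimal sketch that assumes a simple regex split into element symbols, numbers, layer markers such as `/c`, and single characters. It is not the logic used in datasets/prepocess2.py, and the name `tokenize_inchi` is hypothetical.

```python
import re

# Hypothetical token pattern (an assumption, not taken from this repo):
#   element symbol (e.g. C, Cl) | run of digits | layer marker (e.g. /c) | any other char
TOKEN_RE = re.compile(r"[A-Z][a-z]?|\d+|/[a-z]|.")

INCHI_PREFIX = "InChI=1S/"


def tokenize_inchi(inchi):
    """Split an InChI string into a list of tokens (illustrative only)."""
    # Drop the constant prefix so the model does not have to predict it.
    body = inchi[len(INCHI_PREFIX):] if inchi.startswith(INCHI_PREFIX) else inchi
    return TOKEN_RE.findall(body)


if __name__ == "__main__":
    # Ethanol
    print(tokenize_inchi("InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"))
```

Because every character falls into exactly one token, joining the tokens reproduces the InChI body, which makes the mapping easy to invert at inference time.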
This repo supports training any ViT/Swin/CaiT transformer model from timm as the encoder, together with the fairseq transformer decoder.
Here is an example configuration that trains a Swin swin_base_patch4_window12_384 encoder with a 12-layer, 16-head fairseq decoder:
python -m torch.distributed.launch --nproc_per_node=N train.py --logdir=logdir/ \
--pipeline --train-batch-size=50 --valid-batch-size=128 --dataload-workers-nums=10 --mixed-precision --amp-level=O2 \
--aug-rotate90-p=0.5 --aug-crop-p=0.5 --aug-noise-p=0.9 --label-smoothing=0.1 \
--encoder-lr=1e-3 --decoder-lr=1e-3 --lr-step-ratio=0.3 --lr-policy=step --optim=adam --lr-warmup-steps=1000 --max-epochs=20 --weight-decay=0 --clip-grad-norm=1 \
--verbose --image-size=384 --model=swin_base_patch4_window12_384 --loss=ce --embed-dim=1024 --num-head=16 --num-layer=12 \
--fold=0 --train-dataset-size=0 --valid-dataset-size=65536 --valid-dataset-non-sorted
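The command above combines --loss=ce with --label-smoothing=0.1. To show what such a setting usually means, here is a minimal pure-Python sketch of one common label-smoothing variant, where the target distribution is a mix of the one-hot label and a uniform distribution over all classes. This is an assumption for illustration: the repo's own loss (and fairseq's) may use a slightly different formula, and `label_smoothed_cross_entropy` is a hypothetical name.

```python
import math


def label_smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Cross entropy for one example with uniform label smoothing.

    logits:    list of unnormalized scores, one per class
    target:    index of the correct class
    smoothing: weight given to the uniform distribution (0 disables smoothing)
    """
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]

    nll = -log_probs[target]                   # standard cross entropy term
    uniform = -sum(log_probs) / len(logits)    # cross entropy against a uniform target
    return (1.0 - smoothing) * nll + smoothing * uniform


if __name__ == "__main__":
    confident = [10.0, 0.0, 0.0]
    print(label_smoothed_cross_entropy(confident, 0, smoothing=0.0))
    print(label_smoothed_cross_entropy(confident, 0, smoothing=0.1))
```

With smoothing enabled, a very confident correct prediction still incurs some loss, which discourages the decoder from producing extremely peaked token distributions.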
For pseudo labeling, use --pseudo=pseudo.pkl. If you want to subsample the pseudo dataset, use --pseudo-dataset-size=448000.
To use external images, use --extra (--extra-dataset-size=448000).
After training, you can also apply Stochastic Weight Averaging (SWA), which gives a boost of around 0.02:
python swa.py --image-size=384 --input logdir/epoch-17.pth,logdir/epoch-18.pth,logdir/epoch-19.pth,logdir/epoch-20.pth
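The core of averaging several checkpoints, as in the command above, can be sketched as an element-wise mean over their parameters. The snippet below is a minimal illustration and not the code in swa.py: `average_checkpoints` is a hypothetical helper, and it assumes each checkpoint is a mapping from parameter name to a value that supports `+` and `/` (plain floats here; tensors behave the same way).

```python
def average_checkpoints(state_dicts):
    """Return the element-wise mean of several parameter dictionaries."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint")

    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        # All checkpoints must come from the same architecture.
        if sd.keys() != keys:
            raise ValueError("checkpoints have different parameter names")

    n = len(state_dicts)
    # sum() starts from 0, which works for floats and for tensor-like values.
    return {k: sum(sd[k] for sd in state_dicts) / n for k in keys}


if __name__ == "__main__":
    epoch_a = {"weight": 1.0, "bias": 2.0}
    epoch_b = {"weight": 3.0, "bias": 4.0}
    print(average_checkpoints([epoch_a, epoch_b]))
```

A full SWA implementation typically also handles details such as integer buffers and normalization statistics, which this sketch leaves out.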
Evaluation:
python -m torch.distributed.launch --nproc_per_node=N eval.py --mixed-precision --batch-size=128 swa_model.pth
Inference:
python -m torch.distributed.launch --nproc_per_node=N inference.py --mixed-precision --batch-size=128 swa_model.pth
Normalization with RDKit:
./normalize_inchis.sh submission.csv