
Official code repository for the NeurIPS 2022 paper "SatMAE: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery"


SatMAE (NeurIPS 2022)

Project | Paper | Video

This is the official repository for the NeurIPS 2022 paper "SatMAE: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery".

Authors: Yezhen Cong¹, Samar Khanna¹, Chenlin Meng, Patrick Liu, Erik Rozi, Yutong He, Marshall Burke, David B. Lobell, Stefano Ermon.

¹ Equal contribution, order determined via coin flip.

Temporal SatMAE

Pre-training and finetuning on fMoW-Temporal are MEMORY-HEAVY. Please make sure you have enough memory. For context, we ran our experiments on 8 NVIDIA V100 GPUs.

fMoW-Temporal

You can download the fMoW dataset here, then preprocess it with this piece of code. We will also upload the pre-processed dataset soon. The metadata files are here.

After you download the dataset and metadata files, your directory should look like:

<PATH_TO_DATASET_ROOT_FOLDER>
--- train_62classes.csv
--- val_62classes.csv
--- fmow
------- train
---------- airport
---------- ...
------- val
---------- airport
---------- ...
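Before launching a long distributed job, it can help to check that the dataset root matches the layout above. The following is a minimal sketch of such a check; the helper name missing_fmow_entries is hypothetical and not part of this repository, and it only verifies the two metadata CSVs and the fmow/train and fmow/val folders shown in the tree.

```python
import sys
from pathlib import Path

# Entries expected under <PATH_TO_DATASET_ROOT_FOLDER>, taken from the tree above.
REQUIRED_FILES = ("train_62classes.csv", "val_62classes.csv")
REQUIRED_DIRS = ("fmow/train", "fmow/val")


def missing_fmow_entries(root: str) -> list:
    """Return the expected files/folders missing under `root` (empty list if the layout looks right)."""
    base = Path(root)
    missing = [name for name in REQUIRED_FILES if not (base / name).is_file()]
    missing += [name for name in REQUIRED_DIRS if not (base / name).is_dir()]
    return missing


if __name__ == "__main__":
    # Usage: python check_layout.py <PATH_TO_DATASET_ROOT_FOLDER>
    problems = missing_fmow_entries(sys.argv[1] if len(sys.argv) > 1 else ".")
    print("Missing: " + ", ".join(problems) if problems else "Dataset layout looks OK")
```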

Pretraining

For pretraining, this is the default command:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 main_pretrain.py \
    --batch_size 8 --accum_iter 16 \
    --norm_pix_loss --epochs 100 \
    --blr 1.5e-4 --mask_ratio 0.75 \
    --input_size 224 --patch_size 16 \
    --model mae_vit_large_patch16 \
    --model_type temporal \
    --dataset_type temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --num_workers 8

Note that if you want to use wandb, run wandb login in the shell after activating your conda environment and before running the code, then add --wandb <YOUR_WANDB_PROJECT_NAME> to the command above. This applies to all following commands. You will also have to edit the wandb entity name in main_pretrain.py and main_finetune.py.
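The pretraining command passes a base learning rate (--blr) together with a per-GPU batch size and gradient accumulation steps. Assuming this repository keeps the convention of the MAE codebase it builds on, the effective batch size is batch_size × accum_iter × number of GPUs, and the absolute learning rate is blr × effective_batch / 256. The sketch below works this out for the temporal pretraining command; the helper names are illustrative, not functions from this repository.

```python
def effective_batch_size(batch_size: int, accum_iter: int, num_gpus: int) -> int:
    # Samples contributing to one optimizer step: per-GPU batch x accumulation steps x GPUs.
    return batch_size * accum_iter * num_gpus


def absolute_lr(blr: float, eff_batch: int) -> float:
    # MAE-style linear scaling rule (assumed here): lr = blr * effective_batch / 256.
    return blr * eff_batch / 256


# Temporal pretraining command: 8 GPUs, --batch_size 8, --accum_iter 16, --blr 1.5e-4.
eff = effective_batch_size(batch_size=8, accum_iter=16, num_gpus=8)
lr = absolute_lr(1.5e-4, eff)
print(f"effective batch = {eff}, lr = {lr:.1e}")
```

This gives an effective batch of 1024 and a learning rate of about 6e-4. Under the same assumption, if you run on fewer GPUs you can scale --accum_iter up (e.g., 4 GPUs with --accum_iter 32) to keep the effective batch, and therefore the learning rate, unchanged. The same arithmetic applies to the finetuning command (8 × 4 × 4 = 128, so lr ≈ 5e-4 with --blr 1e-3).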

Finetuning

To finetune, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 main_finetune.py \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --batch_size 4 --accum_iter 4 \
    --model vit_large_patch16 --epochs 50 --blr 1e-3 --layer_decay 0.75 \
    --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 \
    --mixup 0.8 --cutmix 1.0 --model_type temporal \
    --finetune <PATH_TO_YOUR_PRETRAIN_CHECKPOINT> \
    --dist_eval --num_workers 8 --dataset_type temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --test_path <PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv

Note: If you are using our provided checkpoint, please add --nb_classes 1000. This is a legacy issue that won't affect model performance, since the actual number of classes (62) is smaller than 1000. To resume a finetuning job, replace --finetune <PATH_TO_YOUR_PRETRAIN_CHECKPOINT> with --resume <PATH_TO_YOUR_FINETUNE_CHECKPOINT> in the command above. To finetune from scratch, simply omit the --finetune argument.
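The --layer_decay 0.75 flag in the finetuning command applies layer-wise learning-rate decay: layers closer to the input get smaller learning rates than layers near the classification head. Below is a minimal sketch of the scheme used in the MAE codebase, which this repository builds on; we assume SatMAE keeps it, and the helper names and the parameter-name prefixes (cls_token, pos_embed, patch_embed, blocks.N) are illustrative assumptions rather than this repository's API.

```python
def layer_id_for_vit(param_name: str, num_layers: int) -> int:
    """Map a ViT parameter name to a depth index (0 = embeddings, num_layers = head)."""
    if param_name in ("cls_token", "pos_embed") or param_name.startswith("patch_embed"):
        return 0
    if param_name.startswith("blocks."):
        # "blocks.<i>.<...>" -> transformer block i sits at depth i + 1
        return int(param_name.split(".")[1]) + 1
    # Anything else (final norm, classification head) gets the full learning rate.
    return num_layers


def lr_scale(param_name: str, depth: int, layer_decay: float) -> float:
    """Multiplier applied to the base lr for this parameter."""
    num_layers = depth + 1
    return layer_decay ** (num_layers - layer_id_for_vit(param_name, num_layers))


# ViT-Large has 24 transformer blocks.
for name in ("head.weight", "blocks.23.attn.qkv.weight", "blocks.0.mlp.fc1.weight", "patch_embed.proj.weight"):
    print(f"{name:28s} lr scale = {lr_scale(name, depth=24, layer_decay=0.75):.2e}")
```

With decay 0.75 and 24 blocks, the head keeps the full learning rate, the last block gets 0.75×, and the patch embedding gets 0.75^25 (roughly 7.5e-4×), so the pretrained early layers change very little during finetuning.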

Evaluation

To evaluate, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=1 --master_port=1234 main_finetune.py \
    --output_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --log_dir <PATH_TO_YOUR_OUTPUT_FOLDER> \
    --batch_size 16 \
    --model vit_large_patch16 \
    --model_type temporal \
    --resume <PATH_TO_YOUR_FINETUNE_CHECKPOINT> \
    --dist_eval --eval --num_workers 8 --dataset_type temporal \
    --train_path <PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \
    --test_path <PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv

Similarly, if you are using our provided checkpoint, please add --nb_classes 1000.
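To see why a 1000-way head is harmless for top-1 accuracy, note that top-1 only compares the highest-scoring output with the label. With --nb_classes 1000 the head emits 1000 logits, but fMoW labels only use indices 0 to 61, so the extra outputs matter only if one of them wins the argmax. The sketch below is a plain-Python illustration of top-1 accuracy (the function name is hypothetical, not this repository's evaluation code), using a toy head where indices 3 and 4 stand in for unused classes.

```python
def top1_accuracy(logits: list, labels: list) -> float:
    """Fraction of samples whose highest-scoring output index equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over all outputs
        correct += int(pred == label)
    return correct / len(labels)


# Toy example: 3 "real" classes (0-2) plus 2 unused head outputs (3-4).
logits = [
    [0.1, 2.0, 0.3, -1.0, 0.0],  # argmax = 1 (a real class), label 1 -> correct
    [1.5, 0.2, 0.1, 0.0, 3.0],   # argmax = 4 (an unused output), label 0 -> wrong
]
print(top1_accuracy(logits, [1, 0]))
```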

Model Weights

You can download model weights pre-trained on fMoW-Temporal and weights finetuned on fMoW-Temporal here.

fMoW Non-Temporal Checkpoints

Model | Top 1 Accuracy | Pretrain | Finetune
ViT-Large | 77.78% | download (800 epochs) | download (29 epochs)

fMoW Temporal Checkpoints

Model | Top 1 Accuracy | Pretrain | Finetune
ViT-Large | 79.99% | download (50 epochs) | download (24 epochs)

The accuracy of SatMAE on fMoW-Temporal (reported above) is achieved without using test-time augmentation (see paper).

Multi-Spectral SatMAE

Training multi-spectral SatMAE is similar to training temporal SatMAE.

fMoW-Sentinel Dataset

You can access and download the fMoW-Sentinel dataset we collected here. Try this link if the previous one doesn't display correctly.

Note that when loading the train.csv or val.csv files, you may have to rebuild the column called image_path. The image_path for any row can be constructed as follows:

fmow-sentinel/<split>/<category>/<category>_<location_id>/<category>_<location_id>_<image_id>.tif
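The template above can be applied row by row when loading the CSV. This is a minimal sketch using only the standard library; it assumes the CSV has columns named category, location_id and image_id (check the header of your copy), and the helper names build_image_path and fix_image_paths are hypothetical.

```python
import csv
import io


def build_image_path(split: str, category: str, location_id: str, image_id: str) -> str:
    # fmow-sentinel/<split>/<category>/<category>_<location_id>/<category>_<location_id>_<image_id>.tif
    stem = f"{category}_{location_id}"
    return f"fmow-sentinel/{split}/{category}/{stem}/{stem}_{image_id}.tif"


def fix_image_paths(csv_text: str, split: str) -> list:
    """Parse CSV text and overwrite (or add) the image_path column of every row."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    for row in rows:
        # Column names below are assumptions about the CSV header.
        row["image_path"] = build_image_path(split, row["category"], row["location_id"], row["image_id"])
    return rows


sample = "category,location_id,image_id,image_path\nairport,12,3,stale/path.tif\n"
print(fix_image_paths(sample, "train")[0]["image_path"])
```

For the real files, read train.csv with split "train" and val.csv with split "val", and prepend your dataset root if the paths need to be absolute.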

Pretraining

For pretraining, this is the default command:

python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
--wandb satmae_pretrain \
--batch_size 16 --accum_iter 32 --blr 0.0001 \
--epochs 200 --warmup_epochs 20 --num_workers 16 \
--input_size 96 --patch_size 8 \
--mask_ratio 0.75 \
--model_type group_c \
--dataset_type sentinel --dropped_bands 0 9 10 \
--grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9 \
--train_path /home/fmow-sentinel-filtered-csv/train.csv \
--output_dir /home/experiments/pretrain \
--log_dir /home/experiments/pretrain
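The --dropped_bands and --grouped_bands arguments select and group Sentinel-2 spectral bands. The sketch below assumes two things that are not stated in this README: that the .tif files store the 13 bands in the standard Sentinel-2 order (B1 to B12, with B8A after B8), and that --grouped_bands indices refer to positions after the dropped bands are removed. Under those assumptions it maps the indices in the command to band names.

```python
# Standard Sentinel-2 band order (13 bands); assumed to match the fMoW-Sentinel .tif files.
S2_BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]

DROPPED = [0, 9, 10]                           # --dropped_bands 0 9 10
GROUPS = [[0, 1, 2, 6], [3, 4, 5, 7], [8, 9]]  # --grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9


def kept_bands(bands: list, dropped: list) -> list:
    """Remove the dropped indices, keeping the remaining bands in their original order."""
    drop = set(dropped)
    return [band for i, band in enumerate(bands) if i not in drop]


kept = kept_bands(S2_BANDS, DROPPED)
group_names = [[kept[i] for i in group] for group in GROUPS]
print(kept)         # 10 bands remain after dropping B1, B9, B10
print(group_names)  # one list of band names per group
```

Under these assumptions the three groups are B2/B3/B4/B8, B5/B6/B7/B8A and B11/B12, and together they cover each of the 10 remaining bands exactly once.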

You can use the --spatial_mask argument to toggle on consistent spatial masking (rather than independent masking). See paper for details (independent masking performs better).
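The difference between independent and consistent spatial masking can be illustrated on patch indices alone. In this sketch (a conceptual illustration, not this repository's tensor implementation; function names are hypothetical), each band group is split into (96 / 8)² = 144 patches for --input_size 96 and --patch_size 8. Independent masking samples a separate 75% mask per group, while consistent masking reuses one mask for every group, so the same spatial locations are hidden across all bands.

```python
import random


def random_mask(num_patches: int, mask_ratio: float, rng: random.Random) -> set:
    """Indices of the masked patches for one group."""
    num_masked = int(num_patches * mask_ratio)
    return set(rng.sample(range(num_patches), num_masked))


def group_masks(num_groups: int, num_patches: int, mask_ratio: float, spatial_mask: bool, seed: int = 0) -> list:
    rng = random.Random(seed)
    if spatial_mask:
        # Consistent masking: one mask shared by all groups.
        shared = random_mask(num_patches, mask_ratio, rng)
        return [set(shared) for _ in range(num_groups)]
    # Independent masking: each group gets its own mask.
    return [random_mask(num_patches, mask_ratio, rng) for _ in range(num_groups)]


num_patches = (96 // 8) ** 2  # 144 patches per group
independent = group_masks(3, num_patches, 0.75, spatial_mask=False)
consistent = group_masks(3, num_patches, 0.75, spatial_mask=True)
print([len(m) for m in independent])  # masked patches per group
```

In both cases 108 of 144 patches are masked per group, leaving 36 visible patches per group (108 visible tokens across the 3 groups). With independent masking a location hidden in one group may still be visible in another, so the model can draw on other spectral groups at the same location, which consistent masking rules out.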

To resume a pretraining job, use --resume PATH/TO/CKPT.PTH (e.g., --resume /home/experiments/pretrain/checkpoint-175.pth).
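When resuming, you typically want the highest-numbered checkpoint-<epoch>.pth in the output folder. A plain lexicographic sort would rank checkpoint-99.pth after checkpoint-175.pth, so this minimal sketch compares epoch numbers numerically. The helper latest_checkpoint is hypothetical, and it assumes checkpoints follow the checkpoint-<epoch>.pth naming shown in the example above.

```python
import re
from pathlib import Path
from typing import Optional

CKPT_PATTERN = re.compile(r"checkpoint-(\d+)\.pth")


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint with the largest epoch number, or None if there is none."""
    best, best_epoch = None, -1
    for path in Path(output_dir).glob("checkpoint-*.pth"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = path, int(match.group(1))
    return best


if __name__ == "__main__":
    ckpt = latest_checkpoint("/home/experiments/pretrain")
    print(f"--resume {ckpt}" if ckpt else "no checkpoint found")
```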

Finetuning

To finetune, the basic command is:

python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
--wandb satmae_finetune \
--batch_size 8 --accum_iter 16 --blr 0.0002 \
--epochs 30 --num_workers 16 \
--input_size 96 --patch_size 8  \
--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
--model_type group_c  \
--dataset_type sentinel --dropped_bands 0 9 10 \
--train_path /home/fmow-sentinel-filtered-csv/train.csv \
--test_path /home/fmow-sentinel-filtered-csv/val.csv \
--output_dir /home/experiments/finetune \
--log_dir /home/experiments/finetune \
--finetune /home/experiments/pretrain/checkpoint-199.pth

To finetune from scratch, simply omit the --finetune argument. To resume a finetuning job, you can replace --finetune path/to/pretrained_checkpoint.pth with --resume path/to/finetune_checkpoint.pth in the command above.

Model Weights

We will be uploading model checkpoints here. The pretrained checkpoints have been trained for 200 epochs, so their accuracy may be higher than the numbers reported in the paper (where the models were pretrained for 50 epochs).
The Top 1 accuracy is measured on the validation set of fMoW-Sentinel.

Model | Top 1 Accuracy | Pretrain | Finetune
ViT-Base (200 epochs) | 62.65% | download | download
ViT-Large (200 epochs) | 63.84% | download | download

Acknowledgements

Code from this repository is inspired by the Masked Autoencoders (MAE) repository.

Citation

If you found our project helpful, please cite our paper:

@inproceedings{
    satmae2022,
    title={Sat{MAE}: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery},
    author={Yezhen Cong and Samar Khanna and Chenlin Meng and Patrick Liu and Erik Rozi and Yutong He and Marshall Burke and David B. Lobell and Stefano Ermon},
    booktitle={Advances in Neural Information Processing Systems},
    editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
    year={2022},
    url={https://openreview.net/forum?id=WBhqzpF6KYH}
}