Create the conda environment with:
conda env create --prefix ./local_venv -f environment.yaml
Then run:
python train_segmentation.py [--options, see the arguments defined in train_segmentation.py]
Default parameters:
- dataset: "landcoverai"
- train-mode: "SSL4EO-pretrain"
- epochs: 100
- batch-size: 32
- workers: 4
- segmentation model: "deeplabv3+" (more models here)
- segmentation loss: cross entropy
- segmentation backbone: resnet50 (more backbones here)
- encoder weights: ResNet50_Weights.SENTINEL2_RGB_MOCO (more SSL-pretrained weights here)
W&B runs for each train mode:
- SSL4EO-pretrain: https://wandb.ai/carlosh93/SSL-Pretraining-Satellite-Images/runs/zzmh6azv
- SSL4EO-pretrain-finetune: https://wandb.ai/carlosh93/SSL-Pretraining-Satellite-Images/runs/aa4vwyr0
- imagenet-pretrain: https://wandb.ai/carlosh93/SSL-Pretraining-Satellite-Images/runs/6tsrm0ol
Training follows a configuration similar to the one described here.
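The defaults listed above can be read as a command-line configuration. The following is a minimal argparse sketch under stated assumptions: `build_parser` is hypothetical, and the flag names `--model`, `--loss`, `--backbone` and `--weights`, the value `"ce"` for cross entropy, and the `choices` for `--train-mode` (taken from the three W&B runs) are guesses, not the actual arguments of train_segmentation.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: flag names mirror the README defaults and may not
    # match the real train_segmentation.py arguments.
    parser = argparse.ArgumentParser(description="Segmentation training (sketch)")
    parser.add_argument("--dataset", default="landcoverai")
    parser.add_argument(
        "--train-mode",
        default="SSL4EO-pretrain",
        # Assumed to be the three modes that have W&B runs.
        choices=["SSL4EO-pretrain", "SSL4EO-pretrain-finetune", "imagenet-pretrain"],
    )
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--workers", type=int, default=4)
    # Hypothetical flags for the model-related defaults.
    parser.add_argument("--model", default="deeplabv3+")
    parser.add_argument("--loss", default="ce")  # cross entropy
    parser.add_argument("--backbone", default="resnet50")
    parser.add_argument("--weights", default="ResNet50_Weights.SENTINEL2_RGB_MOCO")
    return parser
```

argparse maps `--train-mode` and `--batch-size` to `args.train_mode` and `args.batch_size`, so `build_parser().parse_args(["--epochs", "50"])` overrides only the epoch count and keeps every other default.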
You may run into this issue: microsoft/torchgeo#1143
If so, or if you want to use the WandB logger, follow these steps:
Open local_venv/lib/python3.10/site-packages/torchgeo/trainers/segmentation.py
and change the following lines (starting at line 206):
try:
    datamodule = self.trainer.datamodule
    batch["prediction"] = y_hat_hard
    for key in ["image", "mask", "prediction"]:
        batch[key] = batch[key].cpu()
    sample = unbind_samples(batch)[0]
    fig = datamodule.plot(sample)
    summary_writer = self.logger.experiment
    summary_writer.add_figure(
        f"image/{batch_idx}", fig, global_step=self.global_step
    )
    plt.close()
except ValueError:
    pass
to these lines:
try:
    datamodule = self.trainer.datamodule
    batch["prediction"] = y_hat_hard
    for key in ["image", "mask", "prediction"]:
        batch[key] = batch[key].cpu()
    sample = unbind_samples(batch)[0]
    fig = datamodule.plot(sample)
    # With WandbLogger, self.logger.experiment is a wandb run without
    # add_figure(), so save the figure as a PNG and log the file instead
    logger = self.logger
    fig.savefig(f"image_{batch_idx}.png")
    logger.log_image(key=f"image_{batch_idx}", images=[f"image_{batch_idx}.png"])
    plt.close(fig)
except ValueError:
    pass
If you do not apply this change, you have to use the TensorBoardLogger instead, as mentioned in the issue.
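The patched block above only works with the WandB logger. If you want the same code to work with either logger, the logging call can dispatch on whichever logger is attached. This is a minimal sketch, not part of torchgeo: `log_figure` is a hypothetical helper, and it assumes that `TensorBoardLogger.experiment` exposes `SummaryWriter.add_figure(tag, figure, global_step=...)` while Lightning's `WandbLogger` provides `log_image(key, images, step=None)`.

```python
import os
import tempfile
from typing import Any


def log_figure(logger: Any, key: str, fig: Any, step: int) -> str:
    """Hypothetical helper: log a matplotlib figure with TensorBoard or W&B.

    Returns the backend that was used ("tensorboard" or "wandb").
    """
    experiment = getattr(logger, "experiment", None)
    if hasattr(experiment, "add_figure"):
        # TensorBoardLogger: experiment is a SummaryWriter that accepts figures.
        experiment.add_figure(key, fig, global_step=step)
        return "tensorboard"
    if hasattr(logger, "log_image"):
        # WandbLogger: save the figure to a temporary PNG and log its path.
        path = os.path.join(tempfile.gettempdir(), key.replace("/", "_") + ".png")
        fig.savefig(path)
        logger.log_image(key=key, images=[path], step=step)
        return "wandb"
    raise TypeError(f"Unsupported logger type: {type(logger).__name__}")
```

Inside the `try` block you would replace the logger lines with something like `log_figure(self.logger, f"image/{batch_idx}", fig, self.global_step)` and keep `plt.close()` afterwards. Writing to the temporary directory also avoids filling the working directory with PNG files.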