PyTorch and torchvision must be installed before running the scripts, together with PIL for data preprocessing and tqdm for showing the training progress.
To run this repository, install Python 3.7 and PyTorch 1.5.0 using Anaconda.
Anaconda can be downloaded from its official website, which also provides installation instructions: https://www.anaconda.com/download/
Create a new environment and install PyTorch and torchvision in it:
conda create --name mseg python=3.7
conda activate mseg
conda install pytorch=1.5.0 torchvision=0.6.0 -c pytorch
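Before running the scripts it can help to confirm that the activated environment actually provides the required packages. The snippet below is a minimal sketch that is not part of the repository (`check_packages` is a hypothetical helper name): it reports the version of each importable package and flags missing ones. Note that Pillow is imported as `PIL`.

```python
import importlib.util
import sys


def check_packages(import_names):
    """Return {import_name: version string, or None if the module is not installed}."""
    report = {}
    for name in import_names:
        if importlib.util.find_spec(name) is None:
            report[name] = None  # not importable in the active environment
            continue
        module = importlib.import_module(name)
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    print("Python", sys.version.split()[0])  # expected: 3.7.x
    # torch is expected to report 1.5.0 for this repository.
    for name, version in check_packages(["torch", "torchvision", "PIL", "tqdm"]).items():
        print(f"{name}: {version if version else 'MISSING'}")
```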
Clone this repository:
git clone https://github.com/ahirsharan/MetaSegNet.git
The code structure is based on MTL-template and Pytorch-Segmentation.
.
├── Datasets
|   ├── COCOAug
|   ├── Pascal5Aug
|   └── FSS1000Aug
|
└── MetaSegNet
    ├── FewShotPreprocessing.py   # utility to organise the Few-shot data into train and novel
    ├── cocogen.py                # utility to organise the Few-shot data into train and novel after generating masks
    ├── augment.py                # For generic data Augmentation
    |
    ├── dataloader
    |   ├── dataset_loader.py     # data loader for pre datasets
    |   └── samplers.py           # samplers for meta task dataset (Few-Shot)
    |
    ├── models
    |   ├── mtl.py                # meta-transfer class
    |   └── metasegnet.py         # Resnet-9 class
    |
    ├── trainer
    |   └── meta.py               # meta-train trainer class
    |
    ├── utils
    |   ├── gpu_tools.py          # GPU tool functions
    |   ├── metrics.py            # Metrics functions
    |   ├── losses.py             # Loss functions
    |   ├── lovasz_losses.py      # Lovasz Loss function
    |   └── misc.py               # miscellaneous tool functions
    |
    ├── main.py                   # the python file with main function and parameter settings
    └── run_meta.py               # the script to run meta-train and meta-test phases
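The samplers in dataloader/samplers.py build few-shot tasks for meta-training. As an illustration of what such a task contains, here is a generic N-way K-shot episode sampler; it is a sketch under assumptions, not the repository's sampler. `sample_episode` and its arguments are hypothetical names: `way` matches the `way` option, while `shot` and `query` play the roles of `train_query` and `test_query` described in main.py's options. Samples are plain IDs here, whereas the real loader returns images and masks.

```python
import random
from typing import Dict, List, Sequence, Tuple

Pair = Tuple[str, str]  # (class_name, sample_id)


def sample_episode(samples_by_class: Dict[str, Sequence[str]], way: int, shot: int,
                   query: int, rng: random.Random) -> Tuple[List[Pair], List[Pair]]:
    """Draw one few-shot task with `way` classes and `shot` + `query` samples per class."""
    # Only classes with enough samples for disjoint support and query sets are eligible.
    eligible = [c for c, items in samples_by_class.items() if len(items) >= shot + query]
    if len(eligible) < way:
        raise ValueError(f"need {way} classes with >= {shot + query} samples, found {len(eligible)}")
    support: List[Pair] = []
    queries: List[Pair] = []
    for cls in rng.sample(eligible, way):
        # Sampling without replacement keeps support and query samples distinct.
        picked = rng.sample(list(samples_by_class[cls]), shot + query)
        support.extend((cls, s) for s in picked[:shot])  # first `shot` go to the support set
        queries.extend((cls, s) for s in picked[shot:])  # the rest form the query set
    return support, queries


if __name__ == "__main__":
    toy = {f"class{i}": [f"img{i}_{j}" for j in range(10)] for i in range(5)}
    support, queries = sample_episode(toy, way=2, shot=1, query=3, rng=random.Random(0))
    print(len(support), len(queries))  # 2 6
```

Passing an explicit `random.Random` instance keeps episode sampling reproducible independently of other uses of the global random state.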
Run the meta-train and meta-test phases:
python run_meta.py
The test predictions and logs (models) are stored in the same root directory under resultsx and logsx, where the suffix x can be changed in trainer/meta.py. The tensorboardX logs for loss and mIoU are stored in the runs folder inside the MetaSegNet directory.
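The mIoU logged above is the mean, over classes, of the intersection-over-union IoU_c = TP_c / (TP_c + FP_c + FN_c). The repository computes it in utils/metrics.py; the function below is only a plain-Python sketch of the standard definition, not that implementation. Two details are assumptions: ground-truth pixels labelled 255 are ignored (a common Pascal VOC convention), and classes absent from both prediction and ground truth are left out of the mean. Whether the repository includes the background class in its mIoU is not stated here.

```python
from typing import Iterable, Optional


def mean_iou(pred: Iterable[int], target: Iterable[int], num_classes: int,
             ignore_index: Optional[int] = 255) -> float:
    """Mean IoU over the classes present in either the prediction or the ground truth.

    `pred` and `target` are flattened per-pixel class ids in range(num_classes);
    ground-truth pixels equal to `ignore_index` are skipped.
    """
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue
        if p == t:
            inter[p] += 1  # true positive for class p
            union[p] += 1
        else:
            union[p] += 1  # false positive for class p
            union[t] += 1  # false negative for class t
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0


if __name__ == "__main__":
    # Class 0: IoU 1/2, class 1: IoU 2/3, so mIoU = 7/12
    print(round(mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2), 4))  # 0.5833
```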
Hyperparameters and options in main.py
Option          Description
model_type      The network architecture
mtype           Ablation-study argument for choosing MetaSegNet, MetaSegNet-NG or MetaSegConv
valdata         Ablation-study argument for also choosing a validation set
dataset         Meta dataset
phase           train or test
seed            Manual seed for PyTorch; "0" means a random seed is used
gpu             GPU id
dataset_dir     Directory containing the images
max_epoch       Number of epochs for the meta-train phase
num_batch       Number of different tasks used for meta-training
way             Way number: how many classes a task contains (background excluded)
train_query     Shots: the number of training samples per class in a task
test_query      The number of test samples per class in a task
meta_lr         Learning rate for the embedding model
base_lr         Learning rate for the inner loop
update_step     Number of updates in the inner loop
step_size       Number of epochs after which the meta learning rate is reduced
gamma           Gamma (decay factor) for the meta-train learning rate
init_weights    Pretrained weights for the meta-train phase
meta_label      Additional label for the meta-train run
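The options above are set in main.py. The sketch below shows how such an option set could be declared with argparse and how two of the described behaviours work; it is illustrative only. Assumptions: the options are passed as `--name` flags, and every type and default shown is a placeholder rather than the value main.py uses. `resolve_seed` implements the documented "0 means a random seed" rule (the result would then be passed to `torch.manual_seed`), and `meta_lr_at` assumes the `step_size`/`gamma` decay follows the same step rule as PyTorch's `StepLR`, i.e. lr = meta_lr * gamma^(epoch // step_size).

```python
import argparse
import random


def build_parser() -> argparse.ArgumentParser:
    """Sketch of the listed options. Types and defaults are placeholders, not main.py's."""
    p = argparse.ArgumentParser(description="MetaSegNet options (illustrative sketch)")
    p.add_argument("--model_type", type=str, default=None, help="network architecture")
    p.add_argument("--mtype", type=str, default=None,
                   help="ablation: MetaSegNet, MetaSegNet-NG or MetaSegConv")
    p.add_argument("--valdata", type=str, default=None, help="ablation: validation set choice")
    p.add_argument("--dataset", type=str, default=None, help="meta dataset")
    p.add_argument("--phase", type=str, choices=["train", "test"], default="train")
    p.add_argument("--seed", type=int, default=0, help="0 means a random seed")
    p.add_argument("--gpu", type=str, default="0", help="GPU id")
    p.add_argument("--dataset_dir", type=str, default=None, help="directory for the images")
    p.add_argument("--max_epoch", type=int, default=100)
    p.add_argument("--num_batch", type=int, default=100, help="tasks used for meta-train")
    p.add_argument("--way", type=int, default=1, help="classes per task, background excluded")
    p.add_argument("--train_query", type=int, default=1, help="shots per class")
    p.add_argument("--test_query", type=int, default=1, help="test samples per class")
    p.add_argument("--meta_lr", type=float, default=1e-3)
    p.add_argument("--base_lr", type=float, default=1e-2)
    p.add_argument("--update_step", type=int, default=5)
    p.add_argument("--step_size", type=int, default=10)
    p.add_argument("--gamma", type=float, default=0.5)
    p.add_argument("--init_weights", type=str, default=None, help="pretrained weights")
    p.add_argument("--meta_label", type=str, default="exp1")
    return p


def resolve_seed(seed: int) -> int:
    """Apply the documented rule: seed 0 means 'pick a random seed'."""
    return random.randint(1, 10000) if seed == 0 else seed


def meta_lr_at(epoch: int, meta_lr: float, step_size: int, gamma: float) -> float:
    """Step decay (StepLR-style, epochs counted from 0): multiply by gamma every step_size epochs."""
    return meta_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    args = build_parser().parse_args(["--phase", "train", "--way", "1", "--seed", "0"])
    print(args.phase, args.way, resolve_seed(args.seed))
    print(meta_lr_at(25, args.meta_lr, args.step_size, args.gamma))
```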