This is the code for the AISTATS 2021 paper "Distributionally Robust Optimization for Deep Kernel Multiple Instance Learning". The paper evaluates the method on five datasets: (1) ShanghaiTech, (2) UCF-Crime, (3) Avenue, (4) ShanghaiTech Outlier, and (5) UCF-Crime Multimodal.
To train the model on the ShanghaiTech dataset, execute the following command:
python train_mil_sanghaitech.py split_no rep_no eta
This trains the model and saves (1) all losses, (2) test AUCs, and (3) validation AUCs under the directory logs/SanghaiTech every 10 iterations. It also stores the best resulting model under the directory trained_models/SanghaiTech.
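The periodic logging described above can be pictured with the minimal sketch below. It is not the repository's training loop: `train_step`, `validate`, the in-memory log, and picking the "best" model as the one with the highest validation AUC are assumptions made for illustration. The AUC shown is the standard rank-based ROC AUC, i.e. the fraction of (positive, negative) pairs in which the positive instance receives the higher anomaly score.

```python
from typing import Callable, Dict, List, Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Rank-based ROC AUC: fraction of (positive, negative) pairs ranked correctly, ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


def train_with_logging(
    train_step: Callable[[int], float],  # hypothetical: runs one iteration, returns its loss
    validate: Callable[[], float],       # hypothetical: returns the current validation AUC
    num_iters: int,
    log_every: int = 10,
) -> Dict[str, object]:
    """Log loss and validation AUC every `log_every` iterations and track the best checkpoint."""
    log: Dict[str, List[float]] = {"loss": [], "val_auc": []}
    best_auc, best_iter = float("-inf"), -1
    for it in range(1, num_iters + 1):
        loss = train_step(it)
        if it % log_every == 0:
            auc = validate()
            log["loss"].append(loss)
            log["val_auc"].append(auc)
            if auc > best_auc:
                # A real script would save the model weights at this point (e.g. under trained_models/).
                best_auc, best_iter = auc, it
    return {"log": log, "best_auc": best_auc, "best_iter": best_iter}


if __name__ == "__main__":
    print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
    fake_aucs = iter([0.60, 0.70, 0.65])  # toy values, one per logging step
    result = train_with_logging(lambda it: 1.0 / it, lambda: next(fake_aucs), num_iters=30)
    print(result["best_iter"], result["best_auc"])  # 20 0.7
```

For the datasets where only test AUCs are logged, the criterion used to pick the stored model may differ; the corresponding training script is the authoritative reference.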
To train the model on the UCF-Crime dataset, execute the following command:
python train_mil_ucfcrime.py rep_no eta
This trains the model and saves (1) all losses and (2) test AUCs under the directory logs/UCF_Crime every 10 iterations. It also stores the best resulting model under the directory trained_models/UCF_Crime.
To train the model on the Avenue dataset, execute the following command:
python train_mil_avenue.py cv_no rep_no eta
This trains the model and saves (1) the total loss and (2) test AUCs under the directory logs/Avenue every 10 iterations. It also stores the best resulting model under the directory trained_models/Avenue.
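Because each Avenue run is indexed by a cross-validation fold and a repetition, several runs are usually launched back to back. The helper below is a hypothetical convenience wrapper, not part of the repository: it only relies on the positional-argument order of the command above, and the grid values in the usage line (folds 1 and 2, repetitions 0 and 1, eta = 0.1) are placeholders, not the settings used in the paper.

```python
import itertools
import subprocess
import sys
from typing import Dict, List, Sequence


def build_sweep(script: str, grid: Dict[str, Sequence[object]]) -> List[List[str]]:
    """Expand a grid of positional-argument values into one argv per run.

    The insertion order of `grid` must match the script's positional-argument order.
    """
    names = list(grid)
    commands = []
    for values in itertools.product(*(grid[n] for n in names)):
        commands.append([sys.executable, script] + [str(v) for v in values])
    return commands


def run_sweep(commands: List[List[str]], dry_run: bool = True) -> None:
    for cmd in commands:
        print(" ".join(cmd[1:]))  # show the script name and its arguments
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop the sweep if a run fails


if __name__ == "__main__":
    # Placeholder values for illustration only.
    avenue = build_sweep("train_mil_avenue.py", {"cv_no": [1, 2], "rep_no": [0, 1], "eta": [0.1]})
    run_sweep(avenue)  # dry run: prints the four commands without training
```

With `dry_run=True` the helper only prints the commands, which is a cheap way to check the argument order before launching long training jobs; the other scripts can be swept the same way by changing the script name and the grid keys.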
To train the model on the ShanghaiTech Outlier dataset, execute the following command:
python train_mil_sanghaitech_outlier.py split_no rep_no eta
This trains the model and saves (1) all losses, (2) test AUCs, and (3) validation AUCs under the directory logs/SanghaiTech_Outlier every 10 iterations. It also stores the best resulting model under the directory trained_models/SanghaiTech_Outlier.
To train the model on the UCF-Crime Multimodal dataset, execute the following command:
python train_mil_ucfcrime_multimodal.py rep_no eta
This trains the model and saves (1) the total loss and (2) test AUCs under the directory logs/UCF_Crime_Multimodal every 10 iterations. It also stores the best resulting model under the directory trained_models/UCF_Crime_Multimodal.
Where: