/RMINN

Primary LanguagePython

复现论文Revisiting Multiple Instance Neural Networks

1.文件说明

“train_mi.py”文件是mi-Net模型。
“MI_net.py”文件是MI-net模型。
“MI_net_DS.py”文件是 MI-net with deep supervision模型。
"MI_Res.py"文件是MI-net with Res 模型。
"loader.py"文件是针对MUSK1/2数据集的导入方法。
"pre_.py"文件是针对fox/elephant/tiger数据集的导入方法,导入数据集的同时对数据集的顺序进行了打乱,由于没有找到文中的第三个数据集,故第三个数据集的导入函数尚未编写。
"main.py"文件是训练、测试文件,可以通过修改注释改变导入的数据集和模型。

2.环境

CPU的实验manjaro系统下完成,利用anaconda建立虚拟环境,使用CPU版本的pytorch进行实验,python版本为3.7.11。GPU服务器的环境为ubuntu系统,CUDA11.1。

3.训练过程

对每个网络的每个数据集作十折交叉验证,即将数据集分为十份,每次取9份做训练,取1份做测试,总共做十次,使得每份数据都有做过测试集,取十次结果的平均数。总共训练五次求平均数,每次都在训练集上做多次迭代训练再在测试集上做验证.迭代次数根据训练模型在测试集上的泛化性能决定,例如实验中发现MI-net_DS网络在fox数据集上迭代100次时泛化性能最好。可能网络的规模和深度比较小,用于传输数据的时间比例过高,利用GPU训练后速度反而有所下降,但是训练的效果提高了,可能是因为GPU的计算精度更高。下表中标粗的部分就是利用GPU重新训练后的结果。
训练过程采用随机梯度下降算法,学习率设置为0.01

4.实验结果

1. 在max池化函数下的实验结果

MUSK1 MUSK2 fox elepant tiger
mi-net 81.957% 79.216% 62.200% 85.000% 81.700%
MI-net 87.174% 85.098% 62.400% 86.400% 81.400%
MI-Net-DS 92.826% 86.078% 64.500% 87.200% 82.400%
MI-Net-RS 88.913% 85.686% 62.700% 87.800% 82.600%

优化器的lr=0.01,momentum=0.5
其中在MUSK2数据集上需要在优化器中添加weight-decay参数以进行权重衰减再训练才可以得到文中的预测效果
其中momentum=0.9,weight-decay=0.003 , DS。
momentum=0.9,weight-decay=0.003 , RS。
得到上述结果的迭代次数

MUSK1 MUSK2 fox elephant tiger
mi-net 120 300 200 150 200
MI-net 200 60 200 200 200
MI-net-DS 80 60 100 150 100
MI-net-RS 150 60 150 100 100

2. 池化函数的影响

对MI-net-DS网络分别应用max 、mean、lse池化函数,lse函数的r参数设置为2

MUSK1 MUSK2 fox elephant tiger
max 92.826% 83.333% 64.500% 87.200% 82.400%
mean 86.739% 78.039% 64.400% 85.800% 84.200%
lse 87.391% 81.373% 63.900% 87.100% 84.00%

2022.2.11补充,复现文章Attention-based Multiple Instance Learning

这篇文章的核心**就是在基于包分类的方法的基础上,对神经网络输出的最后一层的结果采用注意力机制,使池化方式也是可学习的,从而使最能代表包特征的instance在结果中占据的权重最大。
其中在benchmark数据集上的网络结构就是将MI-net的池化函数变为注意力机制
注意力机制的代码在MIL_pooling.py中,被写为attention类

MUSK1 MUSK2 fox elephant tiger
Attention 87.174% 82.157% 59.100% 83.500% 85.900%
Attention-gate 87.174% 82.7451% 59.100% 84.700% 83.900%

其中fox、musk2数据集使用文中的参数无法得到最优结果
fox : lr = 0.005 , momentum=0.5,weight-decay=0
musk2: lr = 0.0005 , momentum=0.9, weight-decay=0.003