验证集的使用
mehrooio opened this issue · 4 comments
mehrooio commented
你好,根据main.py里面
source_set = EmbeddingDataset(data_root, args.img_size, 'train')
source_loader = DataLoader(
source_set, num_workers=0, batch_size=64, shuffle=True)
test_set = EmbeddingDataset(data_root, args.img_size, 'val')
test_loader = DataLoader(test_set, num_workers=0, batch_size=32, shuffle=False)
然后跳转到datasets.py的class EmbeddingDataset(Dataset):
self.ImagesDir = os.path.join(dataroot,'images')
self.data = loadSplit(splitFile = os.path.join(dataroot,'train' + '.csv'))
self.data = collections.OrderedDict(sorted(self.data.items()))
keys = list(self.data.keys())
self.classes_dict = {keys[i]:i for i in range(len(keys))} # map NLabel to id(0-99)
self.Files = []
self.belong = []
for c in range(len(keys)):
num = 0
**num_train = int(len(self.data[keys[c]]) * 9 / 10)**
for file in self.data[keys[c]]:
if type == 'train' and num <= num_train:
self.Files.append(file)
self.belong.append(c)
**elif type=='val' and num>num_train:
self.Files.append(file)
self.belong.append(c)**
num = num+1
请问验证集的读取是在这一块嘛,根据我理解的,source_set 读取的是train.csv里面百分之九十的数据,然后test_set 读取的是train.csv里面百分之十的数据,然后没明白test_set是否用到了val.csv的数据
Yikai-Wang commented
这里 test_set 是为了验证 backbone 学习性能的,所以选取了与训练集同样类别的数据,没有使用 val.csv 的数据
mehrooio commented
谢谢你的回复,那请问一下代码块的哪一部分使用了val.csv的数据,在main.py里面我没有找到,只发现了train.csv和test.csv数据的使用
Yikai-Wang commented
实验中用 val.csv选取 ICI 的超参,因为超参已经选好了放出来的 code 就直接用选好的超参在 test.csv 上跑测试了
mehrooio commented
理解了,非常感谢你的回答