microsoft/robustlearn

In diversify, do c and c1 come from minibatch[1] and minibatch[4], respectively? Is that right?

Diting-li opened this issue

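```python
def update_a(self, minibatches, opt):
    all_x = minibatches[0].cuda().float()
    all_c = minibatches[1].cuda().long()  # c is read from index 1
    all_d = minibatches[4].cuda().long()  # d is read from index 4
    ...  # rest of the method omitted

def update_d(self, minibatch, opt):
    all_x1 = minibatch[0].cuda().float()
    all_d1 = minibatch[1].cuda().long()  # d1 is read from index 1
    all_c1 = minibatch[4].cuda().long()  # c1 is read from index 4
    ...  # rest of the method omitted
```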
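To make the indexing in the excerpt concrete, here is a minimal sketch that mirrors only the slot lookups of `update_a` and `update_d` on a single placeholder batch. The tuple contents (`"slot1"`, `"slot4"`, etc.) and the helper names `read_update_a` / `read_update_d` are hypothetical; they stand in for whatever the diversify data loader actually puts at each index, and `.cuda()` / tensor conversion is dropped so the sketch runs without PyTorch.

```python
# Hypothetical 5-slot minibatch: index 0 = inputs, indices 1..4 = label fields.
# The string values are placeholders, not the repository's real field names.
batch = ("x", "slot1", "slot2", "slot3", "slot4")


def read_update_a(minibatches):
    """Mirror the indexing in update_a: c from index 1, d from index 4."""
    all_x = minibatches[0]
    all_c = minibatches[1]
    all_d = minibatches[4]
    return all_x, all_c, all_d


def read_update_d(minibatch):
    """Mirror the indexing in update_d: d1 from index 1, c1 from index 4."""
    all_x1 = minibatch[0]
    all_d1 = minibatch[1]
    all_c1 = minibatch[4]
    return all_x1, all_d1, all_c1


_, c, d = read_update_a(batch)
_, d1, c1 = read_update_d(batch)

print(c, c1)             # slot1 slot4
print(c == d1, d == c1)  # True True
```

On the same batch, `c` and `d1` are the same field, and `d` and `c1` are the same field; in other words, the two update steps read indices 1 and 4 with their roles swapped. Whether that swap is intended depends on what the data loader stores at each index, which the excerpt alone does not show.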