Did you just use the first batch to train the model? Can you help me solve my problem?
ToneLi opened this issue · 2 comments
I have a question about model.py in ROTATE. ROTATE uses the `next` function to fetch training data, so shouldn't the `next` call be inside a loop? As far as I can tell, ROTATE just picks the first batch to train on at every step. My reasoning is that if `next` is not called inside a loop, it returns the first item of the list/dict each time. Can anyone help answer my question?
```python
class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        # print("bb", next(self.iterator_head))  # one batch
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        self.step += 1
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        print("self.step", self.step)
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch DataLoader into a Python iterator
        '''
        while True:
            for data in dataloader:
                yield data


def train_step(model, optimizer, train_iterator, args):
    '''
    A single train step. Apply back-propagation and return the loss
    '''
    model.train()
    optimizer.zero_grad()
    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
    # ... (rest of train_step not shown in the issue)
```
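To check what the iterator actually returns, here is a minimal sketch that reuses the class from the issue (with the debug `print` removed). It assumes plain Python lists can stand in for the two PyTorch DataLoaders, which works because `one_shot_iterator` only does `for data in dataloader`. The batch names are hypothetical placeholders.

```python
class BidirectionalOneShotIterator(object):
    def __init__(self, dataloader_head, dataloader_tail):
        # Each attribute holds a generator object, created once and reused.
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __next__(self):
        # Odd steps draw from the tail generator, even steps from the head one.
        self.step += 1
        if self.step % 2 == 0:
            data = next(self.iterator_head)
        else:
            data = next(self.iterator_tail)
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        # Infinite generator: when one pass over the data ends, start another.
        while True:
            for data in dataloader:
                yield data


# Hypothetical stand-ins for the head/tail DataLoaders.
head_batches = ["head-0", "head-1"]
tail_batches = ["tail-0", "tail-1", "tail-2"]

it = BidirectionalOneShotIterator(head_batches, tail_batches)
seen = [next(it) for _ in range(8)]
print(seen)
# ['tail-0', 'head-0', 'tail-1', 'head-1', 'tail-2', 'head-0', 'tail-0', 'head-1']
```

Every call to `next(it)` produces a different batch. Calls alternate between the tail and head sides, and each side wraps around to its first batch only after all of its batches have been used once.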
Hi Tone,
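The built-in `next` function always returns the *next* item from a Python iterator. `one_shot_iterator` is a generator (it uses `yield`), and a generator is an iterator that remembers where it stopped. Each call to `next(train_iterator)` in `train_step` goes through `BidirectionalOneShotIterator.__next__`, which alternates between the tail and head generators. Each of those generators then resumes its `for data in dataloader` loop and yields the following batch. When a dataloader is exhausted, the `while True` loop starts a new pass. So no explicit loop around `next` is needed. (Search for "Python iterator" if you are unfamiliar with the syntax.)

A list or dict is not a Python iterator. It is only an iterable, and calling `next` on it directly raises a `TypeError`. You would get the first item over and over only if you created a fresh iterator every time, e.g. `next(iter(my_list))`.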
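The difference between a generator, a freshly created iterator, and a plain list can be seen in a few lines. This is a small illustrative sketch: `one_shot_iterator` is copied from the issue, and the list `batches` is a hypothetical stand-in for a DataLoader.

```python
def one_shot_iterator(dataloader):
    # Same generator as in the issue: loop over the data forever.
    while True:
        for data in dataloader:
            yield data


batches = ["batch-0", "batch-1", "batch-2"]  # hypothetical stand-in data

# 1) One generator object, reused: next() resumes where it stopped.
gen = one_shot_iterator(batches)
from_generator = [next(gen) for _ in range(4)]

# 2) A brand-new iterator on every call: always restarts at the first item.
from_fresh_iter = [next(iter(batches)) for _ in range(4)]

# 3) The list itself is not an iterator, so next() raises TypeError.
try:
    next(batches)
    list_is_iterator = True
except TypeError:
    list_is_iterator = False

print(from_generator)    # ['batch-0', 'batch-1', 'batch-2', 'batch-0']
print(from_fresh_iter)   # ['batch-0', 'batch-0', 'batch-0', 'batch-0']
print(list_is_iterator)  # False
```

RotatE's training code follows pattern 1. Its generators are created once in `__init__` and stored on the object, so training does not get stuck on the first batch.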