DeepGraphLearning/KnowledgeGraphEmbedding

Did you just use the first batch to train the model? Can you help me solve my problem?

ToneLi opened this issue · 2 comments

I have a question about the RotatE code in model.py. RotatE uses the next function to fetch the training data, but shouldn't the next call be inside the training loop? As written, it looks to me like RotatE only ever trains on the first batch in every step, because if next is not inside the loop it will just return the first item of the list/dict. Can anyone help me answer this question?

class BidirectionalOneShotIterator(object):

def __init__(self, dataloader_head, dataloader_tail):
    self.iterator_head = self.one_shot_iterator(dataloader_head)
    # print("bb",next(self.iterator_head))  #一个batch的
    self.iterator_tail = self.one_shot_iterator(dataloader_tail)
    self.step = 0


def __next__(self):
    self.step += 1

    if self.step % 2 == 0:
        data = next(self.iterator_head)
    else:
        data = next(self.iterator_tail)
    print("self.step", self.step)
    return data


@staticmethod
def one_shot_iterator(dataloader):
    '''
    Transform a PyTorch DataLoader into a Python iterator
    '''
    while True:
        for data in dataloader:
            yield data

def train_step(model, optimizer, train_iterator, args):
    '''
    A single train step. Applies back-propagation and returns the loss
    '''
    model.train()
    optimizer.zero_grad()
    positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)

Hi Tone,

The "next" function will always give the next item in a "python iterator" (please search "python iterator" in Google for python syntax)

list and dict are not "python iterators".
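To make this concrete, here is a minimal sketch (using a plain list of toy batches as a stand-in for the repo's DataLoader) showing that a generator created once keeps advancing each time next is called on it:

```python
def one_shot_iterator(dataloader):
    """Same pattern as in the repo: cycle over the dataloader forever."""
    while True:
        for data in dataloader:
            yield data

batches = [[1, 2], [3, 4], [5, 6]]   # stand-in for a DataLoader
it = one_shot_iterator(batches)      # created once, like in __init__

# next() is called once per "train step"; the generator resumes where
# it left off and wraps around after the last batch.
seen = [next(it) for _ in range(4)]
print(seen)  # [[1, 2], [3, 4], [5, 6], [1, 2]]
```

So even though next appears outside any visible loop in model.py, every call to train_step pulls a fresh batch.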