Zj-BinXia/SSL

Code confusion


    def train(self):
        self.model.train()
        self.train_sampler.set_epoch(100000)
        self.prefetcher.reset()
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.total_iter += 1

            self.lq = train_data['lq'].to(self.device)
            self.gt = train_data['gt'].to(self.device)

            finished = self.optimize_parameters(self.total_iter)
            if finished:
                return True

In this snippet (from SSL/blob/master/basicsr/pruner/SSL_pruner.py), train_data is fetched once before the while loop and is never updated inside it. As written, the loop appears to train on the same batch repeatedly, and the condition train_data is not None never becomes false, so the loop only ends when optimize_parameters reports that it has finished. A call to train_data = self.prefetcher.next() seems to be missing inside the loop.

Should it be the following instead?

    def train(self):
        self.model.train()
        self.train_sampler.set_epoch(100000)
        self.prefetcher.reset()
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.total_iter += 1

            self.lq = train_data['lq'].to(self.device)
            self.gt = train_data['gt'].to(self.device)

            finished = self.optimize_parameters(self.total_iter)
            train_data = self.prefetcher.next()  # add this?
            if finished:
                return True

Yes, you are right. I seem to have deleted this line by mistake while removing debug code. I have added it back now.
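
For anyone reading later, here is a minimal sketch of the loop pattern discussed above. It is not the repository's code: ToyPrefetcher and run_epoch are hypothetical stand-ins, and the only assumption taken from the snippet is that next() returns a batch or None once the data is exhausted. It shows that advancing the prefetcher inside the loop is what lets every batch be visited and lets the loop end on its own.

```python
class ToyPrefetcher:
    """Hypothetical stand-in for a prefetcher: next() returns a batch, or None when exhausted."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.reset()

    def reset(self):
        # Restart iteration from the first batch.
        self.iterator = iter(self.batches)

    def next(self):
        # Return the next batch, or None once all batches are consumed.
        return next(self.iterator, None)


def run_epoch(prefetcher, step_fn):
    """Visit batches in order; stop early if step_fn reports it is finished."""
    seen = []
    prefetcher.reset()
    train_data = prefetcher.next()
    while train_data is not None:
        seen.append(train_data)
        finished = step_fn(train_data)
        # Advance to the next batch. Without this line the loop would keep
        # reusing the first batch and train_data would never become None.
        train_data = prefetcher.next()
        if finished:
            break
    return seen


batches = [{"lq": i, "gt": i * 2} for i in range(3)]

# A step function that never finishes early: all three batches are visited.
seen = run_epoch(ToyPrefetcher(batches), lambda batch: False)
print([b["lq"] for b in seen])  # prints [0, 1, 2]

# A step function that finishes on the second batch: the loop stops there.
early = run_epoch(ToyPrefetcher(batches), lambda batch: batch["lq"] == 1)
print([b["lq"] for b in early])  # prints [0, 1]
```

The fetch is placed before the finished check to match the order in the proposed fix; placing it after the check would behave the same here, apart from not fetching one unused batch on early exit.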