santi-pdp/segan

OOM with small batch size on GPU

Mega4alik opened this issue · 1 comment

Hi,
I'm using @lordet01's implementation: https://github.com/lordet01/segan

In model.py, the following part of the code leads to OOM even with a small batch_size.
Has anyone managed to solve this problem?

while not coord.should_stop():
    start = 0  # timeit.default_timer()
    if counter % config.save_freq == 0:
        for d_iter in range(self.disc_updates):
            _d_opt, _d_sum, \
            d_fk_loss, \
            d_rl_loss = self.sess.run([d_opt, self.d_sum,
                                       self.d_fk_losses[0],
                                       #self.d_nfk_losses[0],
                                       self.d_rl_losses[0]])
            if self.d_clip_weights:
                self.sess.run(self.d_clip)
            #d_nfk_loss, \

        # now G iterations
        _g_opt, _g_sum, \
        g_adv_loss, \
        g_l1_loss = self.sess.run([g_opt, self.g_sum,
                                   self.g_adv_losses[0],
                                   self.g_l1_losses[0]])
    else:
        for d_iter in range(self.disc_updates):
            _d_opt, \
            d_fk_loss, \
            d_rl_loss = self.sess.run([d_opt,
                                       self.d_fk_losses[0],
                                       #self.d_nfk_losses[0],
                                       self.d_rl_losses[0]])
            #d_nfk_loss, \
            if self.d_clip_weights:
                self.sess.run(self.d_clip)

        _g_opt, \
        g_adv_loss, \
        g_l1_loss = self.sess.run([g_opt, self.g_adv_losses[0],
                                   self.g_l1_losses[0]])

Same problem here. Did you solve it?