How many tags can this project train at the same time?
datar001 opened this issue · 3 comments
Hi, thanks for sharing your work.
How many tags have you tried to train at once? What is the relation between the number of tags and the number of training iterations?
And how many tags would you recommend training at the same time?
I've succeeded in training 6 tags at the same time. In my experiments, I found 50k iterations per tag is enough (i.e., 20k for 6 tags).
HiSD supports various numbers of tags, but you should increase the number of training iterations and the model capacity accordingly.
Using gradient accumulation so that all tags are trained in one effective iteration is also important (so you need to change the code a little); see the sketch below.
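To make the idea concrete, here is a minimal, generic PyTorch sketch of gradient accumulation over tags. The toy model, loss, and tag_num are placeholders, not the real HiSD training script: one tag's loss is backpropagated per iteration, and the optimizer only steps after every tag has contributed.

import torch
import torch.nn as nn

tag_num = 3                                   # tags trained together (placeholder)
model = nn.Linear(8, 1)                       # stand-in for the generator
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

opt.zero_grad()
for iteration in range(30):
    tag = iteration % tag_num                 # cycle through the tags, one per iteration
    x = torch.randn(4, 8) + tag               # stand-in for a batch of this tag's data
    loss = model(x).pow(2).mean()             # stand-in for this tag's training loss
    loss.backward()                           # gradients are *summed* into .grad
    if (iteration + 1) % tag_num == 0:        # step only after every tag has contributed
        opt.step()
        opt.zero_grad()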
Sorry for the typo: it should be 200k iterations for 3 tags with 7 attributes.
You've understood the idea of gradient accumulation correctly, and you can modify the update code like this:
def update(self, x, y, i, j, j_trg, iterations):
    this_model = self.models.module if self.multi_gpus else self.models

    # Generator pass: freeze the discriminator, unfreeze the generator.
    for p in this_model.dis.parameters():
        p.requires_grad = False
    for p in this_model.gen.parameters():
        p.requires_grad = True

    # In this trainer the backward passes run inside the model's forward,
    # so this iteration's generator gradients are accumulated into .grad here.
    self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
    x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')
    self.loss_gen_adv = self.loss_gen_adv.mean()
    self.loss_gen_sty = self.loss_gen_sty.mean()
    self.loss_gen_rec = self.loss_gen_rec.mean()

    # Discriminator pass: unfreeze the discriminator, freeze the generator.
    for p in this_model.dis.parameters():
        p.requires_grad = True
    for p in this_model.gen.parameters():
        p.requires_grad = False

    self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
    self.loss_dis_adv = self.loss_dis_adv.mean()

    # Only step the optimizers once every tag_num iterations, after the
    # gradients of all tags have been accumulated.
    if (iterations + 1) % self.tag_num == 0:
        nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
        nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
        self.gen_opt.step()
        self.dis_opt.step()
        self.gen_opt.zero_grad()
        self.dis_opt.zero_grad()
        update_average(this_model.gen_test, this_model.gen)

    return self.loss_gen_adv.item(), \
           self.loss_gen_sty.item(), \
           self.loss_gen_rec.item(), \
           self.loss_dis_adv.item()
And you need to decrease the learning rate (maybe to lr/tag_num), since the accumulated gradient is a sum over the tag_num backward passes rather than an average.
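For example, a minimal sketch of that adjustment (the module names and base learning rate are placeholders, not HiSD's actual config):

import torch
import torch.nn as nn

tag_num = 3            # number of tags trained per effective iteration (placeholder)
base_lr = 1e-4         # learning rate you would use without accumulation (placeholder)

gen = nn.Linear(8, 8)  # stand-in for the generator
dis = nn.Linear(8, 1)  # stand-in for the discriminator

# Option 1: divide the learning rate by tag_num, because tag_num backward passes
# are summed into .grad before each optimizer step.
gen_opt = torch.optim.Adam(gen.parameters(), lr=base_lr / tag_num)
dis_opt = torch.optim.Adam(dis.parameters(), lr=base_lr / tag_num)

# Option 2 (equivalent): keep base_lr and divide each loss by tag_num before
# backward, so the accumulated gradient becomes an average:
#     (loss_gen_total / tag_num).backward()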