Question about the implementation of Meta-Joint Optimization
zero0kiriyu opened this issue · 0 comments
zero0kiriyu commented
I tried to reimplement the meta-joint optimization part, but it always outputs an inf loss after several iterations. @cleardusk
# Meta-joint optimization sketch: two copies of the same network are trained
# side by side, one stepped with the VDC loss and one with the WPDC loss.
# Every `meta_joint_k` batches both copies are compared on the VDC loss and
# the worse one is overwritten with the better one's weights and optimizer
# state.
#
# NOTE(review): `mobilenet`, `trainloader`, `lr`, `N`, `meta_joint_k` and the
# loss tensors are defined outside this snippet; the loss computations were
# elided in the original post and are still placeholders here.
model_vdc = mobilenet()
model_wpdc = copy.deepcopy(model_vdc)
optimizer_vdc = torch.optim.SGD(params=model_vdc.parameters(), lr=lr)
optimizer_wpdc = torch.optim.SGD(params=model_wpdc.parameters(), lr=lr)

for epoch in range(N):
    for batch_idx, batch in enumerate(trainloader):
        # Batch 0 is always a training step; afterwards every
        # meta_joint_k-th batch is a synchronisation step instead.
        is_sync_step = batch_idx != 0 and batch_idx % meta_joint_k == 0

        if not is_sync_step:
            # `loss_vdc` / `loss_wpdc` must be computed from `batch` before
            # this point (not shown in the original post).

            # update by vdc loss
            loss_vdc.backward()
            optimizer_vdc.step()
            optimizer_vdc.zero_grad()

            # update by wpdc loss
            loss_wpdc.backward()
            optimizer_wpdc.step()
            optimizer_wpdc.zero_grad()
        else:
            model_vdc.eval()
            model_wpdc.eval()

            # calculate the vdc loss for two model
            ...  # elided in the original post; must define loss_vdc_vdc and loss_vdc_wpdc

            # load_state_dict() expects a state dict, not the module or
            # optimizer object itself, so take .state_dict() first. The
            # deepcopy keeps the snapshot independent of the source objects.
            if loss_vdc_vdc > loss_vdc_wpdc:
                model_vdc.load_state_dict(copy.deepcopy(model_wpdc.state_dict()))
                optimizer_vdc.load_state_dict(
                    copy.deepcopy(optimizer_wpdc.state_dict())
                )
            else:
                model_wpdc.load_state_dict(copy.deepcopy(model_vdc.state_dict()))
                optimizer_wpdc.load_state_dict(
                    copy.deepcopy(optimizer_vdc.state_dict())
                )

            # `.training` is a bool attribute on nn.Module; `.train()` is the
            # call that switches the module back to training mode.
            model_vdc.train()
            model_wpdc.train()