cleardusk/3DDFA_V2

Question about the implementation of Meta-Joint Optimization

zero0kiriyu opened this issue · 0 comments

I am trying to reimplement the meta-joint optimization part, but the loss always becomes inf after several iterations. @cleardusk

# Meta-joint optimization sketch: two copies of the same network are trained
# side by side, one with the VDC loss and one with the WPDC loss. Every
# `meta_joint_k` batches both copies are scored with the VDC loss. The copy
# with the higher (worse) loss is overwritten with the other copy's weights
# and optimizer state.
model_vdc = mobilenet()
# Start from identical weights so the two branches are directly comparable.
model_wpdc = copy.deepcopy(model_vdc)

optimizer_vdc = torch.optim.SGD(params=model_vdc.parameters(), lr=lr)
optimizer_wpdc = torch.optim.SGD(params=model_wpdc.parameters(), lr=lr)

for epoch in range(N):
    for batch_idx, batch in enumerate(trainloader):
        if batch_idx == 0 or batch_idx % meta_joint_k != 0:
            # loss_vdc / loss_wpdc must come from a fresh forward pass of
            # model_vdc / model_wpdc on `batch` every iteration (omitted here).

            # update by vdc loss
            loss_vdc.backward()
            optimizer_vdc.step()
            optimizer_vdc.zero_grad()

            # update by wpdc loss
            loss_wpdc.backward()
            optimizer_wpdc.step()
            optimizer_wpdc.zero_grad()
        else:
            # NOTE(review): no training step is taken on this batch. Confirm
            # that this is intended.
            model_vdc.eval()
            model_wpdc.eval()
            # Scoring only, so no autograd graph is needed.
            with torch.no_grad():
                # Calculate the VDC loss for both models (omitted):
                #   loss_vdc_vdc  -> VDC loss of model_vdc
                #   loss_vdc_wpdc -> VDC loss of model_wpdc
                ...

            # load_state_dict expects a state dict (a mapping), not the
            # module or optimizer object itself. Module.load_state_dict copies
            # the values into the existing parameters, so the state dict does
            # not need to be deep-copied. The optimizer state is deep-copied so
            # the two optimizers never share buffer tensors (e.g. momentum).
            if loss_vdc_vdc > loss_vdc_wpdc:
                model_vdc.load_state_dict(model_wpdc.state_dict())
                optimizer_vdc.load_state_dict(
                    copy.deepcopy(optimizer_wpdc.state_dict())
                )
            else:
                model_wpdc.load_state_dict(model_vdc.state_dict())
                optimizer_wpdc.load_state_dict(
                    copy.deepcopy(optimizer_vdc.state_dict())
                )
            # `training` is a bool attribute, not a method. train() switches
            # both modules back to training mode.
            model_vdc.train()
            model_wpdc.train()