falkaer/artist-group-factors

Clarification: GradNorm implementation

jeffbaena opened this issue · 1 comment

Dear authors,
Thanks for your great work! I am going through your implementation of GradNorm, as I am reproducing your code for another task (optical flow estimation).

Is there any workaround to avoid using retain_graph=True when doing the backward pass?
Why do we also need to retain the graph when calculating the partial derivative?

    # compute and retain gradients
    total_weighted_loss.backward(retain_graph=True)
    
    # GRADNORM - learn the weights for each tasks gradients
    
    # zero the w_i(t) gradients since we want to update the weights using gradnorm loss
    self.weights.grad = 0.0 * self.weights.grad
    
    W = list(self.model.mtn.shared_block.parameters())
    norms = []
    
    for w_i, L_i in zip(self.weights, task_losses):
        # gradient of L_i(t) w.r.t. W
        gLgW = torch.autograd.grad(L_i, W, retain_graph=True)
        
        # G^{(i)}_W(t)
        norms.append(torch.norm(w_i * gLgW[0]))
    
    norms = torch.stack(norms)

This leads to an out-of-memory error that I am not able to avoid. Did you face a similar problem?

Thanks,
Stefano

Hi Stefano,

With retain_graph=False, PyTorch will free parts of the computational graph on-the-fly as they become unneeded during backpropagation. This results in a lower peak memory usage than retain_graph=True, where the computational graph is not freed at all, and we thus have to keep both the computational graph and the resulting gradients in memory.

Since we backpropagate through the same computational graph (the one tied to total_weighted_loss) multiple times, we have to use retain_graph=True for every gradient computation except the last one:

    self.weights.grad = torch.autograd.grad(grad_norm_loss, self.weights)[0]
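A minimal toy sketch of the retain_graph behaviour described above (the tensors w, h, loss_a and loss_b are made up for illustration and are not from the repository): two gradient computations share one graph, only the last one omits retain_graph=True, and any further pass through that graph then fails because its saved tensors have already been freed.

```python
import torch

# Toy parameter and two losses built on the same intermediate tensor h,
# so both backward passes go through the shared tanh node.
w = torch.randn(3, requires_grad=True)
h = torch.tanh(w)          # tanh saves its output for the backward pass
loss_a = h.pow(2).sum()
loss_b = h.sum()

# Not the last pass: keep the graph so the shared tanh node can be reused.
g_a = torch.autograd.grad(loss_a, w, retain_graph=True)[0]

# Last pass: without retain_graph, PyTorch frees the saved tensors as it goes.
g_b = torch.autograd.grad(loss_b, w)[0]

# Any further pass through tanh now raises, since its saved output is gone.
try:
    torch.autograd.grad(loss_a, w)
    freed = False
except RuntimeError:
    freed = True

print(freed)
```

Freeing the saved tensors during the final pass is exactly what lowers the peak memory, which is why it pays to make the largest backward pass the one without retain_graph=True.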

However, I was fairly new to PyTorch when I wrote this code, and there does appear to be a way to reduce the peak memory usage. We can reorder the computation: first perform all the GradNorm computations and store the resulting weight gradient, then run a full backward pass without retain_graph=True, and finally overwrite the weight gradient computed by that backward pass with the GradNorm one. This lets PyTorch free the computational graph on the fly during the memory-intensive backward pass, like so:

    # compute task losses
    task_losses = tuple(crit(out, tar) for out, tar, crit in zip(task_outs, targets, self.criterions))
    task_losses = torch.stack(task_losses)

    # get the sum of weighted losses
    weighted_losses = self.weights * task_losses
    total_weighted_loss = weighted_losses.sum()

    self.optimizer.zero_grad()

    # GRADNORM - learn the weights for each task's gradients

    W = list(self.model.mtn.shared_block.parameters())
    norms = []

    for w_i, L_i in zip(self.weights, task_losses):
        # gradient of L_i(t) w.r.t. W; retain the graph because
        # total_weighted_loss.backward() below still needs it
        gLgW = torch.autograd.grad(L_i, W, retain_graph=True)

        # G^{(i)}_W(t)
        norms.append(torch.norm(w_i * gLgW[0]))

    norms = torch.stack(norms)

    # set L(0) at the first training step (t is the step counter)
    # if using log(C) init, remove these two lines
    if t == 0:
        self.initial_losses = task_losses.detach()

    # compute the constant term without recording it in the computational graph
    # as it should stay constant during back-propagation
    with torch.no_grad():

        # loss ratios \curl{L}(t)
        loss_ratios = task_losses / self.initial_losses

        # inverse training rate r(t)
        inverse_train_rates = loss_ratios / loss_ratios.mean()

        constant_term = norms.mean() * (inverse_train_rates ** self.alpha)

    # write out the gradnorm loss L_grad and store (but do not yet set) the weight gradients
    grad_norm_loss = (norms - constant_term).abs().sum()
    gweights = torch.autograd.grad(grad_norm_loss, self.weights, retain_graph=True)[0]

    # compute and accumulate regular gradients; this is the last pass through
    # the graph, so no retain_graph and PyTorch frees it on the fly
    total_weighted_loss.backward()

    # overwrite the gradients computed by backward with the gradnorm ones
    self.weights.grad = gweights

    # apply gradient descent
    self.optimizer.step()

    # renormalize the gradient weights in place, so self.weights stays the
    # same leaf tensor that requires grad and that the optimizer updates
    with torch.no_grad():
        normalize_coeff = len(self.weights) / self.weights.sum()
        self.weights.mul_(normalize_coeff)
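Since the snippet above lives inside the repository's training class (self.model, self.criterions, t and so on), here is a self-contained toy sketch of the same reordered step, which may help check that it runs before porting it to another task. Everything task-specific is an assumption made up for illustration: the hypothetical SharedNet two-task model, the MSE losses, alpha = 1.5, the SGD optimizer and the random data. As in the snippet, only the first tensor returned by the shared block's parameters() (here the Linear weight) is used for the norms.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-task model: one shared layer plus one head per task.
class SharedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_block = nn.Linear(4, 8)
        self.heads = nn.ModuleList([nn.Linear(8, 1), nn.Linear(8, 1)])

    def forward(self, x):
        h = torch.relu(self.shared_block(x))
        return [head(h) for head in self.heads]

model = SharedNet()
weights = nn.Parameter(torch.ones(2))  # loss weights w_i(t), one per task
optimizer = torch.optim.SGD(list(model.parameters()) + [weights], lr=0.01)
criterion = nn.MSELoss()
alpha = 1.5
initial_losses = None

x = torch.randn(16, 4)
targets = [torch.randn(16, 1), torch.randn(16, 1)]

for t in range(3):
    outs = model(x)
    task_losses = torch.stack([criterion(o, y) for o, y in zip(outs, targets)])
    total_weighted_loss = (weights * task_losses).sum()
    optimizer.zero_grad()

    # Per-task gradient norms w.r.t. the shared layer; retain the graph
    # because total_weighted_loss.backward() still needs it afterwards.
    W = list(model.shared_block.parameters())
    norms = []
    for w_i, L_i in zip(weights, task_losses):
        gLgW = torch.autograd.grad(L_i, W, retain_graph=True)
        norms.append(torch.norm(w_i * gLgW[0]))
    norms = torch.stack(norms)

    if t == 0:
        initial_losses = task_losses.detach()

    # Target norms are treated as constants, as in the snippet.
    with torch.no_grad():
        loss_ratios = task_losses / initial_losses
        inverse_train_rates = loss_ratios / loss_ratios.mean()
        constant_term = norms.mean() * inverse_train_rates ** alpha

    grad_norm_loss = (norms - constant_term).abs().sum()
    gweights = torch.autograd.grad(grad_norm_loss, weights, retain_graph=True)[0]

    # Last pass through the shared graph: no retain_graph, so it is freed.
    total_weighted_loss.backward()
    weights.grad = gweights  # replace dL/dw_i with the GradNorm gradient
    optimizer.step()

    # Renormalize in place so the weights keep summing to the number of tasks
    # and `weights` remains the Parameter the optimizer updates.
    with torch.no_grad():
        weights.mul_(len(weights) / weights.sum())

print(weights)
```

Writing weights = weights * coeff instead of the in-place mul_ would rebind the name to a plain tensor that no longer requires grad and is no longer the tensor held by the optimizer, so the next autograd.grad call with respect to the weights would fail.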

Unfortunately I don't have the time to test this code myself right now, but hopefully it will be useful to you.

Best regards,
Kenny