te.checkpoint does not work for nested autocast
tohinz opened this issue · 4 comments
According to #438 we should be able to use both BF16 and FP8 autocasts.
In our specific setting, our module consists of some linear layers that are torch.nn.Linear and some that are te.Linear (because some input sizes are not compatible with FP8 and padding is not an option in this case). When we wrap this module with te.checkpoint (following the fix in #776), we get errors in the backward pass because the BF16 autocast is not applied when the function is recomputed.
Concretely, for something like the following:
criterion = torch.nn.MSELoss()
model = torch.nn.Sequential()
model.append(torch.nn.Linear(hidden_dim, hidden_dim))
for _ in range(num_layers):
    model.append(te.Linear(hidden_dim, hidden_dim))
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True):
        output = te.checkpoint(model, model_input, use_reentrant=False)
loss = criterion(output, target)
loss.backward()
we get the error
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float
since the torch.nn.Linear layer is not autocast to BF16 when the function is recomputed with te.checkpoint.
PyTorch's own checkpoint implementation has functionality to make sure the autocast state is also applied when the function is recomputed in the backward pass.
The salient code pieces are here, here, here, and here.
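The idea can be sketched as follows: capture the active autocast settings during the forward pass, then re-enter an equivalent autocast context when the function is recomputed. This is only a minimal illustration, not the actual PyTorch or TE code. The helper names `capture_autocast_kwargs` and `recompute_with_saved_autocast` are hypothetical, and the demo uses CPU autocast so that it runs without a GPU; for the case in this issue the device type would be "cuda".

```python
import torch


def capture_autocast_kwargs(device_type):
    """Snapshot the autocast state that is active right now (hypothetical helper)."""
    if hasattr(torch, "get_autocast_dtype"):
        # Newer PyTorch releases expose device-generic getters.
        enabled = torch.is_autocast_enabled(device_type)
        dtype = torch.get_autocast_dtype(device_type)
    elif device_type == "cpu":
        enabled = torch.is_autocast_cpu_enabled()
        dtype = torch.get_autocast_cpu_dtype()
    else:
        # Older releases: the argument-free getters describe the CUDA state.
        enabled = torch.is_autocast_enabled()
        dtype = torch.get_autocast_gpu_dtype()
    return {
        "enabled": enabled,
        "dtype": dtype,
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }


def recompute_with_saved_autocast(function, device_type, saved_kwargs, *args):
    """Re-run `function` under the autocast state captured in the forward pass."""
    with torch.autocast(device_type, **saved_kwargs):
        return function(*args)


layer = torch.nn.Linear(4, 4)  # float32 parameters
x = torch.rand(2, 4)           # float32 input

# Forward pass: snapshot the state while the BF16 autocast is active.
with torch.autocast("cpu", dtype=torch.bfloat16):
    saved = capture_autocast_kwargs("cpu")
    out_forward = layer(x)

# Recomputation happens after the original context has exited.
out_naive = layer(x)  # no autocast: runs in float32
out_restored = recompute_with_saved_autocast(layer, "cpu", saved, x)

print(out_forward.dtype, out_naive.dtype, out_restored.dtype)
```

Without the restore step, the recomputed output has a different dtype than the forward output, which is the kind of mismatch behind the RuntimeError above.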
I've added that to my local TE branch and it seems to fix the issue, i.e., the code with two autocasts now runs through and the gradient check returns True.
Example code to reproduce the error (adapted from #438):
import torch
from torch.autograd import grad
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
if __name__ == "__main__":
    num_layers = 5
    seq_length = 1024
    hidden_dim = 2048
    # Generate random model input and target for MSE loss
    model_input = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.float)
    target = torch.rand(8, seq_length, hidden_dim).cuda().to(dtype=torch.bfloat16)
    criterion = torch.nn.MSELoss()
    # Define the model
    model = torch.nn.Sequential()
    model.append(torch.nn.Linear(hidden_dim, hidden_dim))
    for _ in range(num_layers):
        model.append(te.Linear(hidden_dim, hidden_dim))
    model.to(dtype=torch.float32).cuda()
    # Define FP8
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(
        fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
    )
    autocast_args = {"enabled": True, "fp8_recipe": fp8_recipe}
    autocast = te.fp8_autocast
    def inner(compare_grads):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            with autocast(**autocast_args):
                output = te.checkpoint(model, model_input, use_reentrant=False)
                output2 = model(model_input)
        loss = criterion(output, target)
        loss2 = criterion(output2, target)
        if compare_grads:
            # compute gradients
            grads = grad(loss, model.parameters())
            grads2 = grad(loss2, model.parameters())
            # compare gradients
            print("Gradients are equal: ")
            print(torch.all(torch.eq(grads[0], grads2[0])))
            print(torch.all(torch.eq(grads[1], grads2[1])))
            # print gradients to check they are nonzero and not nan
            print("")
            print(grads[0])
            print(grads2[0])
            print("")
            print(grads[1])
            print(grads2[1])
        else:
            loss.backward()
            loss2.backward()
    # run model
    fp8_scaling_iters = 50
    # warmup iterations to get FP8 scaling parameters
    for _ in range(fp8_scaling_iters):
        inner(compare_grads=False)
    inner(compare_grads=True)
@tohinz I will take a look at how we can automatically handle this in the TE checkpoint tomorrow. In the meantime, you should be able to make this work via user context functions.
def torch_autocast_ctx():
    fwd_ctx = torch.amp.autocast(...)
    recomp_ctx = torch.amp.autocast(...)
    return fwd_ctx, recomp_ctx
te.distributed.checkpoint(..., context_fn=torch_autocast_ctx, ...)
The two autocast contexts need to be configured consistently with how you want the forward pass and the recomputation to run.
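A concrete version of this workaround might look like the sketch below. The factory name `make_autocast_context_fn` is hypothetical. The commented usage assumes that te.checkpoint accepts the same `context_fn` keyword as the `te.distributed.checkpoint` call above, and that BF16 on CUDA is the intended setting, as in the original report.

```python
import torch


def make_autocast_context_fn(device_type, dtype):
    """Build a context_fn that yields matching autocast contexts (hypothetical helper)."""

    def context_fn():
        # Two separate objects with identical settings: one is entered for
        # the forward pass, the other for the recomputation during backward.
        forward_ctx = torch.autocast(device_type, dtype=dtype)
        recompute_ctx = torch.autocast(device_type, dtype=dtype)
        return forward_ctx, recompute_ctx

    return context_fn


# Assumed usage with the reproduction script from this issue (needs a GPU and
# Transformer Engine, so it is left as a comment):
#
#   bf16_ctx_fn = make_autocast_context_fn("cuda", torch.bfloat16)
#   with te.fp8_autocast(enabled=True):
#       output = te.checkpoint(
#           model, model_input, use_reentrant=False, context_fn=bf16_ctx_fn
#       )
```

Creating both contexts from the same arguments keeps the forward pass and the recomputation in the same precision, so the recomputed activations match the dtypes the backward pass expects.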