How is the autoregressive loss handled?
Closed this issue · 2 comments
I apologize if this is a trivial question, but I couldn't find any answers in the YouTube comments or the issues posted here.
In the train_gpt2.py
file, the forward function for the CausalSelfAttention
module is implemented without a mask. As a result, the returned value y
attends to all tokens in the input x
. Given this, when calculating the loss in the forward function of the GPT
model, it appears that we are not computing ( P(x_t \mid x_1, x_2, \ldots, x_{t-1}) ) as expected. Could you provide some insight into this?
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
# nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
# e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
qkv = self.c_attn(x)
q, k, v = qkv.split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.c_proj(y)
return y
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
Attention with mask is in F.scaled_dot_product_attention
below is source code of F.scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
Awesome! That answers my question!