karpathy/nn-zero-to-hero

Fun solution to loop replacement (`makemore_part4_backprop.ipynb`)

jchwenger opened this issue · 1 comments

Hi Andrej, hi everyone,

First of all, let me add my voice to the chorus: such awesome lectures, very grateful for them, and I recommend them to people around me whenever I get the chance!

At one point in the backprop lecture, you mention that there might be a slicker way to compute the last gradient tensor, dC, than the Python loop you used. This tickled my curiosity, so I tinkered, and here's the solution I came up with. Maybe others have found even better ways! (Although, arguably, if you're not into Torch nerdiness, the threat to time management/peace of mind when basking in advanced indexing might not be a great trade-off against the slow but straightforward loop! : >)

So, instead of:

dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
  for j in range(Xb.shape[1]):
    ix = Xb[k,j]
    dC[ix] += demb[k,j]

It is possible to do:

# arange       -> unsqueeze -> tile        -> flatten
# [0,1,...,31] -> [[0],     -> [[0,0,0],   -> [0,0,0,1,1,1,...,31,31,31]  # batch_size * block_size elements
#                  [1],         [1,1,1],
#                  ...          ...
#                  [31]]        [31,31,31]]
# (tiling by Xb.shape[1] rather than a hard-coded 3 keeps this valid for any block_size)
rows_xi = torch.tile(torch.arange(0, Xb.shape[0]).unsqueeze(1), (1, Xb.shape[1])).flatten()

# [0,1,2] -> [0,1,2,0,1,2,...,0,1,2]  # flat, block_size * batch_size elements
cols_xi = torch.tile(torch.arange(0, Xb.shape[1]), (Xb.shape[0],))

emb_xi = Xb[rows_xi, cols_xi] # block_size * batch_size indices to retrieve rows

dC1 = torch.zeros_like(C)

dC1.index_put_((emb_xi,), demb[rows_xi, cols_xi], accumulate=True)

`torch.allclose(dC1, dC)` yields True on my end.
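For anyone who wants to try this outside the notebook, here is a minimal self-contained sketch of the comparison. The tensor sizes (`vocab_size`, `n_embd`, `batch_size`, `block_size`) and the random stand-ins for `C`, `Xb` and `demb` are my assumptions, not the notebook's actual values. The last part (`dC2`) is a variation I believe is equivalent: since `rows_xi`/`cols_xi` walk `Xb` in row-major order, flattening `Xb` and `demb` directly should give the same indices and values.

```python
import torch

torch.manual_seed(0)

# Assumed sizes, chosen only for illustration
vocab_size, n_embd = 27, 10
batch_size, block_size = 32, 3

# Random stand-ins for the notebook's tensors
C = torch.randn(vocab_size, n_embd)
Xb = torch.randint(0, vocab_size, (batch_size, block_size))
demb = torch.randn(batch_size, block_size, n_embd)

# Reference: the explicit double loop
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]

# Index-based version: build (row, col) pairs covering every entry of Xb
rows_xi = torch.tile(torch.arange(Xb.shape[0]).unsqueeze(1), (1, Xb.shape[1])).flatten()
cols_xi = torch.tile(torch.arange(Xb.shape[1]), (Xb.shape[0],))
emb_xi = Xb[rows_xi, cols_xi]

dC1 = torch.zeros_like(C)
dC1.index_put_((emb_xi,), demb[rows_xi, cols_xi], accumulate=True)

# Variation: flatten Xb and demb directly instead of building index pairs
dC2 = torch.zeros_like(C)
dC2.index_put_((Xb.reshape(-1),), demb.reshape(-1, n_embd), accumulate=True)

# A small absolute tolerance guards against float summation-order differences
loop_matches = torch.allclose(dC1, dC, atol=1e-6)
flat_matches = torch.allclose(dC2, dC, atol=1e-6)
print(loop_matches, flat_matches)
```

If both checks pass, the flattened form saves building `rows_xi` and `cols_xi` at all, at the cost of being a bit less explicit about which entry of `demb` goes where.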

I'm indebted to the all-answering @ptrblck for that .index_put_(... accumulate=True) reference!
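To see why `accumulate=True` matters here, a tiny sketch with a repeated index may help (the tensors below are made up for illustration). The same character can appear several times in a batch, so the same row of `dC` must receive the sum of all its contributions, which is what `accumulate=True` does.

```python
import torch

t = torch.zeros(3, 2)
idx = torch.tensor([0, 2, 0])          # row 0 appears twice
vals = torch.tensor([[1.0, 1.0],
                     [2.0, 2.0],
                     [3.0, 3.0]])

# With accumulate=True, values landing on the same row are summed
t.index_put_((idx,), vals, accumulate=True)
print(t)
# row 0 -> 1 + 3 = 4, row 1 untouched, row 2 -> 2
```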

Have a great day!