Fun solution to loop replacement (`makemore_part4_backprop.ipynb`)
jchwenger opened this issue · 1 comment
Hi Andrej, hi everyone,
First of all, let me add my voice to the chorus: such awesome lectures, I'm very grateful for them, and I recommend them to people around me whenever I have the opportunity!
At one point in the backprop lecture, you mention that there might be a slicker way to update the last gradient tensor, `dC`, than the Python loop you used. This tickled my curiosity, so I tinkered, and here's the solution I came up with; maybe others have found even better ways! (Although, arguably, if you're not into Torch nerdiness, the time and peace of mind spent basking in advanced indexing might not be a great trade-off against the slow but straightforward loop! : >)
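For reference, here is a minimal, self-contained setup that the snippets below can be run against. The shapes follow the lecture's defaults as I remember them (`vocab_size=27`, `n_embd=10`, `batch_size=32`, `block_size=3`), and the random tensors are just stand-ins for the notebook's actual `C`, `Xb` and `demb`:

import torch

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 10), generator=g)            # embedding table: (vocab_size, n_embd)
Xb = torch.randint(0, 27, (32, 3), generator=g)   # input batch: (batch_size, block_size)
emb = C[Xb]                                       # embeddings: (batch_size, block_size, n_embd)
demb = torch.randn_like(emb)                      # stand-in for the upstream gradient dloss/demb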
So, instead of:
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]
It is possible to do:
# arange -> unsqueeze -> tile -> flatten
# [0,1,...,31] -> [[0],  -> [[0,0,0],    -> [0,0,0,1,1,1,...,31,31,31]  # batch_size * block_size elements
#                  [1],      [1,1,1],
#                  ...       ...
#                  [31]]     [31,31,31]]
rows_xi = torch.tile(torch.arange(0, Xb.shape[0]).unsqueeze(1), (1, Xb.shape[1])).flatten()
# [0,1,2] -> [0,1,2,0,1,2,...,0,1,2]  # tiled batch_size times, so block_size * batch_size elements
cols_xi = torch.tile(torch.arange(0, Xb.shape[1]), (Xb.shape[0],))
emb_xi = Xb[rows_xi, cols_xi]  # block_size * batch_size indices into the rows of C
dC1 = torch.zeros_like(C)
dC1.index_put_((emb_xi,), demb[rows_xi, cols_xi], accumulate=True)
A `torch.allclose(dC1, dC)` yields `True` on my end.
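For completeness, here are two other candidates I toyed with; I'm not claiming either is the slicker way Andrej had in mind, and the names `dC2`, `dC3` are just mine. Both reuse the same `C`, `Xb` and `demb` as above:

import torch.nn.functional as F

# variant 1: index_add_ scatters the flattened demb rows into dC, summing duplicate indices
dC2 = torch.zeros_like(C)
dC2.index_add_(0, Xb.reshape(-1), demb.reshape(-1, C.shape[1]))

# variant 2: view the lookup as emb = one_hot(Xb) @ C, so the gradient w.r.t. C
# is the transpose of that one-hot matrix times demb
one_hot = F.one_hot(Xb.reshape(-1), num_classes=C.shape[0]).float()  # (batch*block, vocab_size)
dC3 = one_hot.T @ demb.reshape(-1, C.shape[1])                       # (vocab_size, n_embd)

Both should pass the same `torch.allclose` check against `dC`.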
I'm indebted to the all-answering @ptrblck for the `.index_put_(..., accumulate=True)` reference!
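In case it helps others: the reason `accumulate=True` (or the loop) is needed at all is that `Xb` contains repeated indices, and a plain in-place `+=` with advanced indexing does not sum the duplicate writes. A toy example with made-up values:

t = torch.zeros(3)
idx = torch.tensor([0, 0, 2])   # index 0 appears twice
val = torch.tensor([1., 2., 5.])

t[idx] += val                   # duplicate writes are not summed: t[0] ends up 1. or 2., not 3.
t.zero_()
t.index_put_((idx,), val, accumulate=True)  # duplicates are summed: t is now [3., 0., 5.]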
Have a great day!