Proposal to refactor 'explicit' broadcasting
Closed this issue · 1 comments
Right now, broadcasting seems to be handled explicitly in each operation, which means it also has to be handled in each backward pass (by summing the gradient over the broadcast dims). I think it would be a good option to create a 'broadcast_to' operation that handles both the forward and the backward pass for the tensors. Something like:
def broadcast_to(x, shape):
    res = ....
    def grad(...):
        g = sum_accordingly(grad_out)
        ...

def broadcast_tensors(x, y):
    shape = find_common_shape(x, y)
    return broadcast_to(x, shape), broadcast_to(y, shape)

def add(x, y):
    x, y = broadcast_tensors(x, y)  # now x and y backward functions are the one from broadcast_to
    res = impl(x, y)  # res bw function is the 'add' bw function
    def grad(...):
        # only handle grad for the actual add function, grad_out will already be summed or whatever here
        ...
This would make broadcasting more general and extensible to any binary or ternary op (for example, matmuls do not handle broadcasting right now, and implementing it manually is a pain haha). The code is only a rough draft, and the actual implementation will probably look different, but I think it makes more sense than manually handling broadcasting in every op. There is a small overhead from having one more operation in the compute graph, but the runtime should be pretty much the same. I can get something working quickly if it seems fine to you.
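To make the draft above a bit more concrete, here is a minimal runnable sketch of the same idea. It assumes a NumPy backend and uses a tiny stand-in `Tensor` class, so it is not smolgrad's actual API; `Tensor`, `sum_to_shape` (playing the role of `sum_accordingly`) and `mul` are hypothetical names introduced only for illustration.

```python
import numpy as np

class Tensor:
    """Tiny stand-in for an autograd tensor (hypothetical, not smolgrad's class)."""
    def __init__(self, data, parents=(), backward_fn=None):
        self.data = np.asarray(data, dtype=float)
        self.grad = np.zeros_like(self.data)
        self.parents = parents
        self.backward_fn = backward_fn  # maps grad_out -> tuple of parent grads

    def backward(self):
        # Topologically order the graph, then push gradients from output to inputs.
        order, seen = [], set()
        def visit(t):
            if id(t) not in seen:
                seen.add(id(t))
                for p in t.parents:
                    visit(p)
                order.append(t)
        visit(self)
        self.grad = np.ones_like(self.data)
        for t in reversed(order):
            if t.backward_fn is not None:
                for p, g in zip(t.parents, t.backward_fn(t.grad)):
                    p.grad = p.grad + g

def sum_to_shape(grad, shape):
    """Undo broadcasting by summing grad back down to `shape`."""
    extra = grad.ndim - len(shape)
    if extra:
        # Sum away the leading axes that broadcasting prepended.
        grad = grad.sum(axis=tuple(range(extra)))
    # Sum (keeping the axis) wherever a size-1 dim was stretched.
    axes = tuple(i for i, n in enumerate(shape) if n == 1 and grad.shape[i] != 1)
    return grad.sum(axis=axes, keepdims=True) if axes else grad

def broadcast_to(x, shape):
    if x.data.shape == tuple(shape):
        return x  # already the right shape, so no extra graph node
    out = np.broadcast_to(x.data, shape)
    return Tensor(out, (x,), lambda g: (sum_to_shape(g, x.data.shape),))

def broadcast_tensors(x, y):
    shape = np.broadcast_shapes(x.data.shape, y.data.shape)
    return broadcast_to(x, shape), broadcast_to(y, shape)

def add(x, y):
    x, y = broadcast_tensors(x, y)
    # Shapes match here, so add's backward just passes grad_out through.
    return Tensor(x.data + y.data, (x, y), lambda g: (g, g))

def mul(x, y):
    x, y = broadcast_tensors(x, y)
    # No broadcasting logic here either; broadcast_to's backward does the summing.
    return Tensor(x.data * y.data, (x, y), lambda g: (g * y.data, g * x.data))
```

For example, with `a = Tensor(np.ones((2, 3)))` and `b = Tensor([1.0, 2.0, 3.0])`, calling `add(a, b).backward()` leaves `a.grad` with shape `(2, 3)` and `b.grad` with shape `(3,)`, without `add` itself knowing anything about broadcasting. Skipping the extra node when the shape already matches keeps the graph overhead limited to the ops that actually broadcast.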
Good point! @davidgonmar, that's a very good idea. Sure, I would be more than happy if you want to implement this feature in smolgrad. Before you send a PR, though, please attach some tests for this feature that compare it with PyTorch, so that I can reproduce it locally as well.
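A test along those lines could look roughly like the sketch below. It assumes the function under test has the hypothetical signature `backward_fn(grad_out, in_shape)` and works on NumPy arrays; the real smolgrad test would go through its own tensor type instead. PyTorch's own `torch.broadcast_to` provides the reference gradient.

```python
import numpy as np

def check_broadcast_grad(backward_fn, in_shape, out_shape, seed=0):
    """Compare a candidate broadcast_to backward against PyTorch autograd.

    backward_fn(grad_out, in_shape) is the function under test (hypothetical
    signature): it should reduce grad_out back down to in_shape.
    """
    import torch  # imported lazily so the helper can be defined without torch

    rng = np.random.default_rng(seed)
    x_np = rng.standard_normal(in_shape)
    grad_out_np = rng.standard_normal(out_shape)

    # Reference: let PyTorch differentiate through its own broadcast_to.
    x = torch.tensor(x_np, requires_grad=True)
    torch.broadcast_to(x, out_shape).backward(torch.tensor(grad_out_np))

    got = backward_fn(grad_out_np, in_shape)
    np.testing.assert_allclose(got, x.grad.numpy(), rtol=1e-6, atol=1e-6)
    return True
```

For instance, `check_broadcast_grad(lambda g, s: g.sum(axis=0), (3,), (2, 3))` checks the simple case where a leading axis is prepended. Fixing the seed keeps the check reproducible locally.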