Size hints for derivatives with undefined size
dfdx opened this issue · 1 comment
dfdx commented
ds = rdiff(:(sum(W * x + b)); ctx=[:outfmt => :ein], W = rand(2,2), x=rand(2), b=rand(2))
ds[:W]
Currently, this produces:
dtmp6_dW[m,n] = x[n]
This expression is a correct derivative of the output w.r.t. the components of W; however, it doesn't define the size of dtmp6_dW, since the index m never appears on the right-hand side. Thus we need to somehow pass size hints and use them when converting to vectorized form (to repmat in this case).
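To make the missing size concrete, here is a small standalone check, not taken from the package. It assumes a recent Julia, where repeat replaces the older repmat, and the helper numgrad is a hypothetical name introduced only for this illustration. The number of rows has to be supplied from size(W, 1), because the Einstein expression alone does not carry it.

```julia
# Standalone illustration; nothing here is part of the package API.
W = [1.0 2.0; 3.0 4.0]
x = [0.5, -1.5]
b = [0.1, 0.2]

f(W) = sum(W * x + b)

# Einstein form dtmp6_dW[m,n] = x[n]: the index m occurs only on the left,
# so its range must come from elsewhere, here from size(W, 1).
dW = repeat(x', size(W, 1), 1)   # repmat(x', size(W, 1), 1) in pre-1.0 Julia

# Hypothetical helper: central finite differences over every component of W.
function numgrad(f, W; h=1e-6)
    g = zeros(size(W))
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        g[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return g
end

dW_num = numgrad(f, W)
```

Every row of dW is a copy of x', and dW_num should agree with it up to rounding, which is what the size hint is meant to let the vectorized form express.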
dfdx commented
Current plan is:
- Create a new function transfer_size(primitive_expression) -> size_expression.
- Call it during the forward pass and save the result into rdiff's context.
- Use it in from_einstein (have special syntax in templates?).
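One way the three steps could fit together is sketched below. Only the name transfer_size comes from the plan. The second argument, the forward_sizes helper, the Dict standing in for rdiff's context, and the handful of size rules are all assumptions made for illustration, and repeat is used in place of repmat on the assumption of Julia 1.0 or later.

```julia
# Hypothetical sketch; only the name transfer_size comes from the plan.
# Given a primitive call and the size expressions of its arguments,
# return an expression that computes the size of the result.
function transfer_size(ex::Expr, sizes::Dict{Symbol,Any})
    ex.head == :call || error("expected a primitive call, got $ex")
    op = ex.args[1]
    argsizes = [sizes[a] for a in ex.args[2:end]]
    if op == :* && length(argsizes) == 2
        l, r = argsizes
        # rows of the left operand, trailing dimensions of the right one
        return :(($(l)[1], $(r)[2:end]...))
    elseif op == :+ || op == :-
        return argsizes[1]      # elementwise: same size as the first operand
    elseif op == :sum
        return :(())            # full reduction yields a scalar
    else
        error("no size rule for $op")
    end
end

# Forward pass over primitive assignments, saving size expressions into a
# dictionary that stands in for rdiff's context (an assumption).
function forward_sizes(steps, inputs)
    sizes = Dict{Symbol,Any}(v => :(size($v)) for v in inputs)
    for (var, ex) in steps
        sizes[var] = transfer_size(ex, sizes)
    end
    return sizes
end

steps = [:tmp1 => :(W * x), :tmp2 => :(tmp1 + b), :tmp3 => :(sum(tmp2))]
sizes = forward_sizes(steps, [:W, :x, :b])

# A from_einstein-like step could use the saved hint to fix the free index m
# in dtmp6_dW[m,n] = x[n]:
rows = :($(sizes[:W])[1])
vectorized = :(repeat(x', $rows, 1))   # repmat in pre-1.0 Julia

# Evaluate the generated expressions against concrete inputs.
W, x, b = rand(2, 2), rand(2), rand(2)
dW = eval(vectorized)
```

Keeping sizes as expressions rather than concrete tuples means the generated code stays valid for inputs of any shape; whether templates would need special syntax for this remains the open question from the plan.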