Type of AutoGrad.Rec
xukai92 opened this issue · 15 comments
Is it possible to make `AutoGrad.Rec` a subtype of `Real` so that this package can work with Distributions.jl? I tried that naively, but AutoGrad then fails to compile.
What I am trying to do is to let Distributions.jl (and the functions it depends on) support AD by AutoGrad.
- Most of the functions in Distributions.jl accept arguments of type `Real`.
  - If `Rec <: Real` were true, I think backward (reverse-mode) AD could work easily with functions like `logpdf` (?)
  - I guess that even if a `Rec` wraps an array, technically we could still make it a subtype of `Real`?
    - This is what ForwardDiff.jl does: its `Dual` number is a subtype of `Real`, even though it also carries a collection of partial derivatives.
    - Doing so lets users apply AD to all distributions and their corresponding `logpdf`.
- It would be very hard to work around this the other way, i.e. by changing all the function signatures in Distributions.jl, StatsFuns.jl, etc.
  - It would require these functions to depend on AutoGrad.jl.
  - Even if the maintainers don't mind depending on AutoGrad.jl, there is a lot of work to do, as these packages have many dependencies.
  - Some functions used in Distributions.jl come from StatsFuns.jl, so making all of them accept `Rec` is not easy.
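To illustrate the ForwardDiff.jl approach described above, here is a toy forward-mode dual number that subtypes `Real`, so it can flow through a function whose signature is restricted to `Real` without any change to that function. `ToyDual` and `toy_logdensity` are hypothetical names for this sketch; it is not ForwardDiff's actual `Dual` definition, and it implements only the two multiplication rules the example needs.

```julia
# Toy forward-mode dual number that subtypes Real (hypothetical; this is
# not ForwardDiff.jl's actual Dual type).
struct ToyDual <: Real
    value::Float64   # primal value
    deriv::Float64   # derivative with respect to the input
end

# Product rule, and scaling by an ordinary real constant.
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.value * b.value, a.deriv * b.value + a.value * b.deriv)
Base.:*(c::Real, b::ToyDual) = ToyDual(c * b.value, c * b.deriv)

# Hypothetical stand-in for a logpdf-style library function restricted to Real:
# the unnormalized log-density of a standard normal, -x^2/2.
toy_logdensity(x::Real) = -0.5 * (x * x)

# Seed the input's derivative with 1 to get d/dx at x = 3.
y = toy_logdensity(ToyDual(3.0, 1.0))
println((y.value, y.deriv))   # (-4.5, -3.0)
```

The derivative of -x^2/2 is -x, so -3.0 at x = 3 is the expected result. A reverse-mode `Rec` would need more machinery than this, but the dispatch question is the same: the `::Real` annotation only lets the value through because `ToyDual <: Real`.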
For what it's worth, I think the same happens with the Distances.jl package. When trying to autodifferentiate functions that call functions in Distances.jl, a type error similar to the one described in the original post occurs.
For instance, here is a simple example:
```julia
using AutoGrad, Distances

A = randn(3,5);
f = x -> sum(pairwise(SqEuclidean(), A, reshape(x, size(A,1), size(A,2)))) # some arbitrary function that calls a function in Distances.jl
g = grad(f)
g(vec(randn(3,5)))
```
Unfortunately, the above gives the error:
```
ERROR: MethodError: no method matching pairwise(::Distances.SqEuclidean, ::Array{Float64,2}, ::AutoGrad.Rec{Array{Float64,2}})
Closest candidates are:
pairwise(::Distances.PreMetric, ::AbstractArray{T,2} where T) at /Users/ngiann/.julia/v0.6/Distances/src/generic.jl:125
pairwise(::Distances.PreMetric, ::AbstractArray{T,2} where T, ::AbstractArray{T,2} where T) at /Users/ngiann/.julia/v0.6/Distances/src/generic.jl:118
```
I think the problem with using AutoGrad on Distances.jl is very similar to the one with using AutoGrad on Distributions.jl.
Many thanks.
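One possible workaround for this specific case, not suggested in the thread, is to compute the distance matrix with generic array code that AutoGrad can trace instead of calling Distances.jl. The sketch below assumes (unverified here) that AutoGrad has gradient rules for the broadcasted `.*`, `.+` and `.-`, for `sum` with `dims`, for `transpose`, and for matrix multiplication. `pairwise_sqeuclidean` is a hypothetical helper; like `pairwise` in the Distances.jl version above (as we understand it), it treats columns as observations. It uses Julia 1.x syntax (`dims=1`), whereas the thread uses Julia 0.6.

```julia
# Hypothetical helper, not part of Distances.jl: squared Euclidean distances
# between the columns of A (d × m) and the columns of B (d × n), built only
# from generic array operations, with no ::Real or ::Array restrictions.
function pairwise_sqeuclidean(A, B)
    sqA = sum(A .* A, dims=1)    # 1 × m: squared norm of each column of A
    sqB = sum(B .* B, dims=1)    # 1 × n: squared norm of each column of B
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a'b, for every pair of columns
    return transpose(sqA) .+ sqB .- 2 .* (transpose(A) * B)   # m × n
end

A = randn(3, 5)
X = randn(3, 5)
D = pairwise_sqeuclidean(A, X)

# Brute-force check against the definition.
ref = [sum((A[:, i] .- X[:, j]) .^ 2) for i in 1:5, j in 1:5]
println(maximum(abs.(D .- ref)))   # tiny: floating-point rounding only
```

With such a helper, the function above would become `f = x -> sum(pairwise_sqeuclidean(A, reshape(x, size(A,1), size(A,2))))`; whether `grad(f)` then succeeds depends on the assumption about AutoGrad's primitives.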
Thanks for the response. I will see if I can modify my example above using the suggested macro to get around the problem.
@denizyuret If my understanding is correct, as long as the underlying code is written in pure Julia and the primitive functions it calls are extended to handle `Rec`, I can use AutoGrad.jl on that function.
Distributions.jl and its dependencies are written in pure Julia, which means AutoGrad.jl could work on them as long as the primitives they use are supported.
The only problem is that Distributions.jl and its dependencies use `Real` in their function signatures; because `Rec <: Real` is false, we cannot pass a `Rec` to their functions.
ForwardDiff.jl handles this problem by making its `Dual` type (similar to `Rec`, but for forward-mode AD) a subtype of `Real`. That is what makes it compatible with Distributions.jl.
Yes, I understand this. The problem is not the primitives but the function signatures that Distributions.jl and its dependencies use. Let me put it in an example.
AutoGrad.jl works in the example below:
```julia
using AutoGrad, Knet   # Knet provides KnetArray, its GPU array type

p(x) = x * x
f(w) = p(w[1])
df = grad(f)
w = KnetArray([1.0])
df(w)
```
But if we change `p(x) = x * x` to `p(x::Real) = x * x`, AutoGrad.jl no longer works, because the restriction to `Real` rejects a `Rec` argument.
Well, you can do it, but the point here is that `p(x)` is usually a very complex function, e.g. the `logpdf` of a distribution. We really don't want to derive the derivative ourselves and define it with `@primitive`; otherwise there is no point in using automatic differentiation here.
However, if we either 1) change the function signatures in Distributions.jl or 2) make `Rec <: Real`, the problem disappears, because the functions are written in pure Julia and all the primitives they use are defined in AutoGrad.jl.
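The dispatch problem in this example, and the two ways out, can be reproduced without AutoGrad. In the sketch below, `TapeBox` (not a subtype of `Real`, like `Rec` today) and `TapeReal` (a subtype of `Real`, as proposed) are hypothetical stand-ins that record nothing; they only show which calls Julia's method dispatch allows.

```julia
# Hypothetical stand-ins for a recording type; neither records anything.
struct TapeBox               # like Rec today: NOT a subtype of Real
    value::Float64
end
struct TapeReal <: Real      # the proposed variant: a subtype of Real
    value::Float64
end

# The single "primitive" the example needs.
Base.:*(a::TapeBox, b::TapeBox) = TapeBox(a.value * b.value)
Base.:*(a::TapeReal, b::TapeReal) = TapeReal(a.value * b.value)

p_typed(x::Real) = x * x   # signature style used in Distributions.jl
p_untyped(x) = x * x       # option 1: relax the signature

# true if f(x) runs; false if it throws a MethodError (here, from dispatch)
function accepts(f, x)
    try
        f(x)
        return true
    catch e
        e isa MethodError || rethrow()
        return false
    end
end

println(accepts(p_typed, TapeBox(2.0)))    # false: TapeBox is not a Real
println(accepts(p_untyped, TapeBox(2.0)))  # true: option 1
println(accepts(p_typed, TapeReal(2.0)))   # true: option 2
```

Option 1 has to be applied to every annotated function along the call chain, which is the amount of work described in the original post; option 2 changes a single type definition.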
I tried simply adding `<: Real` to the definition of `Rec`, and AutoGrad then fails to precompile with the error below:
```
julia> using AutoGrad
INFO: Precompiling module AutoGrad.
WARNING: Method definition broadcast(Any, Union{Number, AbstractArray{T, N} where N where T}...) in module AutoGrad at /home/kai/.julia/v0.6/AutoGrad/src/unfuse.jl:35 overwritten at /home/kai/.julia/v0.6/AutoGrad/src/unfuse.jl:37.
ERROR: LoadError: LoadError: MethodError: no method matching sign(::AutoGrad.Broadcasted{Array{Float64,1}})
Closest candidates are:
sign(::Bool) at bool.jl:76
sign(::Unsigned) at number.jl:81
sign(::Rational) at rational.jl:221
...
Stacktrace:
[1] broadcast(::Function, ::Array{Float64,1}) at /home/kai/.julia/v0.6/AutoGrad/src/unfuse.jl:37
[2] #randin#25(::Float64, ::Function, ::Tuple{Float64,Float64}, ::Int64, ::Vararg{Int64,N} where N) at /home/kai/.julia/v0.6/AutoGrad/src/gradcheck.jl:209
[3] addtest1(::Symbol, ::Tuple{Float64,Float64}) at /home/kai/.julia/v0.6/AutoGrad/src/gradcheck.jl:193
[4] macro expansion at /home/kai/.julia/v0.6/AutoGrad/src/base/number.jl:12 [inlined]
[5] anonymous at ./<missing>:?
[6] include_from_node1(::String) at ./loading.jl:569
[7] include(::String) at ./sysimg.jl:14
[8] include_from_node1(::String) at ./loading.jl:569
[9] include(::String) at ./sysimg.jl:14
[10] anonymous at ./<missing>:2
while loading /home/kai/.julia/v0.6/AutoGrad/src/base/number.jl, in expression starting on line 6
while loading /home/kai/.julia/v0.6/AutoGrad/src/AutoGrad.jl, in expression starting on line 25
ERROR: Failed to precompile AutoGrad to /home/kai/.julia/lib/v0.6/AutoGrad.ji.
Stacktrace:
[1] compilecache(::String) at ./loading.jl:703
[2] _require(::Symbol) at ./loading.jl:490
[3] require(::Symbol) at ./loading.jl:398
```
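The WARNING line offers a clue. Our reading, not checked against the AutoGrad source, is that unfuse.jl defines two `broadcast` methods whose argument Unions differ only by `Rec`. Once `Rec <: Real`, the Union containing `Rec` collapses to `Union{Number, AbstractArray}`, the two signatures become identical, and the second definition overwrites the first, which plausibly explains the later `sign(::AutoGrad.Broadcasted...)` MethodError. The toy below, using the hypothetical types `RecLike` and `RecLikeReal`, reproduces only the Union collapse and the overwrite, in plain Julia 1.x.

```julia
# Hypothetical stand-ins for a recording type, before and after the proposed change.
struct RecLike{T}              # not a subtype of Real (the current situation)
    value::T
end
struct RecLikeReal{T} <: Real  # the proposed change: a subtype of Real
    value::T
end

# A Union member that is already a subtype of Number adds nothing to the Union.
collapses_before = Union{Number, AbstractArray, RecLike} == Union{Number, AbstractArray}      # false
collapses_after  = Union{Number, AbstractArray, RecLikeReal} == Union{Number, AbstractArray}  # true

# Two methods whose Unions differ only by RecLike stay distinct ...
bcast_before(f, xs::Union{Number, AbstractArray, RecLike}...) = :with_rec
bcast_before(f, xs::Union{Number, AbstractArray}...) = :plain

# ... but with RecLikeReal the two signatures are identical, so the second
# definition replaces the first (the WARNING above reports this kind of replacement).
bcast_after(f, xs::Union{Number, AbstractArray, RecLikeReal}...) = :with_rec
bcast_after(f, xs::Union{Number, AbstractArray}...) = :plain

println(length(methods(bcast_before)))  # 2
println(length(methods(bcast_after)))   # 1
```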
Yes, I know about the Cassette.jl package, but for now AutoGrad.jl seems to be the most mature AD solution with GPU support in Julia.
Also, would you consider adding a compilation flag that controls whether AutoGrad.jl makes `Rec <: Real`? Even if it is experimental, I feel this would be very useful when users want to use AD through other packages.
Any ideas on this?
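A build-time switch like the one requested could, in principle, choose the supertype once when the code is loaded. The sketch below is hypothetical throughout: the environment variable `TOY_TRACKED_IS_REAL` and the type `Tracked` are made up for illustration and are not AutoGrad.jl options. It only shows that Julia lets the supertype come from a constant computed at load time; it does not address the precompilation failure shown above.

```julia
# Hypothetical switch, read once when this code is loaded.
# TOY_TRACKED_IS_REAL is a made-up environment variable for this sketch.
const TRACKED_IS_REAL = get(ENV, "TOY_TRACKED_IS_REAL", "0") == "1"

# The supertype is an ordinary value, so the switch can choose it.
const TrackedSuper = TRACKED_IS_REAL ? Real : Any

struct Tracked <: TrackedSuper
    value::Float64
end

# A function restricted to Real, in the style of Distributions.jl, with a fallback.
which_path(x::Real) = :real_path
which_path(x) = :generic_path

println(TRACKED_IS_REAL)
println(which_path(Tracked(1.0)))
```

One caveat, to our understanding: Julia's precompilation cache does not track environment variables, so with a switch like this the package would have to be re-precompiled whenever the setting changes.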