Scalarization of registered array functions breaks
BenChung opened this issue · 0 comments
BenChung commented
If you register a vector-valued function with `@register_array_symbolic` and then try to scalarize an expression that uses it, you get a confusing error:
julia> f(a, b, c::AbstractVector, d::AbstractVector) = c
julia> @register_array_symbolic f(a, b, c::AbstractVector, d::AbstractVector) begin
size = (length(c),)
eltype = eltype(c)
end
julia> using Symbolics
julia> @register_array_symbolic f(a, b, c::AbstractVector, d::AbstractVector) begin
size = (length(c),)
eltype = eltype(c)
end
julia> @variables a b c[1:3] d[1:3]
julia> scalarize(f(a, b, c, d) + c)
ERROR: MethodError: no method matching +(::SymbolicUtils.BasicSymbolic{Any}, ::SymbolicUtils.BasicSymbolic{Real})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:587
+(::ChainRulesCore.ZeroTangent, ::Any)
@ ChainRulesCore C:\Users\benchung\.julia\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:99
+(::Any, ::ChainRulesCore.ZeroTangent)
@ ChainRulesCore C:\Users\benchung\.julia\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:100
...
Stacktrace:
[1] #129#130
@ C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:616 [inlined]
[2] (::Symbolics.var"#129#134"{…})(x::SymbolicUtils.BasicSymbolic{…}, f::Function, args::Vector{…})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:612
[3] prewalk_if(cond::Symbolics.var"#131#137", f::SymbolicUtils.Rewriters.PassThrough{…}, t::SymbolicUtils.BasicSymbolic{…}, similarterm::Symbolics.var"#129#134"{…})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:642
[4] replace_by_scalarizing(ex::SymbolicUtils.BasicSymbolic{Real}, dict::Dict{SymbolicUtils.BasicSymbolic{Int64}, Int64})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:634
[5] scalarize(arr::Symbolics.ArrayOp{AbstractVector{Real}}, idx::Tuple{Int64})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:750
[6] scalarize(arr::SymbolicUtils.BasicSymbolic{Real})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:772
[7] scalarize(arr::Num)
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:774
[8] (::Symbolics.var"#150#151")(i::Tuple{Int64})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:768
[9] iterate
@ .\generator.jl:47 [inlined]
[10] collect(itr::Base.Generator{Base.Iterators.ProductIterator{Tuple{Base.OneTo{Int64}}}, Symbolics.var"#150#151"})
@ Base .\array.jl:834
[11] map(f::Function, A::Base.Iterators.ProductIterator{Tuple{Base.OneTo{Int64}}})
@ Base .\abstractarray.jl:3313
[12] scalarize(arr::Symbolics.Arr{Num, 1})
@ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:767
[13] top-level scope
@ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.