JuliaSymbolics/Symbolics.jl

Scalarizing an expression that contains a registered array function throws a MethodError

BenChung opened this issue · 0 comments

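If you register a vector-valued function with `@register_array_symbolic` and then call `scalarize` on an expression that uses it, you get a confusing `MethodError` from `+`. A minimal reproduction is below. The first `@register_array_symbolic` call was entered before `using Symbolics`, so it is repeated once the package is loaded.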

julia>  f(a, b, c::AbstractVector, d::AbstractVector) = c
julia> @register_array_symbolic f(a, b, c::AbstractVector, d::AbstractVector) begin
                  size = (length(c),)
                  eltype = eltype(c)
              end
julia> using Symbolics
julia> @register_array_symbolic f(a, b, c::AbstractVector, d::AbstractVector) begin
                  size = (length(c),)
                  eltype = eltype(c)
              end
julia> @variables a b c[1:3] d[1:3]
julia> scalarize(f(a, b, c, d) + c)
ERROR: MethodError: no method matching +(::SymbolicUtils.BasicSymbolic{Any}, ::SymbolicUtils.BasicSymbolic{Real})

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:587
  +(::ChainRulesCore.ZeroTangent, ::Any)
   @ ChainRulesCore C:\Users\benchung\.julia\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:99
  +(::Any, ::ChainRulesCore.ZeroTangent)
   @ ChainRulesCore C:\Users\benchung\.julia\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:100
  ...

Stacktrace:
  [1] #129#130
    @ C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:616 [inlined]
  [2] (::Symbolics.var"#129#134"{…})(x::SymbolicUtils.BasicSymbolic{…}, f::Function, args::Vector{…})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:612
  [3] prewalk_if(cond::Symbolics.var"#131#137", f::SymbolicUtils.Rewriters.PassThrough{…}, t::SymbolicUtils.BasicSymbolic{…}, similarterm::Symbolics.var"#129#134"{…})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:642
  [4] replace_by_scalarizing(ex::SymbolicUtils.BasicSymbolic{Real}, dict::Dict{SymbolicUtils.BasicSymbolic{Int64}, Int64})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:634
  [5] scalarize(arr::Symbolics.ArrayOp{AbstractVector{Real}}, idx::Tuple{Int64})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:750
  [6] scalarize(arr::SymbolicUtils.BasicSymbolic{Real})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:772
  [7] scalarize(arr::Num)
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:774
  [8] (::Symbolics.var"#150#151")(i::Tuple{Int64})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:768
  [9] iterate
    @ .\generator.jl:47 [inlined]
 [10] collect(itr::Base.Generator{Base.Iterators.ProductIterator{Tuple{Base.OneTo{Int64}}}, Symbolics.var"#150#151"})
    @ Base .\array.jl:834
 [11] map(f::Function, A::Base.Iterators.ProductIterator{Tuple{Base.OneTo{Int64}}})
    @ Base .\abstractarray.jl:3313
 [12] scalarize(arr::Symbolics.Arr{Num, 1})
    @ Symbolics C:\Users\benchung\.julia\packages\Symbolics\HIg7O\src\arrays.jl:767
 [13] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.