SymbolicML/DynamicExpressions.jl

`tree_map`?

MilesCranmer opened this issue · 16 comments

I realize a lot of functions could just be implemented as calls to a generic tree_map function. For example,

tree_map(t -> 1, tree; merge=max)

would calculate the depth of a tree. The “merge” function would be used to aggregate left/right child for binary nodes. For example,

tree_map(t -> 1, tree; merge=(+))

would count the total number of nodes. Meanwhile,

tree_map(tree; merge=(+)) do t
    Int(t.degree==2)
end

would count the number of binary operators. Then something like

tree_map(tree; merge=(l, r)->[l, r]) do t
    if t.degree != 0 || !t.constant
        return []
    end
    return [t.val]
end

would return all constants in a tree (in depth-first traversal order).
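To make the proposal concrete, here is a minimal, self-contained sketch. It does not use the package's `Node` type: `SketchNode`, its helper constructors, and `sketch_tree_map` are hypothetical names introduced only for illustration. It also assumes one particular semantics for `merge`, namely that it receives the parent's mapped value followed by the reduced results of its children, so the depth example is written as `p + maximum(c)` rather than a bare `max`.

```julia
# Hypothetical minimal node type, for illustration only (not the package's Node).
struct SketchNode
    degree::Int                   # 0 = leaf, 1 = unary, 2 = binary
    constant::Bool                # meaningful for leaves only
    val::Float64                  # meaningful for constant leaves only
    children::Vector{SketchNode}
end

leaf(val) = SketchNode(0, true, val, SketchNode[])
variable() = SketchNode(0, false, 0.0, SketchNode[])
unary(c) = SketchNode(1, false, 0.0, [c])
binary(l, r) = SketchNode(2, false, 0.0, [l, r])

# Assumed semantics: `merge` gets the parent's mapped value first,
# then the already-reduced result of each child, left to right.
function sketch_tree_map(f, tree::SketchNode; merge)
    tree.degree == 0 && return f(tree)
    child_results = map(c -> sketch_tree_map(f, c; merge=merge), tree.children)
    return merge(f(tree), child_results...)
end

# An expression shaped like (1.5 + x) * sin(2.0)
tree = binary(binary(leaf(1.5), variable()), unary(leaf(2.0)))

depth = sketch_tree_map(_ -> 1, tree; merge=(p, c...) -> p + maximum(c))
n_nodes = sketch_tree_map(_ -> 1, tree; merge=+)
n_binary = sketch_tree_map(t -> Int(t.degree == 2), tree; merge=+)
constants = sketch_tree_map(tree; merge=vcat) do t
    # Constant leaves contribute their value; everything else contributes nothing.
    t.degree == 0 && t.constant ? [t.val] : Float64[]
end
```

Using `vcat` as the merge keeps the list of constants flat, whereas a merge that builds `[l, r]` would nest the child results inside each other.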


@Moelf would this have been helpful for writing that NYT puzzle solver? What do you think of the API?

@AlCap23 any comment?

Moelf commented

I'm not very familiar with common types of tree-based algorithms, but yeah, I think functions that facilitate tree-walking should be useful for writing custom loss functions, right?

My god, it is beautiful.

The best part about this is that some of these functions actually got faster after this refactor.

Moelf commented

yeah, compiler > human when it comes to optimizing code (un)fortunately.

btw _ -> 1 can be expressed as Returns(1) in Julia now

Yeah. And it's so much more readable now! Makes it easier to think of other functions to map over these objects.

> btw _ -> 1 can be expressed as Returns(1) in Julia now

I'm probably doing something wrong but it seems slower for some reason:

julia> f(_) = 1
f (generic function with 1 method)

julia> @btime f(3.2)
  1.272 ns (0 allocations: 0 bytes)
1

julia> g = Returns(1)
Returns{Int64}(1)

julia> @btime g(3.2)
  24.865 ns (0 allocations: 0 bytes)
1
Moelf commented

yeah, use const g = ...; right now g is a non-const global.
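A small sketch of why the non-const global matters, using only Base. The names `g_const`, `call_global`, and `call_const` are made up for this illustration. The idea is that the compiler cannot assume the type of a non-const global at a call site, so the call is dispatched dynamically, while a `const` binding is fully inferable. With BenchmarkTools, interpolating the global as `@btime $g(3.2)` is another way to keep this overhead out of the measurement.

```julia
g = Returns(1)               # non-const global, as in the REPL session above
const g_const = Returns(1)   # const global: the binding's type is fixed

call_global() = g(3.2)
call_const() = g_const(3.2)

# Ask the compiler what return type it can infer for each call.
inferred_global = only(Base.return_types(call_global, Tuple{}))
inferred_const = only(Base.return_types(call_const, Tuple{}))

println("non-const global: ", inferred_global)
println("const global:     ", inferred_const)
```

An inferred type of `Any` for the non-const version is the sign that the call goes through dynamic dispatch, which would account for the extra nanoseconds in the benchmark.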

Cool. Thanks!

@Moelf what do you think about overloading Base.mapreduce, rather than defining a custom tree_mapreduce? Since it has the exact same syntax as a regular mapreduce, perhaps it is suitable. The one difference is that the merge function here takes in (parent, child_l, child_r), whereas a normal mapreduce's merge takes in (element1, element2, ...).

Why don't I just overload all collection functions... Then you could just iterate through a tree!

Moelf commented

hm, I'm trying to imagine if walking down a tree is ~ iteration, and if that would surprise people

I added some other collection functions including iterate

map(f::F, tree::Node) where {F<:Function} = f.(collect(tree))
all(f::F, tree::Node) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
getindex(tree::Node, i::Int) = collect(tree)[i]
iterate(root::Node) = (root, collect(root)[(begin + 1):end])
iterate(::Node, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
length(tree::Node) = mapreduce(_ -> 1, +, tree)
firstindex(::Node) = 1
lastindex(tree::Node) = length(tree)
setindex!(::Node, _, ::Int) = error("Cannot setindex! on a tree. Use `set_node!` instead.")

The behavior is to traverse a tree depth-first, left to right, and return the current node at each step. So you can do:

for node in tree
    if node.degree == 1 && node.op == 2
        my_operator_count += 1
    end
end

And it will work as you might expect. It’s a tiny bit slower than using the mapreduce because it allocates a stack of nodes, but I think it might be easier for users to write custom losses.

One thing that might make this more intuitive is to have a no-op type conversion. For example:

node_stack = DepthFirstTraversal(tree) |> collect

or filtering:

constant_nodes = filter(t -> t.degree == 0 && t.constant, DepthFirstTraversal(tree))

or looping:

for node in DepthFirstTraversal(tree)
    # next node will be child, if this node has degree > 0
end

DepthFirstTraversal would wrap the Node type and define how the tree is iterated over. But otherwise it wouldn't do anything, and (hopefully) wouldn't affect the performance.

struct DepthFirstTraversal{N<:Node} <: AbstractTraversal
    x::N
end
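As a rough sketch of how such a wrapper could define iteration, the following uses hypothetical stand-in types (`ToyNode`, `AbstractToyTraversal`, `ToyDepthFirst`) rather than the package's `Node`, and an explicit stack as the iterator state. It is one possible design under those assumptions, not the implementation that was merged.

```julia
# Hypothetical stand-ins so the sketch runs on its own.
struct ToyNode
    name::String
    children::Vector{ToyNode}
end
ToyNode(name) = ToyNode(name, ToyNode[])

abstract type AbstractToyTraversal end

# The wrapper only stores the root; it decides how the tree is walked.
struct ToyDepthFirst <: AbstractToyTraversal
    root::ToyNode
end

# Iterator state: a stack of nodes that still have to be visited.
function Base.iterate(t::ToyDepthFirst, stack=ToyNode[t.root])
    isempty(stack) && return nothing
    node = pop!(stack)
    # Push children right-to-left so the leftmost child is visited next.
    append!(stack, reverse(node.children))
    return node, stack
end

# The length is not known without walking the tree.
Base.IteratorSize(::Type{ToyDepthFirst}) = Base.SizeUnknown()
Base.eltype(::Type{ToyDepthFirst}) = ToyNode

# (x * y) + z
tree = ToyNode("+", [ToyNode("*", [ToyNode("x"), ToyNode("y")]), ToyNode("z")])

visited = [n.name for n in ToyDepthFirst(tree)]
leaves = [n.name for n in ToyDepthFirst(tree) if isempty(n.children)]
```

Because only `iterate` is specialized on the wrapper type, a second strategy (for example breadth-first) could be added as another subtype with its own `iterate` method, leaving the node type untouched.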

In the future, other traversal strategies could also be added.

What do you think?

@odow we chatted about tree structures at one point - I'd love to hear your take on this sort of interface!

odow commented

It looks nice. We don't really use algorithms over trees in JuMP/MOI. The first step in our AD engine is to convert everything to a single topologically sorted tape so everything requires a linear pass.

The other issue I ran into was people constructing expressions nested so deeply that you can't use recursion.

A somewhat artificial example, but expressions like this cause trouble:

N = 1_000_000
x = [Variable() for _ in 1:N]
y = x[1]
for i in 2:N
    y = +(y, x[i])
end

(Overlook the fact that you could lift all the nodes into +(x...) etc. It's just an artificial example.)
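To illustrate the recursion problem and one common workaround, here is a sketch that builds a similar left-nested chain and counts its nodes with an explicit stack instead of recursive calls. `ChainNode`, `plus`, `build_chain`, and `count_nodes_iterative` are hypothetical names for this example only; they are not JuMP or MOI types. The chain is kept at 100,000 terms so the example runs quickly.

```julia
# Hypothetical expression node: a leaf has no children, a sum has two.
struct ChainNode
    children::Vector{ChainNode}
end
ChainNode() = ChainNode(ChainNode[])

plus(a, b) = ChainNode([a, b])

# Build ((((x1 + x2) + x3) + ...) + xn): nesting depth grows linearly with n.
function build_chain(n)
    y = ChainNode()
    for _ in 2:n
        y = plus(y, ChainNode())
    end
    return y
end

# Count nodes with an explicit stack, so the traversal depth is limited by
# heap memory rather than by the call stack.
function count_nodes_iterative(root::ChainNode)
    n_seen = 0
    stack = ChainNode[root]
    while !isempty(stack)
        node = pop!(stack)
        n_seen += 1
        append!(stack, node.children)
    end
    return n_seen
end

chain = build_chain(100_000);
total = count_nodes_iterative(chain)
```

A chain of `n` leaves has `n - 1` sum nodes, so the count is `2n - 1`. A recursive version of the same count would need one call frame per nesting level, which is what makes expressions like this troublesome.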

I see, thanks. Indeed, that looks hard for recursion. We are lucky in this sense because we never have expressions with more than ~100 nodes or so, but in the future I definitely want to try some sort of stack-based evaluation.

By the way, unrelated, but eventually I would love to build some sort of interface between JuMP/MOI and SymbolicRegression.jl. Maybe so you could evolve symbolic models in a JuMP problem, or maybe so you could use JuMP inside an objective to optimize a symbolic expression found by SymbolicRegression.jl. At the very least it could be useful to build a converter between them, like the one we have for Symbolics.jl.

Sometimes Julia code optimization is so weird. On Julia 1.9-rc3, this function:

function tree_mapreduce(f_leaf::F1, f_branch::F2, op::G, tree::N; preserve_sharing::Bool=false, result_type::Type{RT}=Nothing) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end

is 2x slower than this function (only change is commenting out the conditional return):

function tree_mapreduce(f_leaf::F1, f_branch::F2, op::G, tree::N; preserve_sharing::Bool=false, result_type::Type{RT}=Nothing) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    # preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end

even when preserve_sharing is set to false! This is true even if I annotate the return type, even with -O3, etc.

But I don't see this for 1.8.5.

Maybe the precompilation caching is breaking something about this inline macro?
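For readers unfamiliar with what the `preserve_sharing` branch is for, here is a hand-written sketch of memoizing a tree reduction in an `IdDict`, so that a subtree referenced from several parents is processed only once. This is not the package's `@with_memoization` macro, whose expansion is not shown in this thread; `SharedNode` and `depth_memo` are hypothetical names used only here.

```julia
# Hypothetical node type; mutable so each instance has its own identity.
mutable struct SharedNode
    children::Vector{SharedNode}
end
SharedNode() = SharedNode(SharedNode[])

# Depth of the expression, caching results per node identity.
# `visits` counts how many nodes are actually processed.
function depth_memo(node::SharedNode, cache::IdDict{SharedNode,Int}, visits::Ref{Int})
    haskey(cache, node) && return cache[node]
    visits[] += 1
    d = if isempty(node.children)
        1
    else
        1 + maximum(c -> depth_memo(c, cache, visits), node.children)
    end
    cache[node] = d
    return d
end

# `shared` appears twice under the root, so the expression is a DAG.
shared = SharedNode([SharedNode(), SharedNode()])
root = SharedNode([shared, shared])

visits = Ref(0)
d = depth_memo(root, IdDict{SharedNode,Int}(), visits)
```

Written out as a plain tree, this expression has 7 nodes, but only 4 distinct objects exist, so the memoized walk processes 4 of them. The cache lookup adds work per node, which is why a flag for skipping it in the common no-sharing case is useful.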

Fixed in https://discourse.julialang.org/t/strange-performance-issue-on-1-9-0-rc3/98427/10. Was quite a subtle issue.

I have since merged the "tree as collections" in #27 to be included in v0.8.0 onwards. Thanks for the tips @Moelf @odow!