`tree_map`?
MilesCranmer opened this issue · 16 comments
I realize a lot of functions could just be implemented as calls to a generic tree_map function. For example,
tree_map(t -> 1, tree; merge=max)
would calculate the depth of a tree. The “merge” function would be used to aggregate left/right child for binary nodes. For example,
tree_map(t -> 1, tree; merge=(+))
would count the total number of nodes. Meanwhile,
tree_map(tree; merge=(+)) do t
    Int(t.degree==2)
end
would count the number of binary operators. Then something like
tree_map(tree; merge=(l, r)->[l..., r...]) do t
    if t.degree != 0 || !t.constant
        return []
    end
    return [t.val]
end
would return all constants in a tree (in depth-first traversal order).
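To make the proposal concrete, here is a minimal sketch of how such a function could behave. It is not the package's implementation: `ToyNode`, `leaf`, `unary`, `binary`, and `toy_tree_map` are hypothetical names, and the sketch assumes that `merge` receives the mapped value of the parent followed by the already-reduced values of its children. Under that assumption, depth needs a merge that adds the parent to the maximum over the children, rather than a bare `max`.

```julia
# Hypothetical minimal node type, standing in for the package's Node.
mutable struct ToyNode
    degree::Int                      # 0 = leaf, 1 = unary, 2 = binary
    val::Float64                     # only meaningful for leaves
    l::Union{ToyNode,Nothing}
    r::Union{ToyNode,Nothing}
end
leaf(v) = ToyNode(0, v, nothing, nothing)
unary(c) = ToyNode(1, 0.0, c, nothing)
binary(a, b) = ToyNode(2, 0.0, a, b)

# Assumed semantics: `merge` gets f(parent), then the reduced children.
function toy_tree_map(f, tree::ToyNode; merge)
    tree.degree == 0 && return f(tree)
    tree.degree == 1 && return merge(f(tree), toy_tree_map(f, tree.l; merge))
    return merge(
        f(tree), toy_tree_map(f, tree.l; merge), toy_tree_map(f, tree.r; merge)
    )
end

# A small tree: a binary root with a leaf on the left and a unary node on the right.
tree = binary(leaf(1.0), unary(leaf(2.0)))

n_nodes = toy_tree_map(Returns(1), tree; merge=(+))
depth = toy_tree_map(Returns(1), tree; merge=(p, c...) -> p + maximum(c))
n_binary = toy_tree_map(t -> Int(t.degree == 2), tree; merge=(+))
constants = toy_tree_map(t -> t.degree == 0 ? [t.val] : Float64[], tree; merge=vcat)
```

Returning a typed empty vector (`Float64[]`) instead of `[]` keeps the collected result concretely typed.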
@Moelf would this have been helpful for writing that NYT puzzle solver? What do you think of the API?
@AlCap23 any comment?
I'm not very familiar with common types of tree-based algorithms, but yeah, I think functions that facilitate tree-walking should be useful for writing custom loss functions, right?
yeah, compiler > human when it comes to optimizing code (un)fortunately.
btw _ -> 1 can be expressed as Returns(1) in Julia now
Yeah. And it's so much more readable now! Makes it easier to think of other functions to map over these objects.
btw _ -> 1 can be expressed as Returns(1) in Julia now
I'm probably doing something wrong but it seems slower for some reason:
julia> f(_) = 1
f (generic function with 1 method)
julia> @btime f(3.2)
1.272 ns (0 allocations: 0 bytes)
1
julia> g = Returns(1)
Returns{Int64}(1)
julia> @btime g(3.2)
24.865 ns (0 allocations: 0 bytes)
1
yeah, const g = ...; right now g is a non-const global.
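A minimal sketch of the difference, using only Base (the names `g_const`, `via_global`, and `via_const` are made up for illustration):

```julia
# Non-const global: its type may change later, so calls through it
# have to be dispatched dynamically.
g = Returns(1)

# Const global: the binding's type is fixed, so the compiler can
# specialize (and typically inline) the call.
const g_const = Returns(1)

via_global(x) = g(x)        # depends on the untyped global `g`
via_const(x) = g_const(x)   # sees the concrete type Returns{Int}

# With BenchmarkTools, interpolating the global into the benchmark
# expression avoids the same overhead:
#   @btime $g(3.2)
```

Both functions return the same value; only the cost of the call differs.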
Cool. Thanks!
@Moelf what would you think if I overloaded Base.mapreduce rather than defining a custom tree_mapreduce? Since it has the exact same syntax as a regular mapreduce, perhaps it is suitable. The one difference is that the merge function here takes in (parent, child_l, child_r), whereas a normal mapreduce's merge takes in (element1, element2, ...).
Why don't I just overload all collection functions... Then you could just iterate through a tree!
hm, I'm trying to imagine if walking down a tree is ~ iteration, and if that would surprise people
I added some other collection functions, including iterate (DynamicExpressions.jl/src/tree_map.jl, lines 107 to 115 in 8d29619).
The behavior is to traverse a tree depth-first, left to right, and return the current node at each step. So you can do:
for node in tree
    # check the operator of the current node, not of the root
    if node.degree == 1 && node.op == 2
        my_operator_count += 1
    end
end
And it will work as you might expect. It’s a tiny bit slower than using the mapreduce because it allocates a stack of nodes, but I think it might be easier for users to write custom losses.
One thing that might make this more intuitive is to have a no-op wrapper type. For example:
node_stack = DepthFirstTraversal(tree) |> collect
or filtering:
constant_nodes = filter(t -> t.degree == 0 && t.constant, DepthFirstTraversal(tree))
or looping:
for node in DepthFirstTraversal(tree)
    # next node will be child, if this node has degree > 0
end
DepthFirstTraversal would wrap the Node type and define how the tree is iterated over. But otherwise it wouldn't do anything, and (hopefully) wouldn't affect the performance.
struct DepthFirstTraversal{N<:Node} <: AbstractTraversal
    x::N
end
In the future we could also add other traversal strategies.
What do you think?
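For illustration, a wrapper like this could implement Julia's iteration protocol roughly as follows. This is a sketch under assumptions, not the code merged into the package: `ToyNode` and its constructors are hypothetical stand-ins for `Node`, and the iteration state is assumed to be an explicit stack of nodes still to visit.

```julia
# Hypothetical minimal node type, standing in for the package's Node.
mutable struct ToyNode
    degree::Int                      # 0 = leaf, 1 = unary, 2 = binary
    val::Float64                     # only meaningful for leaves
    l::Union{ToyNode,Nothing}
    r::Union{ToyNode,Nothing}
end
leaf(v) = ToyNode(0, v, nothing, nothing)
unary(c) = ToyNode(1, 0.0, c, nothing)
binary(a, b) = ToyNode(2, 0.0, a, b)

abstract type AbstractTraversal end

struct DepthFirstTraversal{N} <: AbstractTraversal
    x::N
end

# First call: start with a stack holding only the root.
Base.iterate(t::DepthFirstTraversal{N}) where {N} = iterate(t, N[t.x])

# Later calls: pop the next node and schedule its children.
function Base.iterate(::DepthFirstTraversal, stack)
    isempty(stack) && return nothing
    node = pop!(stack)
    # Push the right child first so the left child is visited first.
    node.degree == 2 && push!(stack, node.r)
    node.degree >= 1 && push!(stack, node.l)
    return node, stack
end

# The number of nodes is not known up front, but the element type is.
Base.IteratorSize(::Type{<:DepthFirstTraversal}) = Base.SizeUnknown()
Base.eltype(::Type{DepthFirstTraversal{N}}) where {N} = N

tree = binary(leaf(1.0), unary(leaf(2.0)))
visited_degrees = [n.degree for n in DepthFirstTraversal(tree)]
leaf_values = [n.val for n in DepthFirstTraversal(tree) if n.degree == 0]
node_stack = collect(DepthFirstTraversal(tree))
```

With only `iterate` defined, `collect`, comprehensions, and `Iterators.filter` work on the wrapper; whether plain `filter` accepts it depends on a method being defined for the wrapper type.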
@odow we chatted about tree structures at one point - I'd love to hear your take on this sort of interface!
It looks nice. We don't really use algorithms over trees in JuMP/MOI. The first step in our AD engine is to convert everything to a single topologically sorted tape so everything requires a linear pass.
The other issue I ran into was people constructing nested expressions that mean you can't use recursion.
A somewhat artificial example, but expressions like this cause trouble:
N = 1_000_000
x = [Variable() for _ in 1:N]
y = x[1]
for i in 2:N
    y = +(y, x[i])
end
(Overlook the fact that you could lift all the nodes into +(x...) etc. It's just an artificial example.)
I see, thanks. Indeed, that looks hard for recursion. We are lucky in this sense because we never have expressions with more than ~100 nodes or so, but in the future I definitely want to try some sort of stack-based evaluation.
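As a rough sketch of what a stack-based traversal could look like on such a left-deep expression, the example below counts nodes with an explicit stack instead of recursion, so the call depth stays constant however deeply the expression is nested. `ToyNode`, `left_deep_sum`, and `count_nodes_iterative` are hypothetical names, not part of either package.

```julia
# Hypothetical minimal node type, standing in for a real expression node.
mutable struct ToyNode
    degree::Int                      # 0 = leaf, 1 = unary, 2 = binary
    val::Float64                     # only meaningful for leaves
    l::Union{ToyNode,Nothing}
    r::Union{ToyNode,Nothing}
end
leaf(v) = ToyNode(0, v, nothing, nothing)
unary(c) = ToyNode(1, 0.0, c, nothing)
binary(a, b) = ToyNode(2, 0.0, a, b)

# Build ((((x1 + x2) + x3) + ...) + xn): nesting depth grows linearly with n.
function left_deep_sum(n)
    y = leaf(1.0)
    for i in 2:n
        y = binary(y, leaf(Float64(i)))
    end
    return y
end

# Count nodes with an explicit stack; no recursive calls are made.
function count_nodes_iterative(tree::ToyNode)
    stack = ToyNode[tree]
    count = 0
    while !isempty(stack)
        node = pop!(stack)
        count += 1
        node.degree == 2 && push!(stack, node.r)
        node.degree >= 1 && push!(stack, node.l)
    end
    return count
end

deep = left_deep_sum(100_000)
n_deep = count_nodes_iterative(deep)
```

A chain of n leaves joined by n - 1 binary nodes has 2n - 1 nodes in total.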
By the way, unrelated, but eventually I would love to build some sort of interface between JuMP/MOI and SymbolicRegression.jl. Maybe so you could evolve symbolic models in a JuMP problem, or maybe so you could use JuMP inside an objective to optimize a symbolic expression found by SymbolicRegression.jl. At the very least it could be useful to build a converter between them, like the one we have for Symbolics.jl.
Sometimes Julia code optimization is so weird. On Julia 1.9-rc3, this function:
function tree_mapreduce(
    f_leaf::F1, f_branch::F2, op::G, tree::N;
    preserve_sharing::Bool=false, result_type::Type{RT}=Nothing
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end
is 2x slower than this function (only change is commenting out the conditional return):
function tree_mapreduce(
    f_leaf::F1, f_branch::F2, op::G, tree::N;
    preserve_sharing::Bool=false, result_type::Type{RT}=Nothing
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
    # preserve_sharing && return @with_memoization(_tree_mapreduce(f_leaf, f_branch, op, tree), IdDict{N,RT}())
    return _tree_mapreduce(f_leaf, f_branch, op, tree)
end
even when preserve_sharing is set to false! This is true even if I annotate the return type, even with -O3, etc.
But I don't see this for 1.8.5.
Maybe the precompilation caching is breaking something about this inline macro?
Fixed in https://discourse.julialang.org/t/strange-performance-issue-on-1-9-0-rc3/98427/10. Was quite a subtle issue.
I have since merged the "tree as collections" in #27 to be included in v0.8.0 onwards. Thanks for the tips @Moelf @odow!