ACEsuit/ACEmd.jl

Bug in assemble when using threads

Opened this issue · 2 comments

@cortner Here is the bug in `assemble`. To reproduce it, PR #46 needs to be applied; the following code then triggers the bug:

using Folds

# Reproducer for the threading bug: a threaded variant of `ACEfit.assemble`.
#
# The weights are computed in a task started with `Threads.@spawn`, while the
# per-element feature matrices and target vectors are computed with `Folds.map`
# over `data` (presumably threaded by Folds' default executor — TODO confirm).
# Returns, in order: all feature matrices concatenated with `vcat`, all target
# vectors concatenated with `vcat`, and the fetched weights.
#
# NOTE(review): repeated calls on the same input give differing results (sometimes
# with different sizes); swapping `Folds.map` for the commented-out `pmap` line
# avoids it. The root cause is not established by this code.
function broken_assemble(data::AbstractArray, basis; kwargs...)
    # Weights are assembled concurrently with the feature loop below; fetched at the end.
    W = Threads.@spawn ACEfit.assemble_weights(data; kwargs...)
    raw_data = Folds.map( data ) do d # threaded map: results vary between runs
    #raw_data = pmap( data ) do d  # this works
        A = ACEfit.feature_matrix(d, basis; kwargs...)
        Y = ACEfit.target_vector(d; kwargs...)
        (A, Y)
    end
    # Split the (features, targets) pairs returned per element of `data`.
    A = [ a[1] for a in raw_data ]
    Y = [ a[2] for a in raw_data ]

    A_final = reduce(vcat, A)
    Y_final = reduce(vcat, Y)
    # `fetch` waits for the spawned weights task and returns its result.
    return A_final, Y_final, fetch(W)
end

When you run the same input several times, it gives different results, and sometimes the arrays even have different sizes.

Here is a full, runnable example:

using ACEmd
using ACEfit
using ACEpotentials

# Ti/Al ACE model built with the given order, total degree, cutoff and
# per-element reference energies (`Eref`).
model = acemodel(;
    elements = [:Ti, :Al],
    order = 3,
    totaldegree = 6,
    rcut = 5.5,
    Eref = [:Ti => -1586.0195, :Al => -105.5954],
)
basis = model.basis

# Training set: every 5th configuration (indices 1, 6, 11, ...) of the TiAl
# tutorial dataset, converted to `FlexibleSystem`s.
data_j, _, meta = ACEpotentials.example_dataset("TiAl_tutorial")
train_new = [FlexibleSystem(frame) for frame in data_j[1:5:end]]

# Reference result computed with the stock `ACEfit.assemble`.
A, Y, W = ACEfit.assemble(
    train_new, basis;
    energy_default_weight = 5,
    energy_ref = model.Vref,
)

## Broken code
using Folds

# Reproducer for the threading bug: a threaded variant of `ACEfit.assemble`.
#
# The weights are computed in a task started with `Threads.@spawn`, while the
# per-element feature matrices and target vectors are computed with `Folds.map`
# over `data` (presumably threaded by Folds' default executor — TODO confirm).
# Returns, in order: all feature matrices concatenated with `vcat`, all target
# vectors concatenated with `vcat`, and the fetched weights.
#
# NOTE(review): repeated calls on the same input give differing results (sometimes
# with different sizes); swapping `Folds.map` for the commented-out `pmap` line
# avoids it. The root cause is not established by this code.
function broken_assemble(data::AbstractArray, basis; kwargs...)
    # Weights are assembled concurrently with the feature loop below; fetched at the end.
    W = Threads.@spawn ACEfit.assemble_weights(data; kwargs...)
    raw_data = Folds.map( data ) do d # threaded map: results vary between runs
    #raw_data = pmap( data ) do d  # this works
        A = ACEfit.feature_matrix(d, basis; kwargs...)
        Y = ACEfit.target_vector(d; kwargs...)
        (A, Y)
    end
    # Split the (features, targets) pairs returned per element of `data`.
    A = [ a[1] for a in raw_data ]
    Y = [ a[2] for a in raw_data ]

    A_final = reduce(vcat, A)
    Y_final = reduce(vcat, Y)
    # `fetch` waits for the spawned weights task and returns its result.
    return A_final, Y_final, fetch(W)
end
##

# Run the threaded variant repeatedly and compare it against the reference `A`, `Y`.
# The bug shows up either as differing values or as differently sized arrays, so
# sizes are checked first: `A - a` would otherwise throw a DimensionMismatch and
# abort the loop at the first mismatch, hiding the remaining runs.
for i in 1:10
    a, y, _ = broken_assemble(
        train_new, basis; energy_default_weight=5, energy_ref=model.Vref
    )
    # Compare results
    if size(a) != size(A) || size(y) != size(Y)
        @warn "Size mismatch" i size(A) size(a) size(Y) size(y)
        continue
    end
    @info "Maximum error " i maximum(abs2, A - a) maximum(abs2, Y - y)
end

For future reference, this is the `feature_matrix` function:

function ACEfit.feature_matrix(
    data,
    basis;
    energy=true,
    force=true,
    virial=true,
    energy_key=:energy,
    force_key=:force,
    virial_key=:virial,
    kwargs...)
    # Build the design-matrix rows contributed by one configuration `data`.
    #
    # Basis functions are on different columns. Rows are stacked in this order:
    #   - energy: the first row (if `energy` and `data` has `energy_key`);
    #   - forces: flattened onto several rows so that each basis function stays in
    #     its own column (if `force` and `data` has `force_key`);
    #   - virial: the virial is symmetric, so duplicate entries are dropped and only
    #     the diagonal and lower triangle are flattened (if `virial` and `data` has
    #     `virial_key`).
    # Remaining `kwargs` are forwarded to `ace_energy`, `ace_forces` and `ace_virial`.
    #
    # Throws an `ArgumentError` when no block is produced, instead of the opaque
    # "reducing over an empty collection" error from `reduce(vcat, [])`.

    blocks = []
    if energy && haskey(data, energy_key)
        e = ace_energy(basis, data; kwargs...)
        # Transpose so the energies of all basis functions form a single row.
        push!(blocks, e')
    end
    if force && haskey(data, force_key)
        f = ace_forces(basis, data; kwargs...)
        # Reinterpret each entry as a flat Float64 vector (presumably a vector of
        # 3-vectors per basis function — TODO confirm); `hcat` then turns each
        # basis function's flattened forces into one column.
        tf = reinterpret.(Float64, f)
        f_block = reduce(hcat, tf)
        push!(blocks, f_block)
    end
    if virial && haskey(data, virial_key)
        v = ace_virial(basis, data; kwargs...)
        # Assuming each `m` is a column-major 3x3 matrix, linear indices
        # 1,5,9,6,3,2 select (1,1),(2,2),(3,3),(3,2),(3,1),(2,1): the diagonal
        # followed by the lower triangle.
        tv = map(v) do m
            m[SVector(1, 5, 9, 6, 3, 2)]
        end
        v_block = reduce(hcat, tv)
        push!(blocks, v_block)
    end
    isempty(blocks) && throw(ArgumentError(
        "feature_matrix: no energy, force or virial data selected or available " *
        "in `data` (keys: $energy_key, $force_key, $virial_key)"))
    return reduce(vcat, blocks)
end

This in turn leads to this code:

# Generate `ace_energy`, `ace_forces` and `ace_virial` methods for an
# `ACE1.IPSuperBasis`. Each sub-basis in `basis.BB` is evaluated with `Folds.map`
# on `executor` (default `ThreadedEx()`), the same `executor` is forwarded to the
# per-sub-basis call, and the per-sub-basis results are concatenated with `vcat`.
for fname in (:ace_energy, :ace_forces, :ace_virial)
    @eval function $fname(basis::ACE1.IPSuperBasis, data;
                          executor=ThreadedEx(), kwargs...)
        # One result per sub-basis, mapped over `basis.BB` by `executor`.
        parts = Folds.map(basis.BB, executor) do sub
            $fname(sub, data; executor=executor, kwargs...)
        end
        return reduce(vcat, parts)
    end
end

and so on...

I made some changes to `feature_matrix` as part of the performance patch in #46. It is still the same: this code triggers the bug.

One possible cause is that, in this form, `evaluate` and `evaluate_d` run at the same time in the same process. As far as I know this does not happen in any other use case, because energy and forces are computed sequentially elsewhere. In any case, this should be easy to check.