Jutho/TensorKit.jl

Permute for non-abelian tensors is slow

xiangjianqian opened this issue · 8 comments

I found that length(fusiontrees(f1,f2)) can be very large for large SU2 tensors (with many indices) in some cases, which slows down the permute operation. The easiest way to solve this problem would be to parallelize the _add_general_kernel! function. Is there any better solution for this problem?
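To get a feel for why the number of fusion trees grows so quickly, consider the simplest case: n uncoupled legs that all carry spin 1/2. The sketch below counts the SU(2) fusion trees for each total spin by coupling one leg at a time. It is plain Julia, not TensorKit's fusiontrees, and the function name count_su2_trees is hypothetical; real tensors with several sectors per leg, and with trees on both the codomain and domain side, have many more trees.

```julia
# Count SU(2) fusion trees for n uncoupled spin-1/2 legs, resolved by total spin.
# Spins are stored as twice their value (2j) so that everything stays an Int.
function count_su2_trees(n::Int)
    counts = Dict(1 => 1)                  # one leg: total spin 1/2, a single tree
    for _ in 2:n
        next = Dict{Int,Int}()
        for (twoj, c) in counts
            # coupling spin j with spin 1/2 gives j + 1/2 and, if j > 0, j - 1/2
            for twoj′ in (twoj + 1, twoj - 1)
                twoj′ >= 0 || continue
                next[twoj′] = get(next, twoj′, 0) + c
            end
        end
        counts = next
    end
    return counts
end
```

For example, count_su2_trees(4) returns Dict(4 => 1, 2 => 3, 0 => 2) (keys are 2j), and the total number of trees, sum(values(count_su2_trees(n))), equals binomial(n, n ÷ 2), which grows roughly like 2^n.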

Jutho commented

The non-abelian permutation is indeed not yet parallelized. It is possible, but some more care is needed, as different blocks of the input tensor can contribute to the same block of the output tensor. Can you post the specific example, so that I can investigate further what the dominant bottleneck is?
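A minimal way to multithread such a loop while respecting this constraint is to serialize the writes per output block. The sketch below is not TensorKit's _add_general_kernel!; it models a permutation as a hypothetical list of (input_key, output_key, coeff) triples acting on plain matrices stored in Dicts, and guards each output block with its own lock so that concurrent contributions to the same block cannot race.

```julia
using Base.Threads

# Hypothetical model: each triple (src, dst, coeff) means output[dst] += coeff * input[src].
# Several triples may share the same dst, which is exactly what makes naive threading unsafe.
# All output blocks are assumed to be preallocated in `output`.
function accumulate_locked!(output::Dict{K,Matrix{Float64}},
                            input::Dict{K,Matrix{Float64}},
                            contribs::Vector{Tuple{K,K,Float64}}) where {K}
    locks = Dict(k => ReentrantLock() for k in keys(output))   # one lock per output block
    @threads for i in eachindex(contribs)
        src, dst, coeff = contribs[i]
        lock(locks[dst]) do
            dest = output[dst]
            dest .+= coeff .* input[src]      # only one task writes to this block at a time
        end
    end
    return output
end
```

Locking is simple, but when many input blocks map onto a few output blocks the tasks mostly wait on each other, which is why restructuring the loop around output blocks can pay off.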

Can one not parallelize over the different output blocks?

Jutho commented

Probably, but it might require changing the permutation logic on the underlying fusion trees somewhat. Alternatively, we could first collect all the permutation results for the different fusion trees in the input tensor (which requires some extra allocations), and then parallelize over the output blocks, finding for each one all the input blocks that contribute to it.
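The collect-then-parallelize idea can be sketched with a hypothetical data model: blocks are plain matrices stored in Dicts, and the permutation result is a list of (input_key, output_key, coeff) triples. This is an illustration of the pattern, not TensorKit code. The serial inversion step corresponds to the extra allocations mentioned above; afterwards each output block is owned by exactly one task, so no locks are needed.

```julia
using Base.Threads

# Hypothetical model: each triple (src, dst, coeff) means output[dst] += coeff * input[src].
# All output blocks are assumed to be preallocated in `output`.
function accumulate_gathered!(output::Dict{K,Matrix{Float64}},
                              input::Dict{K,Matrix{Float64}},
                              contribs::Vector{Tuple{K,K,Float64}}) where {K}
    # Step 1 (serial): invert the map, collecting for every output block
    # all (input_key, coeff) pairs that contribute to it.
    sources = Dict(k => Tuple{K,Float64}[] for k in keys(output))
    for (src, dst, coeff) in contribs
        push!(sources[dst], (src, coeff))
    end
    # Step 2 (parallel): each task owns one output block, so writes never collide.
    dsts = collect(keys(sources))
    @threads for i in eachindex(dsts)
        dest = output[dsts[i]]
        for (src, coeff) in sources[dsts[i]]
            dest .+= coeff .* input[src]
        end
    end
    return output
end
```

The trade-off is the extra storage for sources and a possible load imbalance when a few output blocks receive most of the contributions.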

SU2_test.zip

Here is a short demo of my code. It shows that the time spent on permutation is 200 times larger than that spent on contraction (measured by t1/(t2-t1)). The raw data is available through the link above.

import FileIO
using TensorKit
using TensorOperations

function TN_save(ttn::Vector{TensorMap},fname::String)
    ttn1=[convert(Dict,ttn[i]) for i in 1:length(ttn)]
    FileIO.save(fname,"ttn_tem",ttn1)
    return 0
end

function TN_load(fname::String)
    ttn1=FileIO.load(fname)["ttn_tem"]
    ttn=Vector{TensorMap}(undef, length(ttn1))
    for i in 1:length(ttn1)
        ttn[i]=convert(TensorMap,ttn1[i])
    end
    return ttn
end

function tensor_contract(a)
    # `a` stands in for a general TN input; it is overwritten here for simplicity.
    a=1
    tensors=TN_load("SU2_test.jld2")

    link=[[167, 10, 69], [31, 190, -80], [135, 151, 97], [97, 143, 98], [98, 137, 99], [99, 139, 100], [100, 141, 101], [101, 153, 102], [102, 145, 103], [103, 146, 104], [104, 154, 105], [105, -133, 106], [106, 131, 107], [107, 129, 108], [108, 144, 109], [109, 152, 110], [110, 127, 111], [111, 128, 112], [112, 155, 113], [113, 147, 114], [114, 130, 115], [115, 132, 116], [116, -134, 117], [117, 157, 118], [118, 149, 119], [119, 150, 120], [120, 158, 121], [121, 142, 122], [122, 140, 123], [123, 138, 124], [124, 148, 125], [66, 69, 151, 152], [167, 190, 9, 32], [9, 10, 267], [31, 32, 278], [135, 349, 295], [295, 143, 296], [296, 137, 297], [297, 139, 298], [298, 141, 299], [299, 153, 300], [300, 145, 301], [301, 146, 302], [302, 154, 303], [303, 331, 304], [304, 131, 305], [305, 129, 306], [306, 144, 307], [307, 350, 308], [308, 127, 309], [309, 128, 310], [310, 155, 311], [311, 147, 312], [312, 130, 313], [313, 132, 314], [314, 332, 315], [315, 157, 316], [316, 149, 317], [317, 150, 318], [318, 158, 319], [319, 142, 320], [320, 140, 321], [321, 138, 322], [322, 148, 125], [278, -91, 331, 332], [66, 267, 349, 350]]
    conj_list= [true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false]
    order=[[63, 30], [64, 29], [63, 60], [62, 59], [61, 28], [60, 57], [59, 27], [58, 55], [57, 26], [56, 53], [55, 25], [54, 51], [53, 24], [52, 49], [51, 23], [50, 22], [49, 21], [45, 44], [47, 46], [46, 20], [45, 42], [44, 19], [43, 40], [42, 18], [41, 38], [40, 17], [39, 36], [38, 16], [37, 34], [15, 14], [35, 34], [31, 30], [33, 32], [3, 2], [31, 2], [30, 2], [29, 2], [28, 2], [13, 12], [26, 12], [25, 12], [24, 12], [23, 12], [22, 21], [21, 12], [20, 2], [19, 11], [18, 2], [17, 2], [16, 9], [3, 2], [14, 13], [8, 7], [12, 11], [11, 2], [10, 6], [2, 0], [5, 2], [7, 6], [6, 5], [5, 4], [4, 1], [2, 1], [2, 1], [1, 0]]
    output= [-80, -91, -133, -134]
    order=convert(Vector{Vector{Int64}},order)
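    # Each pair in `order` holds 0-based positions in the current tensor list: the two
    # tensors are contracted, removed, and their result is appended at the end.
    for i in order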
        if length(link)!=2
            indc=Tuple(symdiff(Tuple(link[i[1]+1]),Tuple(link[i[2]+1])))
        else
            indc=tuple(output...)
        end
        temp_tensor=_tensorcontract(tensors[i[1]+1],Tuple(link[i[1]+1]),tensors[i[2]+1],Tuple(link[i[2]+1]),conj_list[i[1]+1],conj_list[i[2]+1],indc)
        link=vcat(deleteat!(link,sort(i+[1,1])),indc)
        tensors=vcat(deleteat!(tensors,sort(i+[1,1])),temp_tensor)
        conj_list=vcat(deleteat!(conj_list,sort(i+[1,1])),false)
    end
    return tensors[1]
end

function _tensorcontract(A::TensorMap, IA::Tuple, B::TensorMap, IB::Tuple,CA1::Bool,CB1::Bool,IC::Tuple)
    CA = CA1 == true ? :C : :N
    CB = CB1 == true ? :C : :N
    oindA, cindA, oindB, cindB, indCinoAB = TensorOperations.contract_indices(IA, IB, Tuple(IC))
    T = promote_type(eltype(A), eltype(B))
    C =TensorOperations.similar_from_indices(T, oindA, oindB,indCinoAB ,(), A, B, CA, CB)
    contraction(1, A, CA, B, CB, 0, C,
    oindA, cindA, oindB, cindB, indCinoAB, ())
    return C
end


function contraction(α,
    tA::AbstractTensorMap{S}, CA::Symbol,
    tB::AbstractTensorMap{S}, CB::Symbol,
    β, tC::AbstractTensorMap{S, N₁, N₂},
    oindA, cindA,
    oindB, cindB,
    p1, p2) where {S, N₁, N₂}

    p = (p1..., p2...)
    pl = ntuple(n->p[n], N₁)
    pr = ntuple(n->p[N₁+n], N₂)
    if CA == :N && CB == :N
        qcontract!(α, tA, tB, β, tC, oindA, cindA, oindB, cindB, pl, pr)
    elseif CA == :N && CB == :C
        oindB = TensorKit.adjointtensorindices(tB, oindB)
        cindB = TensorKit.adjointtensorindices(tB, cindB)
        qcontract!(α, tA, tB', β, tC, oindA, cindA, oindB, cindB, pl, pr)
    elseif CA == :C && CB == :N
        oindA = TensorKit.adjointtensorindices(tA, oindA)
        cindA = TensorKit.adjointtensorindices(tA, cindA)
        qcontract!(α, tA', tB, β, tC, oindA, cindA, oindB, cindB, pl, pr)
    elseif CA == :C && CB == :C
        oindA = TensorKit.adjointtensorindices(tA, oindA)
        cindA = TensorKit.adjointtensorindices(tA, cindA)
        oindB = TensorKit.adjointtensorindices(tB, oindB)
        cindB = TensorKit.adjointtensorindices(tB, cindB)
        qcontract!(α, tA', tB', β, tC, oindA, cindA, oindB, cindB, pl, pr)
    else
        error("unknown conjugation flags: $CA and $CB")
    end
    return tC
end


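function qcontract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
                    β, C::AbstractTensorMap{S},
                    oindA, cindA,
                    oindB, cindB,
                    p1, p2) where {S}
    st=time()
    # Permute A and B into matrix form: (open, contracted) and (contracted, open).
    A′ = permute(A, oindA, cindA)
    B′ = permute(B, cindB, oindB)
    t1=time()-st                 # time spent permuting the inputs
    C′ = A′*B′
    add!(α, C′, β, C, p1, p2)    # C = β*C + α*(C′ permuted into the output order)
    t2=time()-st
    println(t1/(t2-t1))          # permutation time / (multiplication + final add!) time
    return C
end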
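Because Julia compiles methods on their first call, a single time() measurement like the one in qcontract! can include compilation time. A slightly more robust pattern is to warm up first and keep the minimum over a few repeats. The sketch below shows this on plain arrays; permute_vs_multiply and min_time are hypothetical helpers, and the rank-r array with all dimensions d only stands in for A′ = permute(A, ...) followed by A′*B′.

```julia
using LinearAlgebra

# Time f() several times after a warm-up call (which absorbs JIT compilation)
# and keep the minimum, which is less sensitive to noise than a single sample.
function min_time(f; repeats::Int = 5)
    f()                                   # warm-up: compile before measuring
    return minimum(@elapsed(f()) for _ in 1:repeats)
end

# Ratio of permutation time to multiplication time for a rank-r array whose
# dimensions are all d, permuted and then reshaped into a matrix.
function permute_vs_multiply(d::Int, r::Int; repeats::Int = 5)
    A = rand(ntuple(_ -> d, r)...)
    perm = reverse(ntuple(identity, r))   # reverse all indices
    half = d^(r ÷ 2)
    B = rand(half, half)
    tperm = min_time(() -> permutedims(A, perm); repeats)
    Amat = reshape(permutedims(A, perm), :, half)
    tmul = min_time(() -> Amat * B; repeats)
    return tperm / tmul
end
```

Comparing permute_vs_multiply(2, 12) with permute_vs_multiply(16, 4) (roughly the same number of elements, very different numbers of indices) is one way to probe the "many small indices" effect on a given machine.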
Jutho commented

I am not sure much can be done about this. Yes, multithreading add_general might help, but ultimately you have tensors whose individual blocks are all rather small, but plentiful. I also see that, when evaluating the permutation, Strided._mapreduce_block! (which is called by the higher-level axpy! on StridedView representations of the permuted data) spends more time computing the optimal block size than actually evaluating the permutation. This is again an artefact of the same problem, i.e. arrays with several indices that all have a small size. The Strided logic is certainly not optimised for that case.
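One generic low-level remedy for arrays with many small indices is to merge runs of indices that stay adjacent and in the same order under the permutation, so that the actual permutedims call sees a lower-rank array with larger dimensions. The sketch below works on plain column-major Julia arrays, not on TensorKit blocks, and is not taken from Strided or TensorKit; fuse_permutation and fused_permutedims are hypothetical names.

```julia
# Merge runs of indices that remain adjacent and in order under `perm`.
# Returns the fused input dimensions and the fused (lower-rank) permutation.
function fuse_permutation(dims::NTuple{N,Int}, perm::NTuple{N,Int}) where {N}
    # Split the output order into maximal runs p, p+1, ... (contiguous input indices).
    runs = Vector{Vector{Int}}()
    for p in perm
        if !isempty(runs) && last(runs[end]) + 1 == p
            push!(runs[end], p)
        else
            push!(runs, [p])
        end
    end
    # Sorting runs by their first input index gives the fused input layout.
    inorder = sortperm(first.(runs))
    fdims = Tuple(prod(dims[r]) for r in runs[inorder])
    # Output run j sits at fused input position invperm(inorder)[j].
    fperm = Tuple(invperm(inorder))
    return fdims, fperm
end

# Same result as permutedims(A, perm), but the permutation runs on the fused array.
function fused_permutedims(A::AbstractArray{T,N}, perm::NTuple{N,Int}) where {T,N}
    fdims, fperm = fuse_permutation(size(A), perm)
    B = permutedims(reshape(A, fdims), fperm)
    return reshape(B, Tuple(size(A, p) for p in perm))
end
```

For instance, the permutation (3, 4, 1, 2) of a 2×3×4×5 array reduces to a plain transpose of a 6×20 matrix. This only helps when the permutation actually keeps some indices together, which a well-chosen index order of the individual tensors can encourage.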

There are various places where one could intervene to improve the performance, ranging from high-level to low-level changes, but this will always be a very problem-specific process, and there will not be a one-size-fits-all optimal solution.

I have been thinking about a general machinery to be able to select among different contraction or permutation algorithms, but with little progress so far.
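A common Julia pattern for such a selection machinery is to represent each algorithm as a type and dispatch on it, with a heuristic choosing the default and callers free to override it. The sketch below is hypothetical (none of these names exist in TensorKit), operates on plain arrays, and its heuristic is a placeholder; the point is the structure, not the rule.

```julia
# Hypothetical strategy types: each permutation algorithm is its own type.
abstract type PermuteAlgorithm end
struct DirectPermute <: PermuteAlgorithm end   # eager copy via permutedims
struct LazyPermute <: PermuteAlgorithm end     # no copy: a PermutedDimsArray view

# One method per strategy; adding a strategy means adding a type and a method.
run_permute(::DirectPermute, A::AbstractArray, perm) = permutedims(A, perm)
run_permute(::LazyPermute, A::AbstractArray, perm) = PermutedDimsArray(A, perm)

# Placeholder heuristic (an assumption, not a tuned rule): arrays with many
# small dimensions take the lazy path, everything else is copied eagerly.
function choose_algorithm(dims::Tuple; smalldim::Int = 4, maxsmall::Int = 4)
    return count(<=(smalldim), dims) > maxsmall ? LazyPermute() : DirectPermute()
end

# Callers can override the automatic choice explicitly via `alg`.
permute_with(A::AbstractArray, perm; alg::PermuteAlgorithm = choose_algorithm(size(A))) =
    run_permute(alg, A, perm)
```

Because the choice is an ordinary value, it can be threaded through higher-level functions (or set per contraction) without changing their code when new strategies are added.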

Thank you very much for your help. Maybe I need to improve my code to avoid that problem.

Jutho commented

I don't think your code is really the problem. What exactly is the contraction that you are evaluating? Maybe a different order of contracting the tensors, or a different index order of the individual tensors, can already mitigate part of the problem. Overall, the code a = 1; tensor_contract(a) (I am not sure what the a argument is for) runs quite fast on my laptop, and there seem to be only a few specific contractions where t1/(t2-t1) spikes to values much larger than one.
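Comparing contraction orders can start from a simple cost estimate. The sketch below scores a pairwise contraction sequence using the same remove-and-append convention as the loop in tensor_contract, but with 1-based positions, index label lists like link, and a hypothetical dims table. It assumes dense tensors; for SU2 tensors the real cost is a sum over symmetry blocks, so this is only a proxy.

```julia
# Estimate the multiplication cost of a pairwise contraction sequence.
# Tensors are label vectors; `dims` maps each label to its (dense) dimension.
# As in tensor_contract: contract the pair, remove both, append the result.
function order_cost(links::Vector{Vector{Int}}, dims::Dict{Int,Int},
                    order::Vector{Tuple{Int,Int}})
    links = copy(links)               # the caller's list is left untouched
    total = 0
    for (i, j) in order
        a, b = links[i], links[j]
        # flops ∝ product of all dimensions involved (open and contracted)
        total += prod(dims[l] for l in union(a, b))
        deleteat!(links, sort([i, j]))
        push!(links, symdiff(a, b))   # result keeps the uncontracted labels
    end
    return total
end
```

For a matrix chain with labels [[1, 2], [2, 3], [3, 4]] and dims 1 => 10, 2 => 100, 3 => 5, 4 => 50, contracting the first pair first costs 7500, while starting with the last pair costs 75000.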

Maybe it is also a matter of expectations. The "common belief" that multiplication time dominates permutation time in tensor network contractions is only valid in the asymptotic regime of large tensors. Present-day computers have such powerful processors (using vector instructions and all that) that many algorithms (throughout computing, not specifically in TN) are actually limited in speed by memory bandwidth rather than by computation. Furthermore, matrix multiplication has such highly optimised and efficient implementations (exploiting all these processor developments) that permutations (which involve no arithmetic and are entirely bound by memory bandwidth) often take a significant fraction of the total time.
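The bandwidth argument can be made concrete with a back-of-the-envelope arithmetic-intensity model (flops per byte moved). The sketch below assumes Float64 data and ideal caching, where each operand and the result cross the memory bus exactly once; the function names are hypothetical, and the machine balance is a hardware-specific number that is not given here.

```julia
# Arithmetic intensity (flops per byte) of an m×k times k×n Float64 matrix product,
# assuming each operand and the result are transferred exactly once.
matmul_intensity(m, k, n) = 2 * m * k * n / (8 * (m * k + k * n + m * n))

# A permutation performs no flops but reads and writes every Float64 element once.
permute_bytes(len) = 2 * 8 * len

# A kernel is bandwidth-bound when its intensity is below the machine balance
# (peak flops per second divided by peak bytes per second).
is_bandwidth_bound(intensity, machine_balance) = intensity < machine_balance
```

For square n×n blocks the intensity is n/12, so a 12×12 block reaches only 1 flop per byte while a 1200×1200 block reaches 100; a permutation has zero intensity regardless of size. Many small blocks therefore push the whole contraction toward the bandwidth-bound regime.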

It would be interesting to see how the ratio between permutation and multiplication time changes on a different architecture, such as Apple's M1, which seems to have much larger memory bandwidth. Unfortunately, I do not have any hands-on experience with it (yet).

Yes, I want to find a more efficient contraction order that avoids producing tensors with several indices that all have a small size. In my code, I have thousands of such TNs to contract in order to evaluate the expectation value of a Hamiltonian. The a argument is a general TN input; I just set it to 1 for simplicity.

When evaluating the contraction order previously, I didn't consider the problem mentioned above. So, as a next step, I want to take it into account in the order evaluation.
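One way to take the permutation cost into account is to include it in the score used when searching for an order. The sketch below is a hypothetical greedy search (a swapped-in, simple technique, not what the thread prescribes) for dense tensors given as index label lists with a dims table. Its per-step cost adds to the multiplication flops a weighted memory-traffic term for permuting both inputs and an optional per-index penalty for high-rank inputs. The weights wperm and windex are free parameters that would have to be calibrated, for example against measured t1/(t2-t1) ratios.

```julia
# Hypothetical per-step cost: multiplication flops + weighted permutation traffic
# + a per-index penalty meant to model the overhead of many small indices.
function step_cost(a::Vector{Int}, b::Vector{Int}, dims::Dict{Int,Int};
                   wperm::Float64 = 1.0, windex::Float64 = 0.0)
    len(ls) = prod((dims[l] for l in ls); init = 1)
    flops = len(union(a, b))
    traffic = len(a) + len(b)                      # elements moved when permuting both inputs
    overhead = windex * (length(a) + length(b))    # penalize high-rank inputs
    return flops + wperm * traffic + overhead
end

# Greedy search: repeatedly contract the cheapest pair that shares an index.
# Returns 1-based pairs using the remove-both, append-result convention.
function greedy_order(links::Vector{Vector{Int}}, dims::Dict{Int,Int}; kw...)
    links = copy(links)
    order = Tuple{Int,Int}[]
    while length(links) > 1
        best, bestcost = (0, 0), Inf
        for i in 1:length(links), j in i+1:length(links)
            isempty(intersect(links[i], links[j])) && continue
            c = step_cost(links[i], links[j], dims; kw...)
            if c < bestcost
                best, bestcost = (i, j), c
            end
        end
        best == (0, 0) && error("remaining tensors share no index (outer products not handled)")
        i, j = best
        push!(order, best)
        a, b = links[i], links[j]
        deleteat!(links, [i, j])
        push!(links, symdiff(a, b))
    end
    return order
end
```

Greedy search is cheap enough to run for thousands of networks, but it can miss the optimal order; the same step_cost could instead be plugged into an exhaustive or dynamic-programming search for small networks.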

I am very grateful for your advice.