QuantumBFS/Yao.jl

Performance issue

anbirDrea opened this issue

Hi, I am using the excellent package Yao.jl and have run into something odd.

I am trying to implement the parameter-shift rule to compute the gradient of a circuit, and my function turns out to be unexpectedly slow.
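For reference, here is the shift rule being implemented. Yao's `Rx(θ)` is exp(-iθX/2), and for any gate of the form exp(-iθG/2) with G² = I, the derivative of an expectation value is exactly a difference of two shifted evaluations. This is what `gradient_rot` computes. As a sanity check, a single `Rx` on |0⟩ measured with Z gives ⟨Z⟩ = cos θ, and the formula returns -sin θ:

```math
\begin{aligned}
\frac{\partial \langle H \rangle}{\partial \theta}
  &= \frac{\langle H \rangle(\theta + \pi/2) - \langle H \rangle(\theta - \pi/2)}{2}, \\
\text{e.g. } \langle Z \rangle = \cos\theta:\quad
\frac{\cos(\theta + \pi/2) - \cos(\theta - \pi/2)}{2}
  &= \frac{-\sin\theta - \sin\theta}{2} = -\sin\theta .
\end{aligned}
```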

Here is my code:

using Yao

# Shift-rule gradient with respect to the angle of the k-th block of `cir`.
# Assumes every block in `cir.blocks` holds exactly one rotation angle.
function gradient_rot(ham::AbstractBlock, state::AbstractRegister, cir::AbstractBlock, k::Int)
    # Evaluate <ham> at θ + π/2.
    dispatch!(+, cir.blocks[k], pi/2)
    pos = expect(ham, copy(state)=>cir)

    # Evaluate <ham> at θ - π/2.
    dispatch!(-, cir.blocks[k], pi)
    neg = expect(ham, copy(state)=>cir)

    # Restore the original angle θ.
    dispatch!(+, cir.blocks[k], pi/2)

    return (pos - neg) / 2 |> real
end

# Apply the shift rule to every block of the circuit.
function gradient(ham::AbstractBlock, state::AbstractRegister, cir::AbstractBlock)
    grads = zeros(Float64, length(cir.blocks))
    for k = 1:length(cir.blocks)
        grads[k] = gradient_rot(ham, state, cir, k)
    end

    return grads
end

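# A one-qubit test circuit with a single Rx rotation.
c = chain(1, put(1, 1=>Rx(0.23)))

ham = kron(Z)
state = rand_state(nqubits(c))
dispatch!(c, :random)    # randomize the rotation angle

# 1. Gradient from Yao's built-in automatic differentiation.
@time grads1 = expect'(ham, copy(state)=>c)[2];
# 2. Shift-rule gradient through the `gradient` function.
@time grads2 = gradient(ham, copy(state), c);

# 3. The same loop as in `gradient`, written at top level.
@time begin
    grads3 = zeros(Float64, length(c.blocks))
    for k = 1:length(c.blocks)
        grads3[k] = gradient_rot(ham, state, c, k)
    end
end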

On my machine, the results are:

0.000037 seconds (50 allocations: 1.953 KiB)
0.037294 seconds (41.22 k allocations: 2.300 MiB, 99.82% compilation time)
0.000029 seconds (59 allocations: 2.203 KiB)

The `gradient` function is unexpectedly slow (about 0.037 s, versus roughly 0.00003 s for the other two). However, when I write the same loop directly at top level (as for `grads3`), it runs much faster.
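The second timing line reports 99.82% compilation time, so one way to check how much of the gap is Julia's just-in-time compilation rather than actual runtime is to call each function once before timing it. Below is a minimal sketch under that assumption. `shift_gradient` is a hypothetical helper (not part of Yao) that applies the same shift rule to the flat parameter vector returned by `parameters`, assuming every parameter is the angle of a gate exp(-iθG/2) with G² = I (true for `Rx`, `Ry`, `Rz`); its result is compared against `expect'`.

```julia
using Yao

# Hypothetical helper, not part of Yao: parameter-shift gradient over all
# parameters of `cir`. Assumes each parameter is the angle of a gate
# exp(-iθG/2) with G^2 = I, so the shift rule is exact.
function shift_gradient(ham, state, cir)
    θ = parameters(cir)               # fresh vector of all circuit parameters
    grads = zeros(Float64, length(θ))
    for k in eachindex(θ)
        θ[k] += π/2                   # θ_k + π/2
        dispatch!(cir, θ)
        pos = real(expect(ham, copy(state) => cir))
        θ[k] -= π                     # θ_k - π/2
        dispatch!(cir, θ)
        neg = real(expect(ham, copy(state) => cir))
        θ[k] += π/2                   # back to the original θ_k
        grads[k] = (pos - neg) / 2
    end
    dispatch!(cir, θ)                 # leave the circuit unchanged
    return grads
end

c = chain(1, put(1, 1 => Rx(0.23)))
ham = put(1, 1 => Z)
state = rand_state(1)

# Warm-up calls: the first call of each method triggers compilation,
# so they are made once before anything is timed.
shift_gradient(ham, state, c)
expect'(ham, copy(state) => c)

@time g_shift = shift_gradient(ham, state, c)
@time g_ad = expect'(ham, copy(state) => c)[2]
```

BenchmarkTools.jl's `@btime` macro, which runs an expression many times and reports the minimum, is another common way to keep compilation out of the measurement.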

I do not know how to optimize this function. Could someone help me?