Performance issue
anbirDrea opened this issue · 0 comments
anbirDrea commented
Hi, I am using the very good package Yao.jl and found something weird.
I am trying to implement the parameter-shift rule to get the gradient of a circuit, and I find that the function is unexpectedly slow.
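For reference, the shift rule I am trying to implement can be written as follows. This assumes each rotation gate has the form exp(-iθP/2) with P a Pauli operator (as for Rx); the symbols θ_k (the parameter of the k-th block) and ⟨H⟩ (the expectation value of the Hamiltonian) are notation introduced here for illustration:

```latex
\frac{\partial \langle H \rangle}{\partial \theta_k}
  = \frac{1}{2}\left[
      \langle H \rangle\!\left(\theta_k + \frac{\pi}{2}\right)
    - \langle H \rangle\!\left(\theta_k - \frac{\pi}{2}\right)
    \right]
```

In the code, the first term corresponds to `pos` and the second to `neg`.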
Here is my code:
using Yao
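function gradient_rot(ham::AbstractBlock, state::AbstractRegister, cir::AbstractBlock, k::Int)
    num_qubits = nqubits(cir)
    # shift the k-th parameter by +pi/2 and evaluate the expectation value
    dispatch!(+, cir.blocks[k], pi/2)
    pos = expect(ham, copy(state)=>cir)
    # shift by -pi, i.e. to -pi/2 relative to the original value
    dispatch!(-, cir.blocks[k], pi)
    neg = expect(ham, copy(state)=>cir)
    # restore the original parameter
    dispatch!(+, cir.blocks[k], pi/2)
    return (pos - neg) / 2 |> real
end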
function gradient(ham::AbstractBlock, state::AbstractRegister, cir::AbstractBlock)
    grads = zeros(Float64, length(cir.blocks))
    for k = 1:length(cir.blocks)
        grads[k] = gradient_rot(ham, state, cir, k)
    end
    return grads
end
c = chain(1, put(1, 1=>Rx(0.23)))
ham = kron(Z)
state = rand_state(nqubits(c))
dispatch!(c, :random)
@time grads1 = expect'(ham, copy(state)=>c)[2];
@time grads2 = gradient(ham, copy(state), c);
@time begin
    grads3 = zeros(Float64, length(c.blocks))
    for k = 1:length(c.blocks)
        grads3[k] = gradient_rot(ham, state, c, k)
    end
end
On my machine, the results are:
0.000037 seconds (50 allocations: 1.953 KiB)
0.037294 seconds (41.22 k allocations: 2.300 MiB, 99.82% compilation time)
0.000029 seconds (59 allocations: 2.203 KiB)
I noticed that the gradient function is unexpectedly slow. However, when the same loop is written directly in the main script (as for grads3), it is much faster.
I do not know how to optimize this function. Could someone help me?
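Given the 99.82% compilation time reported for the second measurement, one thing that may be worth checking is whether the slowdown only affects the first call. Below is a minimal sketch of that check using a hypothetical stand-in function `sumsq` (not part of Yao, and not my actual `gradient` function); it only assumes that the first call of a Julia function for a given set of argument types includes compilation:

```julia
# Hypothetical stand-in for an expensive function; not part of Yao.
sumsq(x) = sum(abs2, x)

x = rand(1000)

# First call: the measurement includes JIT compilation of sumsq for Vector{Float64}.
t_first = @elapsed sumsq(x)

# Second call with the same argument types: the already compiled method is reused.
t_second = @elapsed sumsq(x)

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

If the same pattern holds here, timing `gradient(ham, copy(state), c)` a second time should give a number much closer to the other two measurements.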