QuantumBFS/Yao.jl

How can we compute the Hessian matrix?

anbirDrea opened this issue · 4 comments

Let's say we want to compute the Hessian matrix of a VQC (variational quantum circuit). I implemented the following code and found that it does not work; it seems that many base functions would have to be implemented.

What is the most convenient way to obtain the Hessian matrix using Yao.jl?

Thanks for your help!

using ForwardDiff: jacobian
using Yao
using LinearAlgebra: I

circuit = chain(5,
    put(5, 2=>Rx(1)),
    put(5, 1=>Ry(2)),
    put(5, 3=>Rz(3)),
    put(5, 2=>shift(4))
)
HamBlock = matblock(Matrix{Complex{Float64}}(I, 1<<5, 1<<5))


function compute_gradient(params::AbstractVector{T}) where T
    dispatch!(circuit, params)
    expect'(HamBlock, zero_state(nqubits(circuit))=>circuit)[2]
end

x = rand(4)*2π
g = compute_gradient(x)
h = jacobian(compute_gradient, x)

Results:

MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(compute_gradient), Float64}, Float64, 4})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at ~/julia-1.7.2/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at ~/julia-1.7.2/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at ~/julia-1.7.2/share/julia/base/char.jl:50
  ...

Can you try the following function?

julia> function compute_gradient(params::AbstractVector{T}) where T
           c = dispatch(circuit, params)
           expect'(HamBlock, zero_state(Complex{T}, nqubits(c))=>c)[2]
       end
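  1. Using the non-in-place dispatch function is much safer with autodiff. The in-place dispatch! writes the new parameters into the existing gates, whose angles are stored as Float64, so ForwardDiff's Dual numbers cannot be stored there; that is the Float64(::Dual) MethodError above.
  2. Make the register element type consistent with the parameter type: zero_state(Complex{T}, nqubits(c)) => c.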
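For reference, the suggestion can be assembled into one complete script that returns the full Hessian. This is a sketch, not an official recipe: it reuses the circuit from the question but swaps the identity HamBlock for a matblock of Z on qubit 2 (an assumption made only so that the Hessian is not identically zero), and it assumes a Yao version that provides the non-in-place dispatch used above.

```julia
using Yao
using ForwardDiff: jacobian

# Same circuit as in the question; the initial angles are overwritten by `dispatch`.
circuit = chain(5,
    put(5, 2=>Rx(1)),
    put(5, 1=>Ry(2)),
    put(5, 3=>Rz(3)),
    put(5, 2=>shift(4))
)

# Assumption: a non-trivial observable (Z on qubit 2) instead of the identity,
# stored as a dense ComplexF64 matblock just like HamBlock in the question.
HamBlock = matblock(Matrix(mat(ComplexF64, put(5, 2=>Z))))

# Gradient of <H> w.r.t. the circuit parameters, generic in the element type T
# so that ForwardDiff can push Dual numbers through it.
function compute_gradient(params::AbstractVector{T}) where T
    c = dispatch(circuit, params)              # non-in-place: new blocks holding T-valued angles
    reg = zero_state(Complex{T}, nqubits(c))   # register eltype matches the parameter type
    expect'(HamBlock, reg => c)[2]             # [2] is the gradient w.r.t. circuit parameters
end

x = rand(4) * 2π
g = compute_gradient(x)              # gradient, length 4
h = jacobian(compute_gradient, x)    # 4×4 Hessian: Jacobian of the gradient
```

As a sanity check: in this circuit only the Rx angle on qubit 2 changes ⟨Z₂⟩ (Ry and Rz act on other qubits, and shift is diagonal), so ⟨Z₂⟩ = cos θ₁. Hence g should be close to [-sin θ₁, 0, 0, 0], and h should vanish except for h[1, 1] ≈ -cos θ₁.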

Many thanks!

I tried to apply the above code to compute the Hessian of a VQC that includes a custom gate. However, it throws the following error:

setparams!(x, θ...) is not implemented

It says I did not implement some functions for this custom gate, but I have already implemented the functions mentioned in the docs: niparams, getiparams and setiparams!.

Did I miss something?

YaoBlocks.setiparams!(x::AbstractBlock, a::Number, xs::Number...)

According to your error message, your call falls back to the generic method above. Try defining a more specific method than that one. Also, please make sure you have imported this function before overloading it, e.g. by extending it under its qualified name:

YaoBlocks.setiparams!
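The import advice is plain Julia method dispatch, independent of Yao. The following self-contained sketch uses a hypothetical module MiniLib as a stand-in for YaoBlocks (MiniLib, AbstractGate, setparams!, assign! and MyGate are all made up for illustration): a method defined without importing creates a new, unrelated function that the library never calls, whereas a method added under the qualified name extends the library's own function and, being more specific than the generic fallback, is the one that gets picked.

```julia
# MiniLib is a hypothetical stand-in for a library such as YaoBlocks: it owns a
# generic fallback and calls it internally (like dispatch relying on setiparams!).
module MiniLib
    abstract type AbstractGate end

    # Generic fallback, analogous to setiparams!(x::AbstractBlock, a::Number, xs::Number...)
    setparams!(x::AbstractGate, a::Number, xs::Number...) =
        error("setparams!(x, θ...) is not implemented")

    # Library-internal caller: it only ever sees MiniLib.setparams!
    assign!(x::AbstractGate, θ::Number) = setparams!(x, θ)
end

# A user-defined gate (hypothetical)
mutable struct MyGate <: MiniLib.AbstractGate
    theta::Float64
end

# Wrong: without an import this defines a brand-new Main.setparams!,
# which MiniLib.assign! never calls, so the fallback still fires.
setparams!(g::MyGate, θ::Real) = (g.theta = θ; g)

gate = MyGate(0.0)
failed = try
    MiniLib.assign!(gate, 0.5)
    false
catch
    true
end

# Right: extend the library's own function via its qualified name
# (an `import .MiniLib: setparams!` before the first definition would also work).
# (MyGate, Real) is more specific than (AbstractGate, Number, Number...),
# so this method wins over the fallback.
MiniLib.setparams!(g::MyGate, θ::Real) = (g.theta = θ; g)

MiniLib.assign!(gate, 0.5)
println("unimported method bypassed: ", failed, ", theta = ", gate.theta)
```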

Many thanks!

It turns out that I forgot to implement setiparams for my custom gate, which is needed to use dispatch. Now I get the desired Hessian matrix.