ITensor/ITensors.jl

[ITensors] [BUG] Pullback of addition of MPOs not working

ntausend opened this issue · 3 comments

Description of bug

Hey, while studying the optimization of quantum circuits, I came across a strange bug in the pullback of the addition of two MPOs. I did not check whether the same bug also appears when adding two MPSs. As an example, consider the deliberately trivial function $f(O) = O - O = 0$, which is just the null function on the space of MPOs. Its pullback should therefore map any dual vector to the null vector. However, when computed on MPOs, the pullback at a test MPO, applied to a dual-vector MPO, is non-zero.
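The argument can be made explicit with the chain rule. This is a worked sketch in standard reverse-mode notation, where $\bar y$ is the incoming dual vector; the split of $f$ into a duplication map $\Delta$ and a subtraction map $s$ is introduced here only for illustration:

```latex
f = s \circ \Delta, \qquad \Delta(O) = (O, O), \qquad s(x_1, x_2) = x_1 - x_2 \\
s^{*}(\bar y) = (\bar y,\, -\bar y), \qquad \Delta^{*}(\bar x_1, \bar x_2) = \bar x_1 + \bar x_2 \\
f^{*}(\bar y) = \Delta^{*}\big(s^{*}(\bar y)\big) = \bar y + (-\bar y) = 0
```

A rule for subtraction that does not negate the dual vector for its second argument breaks this cancellation.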

A sanity check using dense matrix representations of the same operation shows the expected behavior.

The observed behavior is also independent of the element type (Float/Complex), of whether symmetric or non-symmetric tensors are used, and of the number of legs of the MPO.

Minimal code demonstrating the bug or unexpected behavior

Minimal runnable code

using ITensors
using Zygote: pullback
using Random: seed!

seed!(1234)
N = 2
elT = Float64

ninds = 2
# ninds site indices, each of dimension N
idx = [Index(N) for _ in 1:ninds]
tot_dim = N^ninds

zero_mat = zeros(elT, tot_dim, tot_dim)

cl = combiner(idx)

# extracting the matrix representation of the MPO for comparison
to_matrix(O) = matrix(contract(O) * dag(cl) * prime(cl))
# round entries to 3 digits for printing
round_mat(M) = map(m -> round(m, digits = 3), M)

# generate the test MPO
G = MPO(randomITensor(elT, prime.(idx)..., dag.(idx)...), idx)
# matrix version for comparison
Gm = to_matrix(G)

# generate the dual vector MPO
M = MPO(randomITensor(elT, prime.(idx)..., dag.(idx)...), idx)
# matrix version for comparison
Mm = to_matrix(M)

# Null function
f(O) = O - O

fG = to_matrix(f(G))
# should be a simple tot_dim x tot_dim zero matrix
println("Is fG equal to $(tot_dim)x$(tot_dim) null matrix? $(fG ≈ zero_mat)") # returns true as expected

# now calculate the pullback of f at the test MPO G, applied to the dual vector M
∇fG = to_matrix(pullback(f, G)[2](M)[1])

# should again be the tot_dim x tot_dim zero matrix
println("Is ∇fG equal to $(tot_dim)x$(tot_dim) null matrix? $(∇fG ≈ zero_mat)") # returns false
println(round_mat(∇fG))

# sanity check with dense matrix representations
∇fGm = pullback(f, Gm)[2](Mm)[1]
println("Is ∇fGm equal to $(tot_dim)x$(tot_dim) null matrix? $(∇fGm ≈ zero_mat)") # returns true as expected

Expected output or behavior

The expected behavior would be that $\nabla fG$ is the 4x4 null matrix. This would correspond to an output:

Is fG equal to 4x4 null matrix? true
Is ∇fG equal to 4x4 null matrix? true
[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
Is ∇fGm equal to 4x4 null matrix? true

Actual output or behavior

The pullback is some non-zero MPO. The output of the minimal code example:

Is fG equal to 4x4 null matrix? true
Is ∇fG equal to 4x4 null matrix? false
[3.244 -1.527 -1.069 -1.674; -0.2 -1.452 -0.481 1.885; 0.892 0.921 -2.89 0.743; 3.905 0.647 -1.516 2.339]
Is ∇fGm equal to 4x4 null matrix? true

Version information

  • Output from versioninfo():
julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × 13th Gen Intel(R) Core(TM) i7-1355U
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, goldmont)
  Threads: 1 on 12 virtual cores
  • Output from using Pkg; Pkg.status("ITensors"):
julia> using Pkg; Pkg.status("ITensors")
[9136182c] ITensors v0.3.54

I also found that the pullback returned in the example is exactly the pullback expected for the identity function $f(O) = O$, i.e. ∇fG is the dual vector M itself.

Thanks for the report. Maybe this rrule is incorrect: https://github.com/ITensor/ITensors.jl/blob/v0.3.55/src/ITensorChainRules/mps/mpo.jl#L33-L35

Does it work if you change that to:

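function rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...)
  y = -(x1, x2; kwargs...)
  function subtract_pullback(ȳ)
    # d(x1 - x2)/dx1 = +1 and d(x1 - x2)/dx2 = -1, so the
    # dual vector ȳ goes to x1 unchanged and is negated for x2
    return (NoTangent(), ȳ, -ȳ)
  end
  return y, subtract_pullback
end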

?

Looks like that works: with the new rrule, your code outputs:

Is fG equal to 4x4 null matrix? true
Is ∇fG equal to 4x4 null matrix? true
[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
Is ∇fGm equal to 4x4 null matrix? true
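The shape of the corrected rule can also be checked in isolation, without MPOs. The sketch below is an assumption-laden stand-in, not ITensors code: `mysub` is a hypothetical function on plain arrays, given a ChainRulesCore `rrule` with the same `(NoTangent(), ȳ, -ȳ)` structure as the proposed fix, and Zygote is expected to pick that rule up when differentiating the null function built from it.

```julia
using ChainRulesCore
using Zygote: pullback

# Hypothetical stand-in for MPO subtraction: plain arrays replace MPOs here.
mysub(x1, x2) = x1 - x2

# Same structure as the proposed fix: x1 receives ȳ, x2 receives -ȳ.
function ChainRulesCore.rrule(::typeof(mysub), x1, x2)
  y = mysub(x1, x2)
  mysub_pullback(ȳ) = (NoTangent(), ȳ, -ȳ)
  return y, mysub_pullback
end

# The null function from the report, built on the stand-in subtraction.
g(x) = mysub(x, x)

x = randn(4)
ȳ = randn(4)             # random dual vector
∇g = pullback(g, x)[2](ȳ)[1]
println(∇g)              # expected: a vector of zeros, since ȳ + (-ȳ) = 0
```

Changing the rule's return value to `(NoTangent(), ȳ, ȳ)` in this sketch makes the gradient `2ȳ` instead of zero, which is a quick way to exercise a subtraction rule before touching the MPO code.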