Jutho/TensorKit.jl

Problem with gradient when using tensors with Float64 and ComplexF64 entries


I am trying to compute the gradient of a function that takes a matrix as input and contracts it with another tensor. If both tensors have Float64 entries, or both have ComplexF64 entries, the Zygote gradient works. However, if one has Float64 entries and the other has ComplexF64 entries, it fails with an `InexactError: Float64()` error. Below is an MWE that behaves the same way as the actual function I need.

```julia
using TensorKit, Zygote, LinearAlgebra

normDiff(a, b) = norm(a - b);

# Both tensors Float64: the gradient works
A = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)

# Both tensors ComplexF64: the gradient works
A = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)

# Mixed Float64 / ComplexF64: throws InexactError
A = TensorMap(randn, Float64, ComplexSpace(2), ComplexSpace(2));
B = TensorMap(randn, ComplexF64, ComplexSpace(2), ComplexSpace(2));
Zygote.gradient((x, y) -> normDiff(x, y), A, B)
```

Could you test this with the master version of TensorKit.jl? There was a fix to the chain rules for the case of mixed scalar types, which is not yet included in the latest registered version. We will provide a version update soon.
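To illustrate why mixed scalar types can break a reverse-mode rule, here is a minimal sketch using plain Julia arrays rather than TensorKit or Zygote. It is not TensorKit's actual fix, and `complex_grad` and `project_to` are hypothetical helpers. The cotangent for the real input `a` comes out complex, and storing it in a Float64 container throws an `InexactError`. Projecting onto the real part first (similar in spirit to what ChainRulesCore's `ProjectTo` does for real arguments) gives the correct real gradient.

```julia
using LinearAlgebra

# Gradient of f(a, b) = norm(a - b) with respect to `a`, computed on the
# complex side: (a - b) / norm(a - b), which is complex when b is complex.
complex_grad(a, b) = (a - b) / norm(a - b)

# Hypothetical helper: project a cotangent back onto the scalar type of the
# primal input. For a real input only the real part is meaningful.
project_to(x::AbstractArray{<:Real}, dx) = real.(dx)
project_to(x::AbstractArray{<:Complex}, dx) = dx

a = randn(2, 2)             # Float64 entries, like A in the mixed example
b = randn(ComplexF64, 2, 2) # ComplexF64 entries, like B in the mixed example

g = complex_grad(a, b)      # ComplexF64 matrix

# Writing the complex cotangent into a Float64 buffer fails, mirroring the
# InexactError from the issue (imaginary parts are nonzero here).
failed = try
    copyto!(similar(a), g)
    false
catch err
    err isa InexactError
end

# Projecting first yields a Float64 gradient with the same shape as `a`.
ga = project_to(a, g)
```

For a real variable, the derivative of `norm(a - b)` is `real(a - b) / norm(a - b)`, so dropping the imaginary part is the mathematically correct projection, not just a type workaround.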

I had the same problem and updating TensorKit.jl to the current master fixed it.

Thanks, using the current master fixes the problem.

Tagged in v0.12.1