`expect(::AbstractBlock, ::ArrayReg)` is not recognized as typestable
jlbosse opened this issue · 2 comments
Currently the output type of `expect(::PutBlock, ::ArrayReg)` is inferred to be `Union{ArrayReg{...}, Complex}` and the output type of `expect(::Add, ::ArrayReg)` is inferred to be `Any`. This is a problem in one of my applications, because I wanted to use `return_types(expect, Tuple{typeof(hamiltonian), typeof(register)})` to preallocate some memory. See the following MWE to confirm the non-inferability:
using Yao
nq = 2
O1 = put(nq, 1 => X)
O2 = 0.5 * put(nq, 2 => Y)
H = O1 + O2
reg = rand_state(nq)
@code_warntype expect(O1, reg)
# output
MethodInstance for YaoAPI.expect(::PutBlock{2, 1, XGate}, ::ArrayReg{2, ComplexF64, Matrix{ComplexF64}})
from expect(op::AbstractBlock, reg::AbstractRegister) in YaoBlocks at .../Yao/lib/YaoBlocks/src/blocktools.jl:112
Arguments
#self#::Core.Const(YaoAPI.expect)
op::PutBlock{2, 1, XGate}
reg::ArrayReg{2, ComplexF64, Matrix{ComplexF64}}
Body::Union{ArrayReg{2, ComplexF64, Matrix{ComplexF64}}, ComplexF64}
1 ─ %1 = YaoBlocks.:var"'"(reg)::AdjointRegister{2, ArrayReg{2, ComplexF64, Matrix{ComplexF64}}}
│ %2 = YaoBlocks.copy(reg)::ArrayReg{2, ComplexF64, Matrix{ComplexF64}}
│ %3 = YaoBlocks.apply!(%2, op)::ArrayReg{2, ComplexF64, Matrix{ComplexF64}}
│ %4 = (%1 * %3)::Union{ArrayReg{2, ComplexF64, Matrix{ComplexF64}}, ComplexF64}
└── return %4
@code_warntype expect(H, reg)
# output
MethodInstance for YaoAPI.expect(::Add{2}, ::ArrayReg{2, ComplexF64, Matrix{ComplexF64}})
from expect(op::AbstractAdd, reg::AbstractRegister) in YaoBlocks at .../Yao/lib/YaoBlocks/src/blocktools.jl:165
Arguments
#self#::Core.Const(YaoAPI.expect)
op::Add{2}
reg::ArrayReg{2, ComplexF64, Matrix{ComplexF64}}
Locals
#194::YaoBlocks.var"#194#195"{ArrayReg{2, ComplexF64, Matrix{ComplexF64}}}
Body::Any
1 ─ %1 = YaoBlocks.:(var"#194#195")::Core.Const(YaoBlocks.var"#194#195")
│ %2 = Core.typeof(reg)::Core.Const(ArrayReg{2, ComplexF64, Matrix{ComplexF64}})
│ %3 = Core.apply_type(%1, %2)::Core.Const(YaoBlocks.var"#194#195"{ArrayReg{2, ComplexF64, Matrix{ComplexF64}}})
│ (#194 = %new(%3, reg))
│ %5 = #194::YaoBlocks.var"#194#195"{ArrayReg{2, ComplexF64, Matrix{ComplexF64}}}
│ %6 = YaoBlocks.sum(%5, op)::Any
└── return %6
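The same kind of result can be checked programmatically instead of by reading `@code_warntype` output. The following is a minimal, Yao-free sketch: `inner`, `unstable_mul` and `stable_mul` are hypothetical stand-ins for the bra-ket product, not functions from Yao.

```julia
# Inner product with an explicit loop, so the return type is plainly T.
function inner(a::Vector{T}, b::Vector{T}) where {T}
    acc = zero(T)
    for i in eachindex(a, b)
        acc += conj(a[i]) * b[i]
    end
    return acc
end

# Hypothetical stand-in for the unstable product: the branch taken depends on
# runtime lengths, so the return type is a Union of a scalar and a vector.
unstable_mul(a::Vector{ComplexF64}, b::Vector{ComplexF64}) =
    length(a) == length(b) ? inner(a, b) : b

# Hypothetical stable variant: always a scalar, throws on mismatch.
function stable_mul(a::Vector{ComplexF64}, b::Vector{ComplexF64})
    length(a) == length(b) || throw(DimensionMismatch("lengths differ"))
    return inner(a, b)
end

argtypes = Tuple{Vector{ComplexF64}, Vector{ComplexF64}}
unstable_rt = only(Base.return_types(unstable_mul, argtypes))
stable_rt = only(Base.return_types(stable_mul, argtypes))

println("unstable: ", unstable_rt, " (concrete: ", isconcretetype(unstable_rt), ")")
println("stable:   ", stable_rt, " (concrete: ", isconcretetype(stable_rt), ")")
```

A non-concrete inferred type, as in the unstable variant, is what makes `return_types` unusable for preallocation.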
After some digging, the culprit seems to be that `*(::AdjointRegister{D,<:ArrayReg}, ::ArrayReg{D})` is not type stable. This is the code in question:
Yao.jl/lib/YaoArrayRegister/src/operations.jl
Lines 161 to 171 in 4bc2a6f
IMHO `*(bra::AdjointRegister, ket::ArrayReg)` should always return an `ArrayReg`, albeit one on zero qudits if the dimensions match. This is similar to the case in LinearAlgebra:
using LinearAlgebra
v = rand(10, 1)
v' * v
# output is a matrix, not a float.
1×1 Matrix{Float64}
2.584654
There should also be a separate method, `dot(::ArrayReg, ::ArrayReg)`, that is guaranteed to return a number and not a register.
However, I am not deep enough into the code base to understand how feasible that approach would be, in particular whether there are cases where `*(::AdjointRegister, ::ArrayReg)` gets called and it is not clear whether a number or a register is expected as the output. This would also be a breaking change if any users rely on the current behaviour.
Another, simpler fix would be to annotate `expect(op::AbstractBlock{D}, reg::AbstractArrayReg{D,T,AT})::T where {D,T,AT}` so that the output type of `expect` is fixed to the element type of the register. Again, I am not deep enough in the code to understand whether there are cases where this currently does not hold or where this behaviour would be undesirable.
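To illustrate what such a return-type annotation does, here is a toy sketch that does not use Yao. `ToyReg`, `toy_inner`, `toy_mul` and `toy_expect` are hypothetical names invented for this example. The annotation makes Julia call `convert(T, result)` and assert the type, so inference sees a concrete `T`, while a non-scalar result becomes a runtime error instead of a silently returned register.

```julia
# Hypothetical toy register holding a state vector with element type T.
struct ToyReg{T}
    state::Vector{T}
end

# Plain inner product; returns a value of type T.
function toy_inner(a::Vector{T}, b::Vector{T}) where {T}
    acc = zero(T)
    for i in eachindex(a, b)
        acc += conj(a[i]) * b[i]
    end
    return acc
end

# Stand-in for the unstable product: a scalar on full contraction,
# otherwise a register (inferred as a Union).
toy_mul(bra::ToyReg{T}, ket::ToyReg{T}) where {T} =
    length(bra.state) == length(ket.state) ? toy_inner(bra.state, ket.state) : ket

# Annotated wrapper: the result is converted to T and type-asserted.
function toy_expect(bra::ToyReg{T}, ket::ToyReg{T})::T where {T}
    return toy_mul(bra, ket)
end

reg = ToyReg(ComplexF64[1, im])
argtypes = Tuple{ToyReg{ComplexF64}, ToyReg{ComplexF64}}
annotated_rt = only(Base.return_types(toy_expect, argtypes))
println(annotated_rt)
println(toy_expect(reg, reg))
```

The trade-off in this sketch is that a mismatched call such as `toy_expect(reg, ToyReg(ComplexF64[1]))` throws a `MethodError` from `convert`, which is less informative than an explicit dimension check.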
I think the question is what the basis means after a dot product. The reason the result should be a number is that all basis states are contracted, so no register is left. But some qubits do remain if the register is only partially contracted.
Since the next community call is coming up next Thursday, we can probably discuss this there.
> And this change would be a breaking change if any users rely on the current behavior.
I'm not opposed to breaking changes. There were some breaking changes in our last minor release, and very likely more to come: since we are not at 1.0 yet, we do not guarantee backward compatibility in 0.x versions (though we try to stay compatible). At this stage I'd rather choose the right thing than the compatible thing.
I feel like `dot` should always return a number and error if the number of qubits doesn't match. Discussing during the next community call sounds good!
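The semantics suggested here can be sketched on a toy type. This is only an illustration under assumptions: `SketchReg` and its fields are hypothetical and are not Yao's `ArrayReg`; the only library call relied on is `LinearAlgebra.dot` on plain vectors.

```julia
using LinearAlgebra

# Hypothetical toy register: a qubit count plus a state vector.
struct SketchReg{T}
    nqubits::Int
    state::Vector{T}
end

# dot always returns a number, and errors when the qubit counts differ.
function LinearAlgebra.dot(bra::SketchReg{T}, ket::SketchReg{T}) where {T}
    bra.nqubits == ket.nqubits ||
        throw(DimensionMismatch("qubit counts differ: $(bra.nqubits) vs $(ket.nqubits)"))
    return dot(bra.state, ket.state)  # conjugates the first argument
end

a = SketchReg(1, ComplexF64[1, im])
b = SketchReg(2, ComplexF64[1, 0, 0, 0])

println(dot(a, a))
dot_rt = only(Base.return_types(dot, Tuple{SketchReg{ComplexF64}, SketchReg{ComplexF64}}))
println(dot_rt)
```

With this split, a partially contracting product could remain the job of `*`, while `dot` stays inferable as a scalar.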