Reverse net has no rrule
ziyiyin97 opened this issue · 3 comments
ziyiyin97 commented
using InvertibleNetworks, LinearAlgebra, Test, Flux
using ChainRulesCore:rrule
# Reproducer: `rrule` is defined for an invertible layer but not for its reverse,
# so Zygote falls back to differentiating the mutating inverse code and fails.

# Initialize invertible/non-invertible layers
nx = 32
ny = 32
n_ch = 16
n_hidden = 64
batchsize = 2
# Named `use_logdet` rather than `logdet` so the script does not shadow
# `LinearAlgebra.logdet`, which `using LinearAlgebra` brings into scope.
use_logdet = false
N1 = CouplingLayerHINT(n_ch, n_hidden; logdet=use_logdet, permute="full")

# Gradient Test
X = randn(Float32, nx, ny, n_ch, batchsize)
Y0 = randn(Float32, nx, ny, n_ch, batchsize)

## test Reverse network AD
Nrev = reverse(N1)
# ChainRulesCore's fallback `rrule` returns `nothing` when no rule applies,
# so these lines print whether a custom rule exists for each network.
@show rrule(N1, X) !== nothing
@show rrule(Nrev, X) !== nothing
# Expected to throw "Mutating arrays is not supported" while no rule is
# registered for the reversed network.
g = gradient(X -> 0.5f0 * norm(Nrev(X) - Y0)^2, X)
In this script, `rrule(N1, X)` is well-defined, but `rrule(Nrev, X)` is not defined. As a result, the last line throws an error:
ERROR: LoadError: Mutating arrays is not supported -- called copyto!(::SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#445#446"{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/array.jl:74
[3] (::Zygote.var"#2347#back#447"{Zygote.var"#445#446"{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./broadcast.jl:871 [inlined]
[5] Pullback
@ ./broadcast.jl:868 [inlined]
[6] Pullback
@ ./broadcast.jl:864 [inlined]
[7] (::typeof(∂(materialize!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_conv1x1.jl:214 [inlined]
[9] (::typeof(∂(#inverse#255)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_conv1x1.jl:203 [inlined]
[11] #212
@ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
[12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(inverse))}})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[13] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:29 [inlined]
[14] (::typeof(∂(#_predefined_mode#133)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[15] (::Zygote.var"#212#213"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(#_predefined_mode#133))})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
[16] #1750#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[17] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:28 [inlined]
[18] (::typeof(∂(_predefined_mode)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[19] #212
@ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
[20] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(_predefined_mode))}})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[21] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
[22] (::typeof(∂(λ)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[23] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(λ))})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
[24] #1750#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[25] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
[26] (::typeof(∂(λ)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[27] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_hint.jl:193 [inlined]
[28] (::typeof(∂(#inverse#298)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[29] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_hint.jl:158 [inlined]
[30] #212
@ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
[31] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof(∂(inverse))}})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[32] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:29 [inlined]
[33] (::typeof(∂(#_predefined_mode#133)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[34] (::Zygote.var"#212#213"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(#_predefined_mode#133))})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
[35] #1750#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[36] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:28 [inlined]
[37] (::typeof(∂(_predefined_mode)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[38] #212
@ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
[39] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(_predefined_mode))}})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[40] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
[41] (::typeof(∂(λ)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[42] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(λ))})(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
[43] #1750#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[44] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
[45] (::typeof(∂(λ)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[46] Pullback
@ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:135 [inlined]
[47] (::typeof(∂(λ)))(Δ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[48] Pullback
@ ~/.julia/dev/InvertibleNetworks/test/MFE.jl:21 [inlined]
[49] (::typeof(∂(#1)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[50] (::Zygote.var"#56#57"{typeof(∂(#1))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
[51] gradient(f::Function, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
[52] top-level scope
@ ~/.julia/dev/InvertibleNetworks/test/MFE.jl:21
[53] include(fname::String)
@ Base.MainInclude ./client.jl:451
in expression starting at /Users/francisyin/.julia/dev/InvertibleNetworks/test/MFE.jl:21
I think it's because `rrule` isn't defined for the reversed network here:
InvertibleNetworks.jl/src/utils/chainrules.jl, line 127 in 192a7fe.
`forward_update!` isn't defined for the reversed network either:
InvertibleNetworks.jl/src/utils/chainrules.jl, line 134 in 192a7fe.
mloubout commented
Is this the case for every network? Not all layers/networks have their inverse implemented, but those should throw a not-implemented error in that case.
ziyiyin97 commented
The reverse is implemented here, so it's always well-defined. It's just `rrule` that is missing (this also happens for the reverse of `NetworkGlow`).
ziyiyin97 commented
Correct me if I'm wrong, but could something like this work for any generic invertible network?
"""
    rrule(net::Reversed, z::AbstractArray)

Reverse-mode rule for evaluating a reversed invertible network, `net(z)`,
whose forward pass is `net.I.inverse(z)`.

Returns the primal output `x` and a pullback. The pullback maps an output
cotangent `Δ` to `(NoTangent(), Δz)`: ChainRules expects one cotangent for the
callable itself followed by one per argument.
"""
function ChainRulesCore.rrule(net::Reversed, z::AbstractArray{T, N}) where {T, N}
    x = net.I.inverse(z)
    function reversed_pullback(Δ)
        # The reversed network's backward pass is driven by its own output `x`,
        # so `z` does not need to be kept for the pullback.
        # NOTE(review): assumes `backward` returns `(Δinput, input)` per the
        # InvertibleNetworks convention (`ΔX, X = N.backward(ΔY, Y)`) — confirm.
        Δz, _ = net.backward(ChainRulesCore.unthunk(Δ), x)
        return ChainRulesCore.NoTangent(), Δz
    end
    return x, reversed_pullback
end