slimgroup/InvertibleNetworks.jl

Reversed network has no `rrule`

ziyiyin97 opened this issue · 3 comments

using InvertibleNetworks, LinearAlgebra, Test, Flux
using ChainRulesCore:rrule
# Minimal reproduction: reverse-mode AD through a reversed invertible layer.
# `rrule(N1, X)` has a method for the forward layer, but there is no dedicated
# `rrule` method for `Nrev = reverse(N1)`, so the final `gradient` call falls
# back to Zygote tracing the layer code and fails with
# "Mutating arrays is not supported".

# Initialize invertible/non-invertible layers: one HINT coupling layer with
# log-determinant computation disabled.
nx = 32
ny = 32
n_ch = 16
n_hidden = 64
batchsize = 2
# NOTE(review): `using LinearAlgebra` exports `logdet`; this global shadows it
# in this script. A different name (e.g. `use_logdet`) would avoid the clash.
logdet = false
N1 = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")

# Gradient test data: random input X and random target Y0, both of size
# (nx, ny, n_ch, batchsize) -- presumably (width, height, channels, batch).
X = randn(Float32, nx, ny, n_ch, batchsize)
Y0 = randn(Float32, nx, ny, n_ch, batchsize)

## test Reverse network AD

Nrev = reverse(N1)
rrule(N1, X)    # a dedicated rrule method exists for the forward layer
rrule(Nrev, X)  # no dedicated method for the reversed layer (generic fallback)

# Gradient of 0.5*||Nrev(X) - Y0||^2 with respect to X; this is the call that errors.
g = gradient(X -> 0.5f0*norm(Nrev(X) - Y0)^2, X)

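In this script, `rrule(N1, X)` is well-defined, but `rrule(Nrev, X)` has no dedicated method. ChainRulesCore's generic fallback returns `nothing`, so Zygote differentiates through the layer code itself, which mutates arrays. As a result, the last line throws an error: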

ERROR: LoadError: Mutating arrays is not supported -- called copyto!(::SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#445#446"{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/array.jl:74
  [3] (::Zygote.var"#2347#back#447"{Zygote.var"#445#446"{SubArray{Float32, 3, Array{Float32, 4}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:871 [inlined]
  [5] Pullback
    @ ./broadcast.jl:868 [inlined]
  [6] Pullback
    @ ./broadcast.jl:864 [inlined]
  [7] (::typeof((materialize!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_conv1x1.jl:214 [inlined]
  [9] (::typeof((#inverse#255)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_conv1x1.jl:203 [inlined]
 [11] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
 [12] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof((inverse))}})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [13] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:29 [inlined]
 [14] (::typeof((#_predefined_mode#133)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#212#213"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof((#_predefined_mode#133))})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
 [16] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [17] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:28 [inlined]
 [18] (::typeof((_predefined_mode)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [19] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
 [20] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof((_predefined_mode))}})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [21] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
 [22] (::typeof((λ)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [23] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof((λ))})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
 [24] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [25] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
 [26] (::typeof((λ)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [27] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_hint.jl:193 [inlined]
 [28] (::typeof((#inverse#298)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [29] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/layers/invertible_layer_hint.jl:158 [inlined]
 [30] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
 [31] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, typeof((inverse))}})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [32] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:29 [inlined]
 [33] (::typeof((#_predefined_mode#133)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [34] (::Zygote.var"#212#213"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof((#_predefined_mode#133))})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
 [35] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [36] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:28 [inlined]
 [37] (::typeof((_predefined_mode)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [38] #212
    @ ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203 [inlined]
 [39] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof((_predefined_mode))}})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [40] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
 [41] (::typeof((λ)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof((λ))})(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/lib/lib.jl:203
 [43] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [44] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:40 [inlined]
 [45] (::typeof((λ)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [46] Pullback
    @ ~/.julia/dev/InvertibleNetworks/src/utils/neuralnet.jl:135 [inlined]
 [47] (::typeof((λ)))(Δ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [48] Pullback
    @ ~/.julia/dev/InvertibleNetworks/test/MFE.jl:21 [inlined]
 [49] (::typeof((#1)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [50] (::Zygote.var"#56#57"{typeof((#1))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [51] gradient(f::Function, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [52] top-level scope
    @ ~/.julia/dev/InvertibleNetworks/test/MFE.jl:21
 [53] include(fname::String)
    @ Base.MainInclude ./client.jl:451
in expression starting at /Users/francisyin/.julia/dev/InvertibleNetworks/test/MFE.jl:21

I think this is because `rrule` isn't defined for reversed networks here:

function ChainRulesCore.rrule(net::Union{NeuralNetLayer,InvertibleNetwork}, X::AbstractArray{T, N};
Also, `forward_update!` isn't defined for reversed networks either:
forward_update!(state, X, Y, logdet, net)
Any suggestions on how to fix this? Thanks!

Does this happen for every network? Not all layers/networks have their inverse implemented, but those that don't should throw a not-implemented error.

`reverse` is implemented here:

# Build the reversed counterpart of `N`. `N` is deep-copied first, so the copy
# (not `N` itself) is what `tag_as_reversed!` marks; the tagged copy is then
# wrapped in `Reversed`.
function reverse(N::InvertibleNetwork)
    reversed_copy = tag_as_reversed!(deepcopy(N), true)
    return Reversed(reversed_copy)
end
so it is always well-defined. It's only `rrule` that is missing (this also happens for the reverse of `NetworkGlow`).

Correct me if I'm wrong, but could something like this work for any generic invertible network?

# Reverse-mode rule for a reversed invertible network.
#
# Primal: `x = net.I.inverse(z)`.
# Pullback: the output cotangent `Δ` is propagated through
# `reverse(net.I).backward(Δ, x)`. Like the package's other `backward` methods,
# this is assumed to return `(Δz, z)` -- TODO confirm for every wrapped network type.
#
# ChainRulesCore requires the pullback to return one tangent for the callable
# itself followed by one per positional argument, i.e. `(NoTangent(), Δz)`.
# Returning `backward`'s `(Δz, z)` tuple directly would hand the primal `z` back
# as if it were `z`'s tangent.
#
# NOTE(review): logdet-enabled networks are not handled here; `inverse` /
# `backward` may return extra log-determinant terms for them -- confirm.
function ChainRulesCore.rrule(net::Reversed, z::AbstractArray{T, N}) where {T, N}
    x = net.I.inverse(z)
    function reversed_pullback(Δ)
        # Only the input cotangent is needed; the recomputed input is discarded.
        Δz, _ = reverse(net.I).backward(ChainRulesCore.unthunk(Δ), x)
        # NOTE(review): `reverse` deep-copies `net.I`, so any parameter gradients
        # accumulated by `backward` land on that copy rather than on `net`.
        return ChainRulesCore.NoTangent(), Δz
    end
    return x, reversed_pullback
end