NetworkGlow has no field logdet
Hi, I stumbled across this error message (see title) when trying to train a Glow network (the same applies to a HINT network).
MWE:

```julia
using InvertibleNetworks, Flux

# Glow Network
model = NetworkGlow(2, 32, 2, 5)

# dummy input & target
X = randn(Float32, 1, 1, 2, 1)
Y = 2 .* X .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X)[1])

θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end
    @show l
    Flux.update!(opt, θ, grads)
end
```
Hi, is there any news on this? It would be really useful to be able to train INNs as simply as any other Flux model.
Hello,
Sorry! I missed this discussion or probably forgot about it. There is an easy fix: we give NetworkGlow an optional logdet flag, and with logdet=false you can train it as you describe above. Would that be helpful?
If so, I can make that PR in a couple of hours, no problem.
Yes, this would be fabulous, thank you!
All right, I pushed that quick fix. To be clear again, this only works with logdet=false. Tracking and differentiating the logdet is currently difficult with Julia AD. I think it is possible; it just needs some time, which I will have later.
I added the MWE that you suggested here:
https://github.com/slimgroup/InvertibleNetworks.jl/blob/master/examples/chainrules/train_with_flux.jl
I only had to increase the dimensionality of the input, because the ActNorm layer blew up when computing the variance over a single element.
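To see why a single element per channel is a problem, here is a minimal sketch of Glow-style data-dependent ActNorm initialization. It assumes the layer sets its per-channel scale to 1/std and its bias to -mean/std from the statistics of the first batch, as described in the original Glow paper. `actnorm_init` is a hypothetical helper written for illustration only; it is not InvertibleNetworks.jl's implementation.

```julia
using Statistics

# Hypothetical sketch (not the library's code) of Glow-style data-dependent
# ActNorm initialization for a W×H×C×N array: per channel, scale = 1/std and
# bias = -mean/std, with statistics taken over width, height and batch.
function actnorm_init(X::AbstractArray{T,4}) where {T<:AbstractFloat}
    dims = (1, 2, 4)               # reduce over W, H and batch; keep channels
    μ = mean(X; dims=dims)
    σ = std(X; dims=dims)          # corrected (n - 1) estimator by default
    s = one(T) ./ σ                # per-channel scale
    b = -μ ./ σ                    # per-channel bias
    return vec(s), vec(b)
end

# Input shape from the original MWE: one element per channel, so the
# corrected std is 0/0 and the scale is not finite.
s_small, _ = actnorm_init(randn(Float32, 1, 1, 2, 1))
println(s_small)

# With more spatial and batch elements per channel the scale is finite.
s_big, _ = actnorm_init(randn(Float32, 16, 16, 2, 4))
println(all(isfinite, s_big))
```

With the corrected estimator the single-sample variance is undefined; even with `corrected=false` it is zero, so `1/σ` is infinite. In both cases the activations blow up, which is why a larger input makes the example trainable.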
I hope this helps. Thank you for the input!