Flux friendliness
Closed issue · 1 comment
flo-he commented
Hi,
I think it would be very useful to have a plug-and-play option for training the INNs in this package with the standard Flux API. At the moment I am failing to train a simple INN with the following training script:
using InvertibleNetworks, Flux
# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32
# net
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
model = Chain(AN, C)
# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = 2 .* X .+ 1
# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))
# old, implicit-style Flux
θ = Flux.params(model)
opt = ADAM(0.001f0)
for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end
    @info "Loss: $l"
    Flux.update!(opt, θ, grads)
end
When I run this code, the loss stays the same across all iterations (the parameters do not seem to be updated). I do not know whether this style of training is simply not supported at the moment, or whether this is a bug.
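For reference, here is roughly how the "parameters are not updated" observation could be checked. This is only a minimal sketch in plain Julia: `snapshot` and `max_change` are hypothetical helper names introduced here (they are not part of Flux or InvertibleNetworks), and the two arrays stand in for the model's parameters so that the snippet runs on its own.

```julia
# Hypothetical helpers (not part of Flux or InvertibleNetworks).

# Copy every parameter array so later in-place updates do not affect the copy.
snapshot(ps) = [copy(p) for p in ps]

# Largest absolute element-wise change of each array since the snapshot.
max_change(before, ps) =
    [maximum(abs, p .- b; init = zero(eltype(p))) for (b, p) in zip(before, ps)]

# Stand-in "parameters": two plain arrays instead of a real model.
ps = [randn(Float32, 3), randn(Float32, 2, 2)]
before = snapshot(ps)

# Simulate an optimizer step that only touches the first array in place.
ps[1] .-= 0.1f0

changes = max_change(before, ps)
@info "Max change per parameter array" changes
# The first entry is about 0.1, the second is exactly 0 (that array was never updated).
```

In the script above, the same helpers could be applied to `θ` before and after the training loop, assuming that iterating over `θ` yields the parameter arrays. An empty `θ`, or all-zero entries in `changes`, would suggest that the parameters are either not collected by `Flux.params` or not receiving gradients.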
I think this would be a useful feature that makes the package easier to adopt. For my use case, for example, I want to incorporate INNs into a larger personal project that uses Flux, and I only need the guaranteed invertibility of the models after training. Otherwise, they should behave just like any other custom Flux model.