slimgroup/InvertibleNetworks.jl

Flux friendliness

Closed this issue · 1 comment

Hi,

I think it would be very useful if there were a plug-and-play option to train the INNs of this package with the standard Flux API. I am failing to train a simple INN using the following training script:

using InvertibleNetworks, Flux

# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# Invertible network: activation normalization followed by a Glow coupling layer
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
model = Chain(AN, C)

# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = 2 .* X .+ 1

# Loss function: mean squared error between prediction and target
loss(model, X, Y) = Flux.mse(model(X), Y)

# old, implicit-style Flux
θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end

    @info "Loss: $l"

    Flux.update!(opt, θ, grads)
end

When I run this code, the loss stays the same (the parameters do not seem to be updated). I do not know whether this style of training is simply not supported at the moment, or whether it is a bug.
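
One way to check whether the weights are visible to Flux at all (this is only my guess at the cause) is to inspect the parameter collection:

θ = Flux.params(model)
@show length(θ)  # if this prints 0, Flux finds no trainable arrays in the layers

If the collection is empty, Flux.update! has nothing to update, which would explain the constant loss.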

I think this would be a useful feature that eases things up. For my use case, for example, I want to incorporate INNs into a larger personal project that uses Flux, and I only need the guaranteed invertibility of the models after training; otherwise, they should behave just like any other custom Flux model.
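
In the meantime, here is a rough sketch of a workaround along the lines of the manual-gradient examples in the package README. I am assuming the README's API here (get_params, clear_grad!, the layers' forward/backward methods, and the p.data / p.grad fields of the parameters); the MSE gradient is written out by hand:

opt = ADAM(0.001f0)
θ = [get_params(AN); get_params(C)]

for i = 1:5
    # forward pass through both layers
    Z = AN.forward(X)
    Ŷ = C.forward(Z)
    @info "Loss: $(Flux.mse(Ŷ, Y))"

    # hand-written gradient of the MSE loss w.r.t. the prediction
    ΔŶ = 2f0 .* (Ŷ .- Y) ./ length(Y)

    # backward passes fill the .grad fields of the layer parameters
    ΔZ, _ = C.backward(ΔŶ, Ŷ)
    AN.backward(ΔZ, Z)

    # gradient step on each parameter, then reset accumulated gradients
    for p in θ
        Flux.update!(opt, p.data, p.grad)
    end
    clear_grad!(AN)
    clear_grad!(C)
end

That said, having this work through the plain Flux API would of course be much nicer.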

Hi

This was indeed an oversight; it is being fixed in #79. Thanks for reporting it!
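
For background, the usual mechanism for making a custom layer's weights visible to Flux.params is to declare the struct a functor. A minimal sketch with a toy layer (MyAffine is made up purely for illustration; this is not the actual change in #79):

using Flux

# hypothetical toy layer, only to illustrate the mechanism
struct MyAffine
    W::Matrix{Float32}
    b::Vector{Float32}
end

Flux.@functor MyAffine              # lets Flux.params traverse W and b
(m::MyAffine)(x) = m.W * x .+ m.b

m = MyAffine(randn(Float32, 3, 3), zeros(Float32, 3))
@show length(Flux.params(m))        # prints 2: both fields are now tracked

Layers that wrap their weights in a custom parameter type additionally need Flux to be told how to reach the underlying arrays, which is presumably the kind of change the fix involves.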