denizyuret/Knet.jl

Feature Request: Instance Normalization


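Would be nice to have. Instance normalization computes the mean and variance of each W×H×1×1 slice (one channel of one sample), normalizes the slice with them, and then applies a learned per-channel scale and offset to give it a new mean and variance.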
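Written out per slice, the operation looks like the following. The notation is an assumption of this sketch: the input x is W×H×C×N, ε is a small constant (1e-5 in the implementation in this issue), and γ_c, β_c are the learned per-channel scale and offset, corresponding to `scale` and `offset` in that implementation.

```math
\begin{aligned}
\mu_{cn} &= \frac{1}{WH}\sum_{w=1}^{W}\sum_{h=1}^{H} x_{whcn} \\
\sigma^2_{cn} &= \frac{1}{WH}\sum_{w=1}^{W}\sum_{h=1}^{H} \left(x_{whcn} - \mu_{cn}\right)^2 \\
y_{whcn} &= \gamma_c \, \frac{x_{whcn} - \mu_{cn}}{\sqrt{\sigma^2_{cn} + \epsilon}} + \beta_c
\end{aligned}
```

Unlike batch normalization, the statistics are not shared across the N samples of a batch, so each image is normalized on its own.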

Ref:
https://arxiv.org/abs/1607.08022

Flux implementation:
https://github.com/FluxML/Flux.jl/blob/2b1ba184d1a58c37543f4561413cddb2de594289/src/layers/normalise.jl#L249-L276

Currently I'm implementing instance norm in my project as follows. It does not keep a moving (running) mean and variance, since my project did not need them.

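```julia
using Knet, Statistics

struct InstanceNorm
    scale   # learned per-channel scale, size 1×1×C×1
    offset  # learned per-channel offset, size 1×1×C×1
end

function InstanceNorm(nChannels)
    # scale starts near 1 (1 + 0.02 * N(0, 1) noise, same as 2 * gaussian + 1); offset starts at 0
    scale = param(1, 1, nChannels, 1; init=(d...) -> 1 .+ 0.02 .* randn(d...))
    offset = param0(1, 1, nChannels, 1)
    InstanceNorm(scale, offset)
end

function (normLayer::InstanceNorm)(x)
    # Reduce over the spatial dims only: one mean/variance per channel per sample
    dims = collect(1:ndims(x)-2)
    mu = mean(x; dims=dims)
    variance = mean(abs2.(x .- mu); dims=dims)  # biased variance (divide by W*H)
    sigma = sqrt.(variance .+ 1f-5)             # Float32 epsilon keeps Float32 inputs from promoting to Float64
    normalized = (x .- mu) ./ sigma
    normLayer.scale .* normalized .+ normLayer.offset
end
```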
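A plain-Julia reference version can help check a layer like the one above without Knet or AutoGrad in the loop. This is a sketch under assumptions: the input is a 4-D W×H×C×N `Array`, `scale` and `offset` are 1×1×C×1 arrays, the function name `instancenorm_ref` is hypothetical, and only the `Statistics` standard library is used.

```julia
using Statistics

# Hypothetical reference implementation (not part of Knet): plain Arrays only.
# x is assumed to be W×H×C×N; scale and offset are 1×1×C×1.
function instancenorm_ref(x::AbstractArray{T,4}, scale, offset) where {T<:AbstractFloat}
    ϵ = T(1e-5)
    mu = mean(x; dims=(1, 2))                     # 1×1×C×N: one mean per channel per sample
    variance = mean(abs2.(x .- mu); dims=(1, 2))  # 1×1×C×N, biased (divide by W*H)
    return scale .* (x .- mu) ./ sqrt.(variance .+ ϵ) .+ offset
end

# Example input: W=8, H=8, C=3, N=2, deliberately far from zero mean / unit variance
x = randn(Float32, 8, 8, 3, 2) .* 5 .+ 2
y = instancenorm_ref(x, ones(Float32, 1, 1, 3, 1), zeros(Float32, 1, 1, 3, 1))

# With scale = 1 and offset = 0, each W×H slice should have mean ≈ 0 and variance ≈ 1
for n in 1:2, c in 1:3
    s = y[:, :, c, n]
    println("sample $n, channel $c: mean = ", round(mean(s); digits=3),
            ", var = ", round(var(s; corrected=false); digits=3))
end
```

Feeding the same input, `scale`, and `offset` to the Knet layer and to this reference and comparing the outputs is one way to catch mistakes such as reducing over the wrong dimensions or using the unbiased variance.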