/FluxPrune.jl

Pruning framework and methods for Flux

Primary LanguageJuliaMIT LicenseMIT

FluxPrune

Build Status

FluxPrune.jl provides iterative pruning algorithms for Flux models. Pruning strategies can be unstructured or structured. Unstructured strategies operate on arrays, while structured strategies operate on layers.

Examples

Unstructured edge pruning

using Flux, FluxPrune
using MLUtils: flatten

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), flatten, Dense(512, 10))
# prune all weights to 70% sparsity= prune(LevelPrune(0.7), m)
# prune all weights with magnitude lower than 0.5= prune(ThresholdPrune(0.5), m)
# prune each layer in a Chain at a different rate
# (just uses broadcasting then re-Chains)= prune([LevelPrune(0.4), LevelPrune(0.6), identity, LevelPrune(0.7)], m)

Structured channel pruning

using Flux, FluxPrune
using MLUtils: flatten

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), flatten, Dense(512, 10))
# prune all conv layer channels to 30% sparsity= prune(ChannelPrune(0.3), m)

Mixed pruning

using Flux, FluxPrune
using MLUtils: flatten

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), flatten, Dense(512, 10))
# apply channel and edge pruning= prune([ChannelPrune(0.3), ChannelPrune(0.4), identity, LevelPrune(0.8)], m)

Iterative pruning

Target pruning levels step-by-step. The first argument to iterativeprune (or the function block after the do statement) will finetune the model and return true to indicate moving onto the next stage, or false to indicate that finetune must be called again.

using Flux, FluxPrune
using MLUtils: flatten
using Statistics: mean

features = rand(Float32, 8, 8, 3, 100);
labels = Flux.onehotbatch(rand(0:9, 100), 0:9);
data = (features, labels);
loss(m, x, y) = Flux.Losses.mse(m(x), y)
accuracy(m, data) = mean(Flux.onecold(m(data[1]), 0:9) .== Flux.onecold(data[2], 0:9))
target_accuracy = 0.08 # random data, so this is a low target

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), flatten, Dense(512, 10), softmax)
opt_state = Flux.setup(Momentum(), m);

stages = [
    [ChannelPrune(0.1), ChannelPrune(0.1), identity, LevelPrune(0.4), identity],
    [ChannelPrune(0.2), ChannelPrune(0.3), identity, LevelPrune(0.7), identity],
    [ChannelPrune(0.3), ChannelPrune(0.5), identity, LevelPrune(0.9), identity]
]
m̄ = iterativeprune(stages, m) dofor epoch in 1:10
        Flux.train!(loss, m̄, [data], opt_state)
    end
    return accuracy(m̄, data) > target_accuracy
end