FluxML/ParameterSchedulers.jl

Please take any useful ideas from my CyclicOptimisers

Closed this issue · 1 comments

In working on my bachelor project, I have implemented a type CyclicOptimiser. It acts as a drop-in replacement for regular optimisers, by adding a new method to Flux.update!.

I can see that I have reproduced a lot of the work in here. So I simply wanted to offer up my implementation, and let you decide if I have had any ideas that might add value to this project.

using Flux
using Flux.Optimise: AbstractOptimiser
using Base.Iterators: Stateful, Cycle
import UnicodePlots
import Flux.update!  # Learning rate is updated each time Flux.update! is called, allowing seamless drop-in replacement of normal optimisers with cyclic optimisers.
import Base: show
round3(x) = round(x, sigdigits=3)

"""
    optimiser_to_string(opt::AbstractOptimiser) -> String

Render `opt` as `TypeName(field1, field2, ...)`, abbreviating any `IdDict`
field (internal optimiser state) as `...` and rounding the `eta` field
(learning rate) to 3 significant digits.
"""
function optimiser_to_string(opt::AbstractOptimiser)
    parts = String[]
    for name in fieldnames(typeof(opt))
        value = getfield(opt, name)
        if value isa IdDict
            # Optimiser state dictionaries are large and uninformative to print.
            push!(parts, "...")
        elseif name === :eta
            push!(parts, string(value |> round3))
        else
            push!(parts, string(value))
        end
    end
    # `join` also handles the (unlikely) zero-field case, which the previous
    # manual trailing-comma trim would have errored on.
    return string(typeof(opt), "(", join(parts, ", "), ")")
end

"""
    CycleMagnitude(len, magfac)

Functor type that computes a discretely stepped scaling factor: the magnitude
is multiplied by `magfac` each time `len` cycle steps have been completed.

# Fields
- `len::Int`: number of steps making up one full cycle.
- `magfac::Float64`: factor applied once per completed cycle.
"""
struct CycleMagnitude
    len::Int
    magfac::Float64
end

"""
    (cyc::CycleMagnitude)(x)

Return `cyc.magfac ^ (x ÷ cyc.len)`: a magnitude that is multiplied by
`cyc.magfac` every time the input grows by `cyc.len`.

The input is intended to be the `taken` counter of a
`Stateful(Cycle(my_collection))` iterator.

Note that for the actual calculation, the learning rate needs to be shifted
so that the smallest value in the cycle is 0 before scaling, and shifted back
up after scaling (see `cycle!(::CyclicOptimiser)`).
"""
function (cyc::CycleMagnitude)(x)
    return cyc.magfac^(x ÷ cyc.len)
end


abstract type AbstractCycler end

"""
    TriangleCycler(cycle)

Cycler backed by a `Stateful(Cycle(values))` iterator. The struct is
parametrised on the concrete iterator type: the previous field annotation
`Stateful{Cycle{A}} where {A<:AbstractVector}` was a `UnionAll` (abstract),
which prevents field specialisation.
"""
struct TriangleCycler{S<:Stateful{<:Cycle{<:AbstractVector}}} <: AbstractCycler
    cycle::S
end

# Generic human-readable summary for any cycler: the underlying values and
# how many steps have been taken so far.
show(io::IO, cyc::AbstractCycler) = println(io, "Cycler with values $(cyc.cycle.itr.xs).\nCycled $(cyc.cycle.taken) times")

cycle!(cycler::AbstractCycler) = popfirst!(cycler.cycle)

"""
    TriangleCycler(lower, upper, len)

Construct a TriangleCycler containing a set of `len` values that goes from
`lower` up to `upper` and back down again. Plotted against its index, the
returned set looks like a triangle with 2 equal legs.

If the `len` is odd, the first and last point will be the same, causing
repetition when cycled.
"""
function TriangleCycler(lower, upper, len)
    if len == 1
        # Special case: range(a, b; length=1) errors when a != b,
        # so use the midpoint directly.
        values = [(lower + upper) / 2]
    else
        # Build the ascending leg once (previously computed twice), then take
        # the descending leg from its reverse. Drop the peak (first element of
        # the reversed leg); for even `len` also drop the final point so the
        # total count is exactly `len`.
        ascent = range(lower, upper; length=len ÷ 2 + 1)
        descent = reverse(ascent)[begin+1:(iseven(len) ? end - 1 : end)]
        values = vcat(ascent, descent)
    end

    return TriangleCycler(values |> Cycle |> Stateful)
end
show(io::IO, tricy::TriangleCycler) = println(io, "TriangleCycler from $(minimum(tricy.cycle.itr.xs)|>round3) to $(maximum(tricy.cycle.itr.xs)|>round3) of cycle-length $(length(tricy.cycle.itr.xs))")

"""
    check_optimiser(opt::AbstractOptimiser)

Validate that `opt` can drive a `CyclicOptimiser`: it must carry an `eta`
field (the learning rate) for `cycle!` to mutate. Throws an `ArgumentError`
otherwise.
"""
function check_optimiser(opt::AbstractOptimiser)
    # NOTE(review): the previous `opt isa DataType` branch was unreachable —
    # a type object such as `Descent` is not an `AbstractOptimiser` instance,
    # so dispatch to this method would already fail with a MethodError.
    if !hasfield(typeof(opt), :eta)
        throw(ArgumentError("Tried to construct a CyclicOptimiser with $(opt), which has no field eta (e.g. no learningrate parameter)."))
    end
    return nothing
end

"""
    CyclicOptimiser{T<:AbstractOptimiser} <: AbstractOptimiser

Wrap a concrete optimiser whose learning rate (`eta`) is driven by a cyclic
schedule each time `Flux.update!` is called.

# Fields
- `current_optimiser::T`: the wrapped optimiser; its `eta` field is mutated by `cycle!`.
- `learningrate::AbstractCycler`: cyclic source of learning-rate values.
- `cycle_magnitude::CycleMagnitude`: discrete per-cycle scaling of the cycle's span.
"""
struct CyclicOptimiser{T<:AbstractOptimiser} <: AbstractOptimiser
    # The `{T<:AbstractOptimiser}` constraint now sits on the type parameter;
    # the previous trailing `where` clause after the supertype had no effect.
    current_optimiser::T
    learningrate::AbstractCycler
    cycle_magnitude::CycleMagnitude
    function CyclicOptimiser(opt, learningrate::AbstractCycler, cycmag::CycleMagnitude)
        check_optimiser(opt)
        # Input validation must not use `@assert` (it can be compiled out);
        # throw a proper ArgumentError instead.
        length(learningrate.cycle.itr.xs) == cycmag.len ||
            throw(ArgumentError("Length of learningrate cycle does not match the length of the internal CycleMagnitude."))
        return new{typeof(opt)}(opt, learningrate, cycmag)
    end
end


"""
    CyclicOptimiser(opt::AbstractOptimiser, lower, upper, len; cycler=TriangleCycler, magfac=1)

Construct a CyclicOptimiser. The optimiser whose learning rate is cycled is
`opt`, the first positional argument. `lower`, `upper` and `len` are passed on
to `cycler`, constructing an `AbstractCycler` and defaulting to TriangleCycler.

A final keyword argument `magfac` sets the magnitude-controlling factor that
is applied after a full cycle is completed. So if `magfac` is set to 0.5, then
the span of the cycle is halved each cycle. The lower limit is pinned,
so `magfac` only affects the upper limit, to ensure that the learningrate
decreases each cycle (assuming magfac ≤ 1, which is checked for).

# Throws
- `ArgumentError` if `opt` has no `eta` field, or if `magfac > 1`.
"""
function CyclicOptimiser(opt::AbstractOptimiser, lower, upper, len; cycler=TriangleCycler, magfac=1)
    check_optimiser(opt)
    # The docstring promises magfac ≤ 1 is checked; enforce it so the
    # learning-rate span cannot grow each cycle.
    magfac <= 1 ||
        throw(ArgumentError("magfac must be ≤ 1 so the learningrate span does not grow each cycle, got $magfac"))
    return CyclicOptimiser(opt, cycler(lower, upper, len), CycleMagnitude(len, magfac))
end

"""
    plot(cycopt::CyclicOptimiser, n_cycles=3)

Scatter-plot the learning rate `cycopt` would produce over `n_cycles` full
cycles, using the same scaling rule as `cycle!(::CyclicOptimiser)`: the
magnitude is indexed by the 0-based `taken` counter and scales only the
offset above the cycle's lower bound. Works on a deep copy, so the
optimiser's cycle state is left untouched.
"""
function plot(cycopt::CyclicOptimiser, n_cycles=3)
    cycopt = deepcopy(cycopt)  # cycling below mutates the Stateful iterator
    Iterators.reset!(cycopt.learningrate.cycle)
    xs = 1:cycopt.cycle_magnitude.len*n_cycles
    lower = minimum(cycopt.learningrate.cycle.itr.xs)
    # Mirror cycle!: the magnitude sees `taken` (0-based, i.e. x - 1) and the
    # learning rate is shifted so the cycle minimum is pinned. The previous
    # version scaled the raw value and used the 1-based index, so the plotted
    # curve disagreed with the values cycle! actually sets when magfac != 1.
    ys = [cycopt.cycle_magnitude(x - 1) * (cycle!(cycopt.learningrate) - lower) + lower for x in xs]
    return UnicodePlots.scatterplot(xs, ys, xlabel="Iteration", ylabel="Learningrate",
        title="Learningrate for $n_cycles cycles")
end

# Multi-line summary of a CyclicOptimiser: the wrapped optimiser, the
# learning-rate span (3 significant digits), the cycle length and magfac.
function show(io::IO, cycopt::CyclicOptimiser)
    print(io, 
    """
    CyclicOptimiser with following properties:
    Current optimiser = $(cycopt.current_optimiser|>optimiser_to_string)
         Learningrate = $(typeof(cycopt.learningrate)) from $(cycopt.learningrate.cycle.itr.xs|>minimum|>round3) to $(cycopt.learningrate.cycle.itr.xs|>maximum|>round3)
          Cyclelength = $(cycopt.cycle_magnitude.len). Magfac = $(cycopt.cycle_magnitude.magfac)""")
end

"""
    cycle!(co::CyclicOptimiser)

Advance the learning-rate cycle one step, write the resulting learning rate
into `co.current_optimiser.eta`, and return the wrapped optimiser.

The raw cycled value is shifted so the cycle's minimum maps to itself: only
the offset above that minimum is scaled by the current cycle magnitude.
"""
function cycle!(co::CyclicOptimiser)
    # Read the magnitude BEFORE popping: it depends on the `taken` counter,
    # which `cycle!(co.learningrate)` increments.
    scale = co.cycle_magnitude(co.learningrate.cycle.taken)
    lo = minimum(co.learningrate.cycle.itr.xs)
    raw = cycle!(co.learningrate)
    co.current_optimiser.eta = scale * (raw - lo) + lo
    return co.current_optimiser
end

Flux.update!(cycopt::CyclicOptimiser, xs::Params, gs) = Flux.update!(cycle!(cycopt), xs::Params, gs)

Awesome! Thanks, I'll take a look to see if there's any functionality we should port over.