JuliaGaussianProcesses/Stheno.jl

Tests currently broken

martincornejo opened this issue · 6 comments

Testing the current master branch, with Julia 1.9.0, results in the following:

Pass  Error  Total  Time
1067  10     1077   5m48.0s

It seems some of the broken tests are related to JuliaGaussianProcesses/AbstractGPs.jl#356, while others are caused by Zygote.

versioninfo()
Julia Version 1.9.0
Commit 8e63055292 (2023-05-07 11:25 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 8 × Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 4 on 8 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

This is the error message from one of the broken tests:

GP: Error During Test at C:\Users\Cornejo\Documents\GitHub\Stheno.jl\test\gp\atomic_gp.jl:14
  Test threw exception
  Expression: mean(f, x) == AbstractGPs._map_meanfunction(m, x)
  UndefVarError: `_map_meanfunction` not defined

Minimal example to reproduce the Zygote errors:

using Stheno
import Zygote

f = @gppp let
    f1 = GP(SEKernel())
    f2 = GP(Matern52Kernel())
    f3 = f1 + f2
end

x = GPPPInput(:f3, randn(5))
# f3 only exists inside the @gppp block, so query it through f and x.
Zygote.pullback(mean, f, x)

Edit: simplified the example

Ugh -- this second case looks like an example of someone having written an rrule which makes assumptions that are too strong. Debugging now.

The stack trace leads here:

AbstractGPs.mean(f::DerivedGP, x::AbstractVector) = mean(f.args, x)

Zygote is somehow overlooking the following method definition for mean:

const add_args = Tuple{typeof(+), AbstractGP, AbstractGP}
mean((_, fa, fb)::add_args, x::AV) = mean(fa, x) .+ mean(fb, x)

This is potentially the commit that introduced this behavior in ChainRules: JuliaDiff/ChainRules.jl@8424476

It was committed in December 2022 (the last commit in Stheno is from June 2022).

The rrule for mean is currently defined as follows:

function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(mean),
    f::F,
    x::AbstractArray{T};
    dims=:,
) where {F, T<:Union{Real,Complex,AbstractArray}}
    y_sum, sum_pullback = rrule(config, sum, f, x; dims)
    n = _denom(x, dims)
    function mean_pullback_f(ȳ)
        return sum_pullback(unthunk(ȳ) / n)
    end
    return y_sum / n, mean_pullback_f
end
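
To illustrate why this rule can intercept Stheno's call, here is a minimal dispatch sketch. It assumes the problem is purely one of signature matching: F is unconstrained, so a tuple such as (+, fa, fb) satisfies it whenever x is a numeric array. ToyGP, ToyAddArgs and generic_rule_matches are hypothetical stand-ins, not names from Stheno or ChainRules.

```julia
# Toy stand-ins for illustration only; none of these names exist in Stheno or ChainRules.
struct ToyGP end

# Same shape as Stheno's `add_args` alias, with the toy type in place of AbstractGP.
const ToyAddArgs = Tuple{typeof(+), ToyGP, ToyGP}

# Mirrors the positional part of the rrule signature: unconstrained F, numeric array x.
generic_rule_matches(::F, ::AbstractArray{T}) where {F, T<:Union{Real,Complex,AbstractArray}} = true
generic_rule_matches(::Any, ::Any) = false  # fallback for every other argument pair

args = (+, ToyGP(), ToyGP())
x = randn(5)

println(args isa ToyAddArgs)              # true
println(generic_rule_matches(args, x))    # true: a Tuple satisfies an unconstrained F
println(generic_rule_matches(args, :f3))  # false: the second argument is not an array
```

If that assumption holds, the AD system would find a matching rule for mean(f.args, x) and use it instead of differentiating through the tuple-dispatched method.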

Probably specifying where {F<:Function, ... would fix it? I will try it out and open a PR if that works.

Edit: Adding that type check is probably not a solution: JuliaDiff/ChainRules.jl#522. ChainRules also tests that non-function callables work: https://github.com/JuliaDiff/ChainRules.jl/blob/11c230cdf0f37a4f42de909d6c1f8500d1a80d69/test/rulesets/Statistics/statistics.jl#L18
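
A small sketch of why the F<:Function constraint is too strict, using only Base Julia. Scale and restricted_rule_matches are hypothetical names made up for this example; the point is only that a callable struct is not a subtype of Function, so a rule restricted this way would stop matching it.

```julia
# Hypothetical callable struct, for illustration only.
struct Scale
    a::Float64
end
(s::Scale)(v) = s.a * v  # instances can be called like functions

# Signature shape with the proposed F<:Function constraint.
restricted_rule_matches(::F, ::AbstractArray{T}) where {F<:Function, T<:Real} = true
restricted_rule_matches(::Any, ::Any) = false  # fallback for everything else

x = randn(5)

println(restricted_rule_matches(abs, x))         # true: typeof(abs) <: Function
println(restricted_rule_matches(Scale(2.0), x))  # false: callable, but not <: Function
println(restricted_rule_matches((+, 1, 2), x))   # false: tuples are excluded, as intended
```

So the constraint would keep tuples out, but at the cost of dropping callables like Scale that the linked ChainRules test expects to work.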

Resolved by #244