Tests currently broken
martincornejo opened this issue · 6 comments
Testing the current master branch, with Julia 1.9.0, results in the following:
Pass | Error | Total | Time |
---|---|---|---|
1067 | 10 | 1077 | 5m48.0s |
It seems some of the broken tests are related to JuliaGaussianProcesses/AbstractGPs.jl#356, while others are caused by Zygote.
versioninfo()
Julia Version 1.9.0
Commit 8e63055292 (2023-05-07 11:25 UTC)
Platform Info:
OS: Windows (x86_64-w64-mingw32)
CPU: 8 × Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
Threads: 4 on 8 virtual cores
Environment:
JULIA_EDITOR = code
JULIA_NUM_THREADS = 4
This is the error message from one of the broken tests:
GP: Error During Test at C:\Users\Cornejo\Documents\GitHub\Stheno.jl\test\gp\atomic_gp.jl:14
Test threw exception
Expression: mean(f, x) == AbstractGPs._map_meanfunction(m, x)
UndefVarError: `_map_meanfunction` not defined
Minimal example to reproduce the Zygote errors:
using Stheno
import Zygote
f = @gppp let
    f1 = GP(SEKernel())
    f2 = GP(Matern52Kernel())
    f3 = f1 + f2
end
x = GPPPInput(:f3, randn(5))
y = randn(5)
Zygote.pullback(mean, f, x)
Edit: simplified the example
Ugh -- this second case looks like an example of someone having written an rrule
which makes assumptions that are too strong. Debugging now.
The stacktrace leads here?
Stheno.jl/src/gp/derived_gp.jl
Line 19 in bd15654
Zygote is somehow overlooking the following method definition for mean:
Stheno.jl/src/affine_transformations/addition.jl
Lines 20 to 22 in bd15654
This is potentially the commit that introduced this behavior in ChainRules: JuliaDiff/ChainRules.jl@8424476
It was committed in December 2022 (the last commit in Stheno is from June 2022).
The rrule for mean is currently defined as follows:
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(mean),
    f::F,
    x::AbstractArray{T};
    dims=:,
) where {F, T<:Union{Real,Complex,AbstractArray}}
    y_sum, sum_pullback = rrule(config, sum, f, x; dims)
    n = _denom(x, dims)
    function mean_pullback_f(ȳ)
        return sum_pullback(unthunk(ȳ) / n)
    end
    return y_sum / n, mean_pullback_f
end
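To make the structure of that rule easier to follow, here is a minimal sketch in plain Julia with no ChainRules dependency. It assumes a scalar function whose derivative is passed in by hand. The names sum_with_pullback and mean_with_pullback are hypothetical and only mirror the idea that the pullback of mean is the pullback of sum applied to ȳ / n; they are not the ChainRules implementation.

```julia
# Hypothetical helper: value of sum(f, x) plus a pullback w.r.t. x.
# `df` is the hand-written derivative of `f` (an assumption of this sketch).
function sum_with_pullback(f, df, x::AbstractVector)
    y = sum(f, x)
    # d/dx_i sum_j f(x_j) = f'(x_i), scaled by the incoming cotangent ȳ
    pullback(ȳ) = ȳ .* df.(x)
    return y, pullback
end

# Hypothetical helper: mean(f, x) = sum(f, x) / n, so its pullback
# reuses the sum pullback with the cotangent divided by n.
function mean_with_pullback(f, df, x::AbstractVector)
    y_sum, sum_pb = sum_with_pullback(f, df, x)
    n = length(x)
    mean_pb(ȳ) = sum_pb(ȳ / n)
    return y_sum / n, mean_pb
end

y, pb = mean_with_pullback(x -> x^2, x -> 2x, [1.0, 2.0, 3.0])
grad = pb(1.0)
println(y)     # mean of the squares
println(grad)  # gradient of that mean w.r.t. each input
```

Note that this sketch, like the rule above, treats the first argument as something to be applied elementwise to x, which is exactly the assumption that does not hold when the first argument is a GP.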
Probably specifying where {F<:Function, ... would fix it? I will try it out and open a PR if that works.
Edit: Adding that type check is probably not a solution, see JuliaDiff/ChainRules.jl#522. ChainRules also tests that non-function callables work: https://github.com/JuliaDiff/ChainRules.jl/blob/11c230cdf0f37a4f42de909d6c1f8500d1a80d69/test/rulesets/Statistics/statistics.jl#L18
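The dispatch trade-off can be sketched with toy definitions, using only Base and the Statistics standard library. Everything here is a hypothetical stand-in: ToyGP plays the role of a GP object, Scale is a callable struct, and toy_rule / toy_rule_strict only imitate the positional signature of the rule, not ChainRules itself.

```julia
using Statistics

# Hypothetical stand-in for a GP: a plain struct that is not callable.
struct ToyGP end

# Hypothetical callable struct (a "non-function callable").
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

# Imitates the current signature: F is unconstrained, so anything matches.
toy_rule(::typeof(mean), f::F, x::AbstractArray) where {F} = :generic_rule

# Imitates the proposed restriction to subtypes of Function.
toy_rule_strict(::typeof(mean), f::F, x::AbstractArray) where {F<:Function} = :generic_rule

x = randn(3)

# The unconstrained signature also captures the GP-like object.
println(applicable(toy_rule, mean, ToyGP(), x))          # true

# The restricted signature no longer captures it ...
println(applicable(toy_rule_strict, mean, ToyGP(), x))   # false

# ... but it also drops callable structs, which are not subtypes of Function.
println(applicable(toy_rule_strict, mean, Scale(2.0), x)) # false

# Ordinary functions still match.
println(applicable(toy_rule_strict, mean, abs2, x))       # true
```

Under these assumptions, the restriction would avoid intercepting mean for a GP, at the cost of losing the rule for callable structs, which matches the concern raised in the linked ChainRules issue.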
Resolved by #244