Julia-Tempering/Pigeons.jl

Sampling issue when adding variational chains

Closed this issue · 40 comments

Hello all,
I have been seeing occasional issues when adding a variational reference. It seems to be quite sensitive to the model and the number of chains used. I have finally found an MWE that consistently reproduces this effect.

When looking at the corner plots below, note that the second column plots the log posterior density. The second example includes samples that appear to lie far outside the typical set.

I will include the model below.

Without Variational Reference

First, a sampling run with 10 chains and 0 variational chains:
[corner plot images]

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.29     0.0133      6e+04       -111   5.57e-14      0.701          1          1 
        4          0       4.27     0.0184   8.76e+04       -164   1.68e-52      0.612          1          1 
        8          0          5     0.0347   1.48e+05       -103     0.0434      0.545          1          1 
       16          0       5.53     0.0665   2.62e+05       -106     0.0758      0.497          1          1 
       32          0       5.39      0.135   4.53e+05       -108      0.305       0.51          1          1 
       64          6       2.19      0.336   3.16e+07       -109      0.584      0.801          1          1 
      128         16       2.34      0.619   6.34e+07       -109      0.649      0.787          1          1 
      256         33       2.34       1.59   1.26e+08       -109      0.727      0.787          1          1 
      512         77       2.28       2.59   2.51e+08       -109      0.753      0.792          1          1 
 1.02e+03        120       2.38       5.15   5.04e+08       -109      0.754      0.784          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

With Variational Reference

Now, a sampling run with 10 chains and 6 variational chains:
[corner plot images]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       2.48       2.74     0.0178    9.4e+04       -128   1.61e-31      0.693          1          1 
        4          0       4.47       2.75     0.0289   1.27e+05       -181   2.87e-66      0.575          1          1 
        8          0       5.04       3.49     0.0552   2.07e+05       -109   0.000299      0.499          1          1 
       16          0          6        3.5      0.109   3.68e+05       -113   4.64e-05      0.441          1          1 
       32          0       5.59       4.27      0.223   6.83e+05       -111     0.0311       0.42          1          1 
       64          4       5.67       3.84      0.447   1.36e+07       -108     0.0224      0.441          1          1 
      128          9       5.59       4.17      0.903   2.74e+07       -111      0.062      0.426          1          1 
      256         19        5.7       4.13        1.9   5.52e+07       -108     0.0407      0.422          1          1 
      512         37       5.76       4.03       3.76   1.11e+08       -109     0.0803      0.424          1          1 
 1.02e+03         67        5.8       4.02       7.84   2.23e+08       -109      0.144      0.422          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

HMC, for reference

[corner plot images]

Model Code

To reproduce this example, use the latest #main commit of Octofitter and OctofitterRadialVelocity, e.g. `] add Octofitter#main OctofitterRadialVelocity#main`, and then run the following:

using Octofitter
using OctofitterRadialVelocity
using Distributions
using PlanetOrbits


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)

rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)


first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)

    # Remove these, we don't need 'em
    # i ~ Sine()
    # Ω ~ UniformCircular()
    ω ~ UniformCircular()
    θ ~ UniformCircular()

    τ ~ UniformCircular(1.0)
    P = √(b.a^3/system.M)  # orbital period in years (Kepler's third law)
    tp =  b.τ*b.P*365.25 + $first_epoch_for_tp_ref # reference epoch for τ, chosen to be near the data
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

using Random
rng = Xoshiro(0)


results, pt = octofit_pigeons(model, n_rounds=10, explorer=SliceSampler(), n_chains=12, n_chains_variational=0)

Plotting:

using PairPlots, CairoMakie
octocorner(model, results, small=false, includecols=(:iter,:logpost,:b_ωx,:b_ωy,:b_τx,:b_τy), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))

Additional notes: the columns ω, τ, and P are computed from other variables rather than sampled directly.

One complication is that the variables ω, τ, and θ are computed from their counterparts :θx, :θy, and so on, using atan(θy, θx).

In the likelihood, I place a soft constraint on the radius, sqrt(θx^2 + θy^2) ~ Normal(1, 0.1), to prevent numerical issues when θx and θy are both approximately 0.

Of course, in the reference distributions that sqrt(θx^2 + θy^2) ~ Normal(1, 0.1) constraint is not applied.
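
A minimal sketch of this reparameterization, using standalone variables (illustrative only; not Octofitter's actual implementation):

using Distributions

# Sample the two unconstrained components; their angle is uniform on (-π, π].
θx = rand(Normal(0, 1))
θy = rand(Normal(0, 1))
θ  = atan(θy, θx)  # derived angle; never sampled directly

# Soft constraint applied in the likelihood (but not in the reference) to keep
# the radius away from 0, where the angle becomes ill-defined:
r = sqrt(θx^2 + θy^2)
length_penalty = logpdf(Normal(1, 0.1), r)  # added to the log posterior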

Thanks William! Will discuss in more detail in our meeting today, but a quick question: is the second type of plot showing goodness of fit, comparing data to prediction?

Prior only model, sampled with HMC:

[corner plot image]

Variational Chains Only

Here is a corner plot of a sampling run with 6 variational chains and no fixed reference.

[corner plot image]

As you can see, the problem goes away! Very interesting.

Thanks for running that @sefffal ! Definitely helps us narrow down the issue further.

Btw how did you solve the Chains/sample_array/get_sample issue you mentioned with plain Variational PT?

I just pushed a fix to the Octofitter / Pigeons integration: sefffal/Octofitter.jl@49a77c1

Here is a further-reduced example that removes those angular "UniformCircular" parameters.

I note also that the weighting between the number of regular chains and variational chains seems to be important.

using Octofitter
using OctofitterRadialVelocity
using CairoMakie
using PairPlots
using Distributions
using PlanetOrbits


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)
# Makie.lines(orb_template)


rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)

first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)
    ω ~ Uniform(0,2pi)
    τ ~ Uniform(0.0, 1.0)
    tp =  b.τ*√(b.a^3/system.M)*365.25 + $first_epoch_for_tp_ref 
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

results, pt = octofit_pigeons(model, n_rounds=10, explorer=SliceSampler(), n_chains=10, n_chains_variational=6)

octocorner(model, results, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
[corner plot image]

Re:

Variational Chains Only

Here is a corner plot of a sampling run with 6 variational chains and no fixed reference.
...
As you can see, the problem goes away! Very interesting.

Actually, I am not sure this example is right. I get exactly the same values with n_chains=8, n_chains_variational=0 as I do with n_chains=0, n_chains_variational=8. How can I be sure this was using the variational reference at all?

You can tell the variational reference is in use because Λ changes at the 5th round (when GaussianReference(first_tuning_round = 5) starts adapting). For example:

Traditional PT

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
)
results, pt = octofit_pigeons(inp)

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.07     0.0478   5.42e+05       -928          0      0.659          1          1 
        4          0       3.87     0.0428   5.46e+05       -122   5.95e-17       0.57          1          1 
        8          0       4.79     0.0493    8.3e+04       -120   3.94e-12      0.468          1          1 
       16          0       4.13     0.0969   1.53e+05       -107       0.25      0.541          1          1 
       32          0       4.79      0.184   2.58e+05       -105      0.277      0.468          1          1 
       64          0       5.15      0.358   4.72e+05       -106      0.314      0.428          1          1 
      128          1       4.76      0.746    8.8e+05       -108      0.171      0.471          1          1 
      256          5       4.93       1.48    1.7e+06       -108      0.369      0.452          1          1 
      512          9       5.05       3.01   3.25e+06       -108      0.362      0.439          1          1 
 1.02e+03         14       5.05       6.03   6.43e+06       -108      0.285      0.439          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

Variational (non-stabilized) PT

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    variational = GaussianReference(first_tuning_round = 5)
)
results, pt = octofit_pigeons(inp)

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.07      0.333    1.7e+07       -105   1.71e-08      0.659          1          1 
        4          0       3.87     0.0486   5.48e+05       -113   9.57e-09       0.57          1          1 
        8          0       4.79     0.0525   8.86e+04       -152   1.56e-40      0.468          1          1 
       16          0       4.12     0.0944   1.63e+05       -105      0.361      0.542          1          1 
       32          0       4.72      0.182   2.76e+05       -108      0.185      0.476          1          1 
       64          6       1.82      0.674   2.53e+07       -108      0.623      0.797          1          1 
      128         17       1.82      0.871    3.5e+07       -108      0.698      0.798          1          1 
      256         32       1.75       1.63   6.97e+07       -108      0.757      0.805          1          1 
      512         72       1.77       3.43    1.4e+08       -108      0.687      0.803          1          1 
 1.02e+03        121       2.17       6.81   2.77e+08       -108       0.71      0.759          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

Stabilized PT: note that Λ is roughly the same as for standard PT, whereas Λ_var is similar to the final value from (non-stabilized) variational PT.

inp = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
results, pt = octofit_pigeons(inp)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.54       3.07      0.101   1.12e+06       -105  7.61e-163      0.652          1          1 
        4          0       3.89       3.79     0.0443   1.12e+05       -113   9.57e-09      0.596          1          1 
        8          0       3.71        5.1      0.102   1.89e+05       -149   1.21e-39      0.536          1          1 
       16          0       4.07       4.57      0.172   3.25e+05       -108      0.055      0.545          1          1 
       32          0       4.51        4.3      0.335   5.55e+05       -107       0.26      0.536          1          1 
       64         12       4.85       1.72      0.872   1.97e+07       -107      0.288      0.654          1          1 
      128         18       4.94        1.9        1.5   3.57e+07       -108      0.368       0.64          1          1 
      256         37       5.02       1.82       3.08    7.1e+07       -108      0.144       0.64          1          1 
      512         71       4.98       2.14       6.13    1.4e+08       -108      0.361      0.625          1          1 
 1.02e+03        164          5       1.89       12.5   2.82e+08       -108      0.402      0.637          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Ah, okay thanks @miguelbiron. Then I am a bit confused about the interaction between n_chains_variational and n_chains.

Is the following correct?
- Regular PT: n_chains=8, n_chains_variational=0, variational=nothing
- Stabilized variational PT: n_chains=8, n_chains_variational=8, variational=GaussianReference....
- Non-stabilized variational PT: n_chains=8, n_chains_variational=0, variational=GaussianReference....

Correct!

I agree it's pretty confusing; we should improve the naming at some point :(

Ok so I'm not getting any issues whatsoever. First, as you can see in my post, the log-normalization constant is the same across the three alternatives. Second, the following plots are very similar:

PT: [pt.png]

Variational PT: [vpt.png]

Stabilized PT: [stab_vpt.png]

Btw, I'm using Julia 1.10.5 because 1.11.1 is not working right now.

julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 4 default, 0 interactive, 2 GC (on 8 virtual cores)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_DEBUG = 
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

Oh, @miguelbiron my experiments were with Julia 1.11. I know there is an Enzyme failure with 1.11 but are there other issues too?

I'm not sure, but 1.11 has been a bumpy ride so I decided to wait until it's on firmer ground. I'm going to rerun with 1.11 to check if it makes any difference.

Hmm, I can't run this line on 1.11.1:

julia> Octofitter._kepsolve_use_threads[] = false
ERROR: UndefVarError: `_kepsolve_use_threads` not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:31
 [2] top-level scope
   @ ~/projects/Pigeons.jl/test/temp.jl:10

I was able to reproduce the problem on Julia 1.10.

@miguelbiron are you using the latest commit from Octofitter?

julia> versioninfo()
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)

Yes, #main like you suggested. Gonna go back to 1.10 then.

I went through this again.

Blue is stabilized variational PT, gold is regular PT, and green is non-stabilized variational PT.

Is it possible the problem is caused by how I'm using get_sample(pt, 0)? In some situations, could I be getting samples from the wrong chain? That could explain why, as the numbers of each kind of chain vary, the samples seem to interpolate between a prior-like distribution and the target.

[corner plot image]
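
One way to avoid hardcoding a chain index is to ask Pigeons which chains are targets (a sketch using Pigeons.target_chains, which also appears further down this thread):

# Query the index of a target chain instead of assuming it:
target = first(Pigeons.target_chains(pt))
samples = get_sample(pt, target)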

That makes sense... Why are you doing get_sample(pt, 0) again?

Btw, I reran and cannot reproduce for the life of me. This is what I'm running (from Pigeons):

Details

include("activate_test_env.jl")

using Pkg
Pkg.add([PackageSpec(name="Octofitter", rev="main"), PackageSpec(name="OctofitterRadialVelocity", rev="main")])

using Octofitter
using OctofitterRadialVelocity
using CairoMakie
using PairPlots
using Distributions
using PlanetOrbits

Octofitter._kepsolve_use_threads[] = false


epochs = 58849 .+ (20:20:660)
planet_sim_mass = 0.001 # solar masses here


orb_template = orbit(
    a = 1.0,
    e = 0.7,
    # i= pi/4, # You can remove I think
    # Ω = 0.1, # You can remove I think
    ω = 1π/4, # radians
    M = 1.0, # Total mass, not stellar mass FYI
    plx=100.0,
    tp =58829 # Epoch of periastron passage. 
)
# Makie.lines(orb_template)


rvlike = StarAbsoluteRVLikelihood(
    Table(
        epoch=epochs,
        rv=radvel.(orb_template, epochs, planet_sim_mass),
        σ_rv=fill(5.0, size(epochs)),
    ),
    instrument_names=["simulated"]
)

first_epoch_for_tp_ref = first(epochs)
@planet b RadialVelocityOrbit begin
    e ~ Uniform(0,0.999999)
    a ~ truncated(Normal(1, 1),lower=0)
    mass ~ truncated(Normal(1, 1), lower=0)
    ω ~ Uniform(0,2pi)
    τ ~ Uniform(0.0, 1.0)
    tp =  b.τ*√(b.a^3/system.M)*365.25 + $first_epoch_for_tp_ref 
end 

@system SimualtedSystem begin
    M ~ truncated(Normal(1, 0.04),lower=0) # (Baines & Armstrong 2011).
    plx = 100.0
    jitter ~ truncated(Normal(0,10),lower=0)
    rv0 ~ Normal(0, 100)
end rvlike b

model = Octofitter.LogDensityModel(SimualtedSystem)

inp_pt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    # n_chains_variational=10,
    # variational = GaussianReference(first_tuning_round = 5)
)
inp_vpt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    # n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
inp_stab_vpt = Pigeons.Inputs(;
    target = model,
    record = [traces; round_trip; record_default(); index_process],
    multithreaded=true,
    show_report=true,
    n_rounds=10, 
    explorer=SliceSampler(), 
    n_chains=10, 
    n_chains_variational=10,
    variational = GaussianReference(first_tuning_round = 5)
)
results_pt, pt = octofit_pigeons(inp_pt)
results_vpt, pt = octofit_pigeons(inp_vpt)
results_stab_vpt, pt = octofit_pigeons(inp_stab_vpt)

p = octocorner(model, results_pt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("pt.png", p)
p = octocorner(model, results_vpt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("vpt.png", p)
p = octocorner(model, results_stab_vpt, small=false, includecols=(:iter,:logpost), viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))
save("stab_vpt.png", p)

versioninfo()
Pkg.status()

Here's the relevant output (excluding plots, since the logZ values are already correct):

(...)
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000014 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 7.09e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000009 seconds
[ Info: Determining initial positions and metric using pathfinder
┌ Info: Found a sample of initial positions
└   initial_logpost_range = (-108.8666220931488, -97.86911327709223)
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          3      0.048   5.42e+05       -443  7.64e-304      0.667          1          1 
        4          0       3.83     0.0469   5.49e+05       -102     0.0871      0.574          1          1 
        8          0       4.53     0.0466   8.86e+04       -107   4.03e-06      0.497          1          1 
       16          0       4.47     0.0983   1.59e+05       -109      0.138      0.503          1          1 
       32          0       4.54      0.187   2.56e+05       -107      0.294      0.495          1          1 
       64          0        4.7      0.365   4.84e+05       -107       0.36      0.478          1          1 
      128          2       4.97       0.77   8.79e+05       -108      0.301      0.448          1          1 
      256          2       4.87       1.56    1.7e+06       -107      0.274      0.459          1          1 
      512          4       4.97          3   3.24e+06       -107      0.394      0.448          1          1 
 1.02e+03         16       4.97       6.11   6.43e+06       -108      0.402      0.447          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000001 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 8.39e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000001 seconds
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0          3      0.333    1.7e+07       -104   1.87e-09      0.667          1          1 
        4          0       4.21     0.0442   5.47e+05       -113   2.57e-10      0.532          1          1 
        8          0       4.03     0.0464   9.75e+04       -156   5.41e-41      0.552          1          1 
       16          0       5.09      0.113   1.53e+05       -117      0.126      0.434          1          1 
       32          0       4.34      0.201   2.68e+05       -108      0.117      0.517          1          1 
       64          9       1.84      0.622    2.5e+07       -108      0.689      0.795          1          1 
      128         20       1.64      0.907   3.49e+07       -108      0.753      0.817          1          1 
      256         37       1.66       1.72      7e+07       -108      0.757      0.816          1          1 
      512         69       1.83       3.49    1.4e+08       -108      0.717      0.796          1          1 
 1.02e+03        128       1.95       7.02   2.79e+08       -108      0.701      0.783          1          1 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000004 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 7.98e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000002 seconds
[ Info: Preparing model
┌ Info: Determined number of free variables
└   D = 8
ℓπcallback(θ, args...): 0.000003 seconds (9 allocations: 320 bytes)
┌ Info: Tuning autodiff
│   chunk_size = 8
└   t = 8.08e-7
┌ Info: Selected auto-diff chunk size
└   ideal_chunk_size = 8
∇ℓπcallback(θ): 0.000002 seconds
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        Λ_var      time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       3.38          3     0.0909   1.12e+06       -104  1.59e-164      0.664          1          1 
        4          0       3.99       4.01     0.0455   1.12e+05       -111   2.57e-10      0.579          1          1 
        8          0       3.89       3.55     0.0893   1.96e+05       -147    1.5e-39      0.608          1          1 
       16          0       4.71       3.99      0.174   3.23e+05       -104    0.00798      0.542          1          1 
       32          0       4.72       4.61      0.362   5.42e+05       -117    0.00419      0.509          1          1 
       64          9       4.62        1.8      0.917   1.97e+07      -97.2     0.0644      0.662          1          1 
      128         14       4.76       2.24       1.58   3.54e+07       -108      0.342      0.632          1          1 
      256         37       4.87       2.03       3.04   7.08e+07       -108      0.229      0.637          1          1 
      512         57       4.98       2.33       6.23    1.4e+08       -108      0.354      0.615          1          1 
 1.02e+03        157       5.04       2.13       12.1    2.8e+08       -108      0.371      0.623          1          1 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
CairoMakie.Screen{IMAGE}


julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 4 default, 0 interactive, 2 GC (on 8 virtual cores)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_DEBUG = 
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

julia> Pkg.status()
Status `~/projects/Pigeons.jl/test/Project.toml`
  [0bf59076] AdvancedHMC v0.6.2
  [dbc42088] ArgMacros v0.2.4
  [76274a88] Bijectors v0.13.18
  [c88b6f0a] BridgeStan v2.5.0
  [13f3f980] CairoMakie v0.12.14
⌃ [99d987ce] Comrade v0.10.5
  [8bb1440f] DelimitedFiles v1.9.1
  [31c24e10] Distributions v0.25.112
  [ced4e74d] DistributionsAD v0.6.57
⌃ [366bfd00] DynamicPPL v0.28.6
⌅ [7da242da] Enzyme v0.12.36
  [7a1cc6ca] FFTW v1.8.0
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v0.10.36
  [09f84164] HypothesisTests v0.11.3
  [682c06a0] JSON v0.21.4
  [92481ed7] LinearRegression v0.2.1
  [6fdf6af0] LogDensityProblems v2.1.2
⌃ [996a588d] LogDensityProblemsAD v1.10.1
  [2ab3a3ac] LogExpFunctions v0.3.28
  [c7f686f2] MCMCChains v6.0.6
  [be115224] MCMCDiagnosticTools v0.3.10
  [da04e1cc] MPI v0.20.22
  [3da0fdf6] MPIPreferences v0.1.11
  [daf3887e] Octofitter v4.0.0 `https://github.com/sefffal/Octofitter.jl.git#main`
  [c6a353d9] OctofitterRadialVelocity v4.0.0 `https://github.com/sefffal/Octofitter.jl.git:OctofitterRadialVelocity#main`
  [a15396b6] OnlineStats v1.7.1
  [43a3c2be] PairPlots v2.9.2
  [0eb8d820] Pigeons v0.4.6 `~/projects/Pigeons.jl`
  [fd6f9641] PlanetOrbits v0.10.1
  [91a5bcdd] Plots v1.40.8
  [c3e4b0f8] Pluto v0.20.0
  [7f904dfe] PlutoUI v0.7.60
  [37e2e3b7] ReverseDiff v1.15.3
  [276daf66] SpecialFunctions v2.4.0
  [8efc31e9] SplittableRandoms v0.1.2
⌅ [b1ba175b] VLBIImagePriors v0.8.4
  [a5390f91] ZipFile v0.10.1
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [d6f4376e] Markdown
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random
  [9e88b42a] Serialization
  [10745b16] Statistics v1.10.0
  [8dfed614] Test
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`

Here is some code to expose the calls to Inputs, pigeons, and get_sample without going through my convenience wrapper octofit_pigeons:

inputs = Pigeons.Inputs(
       target=model,
       explorer=SliceSampler(),
       n_rounds=10,
       n_chains=10,
       n_chains_variational=6,
       record=[traces; round_trip; record_default(); index_process],
       multithreaded=true,
       variational=GaussianReference()
);
pt = pigeons(inputs);
samples = get_sample(pt,10)
chn =  Octofitter.result2mcmcchain(  model.arr2nt.(model.invlink.(s[1:model.D] for s in samples)));

If we look at the two target chains, they are quite different:

julia> Pigeons.target_chains(pt2)
2-element Vector{Int64}:
 10
 11

julia> chn_10 =  Octofitter.result2mcmcchain(  model.arr2nt.(model.invlink.(s[1:model.D] for s in get_sample(pt2,10))));

julia> chn_11 =  Octofitter.result2mcmcchain(  model.arr2nt.(model.invlink.(s[1:model.D] for s in get_sample(pt2,11))));

julia> octocorner(model, chn_11, chn_10, small=false, viz=(PairPlots.Scatter(markersize=4),PairPlots.MarginHist()))

[corner plot image]

@miguelbiron I ran your script, and agree it did not reproduce the error. Curious.

Okay I found the problem. Whatever the underlying issue, it is quite sensitive to the number of chains.

This does not have the issue:

inp_stab_vpt = Pigeons.Inputs(;
   target = model,
   record = [traces; round_trip; record_default(); index_process],
   multithreaded=true,
   show_report=true,
   n_rounds=10, 
   explorer=SliceSampler(), 
   n_chains=10, 
   n_chains_variational=10,
   variational = GaussianReference(first_tuning_round = 5)
)

But this does reproduce it:

inp_stab_vpt = Pigeons.Inputs(;
   target = model,
   record = [traces; round_trip; record_default(); index_process],
   multithreaded=true,
   show_report=true,
   n_rounds=10, 
   explorer=SliceSampler(), 
   n_chains=10, 
   n_chains_variational=6,
   variational = GaussianReference(first_tuning_round = 5)
)

See n_chains_variational=6.

Ah yes yes, I see: at some point I changed the number in my script. Rerunning again...

Ah yes now I see it. Sorry for the confusion.

Oof, I think I found the issue: around July 2023, the order of the fixed/variational legs was swapped in the linear indexing of the combined chains:

# Note: 2023/07/20: changed order to have variational first (as depicted below)
# to simplify log(Z) code for 2-legged
# <--- variational ----> <----- fixed ------>
# reference ----- target -- target ---- reference
# 1 ----- N -- N + 1 ---- 2N

But this was not propagated to the swap graph definition

is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1)

This should instead be

is_target(deo::VariationalDEO, chain::Int) = (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1)

(is_reference is correct)
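
A toy check of the difference (ToyVariationalDEO is a hypothetical stand-in; its fields mirror the real VariationalDEO):

# Stand-in for Pigeons' VariationalDEO, just to compare the two formulas.
struct ToyVariationalDEO
    n_chains_fixed::Int
    n_chains_var::Int
end

# Buggy version: picks the adjacent pair at the fixed-leg size.
is_target_buggy(deo, chain) =
    (chain == deo.n_chains_fixed) || (chain == deo.n_chains_fixed + 1)

# Corrected version: the variational leg comes first in the global ordering,
# so the two adjacent target copies sit at n_chains_var and n_chains_var + 1.
is_target_fixed(deo, chain) =
    (chain == deo.n_chains_var) || (chain == deo.n_chains_var + 1)

deo = ToyVariationalDEO(10, 6)  # the failing configuration from this thread
filter(c -> is_target_buggy(deo, c), 1:16)  # [10, 11]: interior of the fixed leg
filter(c -> is_target_fixed(deo, c), 1:16)  # [6, 7]: the true targets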

Can confirm that with this change, the results look correct. I'm making a PR with a small test for this issue.

@miguelbiron can we add a unit test for this too? This is a super subtle/insidious bug.

(woops, just saw that you are doing a test already. good stuff :-) )

Thanks again @sefffal for the issue thread!

Wow, great spot @miguelbiron !

I'm left wondering though: why do these two target chains look so different? I.e. how did these samples with way lower posterior density end up in the second target chain to begin with?

I guess that all my previous results with stabilized VPT also had this issue, and I only noticed there was a problem when the target chain from the variational leg was misbehaving?

@sefffal because chains (10, 11) were actually both in the interior of the fixed leg, so neither was actually a target. But since we use adaptive schedules, two consecutive chains can target surprisingly different distributions, if that arrangement is the one that leads to equirejection.

Concerning your past experiments with SVPT, I think you were most likely getting incorrect samples too, from the interior of one or the other leg (depending on the number of chains in each). Sometimes the samples can look ok because the interior points can be arbitrarily close to the endpoints thanks to adaptive scheduling.
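
For intuition, in standard annealing-path notation (this is the generic PT path, not Pigeons-specific): each chain in a leg samples a tempered distribution

π_β(x) ∝ π_ref(x)^(1−β) · π(x)^β,  with 0 = β₁ < β₂ < … < β_N = 1,

so an interior chain with β just below 1 can produce samples that look almost like the target, while one lower on the schedule looks prior-like, consistent with the interpolation observed above.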

The error was sensitive to the number of chains; was that because some numbers of chains didn't have this issue, or because in some cases the samples were arbitrarily close to the endpoints due to the adaptive scheduling?

Sorry for all the questions; I just want to know how this will impact previous results.

No worries! It's probably best to look at the indexer in question from your example:

julia> pt_2_legs.shared.tempering.indexer.i2t
16-element Vector{@NamedTuple{chain::Int64, leg::Symbol}}:
 (chain = 1, leg = :variational)
 (chain = 2, leg = :variational)
 (chain = 3, leg = :variational)
 (chain = 4, leg = :variational)
 (chain = 5, leg = :variational)
 (chain = 6, leg = :variational)
 (chain = 10, leg = :fixed)
 (chain = 9, leg = :fixed)
 (chain = 8, leg = :fixed)
 (chain = 7, leg = :fixed)
 (chain = 6, leg = :fixed)
 (chain = 5, leg = :fixed)
 (chain = 4, leg = :fixed)
 (chain = 3, leg = :fixed)
 (chain = 2, leg = :fixed)
 (chain = 1, leg = :fixed)

This is a vector that, for each i in 1:(total number of chains), tells us which leg chain i belongs to and which index it has within that leg. Now, the true targets are entries 6 and 7:

 (chain = 6, leg = :variational)
 (chain = 10, leg = :fixed)

But the failed target_chains was reporting 10 and 11, corresponding to

 (chain = 7, leg = :fixed)
 (chain = 6, leg = :fixed)

I.e., 2 interior chains of the fixed leg. You can see how different combinations of n_chains and n_chains_variational parameters would result in different output. But most likely, target_chains was returning a pair of chains in the interior of a leg instead of the true 2 targets in different legs.
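
The ordering itself can be reconstructed with a short sketch (illustrative only, not the actual Pigeons code):

# Variational leg first (reference → target), then the fixed leg reversed
# (target → reference), so the two target copies end up adjacent.
n_var, n_fixed = 6, 10
i2t = vcat(
    [(chain = c, leg = :variational) for c in 1:n_var],
    [(chain = c, leg = :fixed)       for c in n_fixed:-1:1],
)
i2t[n_var]      # (chain = 6, leg = :variational): target of the variational leg
i2t[n_var + 1]  # (chain = 10, leg = :fixed): target of the fixed leg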

How different those two interior distributions were from the actual target is impossible to say a priori, given that their beta parameters could've been close to 1, and therefore the samples could've looked ok.

Thanks @miguelbiron, that really clears it up.

So can we say this issue did not occur when the number of regular and variational chains were the same?

For sure; in that special case, the bug does not appear.
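
Using the toy is_target functions sketched earlier in this thread, you can see why: with equal legs, the buggy and corrected formulas pick the same pair.

deo_equal = ToyVariationalDEO(8, 8)
filter(c -> is_target_buggy(deo_equal, c), 1:16)  # [8, 9]
filter(c -> is_target_fixed(deo_equal, c), 1:16)  # [8, 9]: same pair, so no bug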

Excellent! I am relieved, since our published results had so far used the same numbers of regular and variational chains.

I appreciate your help on this very much.

OMG I'm so glad that that is the case... phew!!! 😌

Phhheewww.. thank you so much William for reporting. And Miguel, you have eagle eyes.. amazing job

I am relieved, since our published results had so far used the same numbers of regular and variational chains

Thank goodness!....

and thanks again @miguelbiron and @sefffal for hunting this very bad bug down 👍