Julia-Tempering/Pigeons.jl

Multithreading performance


Hello,

I have been trying to evaluate the performance of this package with multiple threads, but unfortunately setting multithreaded=true seems to result in a slowdown (running on a 6-core CPU):

$ julia -t 6
...
julia> using Pigeons

(omitting first run)

julia> pigeons(target = toy_mvn_target(100), multithreaded=false)
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: More than one threads are available, but explore!() loop is not parallelized as inputs.multithreaded == false
└ @ Pigeons ~/.julia/packages/Pigeons/rBo9q/src/pt/checks.jl:12
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5   5.64e-05   1.28e+04       -122   2.66e-15      0.167
        4       5.59   6.76e-05   1.47e+04       -119    2.6e-07      0.378
        8       6.04   0.000102   1.86e+04       -115    0.00129      0.329
       16       7.27    0.00016   2.42e+04       -118     0.0134      0.193
       32       6.97    0.00025   3.05e+04       -114      0.107      0.225
       64       7.03   0.000439   4.13e+04       -117     0.0531      0.219
      128       7.23   0.000797   6.08e+04       -114     0.0944      0.196
      256       7.05    0.00142   6.77e+04       -115       0.13      0.217
      512       7.14    0.00259   7.18e+04       -115      0.171      0.207
 1.02e+03       7.19    0.00494   7.91e+04       -115      0.172      0.201
────────────────────────────────────────────────────────────────────────────
PT(checkpoint = false, ...)

julia> pigeons(target = toy_mvn_target(100), multithreaded=true)
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        7.5     0.0227   7.49e+05       -122   2.66e-15      0.167
        4       5.59     0.0223   7.57e+05       -119    2.6e-07      0.378
        8       6.04   0.000379   5.62e+04       -115    0.00129      0.329
       16       7.27   0.000604   9.92e+04       -118     0.0134      0.193
       32       6.97   0.000983    1.8e+05       -114      0.107      0.225
       64       7.03    0.00175   3.42e+05       -117     0.0531      0.219
      128       7.23    0.00184   6.61e+05       -114     0.0944      0.196
      256       7.05     0.0036   1.27e+06       -115       0.13      0.217
      512       7.14    0.00811   2.48e+06       -115      0.171      0.207
 1.02e+03       7.19     0.0172   4.91e+06       -115      0.172      0.201
────────────────────────────────────────────────────────────────────────────
PT(checkpoint = false, ...)

julia>

Is this expected?
I tried changing the problem by increasing the number of chains, changing the dimension of the problem, and increasing the number of rounds, but the results remain similar.
For instance:

julia> pigeons(target = toy_mvn_target(500), multithreaded=true, n_rounds=17)
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       7.59   0.000312   2.13e+04       -586   1.31e-48      0.157
        4       8.05   0.000293   3.04e+04       -591   7.42e-36      0.106
        8       8.59   0.000431   4.91e+04       -586   2.12e-32      0.046
       16       8.49   0.000672   8.77e+04       -583   7.84e-20      0.057
       32       8.71    0.00133   1.65e+05       -586   6.39e-25     0.0326
       64       8.58    0.00256   3.19e+05       -580   3.93e-17     0.0466
      128       8.75    0.00441   6.21e+05       -575   2.85e-18     0.0281
      256       8.76    0.00961   1.24e+06       -575   1.33e-17     0.0263
      512       8.86     0.0206   2.44e+06       -575   8.15e-15     0.0152
 1.02e+03       8.83     0.0285   4.89e+06       -575   2.57e-14     0.0192
 2.05e+03       8.88     0.0468   9.72e+06       -576   1.69e-12     0.0135
  4.1e+03       8.88     0.0902   1.94e+07       -574      3e-13     0.0138
 8.19e+03       8.89      0.149   3.87e+07       -576   1.75e-10     0.0121
 1.64e+04       8.89      0.289   7.74e+07       -573   1.05e-08     0.0119
 3.28e+04        8.9      0.575   1.55e+08       -574   2.01e-09     0.0115
 6.55e+04        8.9       1.16   3.09e+08       -574   7.14e-08      0.011
 1.31e+05       8.91       2.31   6.19e+08       -574   1.53e-07     0.0104
────────────────────────────────────────────────────────────────────────────
PT(checkpoint = false, ...)

julia> pigeons(target = toy_mvn_target(500), multithreaded=false, n_rounds=17)
┌ Info: Neither traces, disk, nor online recorders included.
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
┌ Warning: More than one threads are available, but explore!() loop is not parallelized as inputs.multithreaded == false
└ @ Pigeons ~/.julia/packages/Pigeons/rBo9q/src/pt/checks.jl:12
────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)
────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       7.59   0.000293   1.28e+04       -586   1.31e-48      0.157
        4       8.05   0.000282   1.16e+04       -591   7.42e-36      0.106
        8       8.59   0.000482   1.16e+04       -586   2.12e-32      0.046
       16       8.49   0.000855   1.27e+04       -583   7.84e-20      0.057
       32       8.71    0.00157   1.46e+04       -586   6.39e-25     0.0326
       64       8.58    0.00347   1.62e+04       -580   3.93e-17     0.0466
      128       8.75    0.00584   1.62e+04       -575   2.85e-18     0.0281
      256       8.76     0.0119   2.71e+04       -575   1.33e-17     0.0263
      512       8.86     0.0196   2.18e+04       -575   8.15e-15     0.0152
 1.02e+03       8.83     0.0182   3.43e+04       -575   2.57e-14     0.0192
 2.05e+03       8.88     0.0367   3.95e+04       -576   1.69e-12     0.0135
  4.1e+03       8.88     0.0725   5.35e+04       -574      3e-13     0.0138
 8.19e+03       8.89      0.144    5.4e+04       -576   1.75e-10     0.0121
 1.64e+04       8.89      0.287   6.08e+04       -573   1.05e-08     0.0119
 3.28e+04        8.9      0.575   6.24e+04       -574   2.01e-09     0.0115
 6.55e+04        8.9       1.15    6.4e+04       -574   7.14e-08      0.011
 1.31e+05       8.91        2.3    6.4e+04       -574   1.53e-07     0.0104
────────────────────────────────────────────────────────────────────────────
PT(checkpoint = false, ...)

julia>

Interestingly, at least in the second case the threads do seem to be busy.
It also seems that the multithreaded version is allocating significantly more memory.
Am I missing something?

Thanks for pointing this out. The reason is that toy_mvn_target() has a special explorer that performs iid sampling. We use it when we need very fast-running tests (e.g. in CI), but its performance profile is atypical. Because this explorer is so quick, the overhead of multithreading is not worth it. Typical explorers use algorithms such as slice sampling or gradient-informed samplers, and for those algorithms we see improvements once the problem is large enough. I will post a preliminary benchmark on this here in a few minutes...
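To illustrate the overhead argument outside of Pigeons, here is a minimal sketch using only Base.Threads. The functions cheap_step, costly_step, explore_serial! and explore_threaded! are hypothetical stand-ins for a per-chain exploration step and are not Pigeons code. With near-free per-chain work, the cost of scheduling threads would be expected to dominate, while heavier work should benefit from the threaded loop when Julia is started with several threads (e.g. julia -t 6).

```julia
using Base.Threads

# Hypothetical stand-ins for one exploration step per chain (not Pigeons code).
cheap_step(x) = x + 1.0              # comparable to a near-free iid draw
function costly_step(x)              # comparable to a heavier explorer
    acc = x
    for i in 1:200_000
        acc += sin(acc + i) * 1e-6
    end
    return acc
end

# Apply `step` to every chain state, one after the other.
function explore_serial!(states, step)
    for i in eachindex(states)
        states[i] = step(states[i])
    end
    return states
end

# Same loop, with iterations spread across the available threads.
function explore_threaded!(states, step)
    @threads for i in eachindex(states)
        states[i] = step(states[i])
    end
    return states
end

states = collect(1.0:10.0)           # ten "chains"
for (name, step) in (("cheap", cheap_step), ("costly", costly_step))
    # Warm-up calls so that compilation time is not measured.
    explore_serial!(copy(states), step)
    explore_threaded!(copy(states), step)
    t_serial = @elapsed explore_serial!(copy(states), step)
    t_threaded = @elapsed explore_threaded!(copy(states), step)
    println(name, ": serial = ", t_serial, " s, threaded = ", t_threaded, " s")
end
```

Both loops compute the same result; only the wall-clock time differs, and the timings will vary with the machine and the thread count.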

Info on the above plot:

  • targets: same as you used (toy_mvn_target), but with the key difference that I force a more realistic explorer using explorer = AutoMALA(). I show models of increasing dimensionality d, where one would increase the number of chains as prescribed by the theory (N = sqrt(d), x-axis) and hence the number of threads.
  • 10 repeats for each regime (confidence intervals shown)
  • blue dashed line is the theoretical best case
  • red dashed line is the threshold where there are improvements to use multithreading
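
A rough sketch of how one such comparison could be set up, assuming a Pigeons version that provides AutoMALA and accepts the n_chains, n_rounds and multithreaded keywords. The helper time_run is hypothetical (not part of Pigeons), the dimensions are arbitrary, and this is not the script behind the plot. Start Julia with several threads (e.g. julia -t 6) so that multithreaded=true has something to use.

```julia
using Pigeons

# Hypothetical helper (not part of Pigeons): wall-clock time of one run.
function time_run(dim; multithreaded, n_rounds = 8)
    n_chains = max(2, round(Int, sqrt(dim)))   # N ≈ sqrt(d), as in the setup above
    return @elapsed pigeons(
        target = toy_mvn_target(dim),
        explorer = AutoMALA(),                 # force a more realistic explorer
        n_chains = n_chains,
        n_rounds = n_rounds,
        multithreaded = multithreaded)
end

# Warm-up runs so that compilation time is not measured.
time_run(4; multithreaded = false, n_rounds = 2)
time_run(4; multithreaded = true, n_rounds = 2)

for dim in (100, 400)                          # arbitrary example dimensions
    t_single = time_run(dim; multithreaded = false)
    t_multi = time_run(dim; multithreaded = true)
    println("d = ", dim, ": single = ", t_single, " s, multithreaded = ", t_multi, " s")
end
```

A single timing per setting is noisy; repeating each configuration several times, as in the 10 repeats mentioned above, gives a more reliable picture.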

For completeness, similar results for MPI/distributed mode instead