sbi-dev/sbi

New organization of `SNPE` methods


-- inference
----- trainers
--------- npe
------------- npe.py
------------- snpe_a_correction.py
------------- snpe_c_loss.py

Then, the API for NPE (amortized) is:

from sbi.inference import NPE, DirectPosterior

trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`

For SNPE_A, it is:

from sbi.inference import NPE, DirectPosterior, snpe_a_correction

proposal = prior  # Round 0 draws parameters from the prior.
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)

    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    proposal = corrected_posterior
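
To illustrate what a correction step like `snpe_a_correction` would have to compute, here is a minimal sketch for the simplest case: a one-dimensional Gaussian proposal posterior and a Gaussian proposal. It assumes the corrected posterior is proportional to `q(theta) * p(theta) / p_tilde(theta)`, where `q` is the trained proposal posterior, `p` the prior, and `p_tilde` the proposal. The names `Gaussian1D` and `gaussian_correction` are hypothetical and not part of sbi; the real function would operate on posterior objects and, in the last round, on mixture components.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Gaussian1D:
    """A 1D Gaussian, parameterized by mean and variance (hypothetical helper)."""

    mean: float
    var: float


def gaussian_correction(
    proposal_posterior: Gaussian1D,
    proposal: Gaussian1D,
    prior: Optional[Gaussian1D] = None,
) -> Gaussian1D:
    """Return the Gaussian proportional to q(theta) * p(theta) / p_tilde(theta).

    `prior=None` is treated as a flat prior, which contributes no term.
    """
    # Work in natural parameters: multiplying Gaussians adds precisions,
    # dividing by a Gaussian subtracts its precision.
    precision = 1.0 / proposal_posterior.var - 1.0 / proposal.var
    shift = (
        proposal_posterior.mean / proposal_posterior.var
        - proposal.mean / proposal.var
    )
    if prior is not None:
        precision += 1.0 / prior.var
        shift += prior.mean / prior.var
    if precision <= 0.0:
        # The ratio is not a normalizable density in this case.
        raise ValueError("Corrected precision is not positive.")
    return Gaussian1D(mean=shift / precision, var=1.0 / precision)
```

When the proposal equals the prior, as in the first round, the two terms cancel and the function returns the proposal posterior unchanged. The explicit check on the precision reflects a known failure mode of this kind of correction: if the proposal is narrower than the learned proposal posterior, the ratio is not a valid density.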

For SNPE_C (atomic), it is:

from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss

# First round is standard NPE.
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`

# Later rounds use the APT loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`
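
For readers unfamiliar with the atomic APT loss that a function like `snpe_c_atomic_loss` would implement, the following is a minimal sketch of the per-observation computation. It assumes the loss is the negative log-probability of the true parameter under a softmax over `log q(theta_j | x) - log p(theta_j)` across a set of atoms. The function name `atomic_loss` and its list-based inputs are hypothetical; an actual implementation would be batched, operate on tensors, and draw the atoms from the training batch.

```python
import math
from typing import Sequence


def atomic_loss(log_q: Sequence[float], log_prior: Sequence[float]) -> float:
    """Atomic proposal loss for a single observation x (hypothetical helper).

    log_q[j] is log q(theta_j | x) and log_prior[j] is log p(theta_j) for each
    atom theta_j. Index 0 is the parameter that actually generated x.
    """
    # Unnormalized log-weights log(q / p) for every atom.
    logits = [lq - lp for lq, lp in zip(log_q, log_prior)]
    # Numerically stable log-sum-exp over the atoms.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(v - top) for v in logits))
    # Negative log-probability of the true atom under the softmax.
    return -(logits[0] - log_norm)
```

Two properties follow directly from the softmax form: with M atoms that the estimator cannot tell apart the loss equals `log(M)`, and a prior that is constant over the atoms cancels out of the loss entirely.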

For SNPE_C (non-atomic), the only differences are that one also passes `proposal=proposal` to `append_simulations()` and that the density estimator has to be a mixture density network (MDN).