New organization of `SNPE` methods
michaeldeistler commented
```
-- inference
----- trainers
--------- npe
------------- npe.py
------------- snpe_a_correction.py
------------- snpe_c_loss.py
```
Then, the API for NPE (amortized) is:
```python
from sbi.inference import NPE, DirectPosterior

# Assumes `theta`, `x` (simulation outputs) and `prior` are already defined.
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
posterior = DirectPosterior(net, prior)  # Or use `build_posterior()`.
```
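To make the proposed split concrete, here is a toy, dependency-free mock of the trainer interface. `ToyNPE`, its `rounds` field and its dict return value are hypothetical stand-ins, not sbi classes. The only points it illustrates are that `append_simulations()` returns the trainer (so calls chain) and that `train()` optionally takes a loss, while building the posterior stays a separate step.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class ToyNPE:
    """Hypothetical mock of the proposed trainer; NOT sbi's NPE class."""

    density_estimator: str = "maf"
    rounds: List[Tuple[List[float], List[float]]] = field(default_factory=list)

    def append_simulations(self, theta: List[float], x: List[float]) -> "ToyNPE":
        # Store one batch of simulations; returning `self` lets calls chain.
        self.rounds.append((theta, x))
        return self

    def train(self, loss: Optional[Callable] = None) -> dict:
        # Stand-in for training: summarize what a real trainer would fit.
        num_samples = sum(len(theta) for theta, _ in self.rounds)
        loss_name = loss.__name__ if loss is not None else "maximum_likelihood"
        return {"estimator": self.density_estimator, "num_samples": num_samples, "loss": loss_name}


# Usage: the trainer only returns a "network"; posterior construction happens elsewhere.
trainer = ToyNPE()
net = trainer.append_simulations([0.1, 0.2], [1.0, 2.0]).train()
print(net)  # {'estimator': 'maf', 'num_samples': 2, 'loss': 'maximum_likelihood'}
```

Keeping the loss as an argument to `train()` means the same trainer object can serve amortized NPE and the sequential variants below.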
For SNPE_A, it is:
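```python
from sbi.inference import NPE, DirectPosterior, snpe_a_correction

# Assumes `prior`, `simulator` and the observation `x_o` are defined.
proposal = prior  # The first round simulates from the prior.
for r in range(3):
    theta = proposal.sample((1000,))
    x = simulator(theta)
    # Gaussian density estimator in early rounds, MDN in the final round.
    trainer = NPE(density_estimator="Gaussian" if r < 2 else "mdn")
    net = trainer.append_simulations(theta, x).train()
    proposal_posterior = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`.
    corrected_posterior = snpe_a_correction(proposal_posterior, proposal)
    proposal = corrected_posterior
```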
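The SNPE-A correction divides out the proposal: the corrected posterior is proportional to q(theta|x) * p(theta) / p_tilde(theta), where q is the trained estimator, p the prior and p_tilde the proposal. When all three are Gaussian (as in the Gaussian rounds above), this is closed form in precision space. The sketch below is 1D and assumes a Gaussian prior and proposal; `snpe_a_gaussian_correction_1d` is a hypothetical helper, not the proposed `snpe_a_correction`.

```python
def snpe_a_gaussian_correction_1d(mean_q, var_q, mean_prior, var_prior, mean_prop, var_prop):
    """Return (mean, var) of the Gaussian proportional to q(theta|x) * p(theta) / p_tilde(theta).

    Hypothetical 1D helper: q, prior p and proposal p_tilde are all assumed Gaussian.
    """
    # Precisions add for products and subtract for divisions.
    prec = 1.0 / var_q + 1.0 / var_prior - 1.0 / var_prop
    if prec <= 0:
        raise ValueError("Corrected density is not a proper Gaussian (non-positive precision).")
    # Precision-weighted means combine the same way.
    eta = mean_q / var_q + mean_prior / var_prior - mean_prop / var_prop
    return eta / prec, 1.0 / prec


# Round one: the proposal equals the prior, so the estimator is returned unchanged.
print(snpe_a_gaussian_correction_1d(1.0, 0.5, 0.0, 4.0, 0.0, 4.0))  # (1.0, 0.5)
```

A narrow proposal can make the precision negative; that failure mode is one reason SNPE-A keeps the estimator simple until the last round.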
For SNPE_C (atomic), it is:
```python
from sbi.inference import NPE, DirectPosterior, simulate_for_sbi, snpe_c_atomic_loss

# Assumes `prior`, `simulator` and the observation `x_o` are defined.
# First round is standard NPE.
theta, x = simulate_for_sbi(simulator, prior, num_simulations=1000)
trainer = NPE()
net = trainer.append_simulations(theta, x).train()
proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`.

# Later rounds use the APT loss.
for _ in range(1, 3):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=1000)
    net = trainer.append_simulations(theta, x).train(loss=snpe_c_atomic_loss)
    proposal = DirectPosterior(net, prior).set_default_x(x_o)  # Or use `build_posterior()`.
```
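The atomic APT loss renormalizes the estimator over a small set of "atoms" drawn from the batch: for each pair (theta_i, x_i), the loss is -log[(q(theta_i|x_i) / p(theta_i)) / sum_j (q(theta_j|x_i) / p(theta_j))], with theta_i itself among the atoms. The sketch assumes scalar theta and x, log-densities passed as plain callables, and atoms sampled without replacement from the rest of the batch; `atomic_loss` is a hypothetical illustration, not the proposed `snpe_c_atomic_loss`.

```python
import math
import random
from typing import Callable, Optional, Sequence


def atomic_loss(
    log_q: Callable[[float, float], float],  # log q(theta | x) of the density estimator
    log_prior: Callable[[float], float],  # log p(theta)
    thetas: Sequence[float],
    xs: Sequence[float],
    num_atoms: int = 10,
    rng: Optional[random.Random] = None,
) -> float:
    """Mean atomic APT loss over a batch (hypothetical sketch)."""
    rng = rng or random.Random(0)
    n = len(thetas)
    num_atoms = min(num_atoms, n)
    total = 0.0
    for i in range(n):
        # Atom set: the true theta_i plus (num_atoms - 1) contrastive thetas from the batch.
        others = [j for j in range(n) if j != i]
        atoms = [i] + rng.sample(others, num_atoms - 1)
        # Unnormalized log-weights: log q(theta_j | x_i) - log p(theta_j).
        logits = [log_q(thetas[j], xs[i]) - log_prior(thetas[j]) for j in atoms]
        # Log-sum-exp for a numerically stable normalizer.
        top = max(logits)
        log_norm = top + math.log(sum(math.exp(value - top) for value in logits))
        total += -(logits[0] - log_norm)
    return total / n
```

If the estimator ignores x, every atom gets the same weight and each term equals log(num_atoms); an estimator that concentrates on the correct theta drives the loss toward zero.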
For SNPE_C (non-atomic), the only difference is that one would also pass `proposal=proposal` to `append_simulations()`, and one has to use MDNs.
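Non-atomic APT needs the proposal posterior q_tilde(theta|x), proportional to q(theta|x) * p_tilde(theta) / p(theta), in closed form, which is why MDNs are required: a mixture of Gaussians times a ratio of Gaussians is again a mixture of Gaussians, with shifted means, shrunk variances and reweighted components. The 1D sketch below assumes a Gaussian prior and, for simplicity, a single-Gaussian proposal (with a mixture proposal, the same algebra runs over every pair of components). `proposal_posterior_mog_1d` is a hypothetical helper.

```python
import math
from typing import List, Tuple


def _log_gauss_const(mean: float, var: float) -> float:
    # Terms of log N(theta; mean, var) that do not depend on theta.
    return -0.5 * math.log(2 * math.pi * var) - mean * mean / (2 * var)


def proposal_posterior_mog_1d(
    weights: List[float], means: List[float], variances: List[float],
    mean_prop: float, var_prop: float, mean_prior: float, var_prior: float,
) -> Tuple[List[float], List[float], List[float]]:
    """Mixture proportional to q(theta|x) * p_tilde(theta) / p(theta) (hypothetical 1D sketch)."""
    log_ws, new_means, new_vars = [], [], []
    for w, m, v in zip(weights, means, variances):
        prec = 1 / v + 1 / var_prop - 1 / var_prior
        if prec <= 0:
            raise ValueError("Component has non-positive precision; result is not normalizable.")
        eta = m / v + mean_prop / var_prop - mean_prior / var_prior
        # Theta-independent part of log[N_k * N_prop / N_prior].
        c = _log_gauss_const(m, v) + _log_gauss_const(mean_prop, var_prop) - _log_gauss_const(mean_prior, var_prior)
        # log of the integral of exp(c + eta * theta - prec * theta**2 / 2) over theta.
        log_z = c + eta * eta / (2 * prec) + 0.5 * math.log(2 * math.pi / prec)
        log_ws.append(math.log(w) + log_z)
        new_means.append(eta / prec)
        new_vars.append(1 / prec)
    # Renormalize the mixture weights with log-sum-exp.
    top = max(log_ws)
    log_norm = top + math.log(sum(math.exp(lw - top) for lw in log_ws))
    return [math.exp(lw - log_norm) for lw in log_ws], new_means, new_vars
```

When the proposal equals the prior the mixture comes back unchanged; a proposal centered at positive theta shifts weight toward the positive component.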