sbi-dev/sbi

hyperparameter tuning for the models?

Closed this issue · 1 comments

I am testing embedding_net, e.g. in the example

embedding

latent_dim = 10
single_trial_net = FCEmbedding(
    input_dim=theta_dim,
    num_hiddens=40,
    num_layers=2,
    output_dim=latent_dim,
)
embedding_net = PermutationInvariantEmbedding(
    single_trial_net,
    trial_net_output_dim=latent_dim,
    # NOTE: post-embedding is not needed really.
    num_layers=1,
    num_hiddens=10,
    output_dim=10,
)

There are hyperparameters such as num_hiddens, num_layers, and so on. I assume they should be tuned? Is there any guideline?

For the project I am working on, I used a CNN model in which the number of neurons in the dense layers decreases from layer to layer to create a hierarchical structure, and its performance is better than using SBI. But SBI is designed for inference problems, so I would expect it to do better? Any comments and suggestions are appreciated.

Hi @timktsang

Yes, these hyperparameters should be adapted to the given problem.

First, the choice of network is essential: for image data you should use a CNN embedding, for time series an RNN or a transformer, and for other high-dimensional data (say, more than 100 dimensions) a fully connected FCEmbedding may be enough. The permutation-invariant embedding is useful for trial-based data.

I am not sure I understand how you used your CNN model as an alternative to SBI, but if it is working well, you could just use that CNN as the embedding net for the inference with SBI.

Regarding hyperparameter search, you could set apart a test set of $N$ simulations (theta, x) and then calculate the negative log probability of the true parameters under the estimated posterior (NLTP) as a performance metric (see https://arxiv.org/pdf/2101.04653 appendix M1 for details). Say you want to find the best num_hiddens for your embedding net. Then you would repeat the training over a grid of num_hiddens values and, after each training, calculate NLTP as

nltp = - torch.mean(torch.tensor([posterior.log_prob(theta_i, x=x_i) for theta_i, x_i in zip(theta, x)]))

The "best fitting" num_hiddens is the one with the lowest nltp. Note that one should use a large N, e.g., N>100. Therefore, I recommend using NPE (single-round SNPE) as the inference method: it is amortized, so once trained, the posterior can be evaluated for every test x_i without retraining. Then you could easily use a test set of size N=1000.
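The search described above could be sketched roughly as follows. This is only a minimal illustration, not sbi API: `nltp`, `grid_search`, and `train_fn` are hypothetical names, and `ToyGaussianPosterior` is a stand-in so the sketch runs without sbi. In practice `train_fn` would be your own function that builds the embedding net for a given hyperparameter value, trains NPE, and returns the posterior; the only assumption about that object is that it exposes `log_prob(theta, x=...)` as in the snippet above.

```python
import math
from typing import Any, Callable, Dict, Iterable, Sequence, Tuple


def nltp(posterior: Any, theta_test: Sequence, x_test: Sequence) -> float:
    """Mean negative log probability of the true parameters on held-out pairs."""
    log_probs = [
        float(posterior.log_prob(theta_i, x=x_i))
        for theta_i, x_i in zip(theta_test, x_test)
    ]
    return -sum(log_probs) / len(log_probs)


def grid_search(
    train_fn: Callable[[Any], Any],
    grid: Iterable,
    theta_test: Sequence,
    x_test: Sequence,
) -> Tuple[Any, Dict[Any, float]]:
    """Train once per grid value and return the value with the lowest NLTP."""
    scores = {value: nltp(train_fn(value), theta_test, x_test) for value in grid}
    best = min(scores, key=scores.get)
    return best, scores


class ToyGaussianPosterior:
    """Stand-in for a trained posterior: p(theta | x) = N(theta; x, sigma^2).

    Hypothetical: here the "hyperparameter" is sigma itself, so that the
    example needs no training. A real train_fn would take e.g. num_hiddens.
    """

    def __init__(self, sigma: float) -> None:
        self.sigma = sigma

    def log_prob(self, theta: float, x: float) -> float:
        z = (theta - x) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)


# Held-out (theta, x) pairs; the residuals theta - x are +1, -1, +1, -1.
theta_test = [1.0, -1.0, 2.0, 0.0]
x_test = [0.0, 0.0, 1.0, 1.0]

best, scores = grid_search(ToyGaussianPosterior, [0.5, 1.0, 2.0], theta_test, x_test)
print("best value:", best)
print("NLTP per value:", scores)
```

The same held-out test set is reused for every grid value so that the scores are comparable; only the training run changes between entries.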

I hope this helps. Let me know if there are further questions.

Best,
Jan