rsagroup/rsatoolbox

Returning fitted model parameters in inference functions

Aceticia opened this issue · 3 comments

Sometimes we have a parametrized model and are interested in the parameter values that were fitted during evaluation. For example, we might want to learn which weights a flexible model was assigned during a searchlight analysis.

I'd be happy to contribute to this feature. There are two natural ways to implement it:

  1. We optionally store the fitted parameters in the result objects. This would be a list of length n_models containing numpy arrays of shape bootstrap_samples x (cross-validation folds and other repetitions) x n_params.
  2. Inference functions such as eval_dual_bootstrap_random optionally return the parameters fitted in each fold.

Note that both approaches require the crossval function to return additional values, namely the fitted parameters.
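To make the shape in option 1 concrete, here is a minimal numpy sketch of how the per-model parameter arrays could be laid out. `allocate_param_store`, the sizes, and the placeholder `theta` are hypothetical and not part of rsatoolbox; the "other repetitions" dimension is folded into `n_folds` for simplicity.

```python
import numpy as np


def allocate_param_store(n_params_per_model, n_boot, n_folds):
    """Option 1 layout (hypothetical): one array per model, shape (n_boot, n_folds, n_params).

    Entries start as NaN so that fits which were never stored remain visible.
    """
    return [np.full((n_boot, n_folds, n_params), np.nan)
            for n_params in n_params_per_model]


# Hypothetical sizes: two models with 3 and 1 parameters,
# 100 bootstrap samples and 5 cross-validation folds.
store = allocate_param_store([3, 1], n_boot=100, n_folds=5)

# Inside the evaluation loop, crossval would have to hand back the fitted
# theta of each model and fold so that it can be written into the store.
i_model, i_boot, i_fold = 0, 0, 2
theta = np.array([0.5, 0.3, 0.2])  # placeholder for a fitted parameter vector
store[i_model][i_boot, i_fold] = theta
```

Because models can have different numbers of parameters, a list of arrays (rather than one 4D array) keeps each model's last dimension independent.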

Hi @Aceticia,
thanks for looking deeper into the toolbox! An easy way to analyze the fitted parameters certainly seems like a good idea, especially for something like a searchlight analysis.

However, I am not so sure that you would want to use the parameter fits from the bootstrap cross-validation for this type of analysis. If you care about the parameters rather than about model performance, you should probably fit without cross-validation, which yields both more stable parameter estimates and more accurate estimates of their variance.

Thus, for now I would rather go for an analogous function that returns the fits of a model under the different kinds of bootstrapping, instead of returning the fitted parameters from the bootstrap cross-validation. Returning the fitted parameters might still be sensible for debugging models, but for actual inference on the parameters you should probably do something slightly different.

Hi @HeikoSchuett, I agree that using cross-validation is not the best idea here. How does the following sound:

A separate utility function that is essentially a wrapper around bootstrap_sample and model.fit. It takes the models, a set of RDMs, an RDM descriptor to group over, and the number of bootstrap samples; it then iterates over the bootstrap samples and simply returns the result of each call to model.fit.
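A minimal sketch of the proposed wrapper, assuming the RDMs are given as a plain numpy array of shape (n_rdms, n_dissimilarities) with one descriptor value (e.g. subject) per RDM. `bootstrap_fit` and `ToyWeightedModel` are hypothetical: the resampling loop stands in for rsatoolbox's bootstrap_sample, and the toy least-squares `fit` stands in for a flexible model's model.fit, whose actual signature and fitting methods differ.

```python
import numpy as np


class ToyWeightedModel:
    """Hypothetical stand-in for a flexible model: predicts the data RDM as a
    weighted sum of component RDMs and fits the weights by least squares."""

    def __init__(self, components):
        # components: array of shape (n_components, n_dissimilarities)
        self.components = np.asarray(components, dtype=float)

    def fit(self, rdms):
        # Fit the component weights to the mean of the sampled RDMs.
        target = rdms.mean(axis=0)
        theta, *_ = np.linalg.lstsq(self.components.T, target, rcond=None)
        return theta


def bootstrap_fit(models, rdms, groups, n_boot=100, seed=None):
    """Bootstrap over the levels of an RDM descriptor and return each model's
    fitted parameters as an array of shape (n_boot, n_params)."""
    rng = np.random.default_rng(seed)
    levels = np.unique(groups)
    fits = [[] for _ in models]
    for _ in range(n_boot):
        # Resample descriptor levels with replacement, then keep all RDMs
        # that belong to each drawn level (duplicates included).
        drawn = rng.choice(levels, size=len(levels), replace=True)
        idx = np.concatenate([np.flatnonzero(groups == lvl) for lvl in drawn])
        sample = rdms[idx]
        for i, model in enumerate(models):
            fits[i].append(model.fit(sample))
    return [np.stack(f) for f in fits]


# Usage with simulated data: 8 subjects, 6 conditions (15 dissimilarities).
rng = np.random.default_rng(0)
components = rng.random((2, 15))
true_w = np.array([0.7, 0.3])
rdms = true_w @ components + 0.05 * rng.standard_normal((8, 15))
subjects = np.arange(8)
params = bootstrap_fit([ToyWeightedModel(components)], rdms, subjects,
                       n_boot=200, seed=1)
print(params[0].mean(axis=0), params[0].std(axis=0))
```

The spread of the returned parameters across bootstrap samples (e.g. their standard deviation) then serves as the variance estimate for the parameters, with no cross-validation involved.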

This sounds more reasonable to do. I can write a test for it as well.