CompRhys/aviary

separate `fit` and `predict`

sgbaird opened this issue · 7 comments

Thanks for your patience with all the posts.

It seems that the train and test data are passed in all at once. Ideally, I'd like to use Roost in an sklearn-esque "instantiate, fit, predict" style. It's not urgent; my timescale is about a month. Since I'm not familiar with the underlying code, I thought I would ask before diving in. Any thoughts/suggestions?
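For concreteness, this is the kind of usage I have in mind. None of these names exist in aviary; the module and class here are hypothetical, purely to illustrate the interface:

```python
import pandas as pd

from roost_sklearn import RoostRegressor  # hypothetical module, for illustration only

train_df = pd.DataFrame({"composition": ["Fe2O3", "LiCoO2"], "target": [-2.1, -1.8]})
test_df = pd.DataFrame({"composition": ["NaCl"]})

model = RoostRegressor(epochs=100)                      # instantiate
model.fit(train_df["composition"], train_df["target"])  # fit on training data only
preds = model.predict(test_df["composition"])           # predict separately, later
```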

If we refactored, it would be to use PyTorch Lightning's API rather than sklearn's. PyTorch Lightning places a framework around a lot of the stuff we do in our base class. Doing so would introduce breaking changes, so it would need a compelling reason.

I am not sure I recognise the "all the data being passed at once": `train_ensemble` takes the train and val sets; `results_multitask` takes the test set. The various dictionaries needed to reconstruct the models are passed to these functions; the aim there was to make ensembling easy to handle. When training we don't actually use that functionality, as training is embarrassingly parallel, but it's important when testing.
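To make the ensembling point concrete, here is a generic sketch of the pattern, not aviary's actual code (the checkpoint layout and the `model_params` dict are assumptions): the shared dictionaries let you rebuild each fold's model at test time and average their predictions.

```python
import torch


def ensemble_predict(model_class, model_params: dict, checkpoints: list[str], inputs):
    """Rebuild each ensemble member from its checkpoint and average predictions.

    Generic sketch only: the same reconstruction dict rebuilds every fold's
    model. Training each fold is embarrassingly parallel, so this machinery
    only really matters at test time.
    """
    preds = []
    for path in checkpoints:
        model = model_class(**model_params)           # reconstruct the architecture
        state = torch.load(path, map_location="cpu")  # assumed checkpoint layout
        model.load_state_dict(state["state_dict"])
        model.eval()
        with torch.no_grad():
            preds.append(model(inputs))
    return torch.stack(preds).mean(dim=0)  # ensemble average
```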

FWIW, the `train_ensemble()` and `results_multitask()` functions seem like the most opinionated parts of the code base to me, i.e. the ones that would benefit most from a refactor. I've thought in the past that splitting them into smaller, specialized functions that people can compose into flexible scripts would be good; see the sketch below. MEGNet has a simple outward-facing API, so we could also draw some inspiration from there.
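As a rough illustration of the kind of decomposition I mean (the function boundaries and names here are my suggestion, nothing from the code base):

```python
# Hypothetical building blocks that train_ensemble()/results_multitask()
# could be split into, letting scripts compose them freely.

def build_model(model_params: dict):
    """Construct a single model from its reconstruction dict."""
    ...

def train_single(model, train_set, val_set, epochs: int):
    """Train one ensemble fold; folds are embarrassingly parallel."""
    ...

def evaluate_ensemble(models: list, test_set):
    """Aggregate fold predictions and compute metrics on the test set."""
    ...

# train_ensemble() then reduces to a loop over train_single(), and
# results_multitask() to a thin wrapper around evaluate_ensemble().
```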

Looking at the code, I realise that @janosh already made us implement `fit` and `predict` methods inside the ABC:

```python
class BaseModelClass(nn.Module, ABC):
```

It's just that the data inputs are very opinionated - you must pass an id and the material composition - so they're perhaps not the easiest to use.
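For readers unfamiliar with the layout, the pattern looks roughly like this (a minimal sketch; the real `BaseModelClass` has many more responsibilities and its exact signatures differ):

```python
from abc import ABC

import torch.nn as nn


class BaseModelClass(nn.Module, ABC):
    """Sketch of the pattern: model classes inherit shared fit/predict
    loops and only implement their own forward pass."""

    def fit(self, train_generator, val_generator, epochs: int):
        # shared training loop over the opinionated (id, composition, ...) inputs
        ...

    def predict(self, test_generator):
        # shared inference loop that returns ids alongside predictions
        ...
```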

> you must pass an id and the material composition

Where does that requirement come from again? Is it hard to change?

> If we refactored, it would be to use PyTorch Lightning's API rather than sklearn's. PyTorch Lightning places a framework around a lot of the stuff we do in our base class. Doing so would introduce breaking changes, so it would need a compelling reason.

Good point about PyTorch Lightning.

> I am not sure I recognise the "all the data being passed at once": `train_ensemble` takes the train and val sets; `results_multitask` takes the test set. The various dictionaries needed to reconstruct the models are passed to these functions; the aim there was to make ensembling easy to handle. When training we don't actually use that functionality, as training is embarrassingly parallel, but it's important when testing.

Thanks for clarifying! This raises an important question: is the `val_set` kwarg within `train_ensemble` used to affect the predictions in any way (e.g. learning rate scheduler, internal optimizer, hyperparameter optimization)? See also anthony-wang/CrabNet#15.
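To make the concern concrete: in a generic PyTorch loop (not aviary's actual one), a scheduler like `ReduceLROnPlateau` keyed on the validation loss means the val set indirectly shapes the final weights.

```python
import torch

model = torch.nn.Linear(10, 1)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The LR drops when the *validation* loss plateaus, so the val set
# influences the trained weights -- the channel being asked about.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5
)

x_val, y_val = torch.randn(32, 10), torch.randn(32, 1)  # dummy val data
for epoch in range(20):
    # ... one epoch of training on the train set would go here ...
    with torch.no_grad():
        val_loss = torch.nn.functional.mse_loss(model(x_val), y_val)
    scheduler.step(val_loss)  # val metric feeds the scheduler
```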

> Looking at the code, I realise that @janosh already made us implement `fit` and `predict` methods inside the ABC:
>
> `class BaseModelClass(nn.Module, ABC):`
>
> It's just that the data inputs are very opinionated - you must pass an id and the material composition - so they're perhaps not the easiest to use.

Good to know!

> you must pass an id and the material composition
>
> Where does that requirement come from again? Is it hard to change?

Initially it was just `id, comp, target` for Roost; then with CGCNN and Wren I found it useful to still have `comp` in the results file, so I made it a requirement. I think you can add more identification columns, but there needs to be more than one.
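For anyone landing here later, the expected input then looks roughly like this (column names are illustrative; check the example scripts for the exact headers):

```python
import pandas as pd

# Minimal Roost-style input: an id column, the composition, and the target.
df = pd.DataFrame(
    {
        "material_id": ["mp-1", "mp-2"],     # identifier (required)
        "composition": ["Fe2O3", "LiCoO2"],  # composition string (required)
        "E_f": [-2.1, -1.8],                 # target, e.g. formation energy
    }
)
df.to_csv("train.csv", index=False)
```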

Happy to close?