CompRhys/aviary

Add models that are equivalent to Roost

Closed this issue · 11 comments

CrabNet and AtomSets-v0 are both equivalent to Roost in that they are weighted set regression architectures. If aviary is to develop into a DeepChem for inorganic materials property prediction, it might be nice to add implementations of these models.

I made a pip/conda installable version of CrabNet and the README has some basic, updated documentation. In particular, this was to make it easy to incorporate CrabNet into mat_discover, e.g. via a fit/predict API. I'm actively updating this, and eventually I'd like to incorporate the changes into the original repo, but Anthony has been busy defending his PhD as of late ⚔️ 🛡️ (he passed!). Not sure when all the changes will go through, but in the meantime, it should still be easy to incorporate CrabNet via the fork. If you're hoping for a minimal, stand-alone implementation within Aviary, that's a different story.

Also, do you mind clarifying what you mean by "weighted set regression"?

What can I help with for this?

Would probably be aiming to keep the <model>/data.py, <model>/model.py, <model>/utils.py structure, as otherwise there's not much benefit to having them all in one place over just having a list pointing to the different reference implementations. I am not set on keeping the API fixed in terms of the classes, but for this to be a useful effort I think the directory structure and as much common code as possible should be reused. If you look at DeepChem, it's not a useful learning tool imo as they mix frameworks, implementations and structures. There's not much point duplicating other people's work unless we can add something that makes it easier to understand or use (e.g. here the cgcnn implementation can take a variable number of neighbours rather than fixing to 12 and zero padding like the reference implementation).

The chemical system is a set, i.e. {Fe, P, O}; for a composition, each element in the set has a weight. In my opinion the main advance of Roost (and then CrabNet and AtomSets-v0) is that they operate directly on this set, with the element weighting playing a role akin to the positional encoding in sequence-based transformers. I've been calling this weighted set regression, in contrast to standard regression where we use a fixed-length vector (CBFV in your group's terminology).
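To make "operating directly on the set" concrete, here is a minimal numpy sketch (not the actual Roost/CrabNet code, and with random placeholder embeddings instead of learned ones): each element in the composition gets an embedding vector, and the fractional weights steer how the variable-size set is pooled into one fixed-length vector.

```python
import numpy as np

# Toy element embeddings. In the real models these are learned features;
# random placeholders here, purely for illustration.
rng = np.random.default_rng(0)
embed = {"Fe": rng.normal(size=8), "P": rng.normal(size=8), "O": rng.normal(size=8)}

def weighted_set_pool(composition: dict[str, float]) -> np.ndarray:
    """Pool a variable-size {element: amount} set into one fixed vector,
    using the (normalised) element fractions as pooling weights."""
    elems, amounts = zip(*composition.items())
    w = np.array(amounts, dtype=float)
    w /= w.sum()  # amounts -> fractions
    x = np.stack([embed[e] for e in elems])  # (n_elements, d)
    return w @ x  # weighted mean over the set -> (d,)

# Fe3(PO4)2 as a weighted set of {Fe, P, O}
vec = weighted_set_pool({"Fe": 3, "P": 2, "O": 8})
```

Note the output has the same dimension regardless of how many elements the composition contains, and is invariant to the order the elements are listed in, which is the point of treating the input as a set rather than a fixed-length vector.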

Gotcha. Thanks for clarifying! For CrabNet at least, the structure you described might be a decent amount of refactoring, but I'd have to take a second look at examples of the files you mentioned.

Also, marginally relevant but maybe worth mentioning another codebase that is integrating some models under a common framework: https://github.com/ncfrey/litmatter

I think that might just be a wrapper around torch.geometric models, but having sorted out how to split the workload over many GPUs? In which case it's not great for inorganic systems, as the general neighbour graph for triclinic systems doesn't have a standard torch/pyg implementation.
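For context on why that neighbour graph is fiddly: under periodic boundary conditions, an atom's neighbours can be periodic images in adjacent cells, and for a general (triclinic) lattice those images are found via the lattice matrix rather than simple coordinate wrapping. A minimal numpy sketch (not pyg code, and assuming the cutoff is smaller than the cell so one shell of images suffices):

```python
import itertools
import numpy as np

def periodic_neighbors(frac_coords, lattice, cutoff):
    """All (i, j, image_shift, distance) pairs within `cutoff`, scanning
    the 26 surrounding periodic images. Assumes cutoff < cell size so a
    single shell of image cells is enough (an assumption of this sketch)."""
    frac = np.asarray(frac_coords, dtype=float)
    cart = frac @ lattice  # fractional -> Cartesian via the lattice matrix
    pairs = []
    for shift in itertools.product((-1, 0, 1), repeat=3):
        offset = np.asarray(shift) @ lattice  # image translation vector
        for i, ri in enumerate(cart):
            for j, rj in enumerate(cart):
                d = np.linalg.norm(rj + offset - ri)
                if 0 < d <= cutoff:  # d > 0 excludes an atom's own image at zero shift
                    pairs.append((i, j, shift, d))
    return pairs

# Sanity check: simple cubic cell with one atom has 6 nearest image
# neighbours at distance 1.0
nbrs = periodic_neighbors([[0.0, 0.0, 0.0]], np.eye(3), cutoff=1.1)
```

The O(27 n²) loop is only for clarity; real implementations vectorise this and handle cutoffs larger than the cell by widening the image search, which is exactly the part that has no standard pyg building block.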

I think that other codebases are better suited to structure-based problems, but my own work has been on coordinate-free models and discovery workflows (i.e. energy/stability) rather than structure-based models and either dynamics/transport or property prediction.

@sgbaird does crabnet need the edit to pytorch source to work or only to extract the heatmap plots?

Most of my focus has been on compositional models as well. The idea of stability (and synthesizability) is also fairly important to us, though I think we've struggled to settle on an implementation for this. Recently, CoCoCrab focused on multi-objective optimization, for which one of the objectives was a proxy for stability. There's of course Materials Project e_above_hull, or other measures such as formation energy per atom. Happy to hear if you have any thoughts or suggestions on what you've liked for stability proxies. Also, thanks for getting back to me so quickly and thoroughly on this.

The modification to PyTorch is just for being able to output the multi-head attentions for plotting as you mentioned. Other than that visualization, everything processes the same as normal (i.e. same working model/outputs). No need for the edit unless you want a bit of added interpretability of the outputs.

@janosh -- as you can run the roostformer variant using the wrenformer code, and roostformer is essentially CrabNet without the sinusoidal embedding, I am going to close this.
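For reference, the sinusoidal embedding being dropped is, in spirit, a transformer-style sin/cos encoding applied to the element fraction instead of the token position. The exact CrabNet formulation differs; this is a generic sketch with an assumed `scale` factor mapping fractions onto integer-like positions:

```python
import numpy as np

def sinusoidal_encoding(fraction: float, d_model: int = 8, scale: float = 100.0) -> np.ndarray:
    """Transformer-style sin/cos encoding of an element fraction in [0, 1].
    `scale` stretches the fraction onto the position range the classic
    positional encoding expects (an assumption of this sketch)."""
    pos = fraction * scale
    i = np.arange(d_model // 2)
    freqs = 1.0 / 10000 ** (2 * i / d_model)  # geometric frequency ladder
    enc = np.empty(d_model)
    enc[0::2] = np.sin(pos * freqs)  # even dims: sine
    enc[1::2] = np.cos(pos * freqs)  # odd dims: cosine
    return enc

enc = sinusoidal_encoding(0.25)
```

The encoding gives nearby fractions nearby vectors at low frequencies while still distinguishing them at high frequencies, which is the property the fraction-as-position analogy borrows from sequence transformers.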

I am not sure the AtomSets-v0 model is worth keeping this open for, as the QKV attention mechanism probably dominates the set2set pooling mechanism used in the AtomSets work.
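To spell out what "QKV attention" means as a set-pooling mechanism here: every element embedding attends to every other before pooling, so the pooled vector can capture element-element interactions that a plain weighted mean cannot. A minimal single-head numpy sketch with identity Q/K/V projections for brevity (the real models use learned projections and multiple heads):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def qkv_attention_pool(x: np.ndarray) -> np.ndarray:
    """Scaled dot-product self-attention over a set of element embeddings
    x of shape (n_elements, d), then a mean over the set. Q = K = V = x
    (identity projections) to keep the sketch minimal."""
    d = x.shape[1]
    scores = softmax(x @ x.T / np.sqrt(d), axis=-1)  # (n, n) attention weights
    attended = scores @ x  # each element mixes in the rest of the set
    return attended.mean(axis=0)  # pool to a fixed-size vector

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))  # 3 elements, 8-dim embeddings
pooled = qkv_attention_pool(x)
```

Like set2set, this is permutation-invariant after pooling, but the pairwise attention scores give it a richer interaction model, which is the basis for expecting it to dominate.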


There's also https://github.com/vxfung/MatDeepLearn, another collection of materials models in PyTorch focused on structure-based models. They have CGCNN, MEGNet, Schnet and MPNN among others.