
Geometric Deep Learning for Flux

Primary language: Julia. License: MIT.

GeometricFlux.jl


GeometricFlux is a geometric deep learning library for Flux. It aims to be compatible with packages from the JuliaGraphs ecosystem and supports GPU acceleration via CUDA.jl. The message-passing scheme is implemented as a flexible framework and fused with the Graph Network block scheme. GeometricFlux is compatible with other packages that compose with Flux.

Suggestions, issues and pull requests are welcome.

Installation

]add GeometricFlux

Features

  • Extends the Flux deep learning framework in Julia and integrates seamlessly with regular Flux layers.
  • Supports CUDA GPUs via CUDA.jl, with mini-batched training that takes advantage of the GPU.
  • Integrates with the existing JuliaGraphs ecosystem.
  • Supports message-passing and graph network architectures.
  • Supports both static graph and variable graph strategies. The variable graph strategy is useful when training a model over diverse graph structures.
  • Integrates GNN benchmark datasets.
  • Supports dynamic graph updates, towards manifold learning 2.0.
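The message-passing architectures listed above share one pattern: each node collects a message from every neighbour, combines the messages with a permutation-invariant aggregation (such as a sum), and updates its own state with the result. The sketch below illustrates that pattern in plain Julia with no dependencies. It is not GeometricFlux's implementation; the function `message_passing` and its keyword arguments are hypothetical, and node features are assumed to be stored one column per node.

```julia
# Hypothetical, library-free sketch of one round of message passing.
# adj_list[i] lists the neighbours of node i; X has one column per node.
function message_passing(adj_list, X;
                         message = (xi, xj) -> xj,       # message from neighbour j to node i
                         aggregate = sum,                # permutation-invariant reduction
                         update = (xi, mi) -> xi .+ mi)  # new state from old state + aggregate
    out = similar(X)
    for i in eachindex(adj_list)
        xi = X[:, i]
        msgs = [message(xi, X[:, j]) for j in adj_list[i]]
        mi = isempty(msgs) ? zero(xi) : aggregate(msgs)  # isolated nodes get a zero message
        out[:, i] = update(xi, mi)
    end
    return out
end

# Path graph 1 - 2 - 3 - 4 with a single scalar feature per node.
adj_list = [[2], [1, 3], [2, 4], [3]]
X = [1.0 2.0 3.0 4.0]
Y = message_passing(adj_list, X)   # Y == [3.0 6.0 9.0 7.0]
```

Swapping `aggregate` for an element-wise maximum, or `message` for a learned function, yields other familiar GNN layers from the same loop.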

Featured Graphs

GeometricFlux handles graph data (the topology plus node, edge and graph-level features) through the FeaturedGraph type.

A FeaturedGraph can be constructed from various graph structures, including adjacency matrices, adjacency lists and graph types from Graphs.jl:

fg = FeaturedGraph(adj_list)
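To make the two most common input formats concrete, the plain-Julia sketch below stores the same undirected path graph both as an adjacency list and as a dense adjacency matrix. The helper `adjlist_to_matrix` is hypothetical and is not part of GeometricFlux; it only shows how the two representations correspond.

```julia
# Hypothetical helper (not a GeometricFlux function): turns an adjacency list,
# where adj_list[i] holds the neighbours of vertex i, into a dense matrix
# with A[i, j] = 1 whenever there is an edge from i to j.
function adjlist_to_matrix(adj_list::Vector{Vector{Int}})
    n = length(adj_list)
    A = zeros(Int, n, n)
    for (i, nbrs) in enumerate(adj_list)
        for j in nbrs
            A[i, j] = 1
        end
    end
    return A
end

# Undirected path graph 1 - 2 - 3 - 4: each edge appears in both endpoint lists.
adj_list = [[2], [1, 3], [2, 4], [3]]
adj_mat = adjlist_to_matrix(adj_list)   # [0 1 0 0; 1 0 1 0; 0 1 0 1; 0 0 1 0]
```

Because every undirected edge is listed from both endpoints, the resulting matrix is symmetric.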

Graph convolutional layers

Construct a GCN layer:

GCNConv(input_dim => output_dim, relu)

Use it as you would any other Flux layer:

using Flux, GeometricFlux
using Flux.Losses: logitcrossentropy
using Statistics: mean

# fg, train_X (node features) and train_y (one-hot labels) are assumed to be defined,
# e.g. fg = FeaturedGraph(adj_list)
model = Chain(
    WithGraph(fg, GCNConv(1024 => 512, relu)),
    Dropout(0.5),
    WithGraph(fg, GCNConv(512 => 128)),
    Dense(128, 10)
)

## Loss
loss(x, y) = logitcrossentropy(model(x), y)
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))

## Training
ps = Flux.params(model)
train_data = [(train_X, train_y)]
opt = ADAM(0.01)
evalcb() = @show(accuracy(train_X, train_y))

Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
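For intuition about what a GCN layer computes, the sketch below implements the propagation rule of Kipf and Welling's graph convolutional network in plain Julia: add self-loops, normalise the adjacency symmetrically as D^{-1/2}(A + I)D^{-1/2}, then apply a dense transform followed by an activation. This is an illustrative sketch, not GeometricFlux's `GCNConv` source; the functions `normalized_adjacency` and `gcn_layer` are hypothetical, and features are assumed to be laid out one column per node, as in Flux.

```julia
using LinearAlgebra

# Symmetrically normalised adjacency with self-loops:
# Ahat = D^{-1/2} (A + I) D^{-1/2}, where D is the degree matrix of A + I.
function normalized_adjacency(A::AbstractMatrix)
    Atilde = A + I
    d = vec(sum(Atilde, dims=2))
    Dinv_sqrt = Diagonal(1 ./ sqrt.(d))
    return Dinv_sqrt * Atilde * Dinv_sqrt
end

# Hypothetical single GCN layer: X is (in_dim × N), W is (out_dim × in_dim),
# b has length out_dim; the default activation is ReLU.
gcn_layer(W, b, X, Ahat; σ = x -> max(zero(x), x)) = σ.(W * X * Ahat .+ b)

A = [0 1 0 0; 1 0 1 0; 0 1 0 1; 0 0 1 0]  # path graph 1 - 2 - 3 - 4
Ahat = normalized_adjacency(A)
X = ones(3, 4)    # 3 input features for each of 4 nodes
W = ones(2, 3)    # maps 3 => 2 features
b = zeros(2)
H = gcn_layer(W, b, X, Ahat)   # 2 × 4 output, one column per node
```

Right-multiplying by the normalised adjacency mixes each node's features with those of its neighbours, which is why stacking such layers widens each node's receptive field by one hop per layer.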

Roadmap

Following the geometric deep learning blueprint proposed by Bronstein et al., GeometricFlux.jl plans to support deep learning models across all of the "5G" domains. For details, see the official geometric deep learning website.

The 5(+1)G domains are:

  • Graphs and Sets
    • including classical GNN models and networks over sets.
    • Transformer models can be regarded as a kind of GNN over a complete graph; see chengchingwen/Transformers.jl for more details.
  • Grids and Euclidean spaces
    • including classical convolutional neural networks, multi-layer perceptrons, etc.
    • For operators over function spaces on regular grids, see SciML/NeuralOperators.jl for more details.
  • Groups and Homogeneous spaces
    • including a series of equivariant/invariant models.
  • Geodesics and Manifolds
  • Gauges and Bundles
  • Geometric algebra
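The view of Transformers as GNNs over a complete graph can be made concrete: single-head self-attention is message passing in which every node receives a message from every node (itself included), weighted by a softmax over query-key scores. The plain-Julia sketch below writes attention in exactly that per-node form. It is not taken from Transformers.jl or GeometricFlux; the function name and the projection matrices `Wq`, `Wk`, `Wv` are hypothetical.

```julia
using LinearAlgebra

# Hypothetical sketch: single-head self-attention as message passing on a
# complete graph. X has one column per node; each node i aggregates the
# values V[:, j] of all nodes j with weights softmax_j(q_i ⋅ k_j / √d).
function attention_as_message_passing(X, Wq, Wk, Wv)
    Q, K, V = Wq * X, Wk * X, Wv * X
    d, N = size(K)
    out = similar(V, size(V, 1), N)
    for i in 1:N
        scores = [dot(Q[:, i], K[:, j]) / sqrt(d) for j in 1:N]  # every j: complete graph
        α = exp.(scores .- maximum(scores))                      # numerically stable softmax
        α ./= sum(α)
        out[:, i] = sum(α[j] * V[:, j] for j in 1:N)             # weighted sum of messages
    end
    return out
end

X  = [1.0 0.0 2.0; 0.5 1.0 -1.0]   # 2 features, 3 nodes
Wq = [1.0 0.0; 0.0 1.0]
Wk = [0.5 0.2; -0.3 1.0]
Wv = [1.0 1.0; 2.0 0.0]
Y = attention_as_message_passing(X, Wq, Wk, Wv)
```

The per-node loop gives the same result as the usual matrix form `V * softmax(K' * Q / √d)` with the softmax taken over each column; restricting `j` to a node's neighbours instead of all nodes turns it into a graph attention layer.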

Discussions

Direct discussions are welcome in the #graphnet or #flux-bridged channels on Slack. For usage questions, please post a minimal working example (MWE) on Julia Discourse and tag the maintainer @yuehhua.