MKLab-ITI/JGNN

Graph Neural Network tutorial


This is a very interesting library and I want to try it for my project. Would it be possible to have a Graph Neural Network example in the tutorials?

Hi and thanks for the interest!

The Quickstart section of the readme demonstrates a GNN already and you can find more in the examples directory. I strongly recommend using simpler architectures (e.g. APPNP) when processing a lot of edges.
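For readers unfamiliar with APPNP: it separates feature transformation from propagation. A dense network produces per-node predictions once, and these are then smoothed over the graph with a fixed number of personalized PageRank iterations that have no trainable parameters of their own. The following is a minimal plain-Java sketch of only that propagation step, not JGNN's implementation. It assumes a small dense adjacency matrix and the commonly used symmetric normalization with self-loops; `alpha` and the iteration count are illustrative values, and the class name is hypothetical.

```java
import java.util.Arrays;

final class AppnpSketch {
    // Builds D^{-1/2} (A + I) D^{-1/2} for an undirected, unweighted dense adjacency matrix.
    static double[][] normalize(double[][] adj) {
        int n = adj.length;
        double[][] a = new double[n][n];
        double[] degree = new double[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = adj[i][j] + (i == j ? 1.0 : 0.0); // add self-loop
                degree[i] += a[i][j];                       // degree >= 1 thanks to the self-loop
            }
        }
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                a[i][j] /= Math.sqrt(degree[i] * degree[j]);
        return a;
    }

    // Repeats H <- (1 - alpha) * Ahat * H + alpha * H0, where H0 holds the per-node predictions.
    static double[][] propagate(double[][] adj, double[][] h0, double alpha, int iterations) {
        double[][] aHat = normalize(adj);
        int n = h0.length, d = h0[0].length;
        double[][] h = new double[n][];
        for (int i = 0; i < n; i++) h[i] = h0[i].clone();
        for (int it = 0; it < iterations; it++) {
            double[][] next = new double[n][d];
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    if (aHat[i][j] != 0)                     // only edges (and self-loops) contribute
                        for (int k = 0; k < d; k++)
                            next[i][k] += (1 - alpha) * aHat[i][j] * h[j][k];
            for (int i = 0; i < n; i++)
                for (int k = 0; k < d; k++)
                    next[i][k] += alpha * h0[i][k];          // teleport back to the original predictions
            h = next;
        }
        return h;
    }

    public static void main(String[] args) {
        double[][] adj = {{0, 1, 0}, {1, 0, 1}, {0, 1, 0}}; // path graph 0-1-2
        double[][] h0 = {{1, 0}, {0, 0}, {0, 1}};           // per-node predictions before propagation
        System.out.println(Arrays.deepToString(propagate(adj, h0, 0.1, 10)));
    }
}
```

Each iteration touches every edge once, while the only trainable part is the dense network that produces `h0`.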

The corresponding tutorial will be uploaded by Monday with the next commit. I will update this issue at that time.

Feel free to point out errors, missing functionality, or other things that are hard to use/understand when using the library.

P.S. Currently, automatic dataset downloading is disabled (it will be back on in a week or so, once we migrate to the new approach), but you can train architectures with your own data just fine.

Edit: Apparently the quickstart is a little outdated in terms of dataset management. (But the architecture itself should work just fine as-is.)

Added a first take on a GNN tutorial.

@gavalian, could you give some feedback on things that are difficult to understand? (The same goes for whoever else stumbles upon this issue: don't hesitate to re-open it in the future.)

Hi @maniospas,
The tutorial looks good. I'm trying to get started with it, but I do not understand how to structure the data to pass it to the trainer. I want to use this code for graph classification: I have graphs with different numbers of nodes (and connections), and I want to classify them based on their node and edge features. I want to construct a training and a validation sample. A simple dataset example showing how to construct such data samples would be very useful.

Hi again,
I must confess that the library's GCNBuilder and the trainer were made with node classification in mind. Maybe these should be renamed to show that their purpose is to handle node classification.

Issue #1 already mentions the prospect of providing graph classification, but I have been concentrating on making sure that everything else works correctly and have not yet gotten around to creating and testing code for this capability.

If you don't mind calling the backpropagation manually, you can perform graph classification using the general LayeredBuilder with something like the following snippet:

ModelBuilder builder = new LayeredBuilder()
    .var("A")  // set the graph as an architecture input. This is the second input, as the layered builder already defines a first one "h0" (not to be confused with any "h{0}") from its constructor to hold the node features
    .layer(...) // define your architecture per the tutorial
    ...
    .layer("h{l+1}=softmax(mean(h{l}, row))") // or just mean(h{l}, row) without softmax; better graph pooling will probably be provided in the future
    .out("h{l}"); // set the model to output the outcome of the top layer

Model model = builder.getModel().init(new XavierNormal());
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Loss loss = new CategoricalCrossEntropy();
for(int epoch=0; epoch<300; epoch++) {
    for(int graphId=0; graphId<graphLabels.size(); graphId++) {
         Matrix adjacency = graphMatrices.get(graphId); // retrieve from a list
         Matrix nodeFeatures = graphNodeFeatures.get(graphId); // retrieve from a list
         Tensor graphLabel = graphLabels.get(graphId); // one-hot encoding of graph labels, e.g. created with Tensor label = new DenseTensor(numClasses).put(classId, 1.0);
         model.train(loss, optimizer, 
              Arrays.asList(nodeFeatures, adjacency), 
              Arrays.asList(graphLabel));
    }
    optimizer.updateAll();
}
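The loop above assumes three parallel lists indexed by graph: adjacency matrices, node feature matrices, and one-hot labels, where each graph may have its own number of nodes but all graphs share the same number of feature columns. A minimal plain-Java sketch of how such samples could be assembled from edge lists is shown below. The `GraphSample` class and its fields are hypothetical, and converting the arrays into JGNN's `Matrix`/`Tensor` types is deliberately left out.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical container: one graph = its own adjacency, node features and one-hot label.
final class GraphSample {
    final double[][] adjacency;    // n x n, n may differ between graphs
    final double[][] nodeFeatures; // n x f, f is the same for all graphs
    final double[] label;          // one-hot, length = number of classes

    GraphSample(int numNodes, int[][] edges, double[][] nodeFeatures, int classId, int numClasses) {
        if (nodeFeatures.length != numNodes)
            throw new IllegalArgumentException("need one feature row per node");
        this.adjacency = new double[numNodes][numNodes];
        for (int[] e : edges) {              // undirected: store both directions
            adjacency[e[0]][e[1]] = 1.0;
            adjacency[e[1]][e[0]] = 1.0;
        }
        this.nodeFeatures = nodeFeatures;
        this.label = new double[numClasses];
        this.label[classId] = 1.0;           // same idea as put(classId, 1.0) on a dense tensor
    }

    public static void main(String[] args) {
        List<GraphSample> train = new ArrayList<>();
        // a 3-node path labelled class 0 and a 4-node star labelled class 1
        train.add(new GraphSample(3, new int[][]{{0, 1}, {1, 2}}, new double[][]{{0.1}, {0.2}, {0.3}}, 0, 2));
        train.add(new GraphSample(4, new int[][]{{0, 1}, {0, 2}, {0, 3}}, new double[][]{{1}, {1}, {1}, {1}}, 1, 2));
        for (GraphSample g : train)
            System.out.println(g.adjacency.length + " nodes, label " + Arrays.toString(g.label));
    }
}
```

A validation sample can then simply be a second list built the same way from held-out graphs.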

⚠️ This is a first take, just to show how you would go about doing it. I have not tested it yet and will probably get around to integrating this kind of functionality or examples next week. In the meantime, I hope this is helpful.

Edit: Code improvements. Apparently softmax is not implemented for simple tensors (but works well for matrices). I will push a preliminary working version tomorrow.
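The remark that softmax works for matrices is why graph-level outputs are later kept as a single row. As a plain-Java illustration only (not JGNN code; the class name is hypothetical), here is a numerically stable softmax applied independently to each row of a matrix, for example a 1 x C matrix holding one graph's class scores. Subtracting the row maximum before exponentiating avoids overflow without changing the result.

```java
import java.util.Arrays;

final class RowSoftmax {
    // Applies softmax independently to every row of a matrix.
    static double[][] softmaxRows(double[][] scores) {
        double[][] out = new double[scores.length][];
        for (int r = 0; r < scores.length; r++) {
            double max = Double.NEGATIVE_INFINITY;
            for (double s : scores[r]) max = Math.max(max, s);
            double sum = 0;
            out[r] = new double[scores[r].length];
            for (int c = 0; c < scores[r].length; c++) {
                out[r][c] = Math.exp(scores[r][c] - max); // shift by row max for stability
                sum += out[r][c];
            }
            for (int c = 0; c < out[r].length; c++) out[r][c] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] graphScores = {{2.0, 0.5}}; // one graph, two classes, kept as a 1 x 2 row
        System.out.println(Arrays.toString(softmaxRows(graphScores)[0]));
    }
}
```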

I'm working on a generic example; if it works, I can forward you the code in case you want to post it as an example. Meanwhile, LayeredBuilder() does not have a layer(...) method, so the code above does not compile.

Thanks a lot, I would really appreciate a coherent example if you make one work.

I pushed changes to make the code work. The issue was that '.var(...)' returned a base ModelBuilder instead of a LayeredBuilder (sometimes, polymorphism is hard). Make sure to update your dependencies to include the latest commit.

I also added a simple graph classification example that is a more complete version of the above take.

I also added a graph classification tutorial, though as a first take it could be missing things or be unrefined.

I tried the example, and everything works fine (after I updated the dependency). However, the network fails to learn. I wrote a small data provider class that generates a true object trajectory along with a false one, labeled true and false respectively. When I run the graphs through training, accuracy converges to 50% (in other words, the network does not learn). However, when I run the same dataset through an MLP network, it learns the dataset and reaches an accuracy of 99.999% on the test sample.
I have saved the project here:

https://drive.google.com/file/d/10lYVnclZzdYCflScdwbqJSQSngwXu7oD/view?usp=sharing
Any help will be appreciated.

This is a simple one-dimensional trajectory prediction (I'm trying to get the simple case working first); my goal is to eventually get 3-D trajectory classification working. The reason for using graphs instead of an MLP is that not all nodes will be present along the trajectory, which leads to different graph sizes.
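Since missing points are what make graph sizes differ, one way such inputs could be built is sketched below in plain Java. Each recorded point becomes a node whose single feature is its position (matching `features = 1` in the architecture later in this thread), and consecutive recorded points are linked. The point representation, the use of `NaN` for a missing point, the chain-linking rule, and the class name are all assumptions for illustration; the data provider in the shared project may build its graphs differently.

```java
import java.util.ArrayList;
import java.util.List;

final class TrajectoryGraph {
    final double[][] adjacency;
    final double[][] features;

    // positions[i] is the coordinate measured at layer i, or NaN if nothing was recorded there (assumption).
    TrajectoryGraph(double[] positions) {
        List<Double> present = new ArrayList<>();
        for (double p : positions)
            if (!Double.isNaN(p)) present.add(p); // skip missing points, so the node count varies
        int n = present.size();
        adjacency = new double[n][n];
        features = new double[n][1];
        for (int i = 0; i < n; i++) {
            features[i][0] = present.get(i);
            if (i > 0) {                          // link consecutive recorded points
                adjacency[i][i - 1] = 1.0;
                adjacency[i - 1][i] = 1.0;
            }
        }
    }

    public static void main(String[] args) {
        TrajectoryGraph full = new TrajectoryGraph(new double[]{0.0, 1.0, 2.0, 3.0});
        TrajectoryGraph gappy = new TrajectoryGraph(new double[]{0.0, Double.NaN, 2.0, 3.0});
        System.out.println(full.features.length + " vs " + gappy.features.length + " nodes");
    }
}
```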

Thanks a lot for the example. I am looking into it and will update this issue.

It turns out that mean pooling is too naive for this application.
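One way to see how mean pooling can be too naive: it averages node representations, so any two graphs whose node representations have the same per-column mean collapse to the same graph vector, regardless of which node holds which value. The plain-Java illustration below applies mean pooling directly to made-up node feature rows (that is, ignoring the GNN layers that would normally run first); the class name and values are hypothetical.

```java
final class MeanPoolingDemo {
    // Averages the rows (nodes) of an n x d matrix into a single d-dimensional graph vector.
    static double[] meanPool(double[][] nodes) {
        double[] pooled = new double[nodes[0].length];
        for (double[] row : nodes)
            for (int c = 0; c < row.length; c++) pooled[c] += row[c];
        for (int c = 0; c < pooled.length; c++) pooled[c] /= nodes.length; // divide once at the end
        return pooled;
    }

    public static void main(String[] args) {
        double[][] orderly = {{0.0}, {1.0}, {2.0}}; // e.g. steadily increasing positions
        double[][] jumpy = {{2.0}, {0.0}, {1.0}};   // same values, different arrangement
        System.out.println(meanPool(orderly)[0] + " " + meanPool(jumpy)[0]);
    }
}
```

Both inputs pool to the same value, so any information carried by the arrangement has to survive the earlier layers to reach the classifier.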

For the time being, I finished adding support for the sort pooling operation of:

Zhang, Muhan, et al. "An end-to-end deep learning architecture for graph classification." Proceedings of the AAAI conference on artificial intelligence. Vol. 32. No. 1. 2018.

Consider this reply a tentative first working take on this type of architecture. In the next commits, I will also add your example to the code base @gavalian, as it provides a very interesting use case that is easy to experiment with.

Architecture

Sorting can be integrated as in the following example (don't forget to upgrade to the latest version of the library from jitpack first). I tested the snippet locally and it yields approx. 95% accuracy on the above setting after 500 epochs of training. This is not as impressive as the near-perfect MLP performance, but could be acceptable in practice. I would be interested in hearing further insights.

By the way, don't stop training if the architecture keeps producing random or worse-than-random test accuracy before epoch 50. This happens because, while the latent feature values are still shifting, they keep crossing the thresholds that decide the sort order, which keeps changing how the architecture understands each graph's structure. As learning converges, so does the internally understood ordering of nodes, and there is a point after which accuracy skyrockets and remains high.
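If training is wrapped in an automatic stopping rule, the advice above suggests ignoring test accuracy during an initial warm-up. Below is a minimal plain-Java sketch of a patience-based rule that only starts counting after a warm-up epoch. The class name, the 50-epoch warm-up (taken from the remark above), and the patience value are illustrative choices, not part of JGNN.

```java
final class WarmupStopping {
    private final int warmupEpochs;
    private final int patience;
    private double best = Double.NEGATIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    WarmupStopping(int warmupEpochs, int patience) {
        this.warmupEpochs = warmupEpochs;
        this.patience = patience;
    }

    // Returns true when training should stop; accuracy seen during warm-up never triggers a stop.
    boolean shouldStop(int epoch, double accuracy) {
        if (epoch < warmupEpochs) return false;
        if (accuracy > best) {
            best = accuracy;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    public static void main(String[] args) {
        WarmupStopping stop = new WarmupStopping(50, 100);
        // chance-level accuracy during warm-up is tolerated
        System.out.println(stop.shouldStop(10, 0.5));
    }
}
```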

Since there is no tutorial on sort pooling yet, here is an intuitive explanation: the reduced hyperparameter keeps only that many topologically important nodes, ordered by their importance (where importance is measured by latent feature values). The idea is that, thanks to the preceding message-passing layers, these nodes also carry information propagated from the other nodes.
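The intuition above can be made concrete with a plain-Java sketch of sort pooling as described in the cited paper by Zhang et al.: nodes are ranked by the last channel of their latent features, the top `reduced` rows are kept in that order, and the result is flattened into one row. This illustrates the technique only; it is not JGNN's `sort` implementation, whose ranking criterion may differ. Ties are left in node order here (the paper breaks them using further channels), and graphs with fewer than `reduced` nodes are rejected rather than zero-padded as in the paper. The class name is hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

final class SortPoolingSketch {
    // h is n x d (one row per node); returns a flattened 1 x (reduced * d) row.
    static double[] sortPool(double[][] h, int reduced) {
        int n = h.length, d = h[0].length;
        if (n < reduced)
            throw new IllegalArgumentException("graph has fewer than " + reduced + " nodes");
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        // rank nodes by their last latent channel, largest first (stable sort keeps ties in node order)
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> h[i][d - 1]).reversed());
        double[] flat = new double[reduced * d];
        for (int r = 0; r < reduced; r++)
            System.arraycopy(h[order[r]], 0, flat, r * d, d); // gather the top rows, then lay them out as one row
        return flat;
    }

    public static void main(String[] args) {
        double[][] h = {{0.1, 0.9}, {0.5, 0.2}, {0.3, 0.7}};
        System.out.println(Arrays.toString(sortPool(h, 2)));
    }
}
```

Because the output depends only on the ranking, permuting the node order of the input (without ties) yields the same pooled row, which is what lets graphs of different sizes map to a fixed-length vector.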

long reduced = 5;  // input graphs need to have at least that many nodes,  lower values decrease accuracy
long hidden = 8;  // since this library does not use GPU parallelization, many latent dims reduce speed

ModelBuilder builder = new LayeredBuilder()        
        .var("A")  
        .config("features", 1)
        .config("classes", 2)
        .config("reduced", reduced)
        .config("hidden", hidden)
        .layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden))+vector(hidden))")  // don't forget to add bias vectors to dense transformations
        .layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden))+vector(hidden))") 
        .concat(2) // concatenates the outputs of the last 2 layers
        .config("hiddenReduced", hidden*2*reduced)  // 2* due to concatenation
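        .operation("z{l}=sort(h{l}, reduced)")  // currently, the parser fails to understand full expressions within the next step's gather, so we need to create this intermediate variable
        .layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")  // gather the selected rows and flatten them into a single row
        .layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
        .layer("h{l+1}=softmax(h{l}, row)")
        .out("h{l}");

Model model = builder.getModel().init(new XavierNormal());  // model, optimizer and loss set up as in the first snippet of this thread
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Loss loss = new CategoricalCrossEntropy();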
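As a worked check of the `hiddenReduced` value, here are the shapes the configuration above implies for a graph with $n \geq 5$ nodes, assuming each message-passing layer maps an $n$-row matrix to an $n$-row matrix and that `concat(2)` joins the last two layer outputs column-wise (as the "2* due to concatenation" comment indicates):

$$
\begin{aligned}
h_0 &\in \mathbb{R}^{n\times 1} && (\text{features}=1)\\
h_1 = \mathrm{relu}\big(A(h_0W_1)+b_1\big) &\in \mathbb{R}^{n\times 8} && W_1\in\mathbb{R}^{1\times 8}\\
h_2 = \mathrm{relu}\big(A(h_1W_2)+b_2\big) &\in \mathbb{R}^{n\times 8} && W_2\in\mathbb{R}^{8\times 8}\\
[\,h_1 \,\|\, h_2\,] &\in \mathbb{R}^{n\times 16} && \text{concat}(2)\\
\text{top } 5 \text{ sorted rows} &\in \mathbb{R}^{5\times 16} && \text{reduced}=5\\
\text{reshaped row} &\in \mathbb{R}^{1\times 80} && 80 = \text{hidden}\cdot 2\cdot\text{reduced} = 8\cdot 2\cdot 5\\
\mathrm{softmax}(\text{row}\,W_3) &\in \mathbb{R}^{1\times 2} && W_3\in\mathbb{R}^{80\times 2}
\end{aligned}
$$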

For training, the labels should also be cast into row vectors to be compliant with the architecture's outputs. (See next code snippet.)
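In plain-array terms, the cast turns a length-C one-hot vector into a 1 x C matrix so that it has the same shape as the 1 x C prediction, and accuracy then compares the argmax of both. A small sketch, independent of JGNN's `Tensor`/`asRow` API; the class and method names are hypothetical.

```java
final class RowLabels {
    // Wraps a one-hot vector of length C as a 1 x C matrix, mirroring the prediction's shape.
    static double[][] toRowMatrix(double[] oneHot) {
        return new double[][]{oneHot.clone()};
    }

    // Index of the largest entry (first one wins on ties).
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++)
            if (values[i] > values[best]) best = i;
        return best;
    }

    public static void main(String[] args) {
        double[][] label = toRowMatrix(new double[]{0.0, 1.0});
        double[][] prediction = {{0.3, 0.7}};
        boolean correct = argmax(prediction[0]) == argmax(label[0]);
        System.out.println(label.length + " x " + label[0].length + ", correct = " + correct);
    }
}
```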

Parallelized Training

I was previously too focused on the architecture, but for a dataset this large you can take advantage of multi-core processors to calculate derivatives in parallel during training. JGNN is thread-safe and provides its own simplified thread pool to help you do this:

for(int epoch=0; epoch<500; epoch++) {
  // gradient update
  for(int graphId=0; graphId<dtrain.adjucency.size(); graphId++) {
    int graphIdentifier = graphId;
    ThreadPool.getInstance().submit(new Runnable() {
      @Override
      public void run() {
        Matrix adjacency = dtrain.adjucency.get(graphIdentifier);
        Matrix features= dtrain.features.get(graphIdentifier);
        Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();  // Don't forget to cast to the same format as predictions.
        model.train(loss, optimizer,
                    Arrays.asList(features, adjacency),
                    Arrays.asList(graphLabel));
      }
    });
  }
  ThreadPool.getInstance().waitForConclusion();  // waits for all gradients to finish calculating
  optimizer.updateAll();
  
  double acc = 0.0;
  for(int graphId=0; graphId<dtest.adjucency.size(); graphId++) {
    Matrix adjacency = dtest.adjucency.get(graphId);
    Matrix features = dtest.features.get(graphId);
    Tensor graphLabel = dtest.labels.get(graphId);
    if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
       acc += 1;
  }
  System.out.println("iter = " + epoch + "  " + acc/dtest.adjucency.size());  // report test accuracy once per epoch
}
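The pattern above (submit one task per graph, wait for all of them, then apply a single update) is ordinary data-parallel gradient accumulation. Below is a self-contained sketch of the same pattern with `java.util.concurrent`, using a one-parameter least-squares model as a stand-in for the GNN. The model, learning rate, and class name are illustrative assumptions, and this does not show how JGNN's `BatchOptimizer` accumulates gradients internally.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.DoubleAdder;

final class ParallelAccumulation {
    // Fits y ~ w * x with one gradient step per epoch; per-sample gradients are computed in parallel.
    static double train(double[] xs, double[] ys, int epochs, double learningRate) {
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        double w = 0.0;
        try {
            for (int epoch = 0; epoch < epochs; epoch++) {
                final double current = w;                 // parameters stay fixed while gradients are computed
                DoubleAdder gradient = new DoubleAdder(); // thread-safe accumulator
                List<Future<?>> pending = new ArrayList<>();
                for (int i = 0; i < xs.length; i++) {
                    final int sample = i;
                    pending.add(pool.submit(() -> {
                        double error = current * xs[sample] - ys[sample];
                        gradient.add(2 * error * xs[sample]); // derivative of the squared error w.r.t. w
                    }));
                }
                for (Future<?> f : pending) {             // plays the role of waitForConclusion()
                    try {
                        f.get();
                    } catch (InterruptedException | ExecutionException e) {
                        throw new IllegalStateException(e);
                    }
                }
                w -= learningRate * gradient.sum() / xs.length; // single update per epoch, like updateAll()
            }
        } finally {
            pool.shutdown();
        }
        return w;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4};
        double[] ys = {2, 4, 6, 8}; // generated with w = 2
        System.out.println(train(xs, ys, 200, 0.05));
    }
}
```

Keeping the parameters read-only until every task has finished is what makes the per-sample computations independent, so the order in which threads complete does not matter beyond floating-point rounding.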

Notes

I am not yet closing this issue, because I need to also update the related tutorial.

For more discussion, requests on pooling for graph classification, or requests for more tutorials, please open separate issues.