federated-gcn

This repository contains Python scripts from a research project on training graph convolutional neural networks in a federated manner on distributed graph databases.


Federated Learning framework for Graph Neural Networks

Introduction

This small framework can be used to scale up and train any TensorFlow- or PyTorch-based graph neural network in a federated manner on partitioned or distributed graphs. Some graphs are so large that training a neural network on them takes a very long time. Distributed learning is one way to address this, but in some cases (e.g. a distributed graph whose partitions are located in different data centers) federated learning can be a better fit, because it minimizes communication overhead and speeds up training.
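In federated learning, clients train on their local data and exchange only model weights, not the data itself, with a central server. The introduction does not say how this framework combines the clients' weights; a common choice is federated averaging (FedAvg), where the server takes a layer-wise mean of the client weights, weighted by how many training samples each client holds. The sketch below assumes that scheme and represents weights as plain lists of floats. The `federated_average` helper is hypothetical and not part of this repository; check fl_server's source for the aggregation it actually uses.

```python
from typing import List, Sequence

# Illustrative representation: one flat list of floats per layer.
Weights = List[List[float]]


def federated_average(client_weights: Sequence[Weights],
                      num_samples: Sequence[int]) -> Weights:
    """FedAvg-style aggregation: layer-wise mean weighted by local sample count."""
    if not client_weights or len(client_weights) != len(num_samples):
        raise ValueError("need one sample count per client")
    total = sum(num_samples)
    averaged = []
    for layer_idx in range(len(client_weights[0])):
        layer = [0.0] * len(client_weights[0][layer_idx])
        for weights, n in zip(client_weights, num_samples):
            # Each client contributes in proportion to its share of the samples.
            for i, value in enumerate(weights[layer_idx]):
                layer[i] += value * n / total
        averaged.append(layer)
    return averaged
```

For example, `federated_average([[[1.0, 2.0], [0.0]], [[3.0, 4.0], [1.0]]], [1, 3])` returns `[[2.5, 3.5], [0.75]]`: the second client holds three times as many samples, so its weights dominate the average.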

Requirements

  • Python 3.5 or higher (a conda or virtual environment is recommended)
  • TensorFlow 2
  • StellarGraph
  • scikit-learn
  • pandas
  • NumPy

Installation

  1. Clone the repository
  2. Install the dependencies:

pip3 install -r requirements.txt
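To confirm that the installation succeeded, a small check like the one below (not part of the repository) can be run in the same environment. It only verifies that the Python version meets the minimum and that each required package can be found on the import path; it does not check versions such as TensorFlow 2. Note that the pip package scikit-learn is imported as `sklearn`.

```python
import importlib.util
import sys
from typing import Iterable, List

# Import names for the Requirements list ("scikit-learn" is imported as "sklearn").
REQUIRED_MODULES = ["tensorflow", "stellargraph", "sklearn", "pandas", "numpy"]


def missing_modules(modules: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the names in `modules` that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    if sys.version_info < (3, 5):
        sys.exit("Python 3.5 or higher is required")
    missing = missing_modules()
    if missing:
        print("Missing: " + ", ".join(missing))
    else:
        print("All requirements found")
```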

Model structure

To be trained with this framework, a neural network must be wrapped in a Python class that provides the following methods. A few example models are in the models folder of the repository.

class Model:

    def __init__(self, nodes, edges):
        # define class variables here
        self.nodes = nodes
        self.edges = edges

    def initialize(self, **hyper_params):
        # define model initialization logic here;
        # the hyper_params dictionary can be used to pass any variable
        # must return the initial model weights
        raise NotImplementedError

    def set_weights(self, weights):
        # set the passed weights on your model
        raise NotImplementedError

    def get_weights(self):
        # extract and return the model weights
        raise NotImplementedError

    def fit(self, epochs=4):
        # define training logic here
        # must return (model weights, training history)
        raise NotImplementedError

    def gen_embeddings(self):
        # optional: return the node embeddings as a pandas DataFrame
        raise NotImplementedError
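As a minimal, dependency-free illustration of the interface, the sketch below implements logistic regression on node features. It assumes a hypothetical node format (a list of `(feature_vector, label)` pairs) and ignores the edges, so it is not a graph neural network; the real example models in the models folder use GraphSAGE and their own data format. The `lr` hyperparameter name is also an assumption, passed through `**hyper_params` as the template describes.

```python
import math
from typing import Dict, List, Sequence, Tuple

# Hypothetical node format: (feature_vector, binary_label) pairs.
Node = Tuple[Sequence[float], int]


class ToyLogisticModel:
    """Logistic regression on node features that follows the Model interface.

    Edges are stored but ignored; a real GNN would use them.
    """

    def __init__(self, nodes: List[Node], edges):
        self.nodes = nodes
        self.edges = edges
        self.weights = []  # type: List[float]
        self.lr = 0.1

    def initialize(self, **hyper_params) -> List[float]:
        # "lr" is a hypothetical hyperparameter name passed via **hyper_params.
        self.lr = hyper_params.get("lr", 0.1)
        num_features = len(self.nodes[0][0])
        self.weights = [0.0] * (num_features + 1)  # last entry is the bias
        return self.get_weights()

    def set_weights(self, weights: List[float]) -> None:
        self.weights = list(weights)

    def get_weights(self) -> List[float]:
        return list(self.weights)

    def _predict(self, features: Sequence[float]) -> float:
        # zip stops at len(features), so the bias (last weight) is added separately.
        z = self.weights[-1] + sum(w * x for w, x in zip(self.weights, features))
        return 1.0 / (1.0 + math.exp(-z))

    def fit(self, epochs: int = 4) -> Tuple[List[float], Dict[str, List[float]]]:
        history = {"loss": []}  # type: Dict[str, List[float]]
        for _ in range(epochs):
            total_loss = 0.0
            for features, label in self.nodes:
                # Clamp to avoid log(0) on confident predictions.
                p = min(max(self._predict(features), 1e-12), 1.0 - 1e-12)
                total_loss += -math.log(p if label == 1 else 1.0 - p)
                error = p - label  # gradient of the log loss w.r.t. z
                for i, x in enumerate(features):
                    self.weights[i] -= self.lr * error * x
                self.weights[-1] -= self.lr * error
            history["loss"].append(total_loss / len(self.nodes))
        return self.get_weights(), history
```

Returning copies from `get_weights` keeps callers from mutating the model's weights in place; `set_weights` likewise copies what it receives.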

Repository structure

Repository
├── data: partitioned CORA dataset for testing
├── misc: miscellaneous
└── models: where models are stored
    ├── supervised.py: supervised implementation of GraphSAGE
    └── unsupervised.py: unsupervised implementation of GraphSAGE

Start training

Before starting the clients or the server, set the PYTHONHASHSEED environment variable to 0 to get reproducible results:

export PYTHONHASHSEED=0
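Python reads PYTHONHASHSEED only when the interpreter starts, so the variable has to be exported in the shell (or passed in the environment of a child process). Assigning `os.environ["PYTHONHASHSEED"]` inside an already running script does not change that script's own `hash()` results. The sketch below demonstrates this by hashing the same string in two fresh interpreters started with PYTHONHASHSEED=0; the helper name is illustrative. It concerns string hashing only and does not seed any other random number generator.

```python
import os
import subprocess
import sys


def hash_in_fresh_interpreter(text: str, seed: str = "0") -> int:
    """Run hash(text) in a new Python process started with PYTHONHASHSEED=seed."""
    env = dict(os.environ, PYTHONHASHSEED=seed)
    result = subprocess.run(
        [sys.executable, "-c", "import sys; print(hash(sys.argv[1]))", text],
        env=env,
        stdout=subprocess.PIPE,
        universal_newlines=True,
        check=True,
    )
    return int(result.stdout)


if __name__ == "__main__":
    # With a fixed seed, two separate processes agree on the string hash.
    print(hash_in_fresh_interpreter("cora") == hash_in_fresh_interpreter("cora"))
```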

Starting fl_server

The following arguments must be passed in this order:

For unsupervised training

  • path_weights - A location to extract and store model weights
  • path_nodes - Where your graph nodes are stored
  • path_edges - Where your graph edges are stored
  • graph_id - ID used to identify the graph
  • partition_id - ID of the partition located on the server that is used to initialize the weights
  • num_clients - Number of clients that will join the federated training
  • num_rounds - Number of federated rounds to train
  • IP (optional, default localhost) - IP of the VM that fl_server runs on
  • PORT (optional, default 5000) - Port used to communicate with the clients

python fl_server_unsupervised.py ./weights/ ./data/ ./data/ 4 0 2 3 localhost 5000

For supervised training

python fl_server.py ./weights/ ./data/ ./data/ 4 0 2 3 localhost 5000

For supervised + scheduled training

python fl_server_shed.py ./weights/ ./data4/ ./data4/ 3 0 1 3 localhost 5000
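The server arguments listed above are positional. The scripts' own parsing code is not reproduced here; the sketch below only illustrates how the documented order and the defaults for IP and PORT map onto `sys.argv`. The `parse_server_args` helper is hypothetical, and it keeps graph_id as a string because it is an identifier.

```python
from typing import Dict, List, Union

# Documented positional order for fl_server*.py; IP and PORT are optional.
SERVER_ARGS = ["path_weights", "path_nodes", "path_edges", "graph_id",
               "partition_id", "num_clients", "num_rounds"]


def parse_server_args(argv: List[str]) -> Dict[str, Union[str, int]]:
    """Map positional command-line arguments to names (hypothetical helper)."""
    if not len(SERVER_ARGS) <= len(argv) <= len(SERVER_ARGS) + 2:
        raise ValueError("expected 7 to 9 arguments, got %d" % len(argv))
    args = dict(zip(SERVER_ARGS, argv))  # zip stops after the 7 required names
    for key in ("partition_id", "num_clients", "num_rounds"):
        args[key] = int(args[key])
    args["ip"] = argv[7] if len(argv) > 7 else "localhost"
    args["port"] = int(argv[8]) if len(argv) > 8 else 5000
    return args
```

In a script this would be called as `parse_server_args(sys.argv[1:])`. For the supervised example, `num_clients` is 2 and `num_rounds` is 3.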

Starting fl_clients

For unsupervised training

  • path_weights - A location to extract and store model weights
  • path_embeddings - A location to store node embeddings if you want to generate them
  • path_nodes - Where your graph nodes are stored
  • path_edges - Where your graph edges are stored
  • graph_id - ID used to identify the graph
  • partition_id - ID of the graph partition used by this client
  • epochs - Number of epochs to train
  • IP (optional, default localhost) - IP of the VM that fl_server runs on
  • PORT (optional, default 5000) - Port that fl_server is listening on

Any number of clients can be started, but the same number must be passed to fl_server as num_clients when it is started, as explained above.

client 1

python fl_client_unsupervised.py ./weights/ ./embeddings/ ./data/ ./data/ 4 0 4 localhost 5000

client 2

python fl_client_unsupervised.py ./weights/ ./embeddings/ ./data/ ./data/ 4 1 4 localhost 5000
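Starting many clients by hand gets tedious. The sketch below builds the documented command line for each unsupervised client, assuming (as in the two examples) that client i uses partition i and that the paths and graph ID are the same for every client. The helper names and their defaults are hypothetical and not part of the repository.

```python
import os
import subprocess
import sys
from typing import List


def unsupervised_client_commands(num_clients: int, epochs: int, graph_id: str = "4",
                                 ip: str = "localhost", port: int = 5000) -> List[List[str]]:
    """Build one fl_client_unsupervised.py command per partition 0..num_clients-1."""
    return [
        [sys.executable, "fl_client_unsupervised.py", "./weights/", "./embeddings/",
         "./data/", "./data/", graph_id, str(partition_id), str(epochs), ip, str(port)]
        for partition_id in range(num_clients)
    ]


def launch(commands: List[List[str]]) -> List[subprocess.Popen]:
    """Start every client as a separate process."""
    # Pass PYTHONHASHSEED=0 to each client for reproducible hashing.
    env = dict(os.environ, PYTHONHASHSEED="0")
    return [subprocess.Popen(cmd, env=env) for cmd in commands]
```

`launch(unsupervised_client_commands(2, epochs=4))` corresponds to the two commands above, using the current Python interpreter in place of `python`. The server's num_clients must equal the number of commands launched.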

For supervised training

  • path_weights - A location to extract and store model weights
  • path_nodes - Where your graph nodes are stored
  • path_edges - Where your graph edges are stored
  • graph_id - ID used to identify the graph
  • partition_id - ID of the graph partition used by this client
  • epochs - Number of epochs to train
  • IP (optional, default localhost) - IP of the VM that fl_server runs on
  • PORT (optional, default 5000) - Port that fl_server is listening on

client 1

python fl_client.py ./weights/ ./data/ ./data/ 4 0 10 localhost 5000

client 2

python fl_client.py ./weights/ ./data/ ./data/ 4 1 10 localhost 5000

For supervised + scheduled training

  • client_id - ID used to identify the client
  • path_weights - A location to extract and store model weights
  • path_nodes - Where your graph nodes are stored
  • path_edges - Where your graph edges are stored
  • graph_id - ID used to identify the graph
  • partition_ids - Comma-separated IDs of the partitions, in scheduling order (e.g. 1,2,5)
  • epochs - Number of epochs to train
  • IP (optional, default localhost) - IP of the VM that fl_server runs on
  • PORT (optional, default 5000) - Port that fl_server is listening on

client 1

python fl_client_shed.py 1 ./weights/ ./data4/ ./data4/ 3 0,1 3 localhost 5000

client 2

python fl_client_shed.py 2 ./weights/ ./data4/ ./data4/ 3 2,3 3 localhost 5000
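Each scheduled client receives its partitions as one comma-separated argument. The sketch below (hypothetical helpers, not the repository's code) parses that argument while preserving the scheduling order, and shows one way to split a set of partitions into contiguous per-client blocks like those used in the two examples.

```python
from typing import List


def parse_partition_ids(arg: str) -> List[int]:
    """Parse a partition_ids argument such as "1,2,5" into [1, 2, 5], keeping order."""
    return [int(part) for part in arg.split(",") if part.strip()]


def split_partitions(partition_ids: List[int], num_clients: int) -> List[str]:
    """Hypothetical helper: assign each client a contiguous block of partitions.

    Earlier clients get one extra partition when the split is uneven. Each block
    is returned in the comma-separated form of the partition_ids argument.
    """
    per_client, extra = divmod(len(partition_ids), num_clients)
    blocks, start = [], 0
    for client in range(num_clients):
        size = per_client + (1 if client < extra else 0)
        blocks.append(",".join(str(p) for p in partition_ids[start:start + size]))
        start += size
    return blocks
```

`split_partitions([0, 1, 2, 3], 2)` returns `["0,1", "2,3"]`, the partition_ids values passed to clients 1 and 2 above.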