fed-iris

A prototype federated-learning iris-classifier CLI application

Primary language: Python | License: MIT

Federated Iris Classifier: Prototype

Overview

This is a prototype Federated Machine Learning application: a simple classifier for the classic Iris dataset, built primarily on torch and flwr. It is also an experimental sub-project in preparation for Fed-KGQA, a Federated Question-Answering System over Knowledge Graph. The extended abstract can be found under the home folder. This repo implements the Federated Answer Selector module described in the extended abstract.


Pre-req

Before getting started with this repo, a quick look at an overview of Federated Learning and at the flwr package documentation is HIGHLY RECOMMENDED; see the official flwr website.


The General Architecture

The application adopts the canonical Server-Clients architecture, where:

  • the clients are the training workers that:
    • update their own models locally and independently
  • the server is responsible for:
    • collecting the locally updated models
    • aggregating the parameters/weights
    • redistributing the new global model back to the clients
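The aggregation step is not spelled out here. Assuming the server uses flwr's default FedAvg strategy, the global model is an average of the client parameters weighted by each client's number of training examples. The pure-Python sketch below illustrates only that weighted average; it assumes parameters are flattened into plain lists of floats (the real application exchanges NumPy arrays through flwr), and the names `federated_average` and `ClientUpdate` are hypothetical.

```python
from typing import List, Tuple

# Hypothetical update format: (flattened parameter vector, number of local training examples)
ClientUpdate = Tuple[List[float], int]


def federated_average(updates: List[ClientUpdate]) -> List[float]:
    """FedAvg-style aggregation: average client parameters weighted by data size."""
    if not updates:
        raise ValueError("need at least one client update")
    total_examples = sum(n for _, n in updates)
    if total_examples <= 0:
        raise ValueError("total number of examples must be positive")
    dim = len(updates[0][0])
    if any(len(params) != dim for params, _ in updates):
        raise ValueError("all clients must send parameters of the same shape")

    global_params = [0.0] * dim
    for params, n in updates:
        # Each client contributes in proportion to its share of the training data.
        weight = n / total_examples
        for i, value in enumerate(params):
            global_params[i] += weight * value
    return global_params


if __name__ == "__main__":
    # Two hypothetical clients holding 30 and 90 local samples.
    print(federated_average([([1.0, 2.0], 30), ([3.0, 6.0], 90)]))  # prints [2.5, 5.0]
```

The client with three times as much data pulls the result three times as hard; an unweighted mean would give [2.0, 4.0] instead.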

Each of the two sub-modules is discussed in detail in its own README. In short, both the Client module (under the client folder) and the Server module (under the server folder) can be viewed as independent applications that run separately.


Folder Structure

Before running the code, I would like to introduce how I organized the various files. The folder structures of both the server and client sub-modules are quite similar.

Server Module

root
|-- server
    |-- bin      # entry to the CLI
    |-- log      # training log
    |-- models   # cached model parameters
    |-- src      # python source code; called from the bin folder
        |-- __init__.py    # where the CLI is defined

Client Module

root
|-- client
    |-- bin      # same as above
    |-- config   # local settings (not used yet; reserved)
    |-- data     # local training/validation data warehouse
    |-- models   # cached model parameters, for local model serving
    |-- src      # same as above
        |-- __init__.py    # where the CLI is defined, using components defined in the core module
        |-- core           # discussed in client/README.md

Quick Start

In the future I may package the source code into (publication-ready) independent applications. For now, I assume that the conda environment is properly set up or that all dependencies are installed.

Server

The server serves only as a parameter aggregator and currently takes a single command-line argument, --address. If it is omitted, a local server is started.

# To start the server
#   assuming inside root directory
python ./server/bin/main --address '127.0.0.1:8080'

# For help
python ./server/bin/main --help
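A single `--address` option with a local fallback is easy to express with a validating argparse type. The sketch below assumes argparse (the actual CLI in server/src/__init__.py may be built differently); `DEFAULT_ADDRESS`, `host_port`, and `parse_server_args` are hypothetical names, and the default host and port are a stand-in for "a local server".

```python
import argparse
from typing import List, Optional

# Hypothetical default: the README only says that a *local* server is started
# when --address is omitted; the exact host and port here are an assumption.
DEFAULT_ADDRESS = "127.0.0.1:8080"


def host_port(value: str) -> str:
    """argparse type check: accept 'HOST:PORT' with a port in 1..65535."""
    host, sep, port = value.rpartition(":")
    if not sep or not host or not port.isdecimal() or not 0 < int(port) < 65536:
        raise argparse.ArgumentTypeError(f"expected HOST:PORT, got {value!r}")
    return value


def parse_server_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Parameter-aggregation server")
    parser.add_argument(
        "--address",
        type=host_port,  # rejects malformed addresses before any server starts
        default=DEFAULT_ADDRESS,
        help="address to listen on (default: %(default)s)",
    )
    return parser.parse_args(argv)
```

For example, `parse_server_args([]).address` yields the local default, while `parse_server_args(["--address", "localhost"])` exits with a usage error because no port is given. Validating at parse time gives a clear message instead of a failure deep inside the networking layer.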

Client

The client has two modes: infer and train. The inference mode can operate on its own, without a server:

python ./client/bin/main infer

and the expected interface is:

[ INFO ] :: entering inference mode
[ INFO ] :: press < ENTER > to make prediction; < Q > to exit: 
[ READ ] :: |-sep_len >>> 1.1
[ READ ] :: |-sep_wid >>> 2
[ READ ] :: |-pet_len >>> 3
[ READ ] :: |-pet_wid >>> 4
[ INFO ] :: it might be a < Iris-virginica >
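Since the classifier is a softmax regression model (see The Model Support Part), each prediction in the session above reduces to a linear map of the four features, a softmax, and an argmax. The pure-Python sketch below shows that step; the toy weights are made up, the class order in `CLASSES` is an assumption, and the real client would instead use the trained parameters from model_params.npz (loaded as NumPy/torch tensors).

```python
import math
from typing import List, Sequence

# Assumed label order; the real mapping is defined by the training code.
CLASSES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the largest logit first so math.exp cannot overflow.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(features: Sequence[float],
            weights: Sequence[Sequence[float]],
            bias: Sequence[float]) -> str:
    """Most probable class for [sep_len, sep_wid, pet_len, pet_wid]."""
    if len(features) != 4:
        raise ValueError("expected 4 features: sep_len, sep_wid, pet_len, pet_wid")
    # One logit per class: dot(weights[k], features) + bias[k]
    logits = [sum(w * x for w, x in zip(row, features)) + b
              for row, b in zip(weights, bias)]
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASSES[best]


# Hypothetical toy parameters (3 classes x 4 features), NOT trained values.
TOY_WEIGHTS = [[0.0, 0.0, -1.0, 0.0],
               [0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0]]
TOY_BIAS = [0.0, 0.0, 0.0]

if __name__ == "__main__":
    print(predict([1.1, 2, 3, 4], TOY_WEIGHTS, TOY_BIAS))  # prints Iris-virginica
```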

To train the model, the client connects to a running server. Purely local training is currently not supported: the client must connect to the server for federated training.

# Same as above, providing the address of the SERVER
#   by default, it tries to connect to a local server
python ./client/bin/main train --address '127.0.0.1:8080'
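The two client modes map naturally onto CLI subcommands, with `--address` belonging only to `train`. Below is a minimal sketch, assuming argparse (the real client/src/__init__.py may use another library); `parse_client_args` and `DEFAULT_SERVER` are hypothetical, and the default address stands in for "a local server".

```python
import argparse
from typing import List, Optional

DEFAULT_SERVER = "127.0.0.1:8080"  # assumption: stand-in for "a local server"


def parse_client_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Federated iris classifier client")
    modes = parser.add_subparsers(dest="mode", required=True)

    # infer: purely local, uses the cached parameters under ./client/models
    modes.add_parser("infer", help="interactive local inference")

    # train: joins federated training, so it needs the server's address
    train = modes.add_parser("train", help="federated training through the server")
    train.add_argument(
        "--address",
        default=DEFAULT_SERVER,
        help="server address (default: %(default)s)",
    )
    return parser.parse_args(argv)
```

Keeping `--address` on the `train` subparser means `infer` cannot accidentally be given a server address, which mirrors the rule that inference never contacts the server.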

IMPORTANT:

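  • Training won't start until at least TWO clients are connected. This threshold can be changed by configuring the flwr strategy; if the server uses flwr's FedAvg strategy, the relevant arguments are min_fit_clients and min_available_clients, both of which default to 2.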
  • After training, the new global model IS NOT SAVED BY THE CLIENT. Instead, the new model is stored under ./server/models as an .npz file. So, for inference, please manually copy the model_params.npz file to the ./client/models folder.
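The manual copy step can also be scripted. The helper below is a hypothetical sketch that copies ./server/models/model_params.npz into ./client/models relative to the repository root; it does not open the .npz file, since the array names stored in it are not documented here.

```python
import shutil
from pathlib import Path

# Paths taken from the README, relative to the repository root.
SERVER_MODEL = Path("server/models/model_params.npz")
CLIENT_MODELS_DIR = Path("client/models")


def install_global_model(root: Path = Path(".")) -> Path:
    """Copy the server's aggregated parameters into the client's model cache."""
    src = root / SERVER_MODEL
    if not src.is_file():
        raise FileNotFoundError(f"no aggregated model at {src}; has training finished?")
    dst_dir = root / CLIENT_MODELS_DIR
    dst_dir.mkdir(parents=True, exist_ok=True)
    # copy2 also preserves the file's modification time, i.e. when the model was produced.
    return Path(shutil.copy2(src, dst_dir / src.name))
```

Run from the repository root, `install_global_model()` overwrites any previously cached client model with the latest global one.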

Future Works

Though workable, this prototype is still very "sketchy." Further developments largely fall into two broad categories:

  • system optimization: software engineering
  • more complex model: question-answering model support

The Software Part

  • Basic Requirements
    • error checking and handling
    • logging system
    • testing (robustness)
    • removing hard-coded values
    • etc.
  • Design Problems
    • discussed independently in the sub-modules

The Model Support Part

The iris classifier is a very simple example of a Softmax Regression model, whereas a QA model can be much more complex. A possible next step is to implement a Text Similarity Matching model based on the WikiQA dataset. In short, training samples are (Question, Bag-of-Answers) pairs, and the task is to rank the candidate answers by their relevance to the given question, which is very similar to the Answer Selector module mentioned in the extended abstract at the very beginning.
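To make the ranking task concrete, here is a deliberately simple baseline that is swapped in for illustration and is not the Text Similarity Matching model proposed above: each candidate answer is scored by the bag-of-words cosine similarity between its tokens and the question's, and candidates are sorted by score. All function names are hypothetical.

```python
import math
import re
from collections import Counter
from typing import List, Tuple


def bag_of_words(text: str) -> Counter:
    # Lowercased alphanumeric tokens; no stemming or stop-word removal.
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[token] for token, count in a.items())
    norm = (math.sqrt(sum(v * v for v in a.values()))
            * math.sqrt(sum(v * v for v in b.values())))
    return dot / norm if norm else 0.0


def rank_answers(question: str, candidates: List[str]) -> List[Tuple[float, str]]:
    """(score, answer) pairs, most similar to the question first."""
    q = bag_of_words(question)
    scored = [(cosine(q, bag_of_words(c)), c) for c in candidates]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)


if __name__ == "__main__":
    ranked = rank_answers(
        "How tall is the Eiffel Tower?",
        ["Paris is the capital of France.", "The Eiffel Tower is 330 metres tall."],
    )
    for score, answer in ranked:  # the Eiffel Tower answer ranks first
        print(f"{score:.3f}  {answer}")
```

Pure word overlap rewards shared function words such as "is" and "the", which is one reason a learned similarity model is the natural next step for WikiQA-style answer selection.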