Indefinite quadratic programming for graph cuts and semi-supervised learning on graphs
MIT
CutSSL
This repository contains an efficient JAX implementation of CutSSL, an algorithm for semi-supervised learning on graphs. It relies on Jeff Calder's graphlearning package.
Graph-based SSL
In graph-based SSL, we are given a graph $G=(V,E)$ with $n$ vertices. A subset of the vertices is assigned labels, and the goal is to propagate the label information smoothly over the rest of the graph. The key assumption is that the underlying labeling of the graph corresponds to a partitioning with small cut. For this algorithm, the labels are assumed to come from a discrete set $\mathcal{Y}$ and imply a $k$-way partitioning of the vertices. For a $k$-class classification problem, $\mathcal{Y}$ can be taken to be the $k$ one-hot vectors in $\mathbb{R}^k$.
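For concreteness, here is a minimal sketch of this setup using graphlearning. The dataset, neighborhood size, and label rate below are illustrative assumptions, not choices prescribed by this repo.

```python
import graphlearning as gl
import jax.numpy as jnp

# Illustrative setup (assumed choices): a 10-NN graph over MNIST built from
# VAE features, with one labeled example per class.
_, labels = gl.datasets.load('mnist')                 # images (unused here) and integer labels
W = gl.weightmatrix.knn('mnist', 10, metric='vae')    # sparse knn weight matrix on VAE features
L = gl.graph(W).laplacian()                           # graph Laplacian (scipy sparse)

k = int(labels.max()) + 1
train_ind = gl.trainsets.generate(labels, rate=1)     # indices of one labeled vertex per class
Y = jnp.eye(k)[labels[train_ind]]                     # one-hot targets Y_i for the labeled vertices
```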
Algorithm
For the algorithm implemented in this repo, we solve the following cardinality-constrained partitioning problem with supervision:
$$\min_{X\in \mathcal{Y}^n} \mathrm{tr}(X^\top L X) \quad \text{s.t. } 1_n^\top X = m,\quad X_i = Y_i \text{ for all labeled vertices } i$$

Here $L$ is the graph Laplacian, $X_i$ denotes the $i$-th row of $X$, and $m \in \mathbb{R}^k$ encodes a prior on the cardinality of each class (the default is $n/k$ for every class).

The key idea of this algorithm is to (1.) relax this combinatorial problem over the discrete set $\mathcal{Y}$ to a continuous convex set, and then (2.) apply a regularization that drives the solution to an extreme point of that set.

For (1.), the $k$ one-hot vectors are vertices of the unit hypercube in $\mathbb{R}^k$. We relax to the convex hull of the intersection of the hypercube with the cardinality constraint: $\mathcal{X} = \{X \in \mathbb{R}^{n\times k}\ |\ X \ge 0,\ 1_n^\top X = m,\ X1_k = 1_n\}$.

For (2.), we add the regularizer $-s\cdot \|X\|_F^2$ to the objective for some nonnegative scalar $s$. This corresponds to a diagonal perturbation of the Laplacian: the objective becomes $\mathrm{tr}(X^\top (L - sI) X)$. The intuition is that the minimizers of the concave function $-\|x\|_2^2$ over the simplex are exactly the one-hot vectors, so the regularizer pushes each row of $X$ toward a vertex.

Compared to other continuous relaxations of cut problems (e.g. spectral or SDP relaxations), there is no need for a post-hoc rounding or thresholding step. Compared with other graph-based SSL algorithms (e.g. Laplace learning / label propagation), the problem is well-posed at low label rates, and even without any labels at all. The price is nonconvexity, since $L - sI$ is indefinite.
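As a rough illustration of steps (1.) and (2.), the sketch below writes the regularized objective $\mathrm{tr}(X^\top (L - sI) X)$ in JAX and takes a projected-gradient step that enforces only the row-simplex part of $\mathcal{X}$ (nonnegativity and unit row sums). This is a simplified stand-in for intuition, not the ADMM solver implemented in this repo; the function names are hypothetical, and $L$ is assumed to be a `jnp` array or a `jax.experimental.sparse` matrix.

```python
import jax.numpy as jnp

def relaxed_objective(X, L, s):
    """tr(X^T (L - s*I) X): the cut objective plus the concave regularizer -s*||X||_F^2."""
    return jnp.trace(X.T @ (L @ X)) - s * jnp.sum(X ** 2)

def project_rows_to_simplex(X):
    """Euclidean projection of every row of X onto the probability simplex
    {x >= 0, sum(x) = 1}, via the standard sorting algorithm."""
    u = jnp.sort(X, axis=1)[:, ::-1]                  # rows sorted in decreasing order
    css = jnp.cumsum(u, axis=1) - 1.0
    ind = jnp.arange(1, X.shape[1] + 1)
    rho = jnp.sum(u - css / ind > 0, axis=1)          # support size of each projected row
    theta = css[jnp.arange(X.shape[0]), rho - 1] / rho
    return jnp.maximum(X - theta[:, None], 0.0)

def pgd_step(X, L, s, lr=1e-2):
    """One projected-gradient step on the relaxation. For symmetric L, the gradient of
    tr(X^T (L - s*I) X) is 2*(L - s*I) X. The cardinality constraint 1_n^T X = m and
    the label constraints are ignored in this sketch."""
    grad = 2.0 * (L @ X) - 2.0 * s * X
    return project_rows_to_simplex(X - lr * grad)
```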
We implement ADMM to solve this problem on knn graphs built from the features of a small VAE on MNIST, FashionMNIST, and CIFAR-10, and on citation networks (Planetoid). With one labeled example per class, we get 97.5% accuracy on MNIST and 44.7% accuracy on CIFAR-10.
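For reference, a hypothetical end-to-end run chaining the two sketches above (graph construction, then the projected-gradient relaxation) might look like the following. It is not the repo's ADMM entry point, and the number of iterations, step size, and value of $s$ are arbitrary illustrative choices.

```python
import numpy as np
from jax.experimental import sparse as jsparse

# Reuses W, L, labels, train_ind, Y, jnp, and pgd_step from the sketches above.
n, k = W.shape[0], int(labels.max()) + 1
L_jax = jsparse.BCOO.from_scipy_sparse(L.tocoo())   # sparse Laplacian usable from JAX

X = jnp.full((n, k), 1.0 / k)                       # start every row at the simplex barycenter
for _ in range(500):
    X = pgd_step(X, L_jax, s=1.0)
    X = X.at[train_ind].set(Y)                      # clamp labeled vertices to their one-hot targets

pred = np.asarray(jnp.argmax(X, axis=1))            # read off predictions by row-wise arg-max
print("accuracy:", (pred == labels).mean())
```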