PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans
torch_kmeans features implementations of the well-known k-means algorithm as well as its soft and constrained variants.
All algorithms are fully implemented as PyTorch modules and can easily be incorporated into a PyTorch pipeline or model. They therefore support execution on GPU as well as operation on (mini-)batches of data. Moreover, they provide a scikit-learn-style interface featuring the
model.fit(), model.predict() and model.fit_predict()
methods.
-> view official documentation
- Fully implemented in PyTorch. (PyTorch and NumPy are the only package dependencies!)
- GPU support, just like native PyTorch.
- JIT-compiled via TorchScript for the most performance-sensitive parts.
- Works with mini-batches of samples:
  - each instance can have a different number of clusters (see the sketch under the usage examples below).
- Constrained KMeans works with cluster constraints such as:
  - a maximum number of samples per cluster, or
  - a maximum weight per cluster, where each sample has an associated weight (a sketch follows this list).
- SoftKMeans is a fully differentiable clustering procedure and can readily be used in a PyTorch neural network model that requires backpropagation (a sketch follows this list).
- Unit tested against the scikit-learn KMeans implementation.
- GPU execution enables very fast computation even for large batch sizes or very high-dimensional feature spaces (see the speed comparison).
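For the constrained variant, each sample can carry a weight that counts against a per-cluster capacity. The following is a minimal sketch, assuming the forward call of ConstrainedKMeans accepts a weights tensor of shape (BS, N); the weights keyword and the per-instance normalization are assumptions, so consult the documentation for the exact interface:

import torch
from torch_kmeans import ConstrainedKMeans

model = ConstrainedKMeans(n_clusters=4)
x = torch.randn((2, 20, 2))  # (BS, N, D)
# hypothetical per-sample weights, normalized per instance
w = torch.rand((2, 20))
w = w / w.sum(dim=-1, keepdim=True)
result = model(x, weights=w)  # 'weights' keyword is an assumption
print(result.labels)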
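Similarly, because SoftKMeans is differentiable, gradients can flow from a loss on the clustering output back to the input features. A minimal sketch, assuming the returned result exposes the soft assignments as a tensor (the attribute name soft_assignment is an assumption):

import torch
from torch_kmeans import SoftKMeans

model = SoftKMeans(n_clusters=4)
x = torch.randn((2, 20, 2), requires_grad=True)  # (BS, N, D)
result = model(x)
# placeholder loss on the (assumed) soft assignment tensor
loss = result.soft_assignment.sum()
loss.backward()
print(x.grad.shape)  # gradients w.r.t. the input features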
Simply install from PyPI
pip install torch-kmeans
PyTorch style usage
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2))  # (BS, N, D): batch size, num samples, feature dim
result = model(x)
print(result.labels)  # cluster assignment per sample, shape (BS, N)
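As noted in the feature list, each instance in a batch can have a different number of clusters. A minimal sketch, assuming the forward call accepts a per-instance k tensor (the exact keyword is an assumption):

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)
x = torch.randn((2, 20, 2))  # (BS, N, D)
k = torch.tensor([3, 4])     # assumed per-instance cluster counts
result = model(x, k=k)       # 'k' keyword is an assumption
print(result.labels)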
Scikit-learn style usage
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2)) # (BS, N, D)
model = model.fit(x)
labels = model.predict(x)
print(labels)
or
import torch
from torch_kmeans import KMeans
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2)) # (BS, N, D)
labels = model.fit_predict(x)
print(labels)
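Since the algorithms are plain PyTorch modules without learned parameters, GPU execution should only require placing the input data on a CUDA device, as for any typical parameter-free PyTorch module. A sketch under that assumption:

import torch
from torch_kmeans import KMeans

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = KMeans(n_clusters=4)
x = torch.randn((4, 20, 2), device=device)  # data lives on the GPU
result = model(x)  # computation is assumed to follow the input device
print(result.labels.device)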
You can find more examples and detailed usage information in the example notebooks.