This repository aims to classify 3D point clouds using graph neural networks (GNNs). The raw point cloud is fed directly into the GNN, which learns to capture meaningful local structures in order to classify the entire point set.
Dependencies:
- Python 3.9
- PyTorch 1.11
- torch-cluster 1.6
- PyTorch Geometric 2.0.4
- numpy 1.22.3
- matplotlib 3.5.1
The `GeometricShapes` dataset from the PyTorch Geometric dataset collection is used here. It contains 40 different 2D and 3D geometric shapes, such as cubes, spheres, and pyramids. Each shape comes in two versions: one is used to train the GNN and the other is used to evaluate its performance.
The current implementation provides three main functions:
- `train()` to train the GNN-based point cloud classifier,
- `test()` to evaluate the trained network,
- `visualize_points()` to perform three visualization tasks:
  - plot the positions of the points in the point cloud,
  - plot the points selected by farthest point sampling,
  - plot the generated dynamic graph of the point cloud.
- In addition, it provides the class `PPFNet`, which implements a Point Pair Feature network (PPFNet), a rotation-invariant variant of the PointNet++ architecture.
- The average loss and the test accuracy of the model are printed after every epoch. Once training is complete, the best test accuracy achieved is also printed.
- All hyperparameters that control training and testing of the model are defined in the provided `.py` file.
- The current implementation also uses Farthest Point Sampling (FPS) to downsample a point cloud. Downsampling allows deeper layers to extract increasingly global features. FPS iteratively selects a subset of points such that each newly sampled point is the one farthest from all points selected so far, which keeps the sampled points spread far apart.
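The rotation invariance of `PPFNet` comes from describing each pair of neighbouring points only through quantities that do not change under rotation. Assuming the repository follows the formulation used by PyTorch Geometric's `PPFConv` layer, the feature for points p_i, p_j with normals n_i, n_j and d = p_j - p_i is (‖d‖, ∠(n_i, d), ∠(n_j, d), ∠(n_i, n_j)). The plain-Python sketch below computes this feature; all function names are hypothetical helpers, not part of this repository, and the usage lines rotate a point pair about the z-axis to show that the feature stays the same.

```python
import math
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]


def _sub(a: Sequence[float], b: Sequence[float]) -> Vec3:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a: Sequence[float], b: Sequence[float]) -> Vec3:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _angle(a: Sequence[float], b: Sequence[float]) -> float:
    # Angle in [0, pi] via atan2(|a x b|, a . b), which is numerically stable.
    return math.atan2(math.hypot(*_cross(a, b)), _dot(a, b))


def point_pair_feature(p_i, p_j, n_i, n_j) -> Tuple[float, float, float, float]:
    """Hypothetical helper: (|d|, angle(n_i, d), angle(n_j, d), angle(n_i, n_j)), d = p_j - p_i."""
    d = _sub(p_j, p_i)
    return (math.hypot(*d), _angle(n_i, d), _angle(n_j, d), _angle(n_i, n_j))


def rotate_z(v: Sequence[float], theta: float) -> Vec3:
    # Rotate a 3D vector by theta radians about the z-axis.
    c, s = math.cos(theta), math.sin(theta)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1], v[2])


# Usage: the feature of a point pair is unchanged when points and normals are rotated together.
p_i, p_j = (0.3, -1.2, 0.5), (1.1, 0.4, -0.7)
n_i, n_j = (0.0, 0.6, 0.8), (1.0, 0.0, 0.0)
original = point_pair_feature(p_i, p_j, n_i, n_j)
rotated = point_pair_feature(*(rotate_z(v, 0.7) for v in (p_i, p_j, n_i, n_j)))
print(original)
print(rotated)  # matches `original` up to floating-point rounding
```

Because the feature depends only on a distance and three angles, it is also invariant to translation, since d is a difference of positions.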
The `num_graphs` variable is set to 4 in the current implementation, but it can be set to any number as required. Note that the current implementation sets `sampling_ratio` for FPS to 0.5; the grey points in the third column are the points that FPS did not select for that particular point cloud.
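FPS with a sampling ratio of 0.5 can be illustrated with a minimal, dependency-free sketch. It greedily keeps ⌈0.5·n⌉ points and reports the rest, which play the role of the grey points in the third column. The function name `farthest_point_sampling` and the fixed start index are assumptions for illustration; the repository presumably relies on `torch_cluster` (listed in the dependencies), whose `fps` implementation may pick a random start point and differ in other details.

```python
import math
import random
from typing import List, Sequence


def farthest_point_sampling(points: Sequence[Sequence[float]], ratio: float,
                            start: int = 0) -> List[int]:
    """Hypothetical greedy FPS: return indices of ceil(ratio * n) points.

    Each new point is the one farthest from all points selected so far.
    Assumes the points are distinct.
    """
    n = len(points)
    k = min(n, max(1, math.ceil(ratio * n)))
    selected = [start]
    # min_dist[i]: distance from point i to its nearest already-selected point.
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < k:
        nxt = max(range(n), key=min_dist.__getitem__)
        selected.append(nxt)
        # Update each point's distance to the enlarged sample set.
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[nxt]))
    return selected


# Usage: split a random cloud into sampled points and the "grey" leftovers.
random.seed(0)
cloud = [(random.random(), random.random(), random.random()) for _ in range(10)]
kept = farthest_point_sampling(cloud, ratio=0.5)
kept_set = set(kept)
dropped = [i for i in range(len(cloud)) if i not in kept_set]
print("sampled:", kept)
print("not sampled (grey):", dropped)
```

Keeping a running minimum distance per point makes each iteration O(n), so sampling k points costs O(k·n) distance evaluations instead of recomputing all pairwise distances.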
| Point Cloud Class | Position of Points | Farthest Points Sampled | Generated Dynamic Graph |
|---|---|---|---|
| 3d_cone | | | |
| 3d_moon | | | |
| 3d_icecream | | | |
| 3d_ico2 | | | |
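The last column of the table shows the generated dynamic graph. Assuming it is a k-nearest-neighbour (k-NN) graph built from the point positions, as is common for dynamic graph models such as `DynamicEdgeConv` in PyTorch Geometric, its construction can be sketched as follows. The function name `knn_graph_edges` and the value of `k` are hypothetical and not taken from this repository.

```python
import math
from typing import List, Sequence, Tuple


def knn_graph_edges(points: Sequence[Sequence[float]], k: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: directed edges (j, i), where j is one of the k nearest neighbours of i."""
    edges = []
    for i, p in enumerate(points):
        # Sort every other point by Euclidean distance to p and keep the closest k.
        others = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        edges.extend((j, i) for _, j in others[:k])
    return edges


# Usage: each corner of a unit square connects to its two adjacent corners.
square = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0)]
for src, dst in knn_graph_edges(square, k=2):
    print(f"{src} -> {dst}")
```

This brute-force version builds the graph once from positions in O(n² log n) time. In a dynamic graph model the neighbourhoods are typically recomputed from the learned features at each layer, which is what makes the graph "dynamic".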