
PyTorch implementation of the PointNet with applications to 3D object classification and part segmentation.

Primary language: Python. License: MIT.


PyTorch implementation of the PointNet [1], a deep neural network that can directly process point-clouds, without using intermediate representations such as voxels or multi-view images, and is suitable for a variety of point-based 3D recognition tasks such as object classification and part segmentation.


PointNet is a neural network that can learn to approximate, arbitrarily well, any uniformly continuous, permutation-invariant (symmetric) function f on finite sets of points (point clouds) by decomposing it as:

f({x_1, ..., x_n}) ≈ (g ∘ POOL)({h(x_1), ..., h(x_n)})

where x_i is the i-th point vector of size C_in, X_in = {x_1, ..., x_n} is a point cloud of cardinality n, and h: R^{C_in} -> R^{C_out} and g: R^{C_out} -> R^L are continuous functions. POOL is a symmetric pooling operation, such as max-pooling or avg-pooling, that aggregates information across the points and makes the whole function f permutation-invariant. Note that pooling is applied component-wise, i.e. separately to each of the C_out components of h(x_1), ..., h(x_n).
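A minimal NumPy sketch of this decomposition, with hypothetical toy choices for h (C_in = 3, C_out = 6) and g (L = 2) instead of learned networks. POOL reduces over the point axis separately for each of the C_out components, so both max-pooling and avg-pooling give the same output for any ordering of the points.

```python
import numpy as np

def h(x: np.ndarray) -> np.ndarray:
    """Hypothetical per-point feature map R^3 -> R^6 (stands in for a learned h)."""
    return np.concatenate([x, x ** 2])

def g(z: np.ndarray) -> np.ndarray:
    """Hypothetical readout R^6 -> R^2 (stands in for a learned g)."""
    return np.array([z.sum(), z.max() - z.min()])

def f(points: np.ndarray, pool=np.max) -> np.ndarray:
    """Set function f(X) = g(POOL({h(x_1), ..., h(x_n)})) for points of shape (n, C_in)."""
    features = np.stack([h(x) for x in points])  # (n, C_out): one row per point
    pooled = pool(features, axis=0)              # component-wise pooling over points -> (C_out,)
    return g(pooled)                             # (L,)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))              # point cloud with n = 5 points, C_in = 3
X_perm = X[rng.permutation(len(X))]      # the same set of points in a different order

for pool in (np.max, np.mean):
    # Both symmetric pooling operations make f independent of the point order.
    assert np.allclose(f(X, pool), f(X_perm, pool))
```

Replacing `pool` with an order-dependent reduction, e.g. `lambda a, axis: a[0]` (keep only the first point), would generally break this invariance.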

The continuous functions g and h can be approximated by multi-layer perceptrons (MLPs), i.e. stacks of fully-connected (FC) layers, each followed by a non-linearity. Since the MLP approximating h is applied to every point of a point cloud identically and independently, we call it a PointMLP:

PointMLP({x_1, ..., x_n}) = {MLP(x_1), ..., MLP(x_n)}.

Note that PointMLP is permutation-equivariant by construction.
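A minimal NumPy sketch of a PointMLP, assuming ReLU non-linearities and hypothetical layer widths (3 -> 64 -> 128) with random weights; it is an illustration, not the repository's PyTorch code. Because the same weights act on every row of the (n, C_in) array, permuting the input rows permutes the output rows in the same way.

```python
import numpy as np

def mlp(x: np.ndarray, weights: list, biases: list) -> np.ndarray:
    """FC layers, each followed by ReLU. Accepts one vector (C_in,) or rows (n, C_in)."""
    for W, b in zip(weights, biases):
        x = np.maximum(x @ W + b, 0.0)  # fully-connected layer followed by ReLU
    return x

def point_mlp(points: np.ndarray, weights: list, biases: list) -> np.ndarray:
    """PointMLP: the same MLP applied to every row (point) of an (n, C_in) array."""
    return mlp(points, weights, biases)  # rows are processed identically and independently

rng = np.random.default_rng(0)
sizes = [3, 64, 128]  # C_in = 3 -> hidden 64 -> C_out = 128 (hypothetical widths)
weights = [rng.normal(scale=0.1, size=(a, b)) for a, b in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(b) for b in sizes[1:]]

X = rng.normal(size=(10, 3))   # n = 10 points
perm = rng.permutation(len(X))

# Same result as running the MLP on each point separately.
assert np.allclose(point_mlp(X, weights, biases),
                   np.stack([mlp(x, weights, biases) for x in X]))
# Permutation-equivariance: PointMLP(X[perm]) == PointMLP(X)[perm].
assert np.allclose(point_mlp(X[perm], weights, biases), point_mlp(X, weights, biases)[perm])
```

In PyTorch, such a shared per-point layer is commonly written as a `torch.nn.Linear` applied over the last dimension or as a `torch.nn.Conv1d` with `kernel_size=1`.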

The vanilla PointNet can thus be written as:

f(X_in) ≈ MLP(MAX(PointMLP(X_in))).
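The formula above can be sketched end-to-end in NumPy as follows. This is a hypothetical, untrained illustration (random weights, no TNets, batch normalization or dropout), not the repository's model; the sizes C_out = 256 and L = 10 are arbitrary. The last layer of the head is left linear so that it outputs unnormalized class scores.

```python
import numpy as np

def init_mlp(sizes, rng):
    """Random (untrained) weights and zero biases for FC layers with widths `sizes`."""
    return [(rng.normal(scale=1.0 / np.sqrt(a), size=(a, b)), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def mlp(x, params, relu_last=True):
    """FC layers each followed by ReLU; `relu_last=False` leaves the last layer linear."""
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if relu_last or i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x

def pointnet(points, point_params, head_params):
    """Vanilla PointNet f(X) = MLP(MAX(PointMLP(X))): (n, C_in) -> (L,) class scores."""
    per_point = mlp(points, point_params)    # PointMLP, shared by all points: (n, C_out)
    global_feature = per_point.max(axis=0)   # MAX over the points, per component: (C_out,)
    return mlp(global_feature, head_params, relu_last=False)  # MLP head: (L,)

rng = np.random.default_rng(0)
C_in, C_out, L = 3, 256, 10                  # hypothetical sizes; L = number of classes
point_params = init_mlp([C_in, 64, C_out], rng)
head_params = init_mlp([C_out, 128, L], rng)

X = rng.normal(size=(100, C_in))             # a point cloud with n = 100 points
scores = pointnet(X, point_params, head_params)
predicted_class = int(np.argmax(scores))
```

Shuffling the rows of `X` leaves `scores` unchanged, since the only interaction between points happens in the symmetric MAX step.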

Note that for a PointNet to approximate any continuous set function arbitrarily well, its PointMLP must have a sufficiently large number of output neurons (typically C_out >= n).
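One way to see why C_out matters: with max-pooling, each of the C_out pooled components is attained by at least one point, so the pooled feature, and hence f, is determined by at most C_out of the n points (the "critical points" discussed in [1]). The NumPy sketch below uses a hypothetical single-layer PointMLP with random weights and checks that discarding all other points leaves the global feature unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
n, C_in, C_out = 1000, 3, 16          # many more points than pooled components
W = rng.normal(size=(C_in, C_out))    # hypothetical single-layer PointMLP weights
b = rng.normal(size=C_out)

def point_features(points):
    """Single-layer PointMLP with ReLU: (n, C_in) -> (n, C_out)."""
    return np.maximum(points @ W + b, 0.0)

X = rng.normal(size=(n, C_in))
pooled = point_features(X).max(axis=0)   # global feature, shape (C_out,)

# Indices of the points that attain the maximum in at least one component.
critical = np.unique(point_features(X).argmax(axis=0))
assert len(critical) <= C_out

# Keeping only these points reproduces the pooled feature (up to float rounding).
assert np.allclose(point_features(X[critical]).max(axis=0), pooled)
print(f"{len(critical)} of {n} points determine the global feature")
```

With C_out much smaller than n, most points cannot influence the output at all, which is why a wide PointMLP is needed for expressive set functions.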

In addition to permutation invariance, it is desirable for the network to be invariant to certain geometric transformations of the point cloud (e.g. rigid transformations). TNets...


PointNet for classification

PointNet for part segmentation

Note that, unlike the PointNet for classification, this network is permutation-equivariant rather than permutation-invariant.
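A minimal NumPy sketch of why a segmentation network of this kind is permutation-equivariant, assuming the design of [1], in which the max-pooled global feature is concatenated to every per-point feature before a per-point head predicts part scores. The layer sizes, random weights and the `segment` helper are hypothetical; this is not the repository's code.

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_local, C_global, num_parts = 3, 32, 64, 4      # hypothetical sizes

W1 = rng.normal(size=(C_in, C_local))                  # per-point (local) layer
W2 = rng.normal(size=(C_local, C_global))              # per-point layer before pooling
W3 = rng.normal(size=(C_local + C_global, num_parts))  # per-point segmentation head

def relu(a):
    return np.maximum(a, 0.0)

def segment(points):
    """(n, C_in) points -> (n, num_parts) per-point part scores."""
    local = relu(points @ W1)                           # per-point features, (n, C_local)
    global_feature = relu(local @ W2).max(axis=0)       # order-independent, (C_global,)
    tiled = np.broadcast_to(global_feature, (len(points), C_global))
    combined = np.concatenate([local, tiled], axis=1)   # (n, C_local + C_global)
    return combined @ W3                                # per-point scores

X = rng.normal(size=(50, C_in))
perm = rng.permutation(len(X))
scores = segment(X)

# Permuting the input points permutes the per-point outputs in the same way.
assert np.allclose(segment(X[perm]), scores[perm])
```

The global feature is the same for every ordering, while the local features follow the points, so each point keeps its own prediction wherever it appears in the input.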


  • PyTorch (1.10.1)
  • NumPy (1.18.5)
  • Matplotlib (3.4.1)
  • Trimesh (3.9.32)
  • h5py (3.1.0)

The versions in parentheses are the ones used to develop and test the code.


To train the PointNet for 3D object classification on the ModelNet dataset:

cd applications/classification
python train_clf.py --hdf5_path "dataset.hdf5" --device "cuda" --batch_size_train 32 --num_epochs 100 ...

Similarly, for evaluation and inference:

python eval_clf.py ...
python infere_clf.py ...

See arg_parser.py for all available command-line arguments and run_clf.ipynb for an example.


[1] PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, C. Qi et al., 2016

[2] Original TensorFlow implementation of the PointNet