kentechx/pointnet

Issue with Handling Additional Features in PointNet Implementation

Dy7gy22 opened this issue · 1 comment

Hello,

First, thank you for this implementation of PointNet. I am encountering an issue when using the model with input data that has more features than just the XYZ coordinates. Specifically, the error occurs during model initialization, when the STN module is constructed, before any data is passed through the model.

Here is the code snippet where the error arises:

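import torch
from pointnet import PointNetCls, STN

# 6 input channels (XYZ + RGB); the STN is asked for a 3x3 transform (out_nd=3)
stn_3d = STN(in_dim=6, out_nd=3)
model = PointNetCls(in_dim=6, out_dim=2, stn_3d=stn_3d)

# Points are laid out as (batch, channels, num_points)
xyz = torch.randn(16, 3, 1024)
rgb_features = torch.randn(16, 3, 1024)
x = torch.cat([xyz, rgb_features], dim=1)  # (16, 6, 1024)
logits = model(x)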
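For context on what `in_dim=6, out_nd=3` is meant to express: the STN sees all six channels but predicts a 3x3 matrix. Assuming that matrix is meant to act on the coordinates only, the sketch below applies it to the XYZ channels and passes the RGB channels through unchanged. `apply_xyz_transform` is a hypothetical helper, not part of the `pointnet` package, and it assumes XYZ occupies the first three channels, as in the snippet above.

```python
import torch


def apply_xyz_transform(x: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply a per-sample 3x3 transform to XYZ only.

    x:     (B, C, N) point features; channels 0..2 are assumed to be XYZ.
    trans: (B, 3, 3) transform, e.g. what an STN with out_nd=3 would predict.
    """
    xyz, extra = x[:, :3, :], x[:, 3:, :]
    # (B, 3, 3) @ (B, 3, N) -> (B, 3, N): only the coordinates are transformed.
    xyz = torch.bmm(trans, xyz)
    # Extra per-point features (here RGB) are concatenated back unchanged.
    return torch.cat([xyz, extra], dim=1)


x = torch.randn(16, 6, 1024)
identity = torch.eye(3).unsqueeze(0).repeat(16, 1, 1)  # (16, 3, 3)
out = apply_xyz_transform(x, identity)
print(out.shape)  # torch.Size([16, 6, 1024])
```

With an identity transform the output equals the input, which is a quick sanity check that the channel split and re-concatenation line up.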

The error message is:
/usr/local/lib/python3.10/dist-packages/pointnet/pointnet.py in __init__(self, in_dim, out_nd, head_norm)
48
49
---> 50
51 nn.init.normal_(self.head[-1].weight, 0, 0.001)
52 nn.init.eye_(self.head[-1].bias.view(in_dim, -1))

RuntimeError: shape '[6, 6]' is invalid for input of size 9
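The traceback points at the identity initialization of the STN head. A head that predicts an `out_nd x out_nd` matrix has `out_nd * out_nd` bias entries (9 here), so viewing that bias with `in_dim` rows is wrong whenever `in_dim != out_nd`, and here it fails outright because 9 is not divisible by 6. The sketch below reproduces the mismatch on a stand-in `nn.Linear` layer and then initializes the bias as a flattened `out_nd x out_nd` identity instead. The layer and its 256-unit input width are assumptions for illustration; this is not the library's code or the actual change made in a7f9ad0.

```python
import torch
import torch.nn as nn

IN_DIM, OUT_ND = 6, 3  # input channels vs. size of the predicted transform

# Hypothetical stand-in for the last layer of the STN head: it predicts an
# OUT_ND x OUT_ND matrix, so it has OUT_ND * OUT_ND = 9 outputs.
last = nn.Linear(256, OUT_ND * OUT_ND)

# Reshaping the 9-element bias by the *input* width reproduces the bug.
try:
    last.bias.view(IN_DIM, -1)
    reshape_failed = False
except RuntimeError:
    reshape_failed = True

# Start the head near the identity: tiny weights, bias = flattened eye(OUT_ND).
# This is the same effect as eye-initializing the bias viewed as (OUT_ND, OUT_ND).
with torch.no_grad():
    last.weight.normal_(0.0, 0.001)
    last.bias.copy_(torch.eye(OUT_ND).flatten())

features = torch.randn(4, 256)
trans = last(features).view(-1, OUT_ND, OUT_ND)  # (4, 3, 3), close to identity
print(reshape_failed)  # True
```

Keying every reshape of the head's output and bias on `out_nd` (never `in_dim`) keeps the STN valid for any number of input channels.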

fixed in a7f9ad0