Issue with Handling Additional Features in PointNet Implementation
Dy7gy22 opened this issue · 1 comment
Hello,
First, thank you for this implementation of PointNet. I am running into an issue when using the model with input data that has more features than just the XYZ coordinates. Specifically, the error is raised during model initialization, when the spatial transformer is constructed with `STN(in_dim=6, out_nd=3)`.
Here is the code snippet where the error arises:
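```python
import torch
from pointnet import PointNetCls, STN

# Spatial transformer for 6-channel input (XYZ + RGB), predicting a 3x3 transform
stn_3d = STN(in_dim=6, out_nd=3)
model = PointNetCls(in_dim=6, out_dim=2, stn_3d=stn_3d)

# Batch of 16 point clouds with 1024 points each, channels-first layout
xyz = torch.randn(16, 3, 1024)
rgb_features = torch.randn(16, 3, 1024)
x = torch.cat([xyz, rgb_features], dim=1)  # shape (16, 6, 1024)
logits = model(x)
```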
The error message is:
```
/usr/local/lib/python3.10/dist-packages/pointnet/pointnet.py in __init__(self, in_dim, out_nd, head_norm)
48
49
---> 50
51 nn.init.normal_(self.head[-1].weight, 0, 0.001)
52 nn.init.eye_(self.head[-1].bias.view(in_dim, -1))

RuntimeError: shape '[6, 6]' is invalid for input of size 9
```
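The numbers in the message point to a likely cause: with `out_nd=3`, the last layer of the STN head presumably outputs `out_nd * out_nd = 9` values (a flattened 3x3 transform), but line 52 reshapes its bias using `in_dim` (6) rather than `out_nd`. The sketch below reproduces that mismatch in plain PyTorch, assuming the head ends in an `nn.Linear` with `out_nd * out_nd` outputs; the 256-unit input width and the name `head_last` are hypothetical stand-ins, not the library's actual code. The reported message shows `[6, 6]` rather than `[6, -1]`, so the installed source may differ slightly from the traceback line, but the 6 vs. 3 mismatch is the same. Viewing the bias as `(out_nd, out_nd)` instead lets `nn.init.eye_` initialize it to the identity transform.

```python
import torch
import torch.nn as nn

in_dim, out_nd = 6, 3

# Hypothetical last layer of an STN head (assumption): it predicts a flattened
# out_nd x out_nd transform, so its bias has out_nd * out_nd = 9 elements,
# independent of how many input channels (in_dim) the network receives.
head_last = nn.Linear(256, out_nd * out_nd)

# Reproduce the mismatch: 9 elements cannot be split into in_dim = 6 equal rows.
failed = False
try:
    head_last.bias.view(in_dim, -1)
except RuntimeError as err:
    failed = True
    print("view(in_dim, -1) fails:", err)

# Presumed fix (assumption): reshape the bias with the transform's own shape,
# so eye_ sets it to the 3x3 identity and the initial transform is a no-op.
nn.init.normal_(head_last.weight, 0, 0.001)
with torch.no_grad():
    nn.init.eye_(head_last.bias.view(out_nd, out_nd))

print(head_last.bias.detach().view(out_nd, out_nd))
```

With this change the STN's output size depends only on `out_nd`, while `in_dim` only affects the layers that consume the 6-channel input.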