Can torchinfo support BEVFusion (https://github.com/mit-han-lab/bevfusion) ?
dpan817 opened this issue · 3 comments
dpan817 commented
Has anyone tried torchinfo with BEVFusion? I tried it, but it failed with: "TypeError: Model contains a layer with an unsupported input or output type: <mmdet3d.ops.spconv.structure.SparseConvTensor object at 0x7f3d9a48fee0>, type: <class 'mmdet3d.ops.spconv.structure.SparseConvTensor'>"
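For reference, this failure does not require BEVFusion itself: torchinfo hooks every layer and records the sizes of its inputs and outputs, and it raises this TypeError whenever a layer passes an object it cannot size. A minimal sketch that should reproduce the same class of error with the torchinfo version used in this report (ToyModel, ToyLayer, and CustomOutput are hypothetical names, not BEVFusion code):

import torch
import torch.nn as nn
from torchinfo import summary

class CustomOutput:
    """Stand-in for mmdet3d's SparseConvTensor: a non-Tensor object wrapping features."""
    def __init__(self, features: torch.Tensor):
        self.features = features

class ToyLayer(nn.Module):
    def forward(self, x: torch.Tensor) -> "CustomOutput":
        # Returning a custom object is what torchinfo's hooks cannot size.
        return CustomOutput(x)

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = ToyLayer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x).features

# Expected to fail with roughly:
# TypeError: Model contains a layer with an unsupported input or output type: ...
summary(ToyModel(), input_data=torch.randn(1, 3))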
TylerYep commented
Can you post the full code used to reproduce this error?
dpan817 commented
Sorry for the late reply; I was working on other issues over the past two weeks.
I debugged the code without torchinfo to capture the parameters passed to the model's forward call, then composed the same parameters for the summary() call, but it still failed.
This is how I compose the parameters for summary() in tools/test.py:
# in main() of tools/test.py; requires "from torchinfo import summary" at the top of the file
if not distributed:
    model = MMDataParallel(model, device_ids=[0])
    print(f"Model:\n{model}")

    # Take the first sample; each field is an mmcv DataContainer, so .data
    # unwraps the raw value. Per-sample tensors get a batch dim via unsqueeze(0).
    sample = dataset[0]
    img_tensor = sample.get('img').data.unsqueeze(0)
    points_list = [sample.get('points').data]
    camera2ego_tensor = sample.get('camera2ego').data.unsqueeze(0)
    lidar2ego_tensor = sample.get('lidar2ego').data.unsqueeze(0)
    lidar2camera_tensor = sample.get('lidar2camera').data.unsqueeze(0)
    lidar2image_tensor = sample.get('lidar2image').data.unsqueeze(0)
    camera_intrinsics_tensor = sample.get('camera_intrinsics').data.unsqueeze(0)
    camera2lidar_tensor = sample.get('camera2lidar').data.unsqueeze(0)
    img_aug_matrix_tensor = sample.get('img_aug_matrix').data.unsqueeze(0)
    lidar_aug_matrix_tensor = sample.get('lidar_aug_matrix').data.unsqueeze(0)
    metas_list = [sample.get('metas').data]
    gt_masks_bev_tensor = torch.zeros(1, 6, 200, 200)
    gt_bboxes_3d_list = [sample.get('gt_bboxes_3d').data]
    gt_labels_3d_list = [torch.tensor(sample.get('gt_labels_3d').data, device='cuda:0')]

    # Keyword arguments matching the model's forward call.
    args_dict = {
        'return_loss': False,
        'rescale': True,
        'img': img_tensor,
        'points': points_list,
        'gt_bboxes_3d': gt_bboxes_3d_list,
        'gt_labels_3d': gt_labels_3d_list,
        'gt_masks_bev': gt_masks_bev_tensor,
        'camera_intrinsics': camera_intrinsics_tensor,
        'camera2ego': camera2ego_tensor,
        'lidar2ego': lidar2ego_tensor,
        'lidar2camera': lidar2camera_tensor,
        'camera2lidar': camera2lidar_tensor,
        'lidar2image': lidar2image_tensor,
        'img_aug_matrix': img_aug_matrix_tensor,
        'lidar_aug_matrix': lidar_aug_matrix_tensor,
        'metas': metas_list,
    }
    input_dict = {}
    summary(model, input_data=[input_dict, args_dict])
The resulting error is:
Traceback (most recent call last):
  File "tools/test.py", line 288, in <module>
    main()
  File "tools/test.py", line 250, in main
    summary(model, input_data=[input_dict, args_dict])
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 220, in summary
    x, correct_input_size = process_input(
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 246, in process_input
    correct_input_size = get_input_data_sizes(input_data)
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 496, in get_input_data_sizes
    return traverse_input_data(
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 435, in traverse_input_data
    {
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 436, in <dictcomp>
    k: traverse_input_data(v, action_fn, aggregate_fn)
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
    [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
  File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
    result = aggregate(
  File "/home/adlink/Downloads/Lidar_AI_Solution/CUDA-BEVFusion/bevfusion/mmdet3d/core/bbox/structures/base_box3d.py", line 46, in __init__
    assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size()
AssertionError: torch.Size([9, 1])
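Reading the traceback, summary() crashes before the forward pass even starts: get_input_data_sizes walks input_data, and when traverse_input_data reaches the LiDARInstance3DBoxes inside gt_bboxes_3d_list it appears to treat the boxes object as a plain iterable, map each yielded box row to its size, and then rebuild the container via the aggregate call, i.e. LiDARInstance3DBoxes(list_of_sizes). That constructor asserts a 2-D tensor with box_dim columns, which is why the assertion reports torch.Size([9, 1]) rather than a real box tensor shape.

A possible workaround (a sketch, not a verified fix): keep only plain tensors in input_data so torchinfo's size traversal never touches the boxes object, and route everything else through the extra keyword arguments of summary(), which are passed on to model.forward. This assumes the forward call, wrapped in MMDataParallel, tolerates this calling pattern:

# Only plain tensors go through torchinfo's size traversal.
tensor_inputs = {
    'img': img_tensor,
    'camera2ego': camera2ego_tensor,
    'lidar2ego': lidar2ego_tensor,
    'lidar2camera': lidar2camera_tensor,
    'camera2lidar': camera2lidar_tensor,
    'lidar2image': lidar2image_tensor,
    'camera_intrinsics': camera_intrinsics_tensor,
    'img_aug_matrix': img_aug_matrix_tensor,
    'lidar_aug_matrix': lidar_aug_matrix_tensor,
    'gt_masks_bev': gt_masks_bev_tensor,
}
# Everything torchinfo cannot size is forwarded to model.forward via **kwargs.
non_tensor_inputs = {
    'return_loss': False,
    'rescale': True,
    'points': points_list,
    'metas': metas_list,
    'gt_bboxes_3d': gt_bboxes_3d_list,  # LiDARInstance3DBoxes stays out of input_data
    'gt_labels_3d': gt_labels_3d_list,
}
summary(model, input_data=tensor_inputs, **non_tensor_inputs)

Even if this gets past the input traversal, the SparseConvTensor TypeError from the first post will likely resurface once the layer hooks run, since torchinfo still has no way to size that layer's output; that part would need support in torchinfo itself, or a wrapper that converts the SparseConvTensor to a dense tensor before it crosses a hooked module boundary.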