hidet-org/hidet

[Bug] ops.concat does not work the same as torch.cat

Closed · 1 comment

Describe the bug
ops.concat does not behave the same as torch.cat on an empty tensor (shape=[0]), and this causes a runtime error when interpreting torch.cat.

To Reproduce
The following is my test code:

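import torch
import hidet
from hidet.graph import ops

# PyTorch: concatenating with an empty 1-D tensor (shape=[0]) works.
x = torch.rand([3, 4, 5])
y = torch.rand([0])
print(torch.cat([x, y]))

# hidet: the equivalent call reports an error (uncomment to reproduce).
# x = hidet.randn([3, 4, 5])
# y = hidet.randn([0])
# print(ops.concat([x, y], axis=0))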

The hidet code reports an error, but the torch code works normally.

Expected behavior
I hope that ops.concat can behave as torch.cat does.
Alternatively, how can I customize what the interpreter does, for example, interpret torch.cat as my custom operator instead of the default ops.concat?
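
A possible workaround is to drop shape-[0] tensors before concatenating, which mirrors how torch.cat treats them in the example above. The sketch below is only an illustration: the helper names `is_legacy_empty` and `concat_skip_empty` are hypothetical and not part of hidet or torch, and the helper assumes only that each tensor exposes a `.shape` attribute.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def is_legacy_empty(tensor) -> bool:
    """Return True for a 1-D tensor with zero elements (shape == [0])."""
    return tuple(tensor.shape) == (0,)


def concat_skip_empty(tensors: Sequence[T], concat: Callable[[list], T]) -> T:
    """Drop shape-[0] tensors, then delegate to the supplied concat function.

    `concat` is whatever concatenation routine the caller wants to use;
    this helper (hypothetical, for illustration) only filters the inputs.
    """
    kept = [t for t in tensors if not is_legacy_empty(t)]
    if not kept:
        # Every input was a shape-[0] tensor; return the first one unchanged.
        return tensors[0]
    return concat(kept)
```

With the tensors from the reproduction script, the call might look like `concat_skip_empty([x, y], lambda ts: ops.concat(ts, axis=0))`, so that ops.concat only ever sees the non-empty input.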

Environment

  • OS: Ubuntu 22.04
  • hidet version: 0.3.0
  • torch version: 2.2.1 cuda

Fixed by @ZichuWu in our internal repo and will be available in our next release (end of Oct.).