[Bug] ops.concat does not work the same as torch.cat
CBalaa commented
Describe the bug
ops.concat does not behave the same as torch.cat when one of the inputs is an empty tensor (shape=[0]), and this causes a runtime error when hidet interprets torch.cat.
To Reproduce
The following code reproduces the issue:
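```python
import torch
import hidet
from hidet.graph import ops

# PyTorch: concatenating a [3, 4, 5] tensor with a 1-D empty tensor succeeds.
x = torch.rand([3, 4, 5])
y = torch.rand([0])
print(torch.cat([x, y]))  # result has shape [3, 4, 5]

# hidet equivalent (uncomment to reproduce the error):
# x = hidet.randn([3, 4, 5])
# y = hidet.randn([0])
# print(ops.concat([x, y], axis=0))
```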
The hidet version raises an error, while the torch version runs normally.
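For reference, the torch.cat documentation states that all inputs must have the same shape except in the concatenating dimension, or be a 1-D empty tensor of size (0,); such empty tensors are effectively skipped, which is why the torch call above succeeds. The sketch below reproduces only that shape rule in plain Python, operating on shape tuples rather than real tensors. The helper name `cat_result_shape` is hypothetical and is not part of torch or hidet.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def cat_result_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Return the output shape torch.cat would produce, working on shapes only.

    Follows the documented rule: every input shares the same shape except
    along `dim`, or is a 1-D empty tensor of size (0,), which is skipped.
    """
    # Drop legacy 1-D empty tensors (shape (0,)) before comparing shapes.
    kept: List[Shape] = [tuple(s) for s in shapes if tuple(s) != (0,)]
    if not kept:
        # Every input was a 1-D empty tensor, so the result is (0,) as well.
        return (0,)
    ref = kept[0]
    rank = len(ref)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    dim %= rank  # normalize negative dims
    total = 0
    for s in kept:
        # All non-skipped inputs must match `ref` everywhere except at `dim`.
        same_rank = len(s) == rank
        if not same_rank or any(a != b for i, (a, b) in enumerate(zip(s, ref)) if i != dim):
            raise ValueError(f"shape {s} is incompatible with {ref} along dim {dim}")
        total += s[dim]
    return ref[:dim] + (total,) + ref[dim + 1:]


print(cat_result_shape([(3, 4, 5), (0,)]))  # (3, 4, 5)
```

Under the same assumption, a caller-side workaround would be to drop inputs whose shape is exactly (0,) before calling ops.concat; this presumes the hidet tensors expose a concrete `.shape`, which may not hold for symbolic shapes.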
Expected behavior
I would expect ops.concat to behave as torch.cat does.
Alternatively, how can I customize what the interpreter does? For example, can I make it interpret torch.cat
as my own custom operator instead of the default ops.concat?
Environment
- OS: Ubuntu 22.04
- hidet version: 0.3.0
- torch version: 2.2.1 cuda
wangshangsam commented
Fixed by @ZichuWu in our internal repo and will be available in our next release (end of Oct.).