jittor.nn.Conv3d 没有对输入的维度进行检查，导致底层崩溃，也没有提供正确的引导信息
PhyllisJi opened this issue · 0 comments
PhyllisJi commented
Describe the bug
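`jittor.nn.Conv3d` 没有检查输入的维度：它要求 5 维输入 (N, C, D, H, W)，而下面的复现中传给它的是 4 维张量。结果在 cuDNN 后端（`cudnn_conv3d_op.cc`）触发断言 `x->shape.size()(4) == 5(5)` 失败而崩溃。报错信息只提示 "Something wrong ... Could you please report this issue?"，没有告诉用户应当传入 5 维输入。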
Full Log
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[10], line 34
32 y = m(x)
33 return list(y.shape)
---> 34 go()
Cell In[10], line 32, in go()
30 x = jittor.randn([1, 1, 28, 28])
31 m = lenet()
---> 32 y = m(x)
33 return list(y.shape)
File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/__init__.py:1168, in Module.__call__(self, *args, **kw)
1167 def __call__(self, *args, **kw):
-> 1168 return self.execute(*args, **kw)
Cell In[10], line 22, in lenet.execute(self, x)
20 x = self.relu1(x)
21 x = self.pool1(x)
---> 22 x = self.conv2(x)
23 return x
File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/__init__.py:1168, in Module.__call__(self, *args, **kw)
1167 def __call__(self, *args, **kw):
-> 1168 return self.execute(*args, **kw)
File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1138, in Conv3d.execute(self, x)
1137 def execute(self, x):
-> 1138 return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1277, in conv3d(x, weight, bias, stride, padding, dilation, groups)
1274 out_channels = weight.shape[0]
1276 if jt.flags.use_cuda and jt.cudnn:
-> 1277 y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
1278 elif groups == 1:
1279 N,C,D,H,W = x.shape
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.cudnn_conv3d)).
Types of your inputs are:
self = module,
args = (Var, Var, int, int, int, int, int, int, int, int, int, int, ),
The function declarations are:
VarHolder* cudnn_conv3d(VarHolder* x, VarHolder* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw")
Failed reason:[f 0906 09:47:16.870443 64 cudnn_conv3d_op.cc:37] Check failed x->shape.size()(4) == 5(5) Something wrong ... Could you please report this issue?
Minimal Reproduce
import os
# NOTE(review): set before `import jittor`, presumably so jittor reads it at import
# time and skips its lock handling — confirm against jittor's environment-flag docs.
os.environ["disable_lock"] = "1"
import jittor
import jittor.nn as nn
import jittor.optim as optim  # not used by this reproduction
import numpy as np  # not used by this reproduction
import copy  # not used by this reproduction
class lenet(nn.Module):
    """Tiny model that routes the output of 2-D layers into a Conv3d layer.

    This is a bug reproduction, so the model is deliberately ill-formed.
    With the 4-D input used by ``go()``, the tensor reaching ``conv2`` is
    still 4-D, whereas Conv3d expects a 5-D (N, C, D, H, W) input.
    ``conv2`` also uses a dilation of 0 in its last dimension.
    Both are intentional and must be kept for the crash to reproduce.
    """

    def __init__(self):
        super().__init__()
        # Stem: 1 -> 6 channels, kernel size 5, then ReLU and 2x2 max-pooling.
        self.conv1 = jittor.nn.Conv(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = jittor.nn.ReLU()
        self.pool1 = jittor.nn.MaxPool2d(kernel_size=2, stride=2)
        # 3-D convolution fed by the 2-D stem; this rank mismatch is what the
        # issue reports as not being validated before reaching the backend.
        self.conv2 = jittor.nn.Conv3d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            dilation=(2, 7, 0),
        )

    def execute(self, x):
        """Apply conv1, relu1, pool1 and conv2 in sequence and return the result."""
        for layer in (self.conv1, self.relu1, self.pool1, self.conv2):
            x = layer(x)
        return x
def go():
    """Run one forward pass of ``lenet`` on CUDA and return the output shape.

    A random 1x1x28x28 tensor is created first and the model is built second.
    This order is kept because both draw from the random generator.
    """
    # conv3d takes the cuDNN branch when use_cuda is set (and cudnn is available),
    # which is where the reported crash happens.
    jittor.flags.use_cuda = 1
    inp = jittor.randn([1, 1, 28, 28])
    model = lenet()
    out = model(inp)
    return list(out.shape)


go()