Jittor/jittor

jittor.nn.Conv3d does not validate the dimensions of its input, causing a low-level crash without a helpful error message

PhyllisJi opened this issue

Describe the bug

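jittor.nn.Conv3d does not check the number of dimensions of its input. When it receives a 4-D tensor instead of the expected 5-D (N, C, D, H, W) tensor and use_cuda is enabled, the input is passed straight to jt.cudnn.ops.cudnn_conv3d. That call fails an internal check (x->shape.size()(4) == 5(5)) and asks the user to report the issue, instead of explaining what is wrong with the input.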
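A minimal sketch of the kind of up-front check that would turn this crash into an actionable error, assuming a plain Python helper that runs on the input's shape before the cuDNN call. The function name check_conv3d_input and its message wording are hypothetical and not part of Jittor's API:

```python
from typing import Sequence


def check_conv3d_input(shape: Sequence[int], in_channels: int) -> None:
    """Hypothetical pre-check: raise a readable error unless shape is (N, C, D, H, W)."""
    if len(shape) != 5:
        raise ValueError(
            "Conv3d expects a 5-D input of shape (N, C, D, H, W), "
            f"but got a {len(shape)}-D input of shape {list(shape)}. "
            "If the input is (N, C, H, W), Conv2d may be the intended layer."
        )
    if shape[1] != in_channels:
        raise ValueError(
            f"Conv3d expects {in_channels} input channels, "
            f"but got {shape[1]} (input shape {list(shape)})."
        )


# The shape that reaches conv2 in the reproduction is 4-D, so this raises.
try:
    check_conv3d_input([1, 6, 12, 12], in_channels=6)
except ValueError as err:
    print(err)
```

In Jittor, a check like this could be called at the start of Conv3d.execute with x.shape and the layer's expected channel count. It reports the same fact the failing cuDNN check already knows (4 dimensions given, 5 expected), but phrases it in terms of the user's tensor.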

Full Log

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[10], line 34
     32     y = m(x)
     33     return list(y.shape)
---> 34 go()

Cell In[10], line 32, in go()
     30 x = jittor.randn([1, 1, 28, 28])
     31 m = lenet()
---> 32 y = m(x)
     33 return list(y.shape)

File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/__init__.py:1168, in Module.__call__(self, *args, **kw)
   1167 def __call__(self, *args, **kw):
-> 1168     return self.execute(*args, **kw)

Cell In[10], line 22, in lenet.execute(self, x)
     20 x = self.relu1(x)
     21 x = self.pool1(x)
---> 22 x = self.conv2(x)
     23 return x

File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/__init__.py:1168, in Module.__call__(self, *args, **kw)
   1167 def __call__(self, *args, **kw):
-> 1168     return self.execute(*args, **kw)

File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1138, in Conv3d.execute(self, x)
   1137 def execute(self, x):
-> 1138     return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1277, in conv3d(x, weight, bias, stride, padding, dilation, groups)
   1274 out_channels = weight.shape[0]
   1276 if jt.flags.use_cuda and jt.cudnn:
-> 1277     y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
   1278 elif groups == 1:
   1279     N,C,D,H,W = x.shape

RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.cudnn_conv3d)).

Types of your inputs are:
 self	= module,
 args	= (Var, Var, int, int, int, int, int, int, int, int, int, int, ),

The function declarations are:
 VarHolder* cudnn_conv3d(VarHolder* x, VarHolder* w,  int strided,  int strideh,  int stridew,  int paddingd,  int paddingh,  int paddingw,  int dilationd=1,  int dilationh=1,  int dilationw=1,  int groups=1,  string xformat="ncdhw")

Failed reason:[f 0906 09:47:16.870443 64 cudnn_conv3d_op.cc:37] Check failed x->shape.size()(4) == 5(5) Something wrong ... Could you please report this issue?

Minimal Reproduce

import os
os.environ["disable_lock"] = "1"
import jittor
import jittor.nn as nn


class lenet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = jittor.nn.Conv(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = jittor.nn.ReLU()
        self.pool1 = jittor.nn.MaxPool2d(kernel_size=2, stride=2)
        # Conv3d expects a 5-D (N, C, D, H, W) input
        self.conv2 = jittor.nn.Conv3d(in_channels=6, kernel_size=5, out_channels=16, dilation=(2, 7, 0))
    
    def execute(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        # x is still 4-D (N, C, H, W) here, so conv2 receives the wrong number of dimensions
        x = self.conv2(x)
        return x

def go():
    jittor.flags.use_cuda = 1  # takes the cuDNN path shown in the log
    x = jittor.randn([1, 1, 28, 28])  # 4-D (N, C, H, W) input
    m = lenet()
    y = m(x)  # raises the RuntimeError shown in the log
    return list(y.shape)
go()
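Tracing the shapes through the reproduction shows why the check fails. This sketch assumes the standard convolution output-size formula and Jittor's defaults of stride=1, padding=0, dilation=1 for conv1 and padding=0 for pool1; conv_out and shape_before_conv2 are illustrative helpers, not Jittor functions:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension of a convolution or max-pool window."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def shape_before_conv2(n: int = 1, h: int = 28, w: int = 28) -> list:
    # conv1: Conv(1 -> 6, kernel_size=5) with assumed defaults: 28 -> 24
    h, w = conv_out(h, 5), conv_out(w, 5)
    c = 6  # out_channels of conv1
    # pool1: MaxPool2d(kernel_size=2, stride=2), N and C unchanged: 24 -> 12
    h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
    return [n, c, h, w]


shape = shape_before_conv2()
print(shape, len(shape))  # [1, 6, 12, 12] 4
```

The 4-D result [1, 6, 12, 12] matches the shape.size() of 4 reported in the log. A valid Conv3d input would need an additional depth axis, i.e. (N, C, D, H, W).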