jittor.nn.Conv2d的stride接受非法参数异常
PhyllisJi opened this issue · 0 comments
PhyllisJi commented
关键错误信息
jittor.nn. Conv2d的stride设置为-1(或者其他非法值)时,模型定义时不会对此抛出异常,在执行模型逻辑时,计算oh、ow时得到负数,从而执行assert oh>0 and ow>0抛出异常。
预期行为
jittor.nn. Conv2d的stride非法时,应该拒绝该参数,并抛出正确的异常信息提醒用户,以防止底层算子出现未知问题,并不影响其他正常运算。
错误日志
Traceback (most recent call last):
File "D:\myvscode\softWare\PyCharm 2022.3.2\plugins\python\helpers\pydev\pydevd.py", line 1496, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\myvscode\softWare\PyCharm 2022.3.2\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "D:\课件\研究生\项目\2023-7-5 bug标注\models\jttest.py", line 8, in <module>
output = conv(input1)
File "D:\myvscode\environment\python3.9\lib\site-packages\jittor\__init__.py", line 1172, in __call__
return self.execute(*args, **kw)
File "D:\myvscode\environment\python3.9\lib\site-packages\jittor\nn.py", line 962, in execute
assert oh>0 and ow>0
AssertionError
python-BaseException
Process finished with exit code 1
最小复现
conv = nn.Conv2d(3, 5, 2, stride=-1)
input1 = jt.randn(4, 3, 10, 10)
output = conv(input1)
print(output.shape)