jittor.nn.ConvTranspose3d 的 groups 参数没有实现
PhyllisJi opened this issue · 0 comments
PhyllisJi commented
Describe the bug
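jittor.nn.ConvTranspose3d 的 groups 参数没有实现：构造函数中有 `assert groups==1`，因此只要传入 groups != 1（即使 in_channels 和 out_channels 都能被 groups 整除），就会抛出 `AssertionError: Group conv not supported yet.`。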
Full Log
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[11], line 30
26 y = m(x)
27 return list(y.shape)
---> 30 go()
Cell In[11], line 25, in go()
23 jittor.flags.use_cuda = 1
24 x = jittor.randn([1, 1, 28, 28])
---> 25 m = lenet()
26 y = m(x)
27 return list(y.shape)
Cell In[11], line 13, in lenet.__init__(self)
11 def __init__(self):
12 super().__init__()
---> 13 self.conv1 = jittor.nn.ConvTranspose3d(in_channels=1, kernel_size=5, out_channels=6, groups=9)
File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1442, in ConvTranspose3d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation)
1440 self.dilation = dilation
1441 self.group = groups
-> 1442 assert groups==1, "Group conv not supported yet."
1444 self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
1445 self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
AssertionError: Group conv not supported yet.
Minimal Reproduce
import os
os.environ["disable_lock"] = "1"
import jittor
import jittor.nn as nn
import jittor.optim as optim
import numpy as np
import copy
class lenet(nn.Module):
    """Single grouped 3-D transposed convolution used to reproduce the issue.

    ``groups`` must divide both ``in_channels`` and ``out_channels`` for a
    grouped convolution to be well defined, so 3 -> 6 channels with
    ``groups=3`` is a valid configuration that isolates the missing feature
    (a config such as in=1, out=6, groups=9 would be rejected even by an
    implementation that supports grouping).
    """

    def __init__(self):
        super().__init__()
        # In the reported jittor version this raises
        # "AssertionError: Group conv not supported yet." because
        # ConvTranspose3d.__init__ asserts groups == 1.
        self.conv1 = jittor.nn.ConvTranspose3d(
            in_channels=3, kernel_size=5, out_channels=6, groups=3
        )

    def execute(self, x):
        """Apply the grouped transposed convolution to ``x``."""
        x = self.conv1(x)
        return x


def go():
    """Build the model, run one forward pass and return the output shape.

    A 3-D transposed convolution consumes 5-D input laid out as
    (N, C, D, H, W), and C must equal the layer's ``in_channels`` (3).
    With kernel_size=5, stride=1 and no padding every spatial dimension
    grows by 4, so once grouping is supported the expected result is
    [1, 6, 12, 32, 32].
    """
    # The failure is device-independent; use the GPU only when present so the
    # repro also runs on CPU-only machines.
    jittor.flags.use_cuda = jittor.has_cuda
    x = jittor.randn([1, 3, 8, 28, 28])
    m = lenet()
    y = m(x)
    return list(y.shape)


if __name__ == "__main__":
    go()