Jittor/jittor

jittor.nn.ConvTranspose3d 的 `groups` 参数尚未实现 (the `groups` parameter of jittor.nn.ConvTranspose3d is not implemented)

PhyllisJi opened this issue · 0 comments

Describe the bug

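jittor.nn.ConvTranspose3d 的 `groups` 参数尚未实现：传入 `groups != 1` 时，构造函数直接触发 `AssertionError: Group conv not supported yet.`
(The `groups` parameter of jittor.nn.ConvTranspose3d is not implemented: passing any `groups != 1` makes the constructor fail with `AssertionError: Group conv not supported yet.`)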

Full Log

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[11], line 30
     26     y = m(x)
     27     return list(y.shape)
---> 30 go()

Cell In[11], line 25, in go()
     23 jittor.flags.use_cuda = 1
     24 x = jittor.randn([1, 1, 28, 28])
---> 25 m = lenet()
     26 y = m(x)
     27 return list(y.shape)

Cell In[11], line 13, in lenet.__init__(self)
     11 def __init__(self):
     12     super().__init__()
---> 13     self.conv1 = jittor.nn.ConvTranspose3d(in_channels=1, kernel_size=5, out_channels=6, groups=9)

File ~/miniconda3/envs/myconda/lib/python3.9/site-packages/jittor/nn.py:1442, in ConvTranspose3d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation)
   1440 self.dilation = dilation
   1441 self.group = groups
-> 1442 assert groups==1, "Group conv not supported yet."
   1444 self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
   1445 self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)

AssertionError: Group conv not supported yet.

Minimal Reproduction

import os
os.environ["disable_lock"] = "1"
import jittor
import jittor.nn as nn
import jittor.optim as optim
import numpy as np
import copy


class lenet(nn.Module):
    """Minimal module wrapping a single grouped ``ConvTranspose3d`` layer.

    Constructing it exercises the ``groups`` argument of
    ``jittor.nn.ConvTranspose3d``. On Jittor versions that do not support
    grouped transposed convolution, ``__init__`` raises
    ``AssertionError: Group conv not supported yet.``
    """

    def __init__(self):
        super().__init__()
        # Grouped convolution requires ``groups`` to divide both
        # in_channels and out_channels (the standard constraint, as in
        # PyTorch's ConvTranspose3d). The previous groups=9 with
        # in_channels=1 / out_channels=6 would be rejected even once the
        # feature exists, so use a valid configuration: 2 | 2 and 2 | 6.
        self.conv1 = jittor.nn.ConvTranspose3d(
            in_channels=2, kernel_size=5, out_channels=6, groups=2
        )

    def execute(self, x):
        # Jittor's forward method (called via ``m(x)``).
        x = self.conv1(x)
        return x


def go():
    """Build the model, run one forward pass and return the output shape."""
    # The failure happens in the constructor regardless of device, so only
    # switch to CUDA when it is actually available.
    if jittor.has_cuda:
        jittor.flags.use_cuda = 1
    # A 3-D (transposed) convolution consumes 5-D input (N, C, D, H, W);
    # the channel dimension must match conv1's in_channels (2).
    x = jittor.randn([1, 2, 8, 28, 28])
    m = lenet()
    y = m(x)
    # Expected once grouped ConvTranspose3d works, assuming the default
    # stride=1, padding=0, output_padding=0, dilation=1:
    # [1, 6, 12, 32, 32]  (each spatial size becomes size - 1 + kernel_size)
    return list(y.shape)


if __name__ == "__main__":
    print(go())