2017-fall-DL-training-program/ConvNetwork

What is Plain network in Lab01?

fansia opened this issue · 2 comments

Hi,

Below is the BasicBlock class from a ResNet sample code.
To get a plain net, can we just remove the residual part in `forward`?
Is that right?

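```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (helper used by the ResNet sample)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the stride or channel count changes
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual  # the skip connection
        out = self.relu(out)

        return out
```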

Thanks.

Hi,
Yes. Just remove the skip connection by deleting the line `out += residual`.
Jia-Ren
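For reference, here is a minimal sketch of what the block looks like once `out += residual` is deleted. The class name `PlainBlock` is hypothetical (not from the course code), and the `conv3x3` helper is assumed to match the one in the ResNet sample above. Because the `downsample` branch only exists to make the shortcut's shape match `out` for the addition, it is no longer needed and is dropped here.

```python
import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (assumed same helper as the ResNet sample)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class PlainBlock(nn.Module):
    """Hypothetical name: BasicBlock with the skip connection removed."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(PlainBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # No "out += residual": this is a plain conv-BN-ReLU stack,
        # so no downsample projection is needed either.
        return self.relu(out)


if __name__ == "__main__":
    block = PlainBlock(16, 32, stride=2)
    x = torch.randn(4, 16, 32, 32)
    print(block(x).shape)  # torch.Size([4, 32, 16, 16])
```

Swapping `BasicBlock` for a block like this when building the network gives the plain counterpart with the same depth and layer widths, which is what makes the plain vs. residual comparison fair.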

Hi Jia-Ren,

Thank you.