jt827859032/DRRN-pytorch

How does your implementation share the weights?

pzz2011 opened this issue · 1 comment

Hi there,
I can't find any code showing that the parameters are shared. Maybe it's because I don't understand how to use the weight-sharing mechanism of PyTorch? Can you help me?
Thanks.

import torch
import torch.nn as nn
from math import sqrt


class DRRN(nn.Module):
	def __init__(self):
		super(DRRN, self).__init__()
		self.input = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
		self.conv1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
		self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
		self.output = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
		self.relu = nn.ReLU(inplace=True)

		# weights initialization (He initialization)
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, sqrt(2. / n))

	def forward(self, x):
		residual = x
		inputs = self.input(self.relu(x))
		out = inputs
		# The same conv1/conv2 instances are applied in all 25
		# residual units, so every unit shares one pair of weight tensors.
		for _ in range(25):
			out = self.conv2(self.relu(self.conv1(self.relu(out))))
			out = torch.add(out, inputs)

		out = self.output(self.relu(out))
		out = torch.add(out, residual)
		return out

Hi, @pzz2011
You can regard conv = nn.Conv2d(...) as an instantiation of a convolutional layer. self.conv1 and self.conv2 are two such instantiations defined in the __init__ function, and I keep reusing those same two instantiations in the forward implementation (see the for loop):

for _ in range(25):
	out = self.conv2(self.relu(self.conv1(self.relu(out))))
	out = torch.add(out, inputs)
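
Because the same Python objects are called on every iteration, there is only one weight tensor for conv1 and one for conv2; every call reads the same tensors, and the gradients from all 25 calls accumulate into them. A minimal standalone sketch (my own example, not code from this repo) makes that concrete:

    import torch
    import torch.nn as nn

    # One instantiation = one weight tensor, no matter how often it is called.
    conv = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)

    x = torch.randn(1, 128, 31, 31)
    out = conv(conv(x))  # applied twice, still a single weight tensor

    # 128 * 128 * 3 * 3 = 147456 parameters, counted once.
    print(sum(p.numel() for p in conv.parameters()))  # 147456

    # Gradients from both calls accumulate into the same .weight.grad.
    out.sum().backward()
    print(conv.weight.grad.shape)  # torch.Size([128, 128, 3, 3])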

If you want to build 3 convolutional layers with different weights, you should define 3 separate instantiations with nn.Conv2d:

    conv1 = nn.Conv2d(xxx)
    conv2 = nn.Conv2d(xxx)
    conv3 = nn.Conv2d(xxx)
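
For contrast, here is a quick sketch (again my own illustration, with assumed channel sizes) showing that separate instantiations hold independent weights, so the parameter count grows with the number of layers, while the shared version stays constant however deep the recursion goes:

    import torch.nn as nn

    # Three separate instantiations: three independent weight tensors.
    separate = nn.ModuleList(
        nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False) for _ in range(3)
    )
    # One shared instantiation called three times: one weight tensor.
    shared = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)

    print(sum(p.numel() for p in separate.parameters()))  # 442368 (3 x 147456)
    print(sum(p.numel() for p in shared.parameters()))    # 147456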