Convolution inside drift and diffusion term
dungxibo123 opened this issue · 2 comments
Hi,
I am re-implementing the Neural SDE paper.
When I tried a Neural SDE built from a bunch of linear layers, everything seemed fine: accuracy reached about 0.92 on the validation set.
After that, I tried convolutional layers. First of all, I realized that f and g take input of shape (B, N). So in the forward step of my network I used torch.view to reshape the flattened input back into a square image, fed the data through a stack of convolutional layers, and then, as a final step, used torch.view once more to reshape the output back to (B, N). (Take a look below.)
class ConvolutionDrift(nn.Module):
    def __init__(self, in_channel, size=32, device="cpu"):
        super(ConvolutionDrift, self).__init__()
        self.size = size
        self.in_channel = in_channel
        self.conv1 = ConcatConv2d(in_channel, 64, ksize=3, padding=1)
        self.norm1 = Norm(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = ConcatConv2d(64, 64, ksize=3, padding=1)
        # conv2 outputs 64 channels, so this assumes in_channel == 64 in practice.
        self.norm2 = Norm(in_channel)

    def forward(self, t, x):
        bs = x.shape[0]
        # Reshape the flattened state back into an image before the convolutions.
        out = x.view(bs, self.in_channel, self.size, self.size)
        out = self.conv1(t, out)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        # Flatten back to (B, N) so the output matches the state shape torchsde expects.
        out = out.view(bs, -1)
        return out
class ConvolutionDiffusion(nn.Module):
    def __init__(self, in_channel, size=32, brownian_size=2, device="cpu"):
        super(ConvolutionDiffusion, self).__init__()
        self.size = size
        self.in_channel = in_channel
        self.relu = nn.ReLU()
        self.conv1 = ConcatConv2d(in_channel, 64, ksize=3, padding=1)
        self.norm1 = Norm(64)
        self.conv2 = ConcatConv2d(64, 64, ksize=3, padding=1)
        self.norm2 = Norm(64)
        # The last convolution produces in_channel * brownian_size channels so the
        # flattened output can later be reshaped to (B, state_size, brownian_size).
        self.conv3 = ConcatConv2d(64, in_channel * brownian_size, ksize=3, padding=1)
        self.norm3 = Norm(in_channel * brownian_size)

    def forward(self, t, x):
        bs = x.shape[0]
        out = x.view(bs, self.in_channel, self.size, self.size)
        out = self.conv1(t, out)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(t, out)
        out = self.norm3(out)
        out = self.relu(out)
        out = out.view(bs, -1)
        return out
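For reference, with noise_type="general" torchsde expects f(t, y) to return a tensor of shape (B, state_size) and g(t, y) to return (B, state_size, brownian_size). A quick shape sanity check of the two modules above (just a sketch; it assumes ConcatConv2d and Norm are the same helpers used in the rest of this code, and the sizes are made up):

import torch

# Sketch: verify drift/diffusion output shapes before handing them to sdeint.
# Assumes ConcatConv2d and Norm are defined as in the rest of this code;
# the sizes below are illustrative only.
B, C, H = 4, 64, 6
state_size = C * H * H
brownian_size = 2

drift = ConvolutionDrift(C, size=H)
diffusion = ConvolutionDiffusion(C, size=H, brownian_size=brownian_size)

x = torch.randn(B, state_size)
t = torch.tensor(0.0)

f_out = drift(t, x)                      # should be (B, state_size)
g_out = diffusion(t, x)                  # flat: (B, state_size * brownian_size)
g_out = g_out.view(B, state_size, brownian_size)
print(f_out.shape, g_out.shape)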
class SDEBlock(nn.Module):
    noise_type = "general"
    sde_type = "ito"

    def __init__(self, state_size, brownian_size, batch_size, option=dict(), device="cpu", parallel=False,
                 method="euler", noise_type="general", integral_type="ito", is_ode=False,
                 input_conv_channel=64, input_conv_size=6, layers="linear"):
        super(SDEBlock, self).__init__()
        self.noise_type = noise_type
        self.sde_type = integral_type
        self.state_size = state_size
        self.batch_size = batch_size
        self.brownian_size = brownian_size
        self.parallel = parallel
        self.device = device
        self.is_ode = is_ode
        if parallel:
            self.batch_size = int(self.batch_size / 2)
        if layers == "linear":
            self.drift = LinearDrift(state_size, state_size).to(device)
            self.diffusion = LinearDiffusion(state_size, state_size * brownian_size).to(device)
        elif layers == "conv":
            self.drift = ConvolutionDrift(input_conv_channel, input_conv_size).to(device)
            self.diffusion = ConvolutionDiffusion(input_conv_channel, input_conv_size, brownian_size=self.brownian_size).to(device)

    def f(self, t, x):
        # Drift: returns a tensor of shape (B, state_size).
        return self.drift(t, x)

    def g(self, t, x):
        bs = x.shape[0]
        if self.is_ode:
            # Zero diffusion turns the SDE into an ODE
            # (torch.zeros, not torch.zeros_like, since we build from a shape).
            return torch.zeros(bs, self.state_size, self.brownian_size, device=self.device)
        # Diffusion: reshaped to (B, state_size, brownian_size) as required by noise_type="general".
        out = self.diffusion(t, x)
        out = out.view(bs, self.state_size, self.brownian_size)
        return out
"""
SDEBlock: LinearDrift dx + LinearDiffusion dW
SDENet: fe -> SDEBlock -> fcc
"""
class SDENet(Model):
    def __init__(self, input_channel, input_size, state_size, brownian_size, batch_size, option=dict(), method="euler",
                 noise_type="general", integral_type="ito", device="cpu", is_ode=False, parallel=False):
        super(SDENet, self).__init__()
        self.batch_size = batch_size
        self.parallel = parallel
        self.input_size = input_size
        self.option = option
        self.input_channel = input_channel
        self.device = device
        # Feature extractor: maps the input image to a smaller (B, 64, H', W') feature map.
        self.fe = nn.Sequential(*[
            nn.Conv2d(input_channel, 16, 3, padding=1),
            nn.GroupNorm(8, 16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, padding=2),
            nn.GroupNorm(16, 32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2),
            nn.GroupNorm(32, 64),
            nn.ReLU(),
        ]).to(device)
        state_size, input_conv_channel, input_conv_size = self.get_state_size()
        self.input_conv_channel = input_conv_channel
        self.input_conv_size = input_conv_size
        if parallel:
            self.batch_size = int(self.batch_size / 2)
        self.rm = SDEBlock(
            state_size=state_size,
            brownian_size=brownian_size,
            batch_size=batch_size,
            option=option,
            method=method,
            integral_type=integral_type,
            noise_type=noise_type,
            device=device,
            parallel=parallel,
            is_ode=is_ode,
            input_conv_channel=input_conv_channel,
            input_conv_size=input_conv_size,
            layers="conv"
        ).to(device)
        # Classification head.
        self.fcc = nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(input_conv_channel, 10),
            nn.Softmax(dim=1)
        ]).to(device)
        self.integrated_time = torch.Tensor([0.0, 1.0]).to(device)
        self.device = device
        self.method = method

    def get_state_size(self):
        # Run a dummy input through the feature extractor to infer the flattened state size.
        out = torch.rand((1, self.input_channel, self.input_size, self.input_size)).to(self.device)
        shape = self.fe(out)
        return shape.view(1, -1).shape[-1], shape.shape[1], shape.shape[2]

    def forward(self, x):
        out = self.fe(x)
        bs = x.shape[0]
        # Flatten to (B, N) before handing the state to the SDE solver.
        out = out.view(bs, -1)
        out = sdeint(self.rm, out, self.integrated_time, options=self.option, method="euler",
                     atol=5e-2, rtol=5e-2, dt=0.1, dt_min=0.05, adaptive=True)[-1]
        # Reshape back into an image for the classification head.
        out = out.view(bs, self.input_conv_channel, self.input_conv_size, self.input_conv_size)
        out = self.fcc(out)
        return out
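For completeness, a sketch of how this model might be instantiated and called (the hyperparameters here are assumptions; the state_size argument is recomputed internally by get_state_size, so the value passed in is ignored):

# Sketch: instantiating and running the model above.
# Hyperparameters are assumptions; state_size is recomputed by get_state_size().
model = SDENet(input_channel=3, input_size=32, state_size=0, brownian_size=2,
               batch_size=8, device="cpu")
x = torch.randn(8, 3, 32, 32)
probs = model(x)        # (8, 10) class probabilities from the softmax head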
The code runs without any errors, but the results are not good: training usually converges by the second or third epoch, and the accuracy with convolutional layers is poor.
So my question is: are convolutional layers usable in torchsde in some way, or does loss.backward() simply not update the ConcatConv2d parameters?
I'm afraid that's far too large an example for us to help you debug. Moreover it's not clear that this is actually an issue to do with torchsde at all (indeed you mention that using linear layers seems to work correctly).
Your basic approach of reshaping from (batch, channels * height * width) to (batch, channels, height, width), and back again, sounds essentially correct, however.
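One quick way to rule out the loss.backward() question is to run a single backward pass and check that the parameters inside the drift and diffusion actually receive gradients. A minimal sketch against the model posted above:

# Sketch: check that gradients flow into the ConcatConv2d parameters inside
# the drift/diffusion (uses the SDENet instance from the code above).
import torch

x = torch.randn(4, 3, 32, 32)
out = model(x)
out.sum().backward()    # any scalar works for a gradient-flow check

for name, p in model.rm.drift.named_parameters():
    print(name, p.grad is not None)
for name, p in model.rm.diffusion.named_parameters():
    print(name, p.grad is not None)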
Thanks @patrick-kidger. Whenever I raise an issue and get your answer, I figure out what is going on in my work.
I have done more experiments, and the issue was caused by the adaptive average pooling; I don't really know why.
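For anyone who hits the same thing: one way to probe that hypothesis is to swap the pooling head for a flatten-based head and compare accuracy. This is only a sketch of the kind of substitution described, reusing the attributes from the SDENet code above:

# Sketch: replace the AdaptiveAvgPool2d head with a flatten-based head to test
# whether the pooling layer is responsible for the drop in accuracy.
# input_conv_channel / input_conv_size are the values inferred by get_state_size().
import torch.nn as nn

model.fcc = nn.Sequential(
    nn.Flatten(),
    nn.Linear(model.input_conv_channel * model.input_conv_size ** 2, 10),
    nn.Softmax(dim=1),
).to(model.device)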