google-research/torchsde

Convolutions inside the drift and diffusion terms

dungxibo123 opened this issue · 2 comments

Hi,

I am re-implementing the Neural SDE paper.

When I used a Neural SDE built from a stack of linear layers, everything seemed fine: accuracy was about 0.92 on the validation set.

Next, I tried convolution layers. I realized that f and g take inputs of shape (B, N). So in the forward pass of my network, I use torch.view to reshape the flattened input into a square image. I then feed it through a stack of convolution layers, and as a final step I use torch.view again to reshape the output back to (B, N). (See the code below.)

class ConvolutionDrift(nn.Module):
    """Convolutional drift network f(t, x) for an SDE over flattened image states.

    The state ``x`` arrives flattened as (batch, in_channel * size * size). It is
    reshaped to (batch, in_channel, size, size) and passed through two
    convolutions that also receive ``t`` (``ConcatConv2d``, defined elsewhere).
    Each convolution is followed by a ``Norm`` layer (defined elsewhere) and a
    ReLU. The result is flattened back to (batch, -1). The final convolution
    maps back to ``in_channel`` channels so the drift has the same number of
    elements as the state, assuming ``ConcatConv2d`` with ksize=3, padding=1
    preserves the spatial size.

    Args:
        in_channel: number of channels of the unflattened state.
        size: height and width of the unflattened state (a square map is assumed).
        device: accepted for interface compatibility; not used by this module.
        hidden_channels: width of the intermediate feature map (64 by default,
            the value that used to be hard-coded).
    """

    def __init__(self, in_channel, size=32, device="cpu", hidden_channels=64):
        super(ConvolutionDrift, self).__init__()
        self.size = size
        self.in_channel = in_channel
        self.conv1 = ConcatConv2d(in_channel, hidden_channels, ksize=3, padding=1)
        self.norm1 = Norm(hidden_channels)
        self.relu = nn.ReLU(inplace=True)
        # The drift must have the same shape as the state, so the last conv maps
        # back to ``in_channel`` channels. It used to emit a fixed 64 channels,
        # which only matched the state (and ``norm2``) when in_channel == 64.
        self.conv2 = ConcatConv2d(hidden_channels, in_channel, ksize=3, padding=1)
        self.norm2 = Norm(in_channel)

    def forward(self, t, x):
        """Evaluate the drift at time ``t`` for the flattened state ``x``."""
        bs = x.shape[0]
        out = x.view(bs, self.in_channel, self.size, self.size)
        out = self.conv1(t, out)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm2(out)
        # NOTE(review): this trailing ReLU makes every drift component
        # non-negative, so the state can only move upwards under the drift;
        # confirm this is intended.
        out = self.relu(out)
        return out.view(bs, -1)
class ConvolutionDiffusion(nn.Module):
    """Convolutional diffusion network g(t, x) over flattened image states.

    The state ``x`` arrives flattened as (batch, in_channel * size * size) and is
    reshaped to (batch, in_channel, size, size). It then goes through three
    convolutions that also receive ``t`` (``ConcatConv2d``, defined elsewhere).
    Each is followed by a ``Norm`` layer (defined elsewhere) and a ReLU. The
    result is flattened to (batch, -1): in_channel * brownian_size * size * size
    elements per sample, assuming ksize=3, padding=1 preserves the spatial size.
    ``SDEBlock.g`` views this as (batch, state_size, brownian_size).

    NOTE(review): with that flat view, row i of the diffusion matrix consists of
    ``brownian_size`` consecutive elements of the flattened output (typically
    neighbouring pixels of one output channel), not one output channel per
    Brownian dimension. This is still a valid parameterisation, but probably
    not the intended one. The trailing ReLU also makes every diffusion entry
    non-negative.

    Args:
        in_channel: number of channels of the unflattened state.
        size: height and width of the unflattened state (a square map is assumed).
        brownian_size: dimension of the Brownian motion.
        device: accepted for interface compatibility; not used by this module.
        hidden_channels: width of the intermediate feature maps (64 by default,
            the value that used to be hard-coded).
    """

    def __init__(self, in_channel, size=32, brownian_size=2, device="cpu", hidden_channels=64):
        super(ConvolutionDiffusion, self).__init__()
        self.size = size
        self.in_channel = in_channel
        self.brownian_size = brownian_size
        # Sub-modules are registered in the original order so that parameter
        # initialisation under a fixed seed is unchanged.
        self.relu = nn.ReLU()
        self.norm1 = Norm(hidden_channels)
        self.conv1 = ConcatConv2d(in_channel, hidden_channels, ksize=3, padding=1)
        self.conv2 = ConcatConv2d(hidden_channels, hidden_channels, ksize=3, padding=1)
        self.norm2 = Norm(hidden_channels)
        self.conv3 = ConcatConv2d(hidden_channels, in_channel * brownian_size, ksize=3, padding=1)
        self.norm3 = Norm(in_channel * brownian_size)

    def forward(self, t, x):
        """Evaluate the (flattened) diffusion at time ``t`` for state ``x``."""
        bs = x.shape[0]
        out = x.view(bs, self.in_channel, self.size, self.size)
        out = self.conv1(t, out)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(t, out)
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(t, out)
        out = self.norm3(out)
        out = self.relu(out)
        return out.view(bs, -1)
        
class SDEBlock(nn.Module):
    """SDE dX = f(t, X) dt + g(t, X) dW in the form torchsde's ``sdeint`` expects.

    ``noise_type`` / ``sde_type`` are the attributes torchsde inspects. ``f``
    returns the drift, and ``g`` returns the diffusion with shape
    (batch, state_size, brownian_size), as required for "general" noise.

    Args:
        state_size: number of elements in the flattened state.
        brownian_size: dimension of the Brownian motion.
        batch_size: nominal batch size (halved when ``parallel`` is True).
        option: unused here; kept for interface compatibility.
        device: device the drift/diffusion networks are moved to.
        parallel: halves ``batch_size``.
        method: unused here; kept for interface compatibility.
        noise_type: stored as ``self.noise_type``.
        integral_type: stored as ``self.sde_type``.
        is_ode: if True, ``g`` returns zeros, turning the SDE into an ODE.
        input_conv_channel, input_conv_size: state layout (channels, spatial size)
            used when ``layers == "conv"``.
        layers: ``"linear"`` or ``"conv"``; selects the drift/diffusion networks.

    Raises:
        ValueError: if ``layers`` is neither ``"linear"`` nor ``"conv"``.
    """

    noise_type = "general"
    sde_type = "ito"

    def __init__(self, state_size, brownian_size, batch_size, option=None, device="cpu", parallel=False,
                 method="euler", noise_type="general", integral_type="ito", is_ode=False,
                 input_conv_channel=64, input_conv_size=6, layers="linear"):
        super(SDEBlock, self).__init__()
        self.noise_type = noise_type
        self.sde_type = integral_type
        self.state_size = state_size
        self.batch_size = batch_size
        self.brownian_size = brownian_size
        self.parallel = parallel
        self.device = device
        self.is_ode = is_ode
        if parallel:
            self.batch_size = int(self.batch_size / 2)

        if layers == "linear":
            self.drift = LinearDrift(state_size, state_size).to(device)
            self.diffusion = LinearDiffusion(state_size, state_size * brownian_size).to(device)
        elif layers == "conv":
            self.drift = ConvolutionDrift(input_conv_channel, input_conv_size).to(device)
            self.diffusion = ConvolutionDiffusion(
                input_conv_channel, input_conv_size, brownian_size=self.brownian_size
            ).to(device)
        else:
            # Previously an unknown value silently left drift/diffusion undefined,
            # failing later with an AttributeError inside the solver.
            raise ValueError(f"Unknown layers={layers!r}; expected 'linear' or 'conv'.")

    def f(self, t, x):
        """Drift term: returns ``self.drift(t, x)``."""
        return self.drift(t, x)

    def g(self, t, x):
        """Diffusion term, shaped (batch, state_size, brownian_size)."""
        bs = x.shape[0]
        if self.is_ode:
            # Zero diffusion reduces the SDE to an ODE. Size it from the actual
            # input batch (x need not have exactly batch_size rows) and match x's
            # dtype/device. The previous torch.zeros_like(<tuple>) raised TypeError.
            return torch.zeros(bs, self.state_size, self.brownian_size, dtype=x.dtype, device=x.device)
        out = self.diffusion(t, x)
        return out.view(bs, self.state_size, self.brownian_size)

        


   

"""
SDEBlock: LinearDrift dx + LinearDiffusion dW
SDENet: fe -> SDEBlock -> fcc
"""
    
    
class SDENet(Model):
    def __init__(self, input_channel, input_size, state_size, brownian_size, batch_size, option=None, method="euler",
                 noise_type="general", integral_type="ito", device="cpu", is_ode=False, parallel=False):
        """Image classifier: conv feature extractor -> SDE block -> pooled linear head.

        Args:
            input_channel: channels of the input images.
            input_size: height and width of the (square) input images.
            state_size: ignored; the SDE state size is recomputed from the feature
                extractor's output by ``get_state_size``.
            brownian_size: dimension of the Brownian motion.
            batch_size: nominal batch size (halved when ``parallel`` is True).
            option: ``options`` dict forwarded to ``sdeint``; a fresh empty dict is
                used when omitted (avoids a mutable default shared across instances).
            method: solver name passed to ``sdeint`` in ``forward``.
            noise_type, integral_type: forwarded to ``SDEBlock``.
            device: device for all sub-modules.
            is_ode: forwarded to ``SDEBlock``; if True its diffusion is zero.
            parallel: halves ``batch_size``.
        """
        super(SDENet, self).__init__()
        if option is None:
            option = {}
        self.batch_size = batch_size
        self.parallel = parallel
        self.input_size = input_size
        self.option = option
        self.input_channel = input_channel
        self.device = device
        self.fe = nn.Sequential(*[
            nn.Conv2d(input_channel, 16, 3, padding=1),
            nn.GroupNorm(8, 16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, padding=2),
            nn.GroupNorm(16, 32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2),
            nn.GroupNorm(32, 64),
            nn.ReLU(),
        ]).to(device)
        # The SDE state is the flattened feature map; its layout is probed from fe.
        state_size, input_conv_channel, input_conv_size = self.get_state_size()
        self.input_conv_channel = input_conv_channel
        self.input_conv_size = input_conv_size

        if parallel:
            self.batch_size = int(self.batch_size / 2)
        self.rm = SDEBlock(
            state_size=state_size,
            brownian_size=brownian_size,
            batch_size=batch_size,
            option=option,
            method=method,
            integral_type=integral_type,
            noise_type=noise_type,
            device=device,
            parallel=parallel,
            is_ode=is_ode,
            input_conv_channel=input_conv_channel,
            input_conv_size=input_conv_size,
            layers="conv",
        ).to(device)

        # NOTE(review): the head ends in Softmax, so forward() returns
        # probabilities. If training uses nn.CrossEntropyLoss (which applies
        # log-softmax itself), this double softmax can badly slow learning —
        # confirm the loss used.
        self.fcc = nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(input_conv_channel, 10),
            nn.Softmax(dim=1),
        ]).to(device)

        # Integration interval [0, 1]; attribute name (typo included) kept for compatibility.
        self.intergrated_time = torch.Tensor([0.0, 1.0]).to(device)
        self.method = method

    def get_state_size(self):
        """Probe ``self.fe`` to find the SDE state layout.

        Returns:
            (flattened size per sample, channels, spatial size) of ``self.fe``'s
            output for one input of shape (1, input_channel, input_size, input_size).
            The spatial size is read from the height; a square map is assumed.
        """
        # A zero probe under no_grad: only the shape matters, so avoid building an
        # autograd graph and avoid consuming the global RNG (torch.rand did).
        probe = torch.zeros(1, self.input_channel, self.input_size, self.input_size, device=self.device)
        with torch.no_grad():
            feat = self.fe(probe)
        return feat[0].numel(), feat.shape[1], feat.shape[2]

    def forward(self, x):
        """Classify a batch of images ``x`` of shape (B, input_channel, H, W)."""
        bs = x.shape[0]
        # Flatten features to (batch, state) as required by the SDE solver.
        out = self.fe(x).view(bs, -1)
        # Use the configured solver; previously "euler" was hard-coded here and the
        # ``method`` argument was silently ignored. [-1] keeps the state at t = 1.
        out = sdeint(self.rm, out, self.intergrated_time, options=self.option, method=self.method,
                     atol=5e-2, rtol=5e-2, dt=0.1, dt_min=0.05, adaptive=True)[-1]
        out = out.view(bs, self.input_conv_channel, self.input_conv_size, self.input_conv_size)
        return self.fcc(out)

The code runs without errors, but the results are not good: training usually plateaus by the second or third epoch, and accuracy with convolution layers is not good at all.
So my question is: can convolution layers be used with torchsde at all, or does loss.backward() fail to update the ConcatConv2d parameters?

image

I'm afraid that's far too large an example for us to help you debug. Moreover it's not clear that this is actually an issue to do with torchsde at all (indeed you mention that using linear layers seems to work correctly).

Your basic approach of reshaping from (batch, channels * height * width) to (batch, channels, height, width), and back again, sounds essentially correct, however.

Thanks, @patrick-kidger — whenever I raise an issue and get your answer, I figure out what is going on in my work.

I have done more experiments, and the issue is caused by the adaptive average pooling, although I don't really know why.