google-research/torchsde

Learning a generative model of a periodic process

Opened this issue · 1 comment

I've applied the WGAN algorithm implemented in torchsde/examples/sde_gan.py to a sine function (deterministic, with a fixed initial condition). After 30,000 training epochs the algorithm still struggles to capture the periodic structure of the signal:

sine wave

The sine function was implemented as the following SDE with zero diffusion:

class PeriodicSDE(torch.nn.Module):
    sde_type = 'ito'
    noise_type = 'diagonal'

    def __init__(self):
        super().__init__()

    def f(self, t, y):
        # Drift: harmonic oscillator with angular frequency 1/3,
        # so each coordinate of (x1, x2) traces a sinusoid.
        x1, x2 = torch.split(y, split_size_or_sections=(1, 1), dim=1)
        f1 = -x2 / 3
        f2 = x1 / 3
        return torch.cat([f1, f2], dim=1)

    def g(self, t, y):
        # Zero diffusion: the process is deterministic.
        return torch.zeros_like(y)


ou_sde = PeriodicSDE().to(device)
# Fixed initial condition, normalised onto the unit circle.
y0 = torch.ones([dataset_size, 2], device=device) * 2 - 1
norm = torch.sqrt(torch.sum(y0 ** 2, dim=1)).unsqueeze(1)
y0 = y0 / norm
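As a sanity check on the target dynamics (not part of the original code), the drift above can be Euler-integrated in plain Python: the system dx1/dt = -x2/3, dx2/dt = x1/3 is a rotation with angular frequency 1/3, so the trajectory should return to its starting point after one period of 6π.

```python
# Pure-Python Euler integration of the drift f1 = -x2/3, f2 = x1/3,
# confirming the deterministic dynamics are a rotation with period 6*pi.
import math

def euler_trajectory(x1, x2, t_end, dt=1e-3):
    steps = int(t_end / dt)
    for _ in range(steps):
        x1, x2 = x1 + dt * (-x2 / 3), x2 + dt * (x1 / 3)
    return x1, x2

c = 1 / math.sqrt(2)  # normalised initial condition (1, 1) / ||(1, 1)||
x1, x2 = euler_trajectory(c, c, 6 * math.pi)
assert abs(x1 - c) < 0.05 and abs(x2 - c) < 0.05  # back to start after one period
```

This is what the generator is being asked to learn: a clean sinusoid with a long period relative to the sampling window, which the WGAN apparently fails to reproduce.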

In my opinion, the low efficiency is caused by vanishing/exploding gradients in the discriminator network due to weight clipping. Here are the weight histograms for the input and output layers of the "f" function of the NCDE discriminator:

input_layer

out_layer

Most weights are stuck at the limits imposed by clipping, and the discriminator effectively stops learning once this happens. Could this be fixed with a gradient penalty?

This sounds like an open research question :)