Learning a generative model of a periodic process
I've applied the WGAN algorithm implemented in torchsde/examples/sde_gan.py to a sine function (deterministic, with fixed initial conditions). After 30,000 training epochs, the algorithm still struggles to capture the periodic structure of the signal:
The sine function was implemented as:
```python
class PeriodicSDE(torch.nn.Module):
    sde_type = 'ito'
    noise_type = 'diagonal'

    def __init__(self):
        super().__init__()

    def f(self, t, y):
        # Drift: a rotation with angular frequency 1/3, so each
        # coordinate of the solution traces a sine/cosine wave.
        x1, x2 = torch.split(y, split_size_or_sections=(1, 1), dim=1)
        f1 = -x2 / 3
        f2 = x1 / 3
        return torch.cat([f1, f2], dim=1)

    def g(self, t, y):
        # Zero diffusion: the process is deterministic.
        return torch.zeros_like(y)


ou_sde = PeriodicSDE().to(device)
# Fixed initial condition, normalized onto the unit circle.
y0 = torch.ones([dataset_size, 2], device=device) * 2 - 1
norm = torch.sqrt(torch.sum(y0 ** 2, dim=1)).unsqueeze(1)
y0 = y0 / norm
```
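As a sanity check on the target process (a minimal stdlib sketch, independent of torchsde), Euler-integrating the drift above shows it is a pure rotation, so each coordinate of the ground-truth signal really is a sine wave:

```python
import math

# Euler-integrate the deterministic drift f(y) = (-y2/3, y1/3).
# This flow rotates y around the origin with angular frequency 1/3,
# so each coordinate oscillates with period 6*pi.
def euler_rollout(y0, dt=0.01, steps=2000):
    y1, y2 = y0
    traj = [(y1, y2)]
    for _ in range(steps):
        y1, y2 = y1 + dt * (-y2 / 3), y2 + dt * (y1 / 3)
        traj.append((y1, y2))
    return traj

# Start on the unit circle, matching the normalized y0 in the snippet above.
start = (1 / math.sqrt(2), 1 / math.sqrt(2))
traj = euler_rollout(start)

# The radius is (approximately) conserved under the rotation;
# the small drift away from 1 is just Euler discretization error.
r0 = math.hypot(*traj[0])
rN = math.hypot(*traj[-1])
```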
In my opinion, the poor performance is caused by vanishing/exploding gradients in the discriminator network due to weight clipping. Here are histograms of the weights for the input and output layers of the "f" function of the NCDE discriminator:
Most weights are stuck at the limits imposed by clipping, and the learning process for the discriminator effectively stops once this happens. Could this be fixed with a gradient penalty?
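For reference, here is a minimal sketch of the WGAN-GP penalty (Gulrajani et al.) that would replace the weight-clipping step: it penalizes deviation of the critic's gradient norm from 1 at points interpolated between real and fake samples. The `critic` below is a hypothetical stand-in; in the actual example it would be the NCDE discriminator from torchsde's examples/sde_gan.py.

```python
import torch

# Hypothetical toy critic standing in for the NCDE discriminator.
critic = torch.nn.Sequential(
    torch.nn.Linear(2, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    """WGAN-GP: penalize (||grad critic||_2 - 1)^2 at random
    interpolates between real and fake samples, instead of
    clipping the critic's weights."""
    eps = torch.rand(real.size(0), 1)  # per-sample mixing weight
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores.sum(), inputs=interp, create_graph=True
    )
    return gp_weight * ((grads.norm(2, dim=1) - 1) ** 2).mean()

real = torch.randn(8, 2)
fake = torch.randn(8, 2)
gp = gradient_penalty(critic, real, fake)
# gp is differentiable (create_graph=True), so it can simply be
# added to the critic loss in place of the clipping step.
```

One caveat: with a neural CDE critic, the penalty requires a second backward pass through the CDE solve, which can be expensive, so it may be worth comparing against other Lipschitz constraints as well.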
This sounds like an open research question :)