Not Learning the Data Distribution
I am working with a structured (tabular) dataset with 20 input features and a single output value. All input features are uniformly distributed.
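Every feature is uniform, and MinMaxScaler maps the training data onto the closed interval [0, 1], so the target density has hard edges. A bijection of R^20 applied to a Gaussian latent always yields a density that is positive everywhere, so the INN can only approximate those edges. One preprocessing step sometimes tried for bounded features (an assumption here, not part of the original setup) is to map the scaled features to the real line with a clipped logit before training and to undo it after sampling. The following is a minimal sketch. The helper names `to_unbounded` / `from_unbounded` and the `eps` value are hypothetical.

```python
import numpy as np

EPS = 1e-4  # hypothetical clipping margin so logit(0) and logit(1) stay finite


def to_unbounded(x_scaled, eps=EPS):
    """Map features in [0, 1] to the real line with a clipped logit."""
    p = np.clip(x_scaled, eps, 1.0 - eps)
    return np.log(p) - np.log1p(-p)


def from_unbounded(y):
    """Inverse of to_unbounded: the logistic sigmoid, written via tanh to avoid overflow."""
    return 0.5 * (1.0 + np.tanh(0.5 * y))
```

Under this assumption, training would use `to_unbounded(X_train_full_scaled)`, and reversed samples would be mapped back with `scaler_input.inverse_transform(from_unbounded(samples))`.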
I trained a sequential INN (FrEIA `SequenceINN`) to map the input features to an N(0, 1) latent space. When I run the trained INN in reverse on z sampled from N(0, 1), the generated distribution of each input feature doesn't quite match the data distribution p(x). I tried tuning the hyperparameters and training for more epochs, but had no luck. An example for one of the input features is shown below:
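To put a number on "doesn't quite match" for each feature, one option is the two-sample Kolmogorov-Smirnov statistic. It is the largest gap between the empirical CDFs of the training data and of the reversed samples, where 0 means identical and 1 means no overlap. The following is a minimal numpy-only sketch. `ks_statistic` and `per_feature_ks` are hypothetical helpers, and `scipy.stats.ks_2samp` computes the same statistic together with a p-value.

```python
import numpy as np


def ks_statistic(a, b):
    """Two-sample KS statistic: max |ECDF_a(t) - ECDF_b(t)| over all t."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    grid = np.concatenate([a, b])  # the supremum is attained at a pooled sample point
    cdf_a = np.searchsorted(a, grid, side="right") / a.size
    cdf_b = np.searchsorted(b, grid, side="right") / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))


def per_feature_ks(data, samples):
    """KS statistic for each column (feature) of two 2-D arrays."""
    return np.array([ks_statistic(data[:, j], samples[:, j])
                     for j in range(data.shape[1])])
```

For example, `per_feature_ks(X_train_full, pred)` with the arrays from the code in this issue gives one score per input feature, which makes it easier to compare runs than eyeballing histograms.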
The same setup works fine for the Make Moons and Boston Housing datasets in sklearn: running in reverse yields the right distribution for the input features. An example for the Boston dataset is included below:
Any suggestions as to why it is not learning the data distribution? My dataset is small, only 361 samples. Could that be the issue? On the other hand, Boston Housing has only 506 samples and 13 features, and the INN works fine on it.
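The following quick capacity check addresses the small-dataset question. It counts only the weights in the `subnet_fc` networks and assumes (from how FrEIA's `AllInOneBlock` splits its input) that each subnet maps the 10 passive features to 20 outputs, a scale and a shift for each of the other 10. It ignores the blocks' own small per-channel parameters. `linear_params` and `subnet_fc_params` are hypothetical helpers.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def subnet_fc_params(dims_in, dims_out, hidden=512):
    """Parameters of subnet_fc: Linear(dims_in, hidden) -> ReLU -> Linear(hidden, dims_out)."""
    return linear_params(dims_in, hidden) + linear_params(hidden, dims_out)


N_DIM, N_BLOCKS, N_SAMPLES = 20, 8, 361
half = N_DIM // 2
per_block = subnet_fc_params(half, 2 * half)  # assumed 10 -> 20 split per block
total = N_BLOCKS * per_block
print(per_block, total, N_SAMPLES * N_DIM)  # prints: 15892 127136 7220
```

Under that assumption, the subnets hold about 17.6 weights per scalar value in the training set (127136 / 7220). This ratio can be useful when judging whether overfitting is plausible.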
Code below:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
import FrEIA.framework as Ff
import FrEIA.modules as Fm

BATCHSIZE = 19   # 361 samples = 19 batches of 19
N_DIM = 20       # number of input features
epochs = 500

# Subnetwork used inside each coupling block to predict scale and shift
def subnet_fc(dims_in, dims_out):
    return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(),
                         nn.Linear(512, dims_out))

# 8 affine coupling blocks with soft (random orthogonal) permutations
inn = Ff.SequenceINN(N_DIM)
for k in range(8):
    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

# X_train_full: the (361, 20) array of input features, defined elsewhere
scaler_input = MinMaxScaler()
X_train_full_scaled = scaler_input.fit_transform(X_train_full)
train_x = torch.tensor(X_train_full_scaled, dtype=torch.float32)
train_dataset = TensorDataset(train_x)
train_dataloader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True)

optimizer = torch.optim.Adam(inn.parameters(), lr=0.001)
losses = []
for epoch in range(epochs):
    running_loss = 0.0
    for (x,) in train_dataloader:  # TensorDataset yields 1-tuples
        optimizer.zero_grad()
        z, log_jac_det = inn(x)
        # Negative log-likelihood under N(0, 1), per dimension
        loss = 0.5 * torch.sum(z**2, 1) - log_jac_det
        loss = loss.mean() / N_DIM
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_dataloader)  # mean batch loss
    losses.append(epoch_loss)

# Sample z ~ N(0, 1) and run the INN in reverse to get samples in x-space
sample_size = 100
with torch.no_grad():
    z = torch.randn(sample_size, N_DIM)
    samples, _ = inn(z, rev=True)
samples = samples.numpy()
pred = scaler_input.inverse_transform(samples)
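For reference, the loss in the training loop is the change-of-variables negative log-likelihood of x under a standard normal latent. Here z = f(x) is the INN's forward pass, D = N_DIM, and B is the batch size. The loop divides it by D and drops the constant term:

```latex
-\log p_X(x) = \tfrac{1}{2}\lVert z \rVert^2 - \log\left|\det \frac{\partial f(x)}{\partial x}\right| + \tfrac{D}{2}\log 2\pi,
\qquad z = f(x)

\mathcal{L} = \frac{1}{D}\,\frac{1}{B}\sum_{i=1}^{B}\left(\tfrac{1}{2}\lVert z_i \rVert^2 - \log\left|\det \frac{\partial f(x_i)}{\partial x_i}\right|\right)
```

The per-dimension negative log-likelihood is therefore the logged loss plus \tfrac{1}{2}\log 2\pi \approx 0.919.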