vislearn/FrEIA

Train on multiple GPUs with nn.DataParallel()

delitroz opened this issue · 1 comment

Hello,

I am having trouble training an INN on multiple GPUs with nn.DataParallel().
I reduced the problem to a minimal example based on the one at the beginning of the README. When the INN is built with Ff.SequenceINN() (commented out in the code below), everything works as intended and all of my GPUs are used. When it is built with Ff.ReversibleGraphNet() instead, I get an error reporting that several tensors are on different devices (note: this version works fine on a single GPU).
Since I am planning to implement a convolutional conditional INN, I will need something more expressive than a plain SequenceINN(), i.e. a graph that supports splitting.
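For context on the SequenceINN case that works: nn.DataParallel splits each tensor argument along dim 0 across the visible GPUs, runs a replica of the module on each chunk, passes non-tensor keyword arguments such as `rev=True` unchanged to every replica, and gathers tensor outputs (including each element of a returned tuple) back onto `device_ids[0]`. The sketch below wraps a hypothetical toy flow, `ToyAffineFlow` (not a FrEIA class), that follows the same `forward(x, rev=False) -> (z, log_jac_det)` convention as the INN in the repro. It assumes `log_jac_det` is returned per sample, shape `(batch,)`, so the gathered result lines up with `z`. Without CUDA, nn.DataParallel simply calls the wrapped module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyAffineFlow(nn.Module):
    """Hypothetical element-wise affine flow: z = x * exp(s) + t (not part of FrEIA)."""

    def __init__(self, dim: int):
        super().__init__()
        # Registered parameters, so DataParallel copies them to every GPU
        self.s = nn.Parameter(0.1 * torch.randn(dim))
        self.t = nn.Parameter(0.1 * torch.randn(dim))

    def forward(self, x, rev=False):
        batch = x.shape[0]
        if not rev:
            z = x * torch.exp(self.s) + self.t
            # Per-sample log|det J| so DataParallel can gather it along dim 0
            return z, self.s.sum().expand(batch)
        x_out = (x - self.t) * torch.exp(-self.s)
        return x_out, (-self.s.sum()).expand(batch)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
flow = ToyAffineFlow(2).to(device)
model = nn.DataParallel(flow)  # passthrough to `flow` when no GPU is available

x = torch.randn(1000, 2, device=device)
z, log_jac_det = model(x)           # batch scattered, outputs gathered on device
x_rec, _ = model(z, rev=True)       # rev=True is forwarded to every replica
```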

Has anyone had success training such a model on multiple GPUs?
Thanks.

```python
import torch
import torch.nn as nn
from sklearn.datasets import make_moons

import FrEIA.framework as Ff
import FrEIA.modules as Fm

from tqdm import trange

BATCHSIZE = 1000
N_DIM = 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def subnet_fc(dims_in, dims_out):
    # Fully connected subnetwork used inside each coupling block
    return nn.Sequential(nn.Linear(dims_in, 512), nn.ReLU(),
                         nn.Linear(512, dims_out))

# This version works with nn.DataParallel:
# inn = Ff.SequenceINN(N_DIM)
# for k in range(8):
#     inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

# This version raises the device-mismatch error on multiple GPUs:
nodes = [Ff.InputNode(N_DIM)]
for k in range(8):
    nodes.append(Ff.Node(nodes[-1],
                         Fm.AllInOneBlock,
                         {'subnet_constructor': subnet_fc, 'permute_soft': True}))
nodes.append(Ff.OutputNode(nodes[-1]))
inn = Ff.ReversibleGraphNet(nodes)

# Move the model to the primary GPU, then replicate it across all visible GPUs
inn = nn.DataParallel(inn.to(device))

optimizer = torch.optim.Adam(inn.parameters(), lr=0.001)

for i in trange(1000):
    optimizer.zero_grad()
    data, _ = make_moons(n_samples=BATCHSIZE, noise=0.05)
    x = torch.tensor(data, dtype=torch.float32, device=device)
    z, log_jac_det = inn(x)
    # Negative log-likelihood under a standard normal latent (up to a constant), per dimension
    loss = 0.5 * torch.sum(z**2, 1) - log_jac_det
    loss = loss.mean() / N_DIM
    loss.backward()
    optimizer.step()

# Sample by running the network in reverse on Gaussian noise
with torch.no_grad():
    z = torch.randn(100, N_DIM, device=device)
    samples, _ = inn(z, rev=True)
```
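One general cause of device-mismatch errors under nn.DataParallel (not confirmed to be what happens inside ReversibleGraphNet) is a tensor stored as a plain module attribute instead of being registered as a parameter or buffer. When nn.DataParallel replicates a module, each replica gets its own copies of parameters and buffers on its GPU, but plain tensor attributes are shared with the original module and stay on `device_ids[0]`, so replicas on the other GPUs end up mixing devices. The helper below, `find_unregistered_tensors` (a hypothetical name, not part of PyTorch or FrEIA), is a sketch that lists such attributes in a module tree.

```python
from typing import List

import torch
import torch.nn as nn

# Internal nn.Module containers that legitimately hold tensors/submodules
_INTERNAL = {"_parameters", "_buffers", "_modules"}

def find_unregistered_tensors(model: nn.Module) -> List[str]:
    """List tensor attributes that are neither parameters nor buffers.

    nn.Module keeps parameters in `_parameters` and buffers in `_buffers`,
    so any tensor found directly in a submodule's __dict__ (or inside a
    list/tuple/dict stored there) is ignored by .to()/.cuda() and is shared,
    not copied, by nn.DataParallel's replicas.
    """
    found = []
    for mod_name, mod in model.named_modules():
        prefix = f"{mod_name}." if mod_name else ""
        for attr, value in vars(mod).items():
            if attr in _INTERNAL:
                continue
            if isinstance(value, torch.Tensor):
                found.append(prefix + attr)
            elif isinstance(value, (list, tuple)) and any(
                isinstance(v, torch.Tensor) for v in value
            ):
                found.append(prefix + attr)
            elif isinstance(value, dict) and any(
                isinstance(v, torch.Tensor) for v in value.values()
            ):
                found.append(prefix + attr)
    return found
```

Calling it on the network before wrapping, e.g. `find_unregistered_tensors(inn)` right after `Ff.ReversibleGraphNet(nodes)`, would list any candidates. If one of them turns out to be the culprit, registering it with `register_buffer` (or creating it inside `forward` on the input's device) is the usual remedy.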

Did you manage to find a solution?