sksq96/pytorch-summary

Dividing a model over multiple GPUs in PyTorch.

sreenithakasarapu opened this issue · 1 comment

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target in method wrapper_nll_loss_forward)
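This error means the loss function received tensors living on two different devices: the model's output is on cuda:1 while the target labels are still on cuda:0. A minimal sketch that reproduces the same failure, assuming a machine with at least two CUDA devices (independent of the model code below):

import torch
import torch.nn.functional as F

output = torch.randn(4, 2, device='cuda:1')          # "model output" on cuda:1
target = torch.randint(0, 2, (4,), device='cuda:0')  # labels left behind on cuda:0

# raises: RuntimeError: Expected all tensors to be on the same device ...
loss = F.cross_entropy(output, target)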

from torchvision.models.resnet import ResNet, Bottleneck
import torch.nn as nn

num_classes = 2

class CNN(ResNet):
    def __init__(self, *args, **kwargs):
        super(CNN, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        # first half of the ResNet-50 pipeline lives on cuda:0
        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
        ).to('cuda:0')

        # second half, plus the classifier head, lives on cuda:1
        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        self.fc.to('cuda:1')

    def forward(self, x):
        # hand the activations over from cuda:0 to cuda:1 between the two halves
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))
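For reference, pushing a dummy batch through this model shows the device handoff: inputs enter on cuda:0 and the output lands on cuda:1, so anything compared against that output must also live on cuda:1. A minimal sketch, assuming two CUDA devices and the CNN class above:

import torch

cnn = CNN()
dummy = torch.randn(8, 3, 224, 224, device='cuda:0')  # seq1 expects cuda:0 inputs
out = cnn(dummy)
print(out.device)  # cuda:1 -- the loss target must live here too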

num_epochs = 10

# loss_func and optimizer are assumed to be defined elsewhere in the script,
# e.g. loss_func = nn.CrossEntropyLoss() and an optimizer over cnn.parameters()

def train(num_epochs, cnn, loaders):
    cnn.train()

    # Train the model
    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            # batch data; inputs go to cuda:0, where seq1 lives
            b_x = images.to('cuda:0')   # batch x
            b_y = labels.to('cuda:0')   # batch y -- stays on cuda:0
            output = cnn(b_x)           # but the output comes back on cuda:1

            # output (cuda:1) and b_y (cuda:0) sit on different devices, so this
            # call raises the RuntimeError above; the trailing .to('cuda:1')
            # would only run after the loss had already been computed
            loss = loss_func(output, b_y).to('cuda:1')

            print(loss)

            # clear gradients for this training step
            optimizer.zero_grad()

            # backpropagation, compute gradients
            loss.backward()
            # apply gradients
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, num_epochs, i + 1, total_step, loss.item()))

train(num_epochs, cnn, loaders)

By chance were you able to solve this issue?
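For anyone else hitting this: the usual fix in this pipeline-parallel setup (the same pattern the PyTorch model-parallel tutorial uses) is to move the labels to the device the output ends up on, cuda:1, before computing the loss. A sketch of the corrected inner loop, keeping the names from the code above:

b_x = images.to('cuda:0')        # inputs start where seq1 lives
b_y = labels.to('cuda:1')        # labels go where the output will land

output = cnn(b_x)                # forward pass ends on cuda:1
loss = loss_func(output, b_y)    # both tensors now on cuda:1

optimizer.zero_grad()
loss.backward()
optimizer.step()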