pytorch/pytorch

[DataParallel] flatten_parameters doesn't work under torch.no_grad

apsdehal opened this issue · 2 comments

๐Ÿ› Bug

When the model is wrapped in DataParallel and flatten_parameters is called inside the model under torch.no_grad, it throws this error:

RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()

It works fine otherwise. This behavior only happens on 1.1.0; it worked fine on 1.0.1.post2.

To Reproduce

Run the code below on 1.1.0 to reproduce the behavior:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
    def forward(self, x):
        # flatten_parameters is called on every replica that DataParallel creates in forward
        self.rnn.flatten_parameters()
        return self.rnn(x)  # N * T * hidden_dim


model = torch.nn.DataParallel(Model().to('cuda'))

with torch.no_grad():
    x = model(torch.rand(2, 4, 300))

Expected behavior

flatten_parameters should work under DataParallel just as it does without it.
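
For comparison, here is a minimal sketch of the same model without DataParallel; as described above, this variant runs without the error under torch.no_grad on 1.1.0 (assuming a single CUDA device is available):

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
    def forward(self, x):
        # no DataParallel replication here, so flatten_parameters acts on the original parameters
        self.rnn.flatten_parameters()
        return self.rnn(x)  # N * T * hidden_dim


model = Model().to('cuda')

with torch.no_grad():
    # the input is created on the GPU directly since there is no DataParallel scatter
    x = model(torch.rand(2, 4, 300, device='cuda'))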

Environment

Collecting environment information...
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.9.4

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] msgpack-numpy==0.4.1
[pip] numpy==1.16.4
[pip] numpydoc==0.7.0
[pip] pytorch-nlp==0.3.5
[pip] pytorch-pretrained-bert==0.3.0
[pip] torch==1.1.0
[pip] torchfile==0.1.0
[pip] torchtext==0.2.3
[pip] torchvision==0.2.0
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] faiss-cpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] faiss-gpu 1.2.1 py36_cuda9.0.176_1 pytorch
[conda] magma-cuda90 2.3.0 1 pytorch
[conda] mkl 2018.0.1 h19d6760_4 anaconda
[conda] mkl-fft 1.0.0
[conda] mkl-include 2018.0.3 1
[conda] mkl-random 1.0.1
[conda] mkl-service 1.1.2 py36h17a0993_4
[conda] mkl_fft 1.0.2 np114py36_intel_0 [intel] intel
[conda] mkl_random 1.0.1 np114py36_intel_0 [intel] intel
[conda] mkldnn 0.14.0 0 mingfeima
[conda] nccl2 1.0 0 pytorch
[conda] pytorch-nlp 0.3.5
[conda] pytorch-pretrained-bert 0.3.0
[conda] torch 1.1.0
[conda] torchfile 0.1.0
[conda] torchtext 0.2.3
[conda] torchvision 0.2.0

I ran into a very similar bug with torch.nn.parallel.data_parallel in PyTorch 1.2.0/1.3.0.

When data_parallel is applied to a model that calls flatten_parameters in its forward pass under torch.no_grad, it also throws the same error:

RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().

You can run the code below on 1.2.0/1.3.0 to reproduce the behavior:

import torch
from torch.nn.parallel import data_parallel

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
    def forward(self, x):
        self.rnn.flatten_parameters()
        return self.rnn(x)  # N * T * hidden_dim


model = Model().to('cuda')
x = torch.rand(4, 52, 300, device='cuda')

with torch.no_grad():
    data_parallel(model, x, range(2))  # replicate the model on GPUs 0 and 1 and split x along dim 0

Environment

PyTorch version: 1.2.0/1.3.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: CentOS 7
GCC version: 6.4.0
CMake version: 3.12.0

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130

GPU models and configuration:
GPU 0: Tesla K40m
GPU 1: Tesla K40m
Nvidia driver version: 418.56

Guys, I think the issue is somehow related to how GRU/LSTM internally deal with the hidden/cell states when they are None. For example, the following code works on 1.2.0 and 1.3.0:

import torch
from torch.nn.parallel import data_parallel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpu = torch.cuda.device_count()
print('Number of GPUs Available:', num_gpu)

def initHidden(batch_size, bidirectional, hidden_size, num_layers, device, num_gpu):
    '''
    Create the initial hidden/cell states for a GRU/LSTM.
    '''
    if bidirectional:
        num_directions = 2
    else:
        num_directions = 1
    if num_gpu > 1:
        # DataParallel splits its inputs on dim=0 by default, so the states are created
        # batch-first here and permuted back inside the model's forward
        hidden = torch.zeros(batch_size, num_layers * num_directions, hidden_size, device=device)
        initial_cell = torch.zeros(batch_size, num_layers * num_directions, hidden_size, device=device)
        return hidden, initial_cell
    else:
        # standard (num_layers * num_directions, batch, hidden_size) layout for a single device
        hidden = torch.zeros(num_layers * num_directions, batch_size, hidden_size, device=device)
        initial_cell = torch.zeros(num_layers * num_directions, batch_size, hidden_size, device=device)
        return hidden, initial_cell

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.GRU(300, 1024, 1, batch_first=True, bidirectional=True)
    def forward(self, x, hidden):
        if self.training:
            self.rnn.flatten_parameters()
        # move the batch dim back after DataParallel has split the batch-first hidden on dim 0
        return self.rnn(x, hidden.permute(1, 0, 2).contiguous())  # N * T * hidden_dim


model = Model()
if num_gpu > 1:
    model = torch.nn.DataParallel(model)
model = model.to(device)

x = torch.rand(4, 52, 300, device=device)
hidden = initHidden(4, True, 1024, 1, device, num_gpu)

with torch.no_grad():
    model(x, hidden[0])