huggingface/pytorch-image-models

Gradient computation fails for 'HRNet_FeatureExtractor' due to an inplace operation

mohammadalihumayun opened this issue · 6 comments

Using the latest torch version, when I try to train `HRNet_FeatureExtractor` from `modules.feature_extraction`, I get the following error:

```
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [4, 512, 4, 50]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later.
```
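
This class of error typically means an in-place op mutated a tensor that autograd had saved during the forward pass. As a minimal sketch (not the actual HRNet code; the shapes are just borrowed from the message above), the same `ReluBackward0` complaint can be reproduced like this:

```python
import torch

x = torch.randn(4, 512, 4, 50, requires_grad=True)
y = torch.relu(x)    # ReLU saves its *output* for the backward pass
y += 1               # the in-place add bumps the saved tensor's version counter
y.sum().backward()   # RuntimeError: ... output 0 of ReluBackward0, is at version 1

# writing the add out-of-place (y = y + 1) makes backward() succeed
```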

@mohammadalihumayun can you provide a snippet of how it's being used, eg. whatever module you're using to wrap the feat extractor and how it's connected to other modules?

The following is the model using the feature extractor.
Please note that the dataset used as input is a list of tuples, each containing an image as a numpy array and a label as a text string.
```python
import torch
import torch.nn as nn

from modules.feature_extraction import HRNet_FeatureExtractor
from modules.sequence_modeling import BidirectionalLSTM

# CTCLabelConverter, character and device are defined earlier
converter = CTCLabelConverter(character)
num_class = len(converter.character)

class new_Model(nn.Module):
    def __init__(self, input_channel=3,
                 output_channel=32,
                 FeatureExtraction='HRNet',
                 SequenceModeling='DBiLSTM',
                 Prediction='CTC',
                 batch_max_length=100,
                 hidden_size=256,
                 imgH=32,
                 imgW=400):
        super(new_Model, self).__init__()
        self.stages = {'Feat': FeatureExtraction,
                       'Seq': SequenceModeling,
                       'Pred': Prediction}

        self.FeatureExtraction = HRNet_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # transform final (imgH/16 - 1) -> 1
        self.SequenceModeling_output = hidden_size
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text=None, is_train=True):
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)
        contextual_feature = self.SequenceModeling(visual_feature)
        prediction = self.Prediction(contextual_feature.contiguous())
        return prediction

model = new_Model()
model = model.to(device)
model.train()
```
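
As a quick sanity check, a hypothetical smoke test (assuming `device` is set and that the extractor accepts NCHW input; note the training loop below permutes the batch differently):

```python
# hypothetical smoke test; `model` and `device` come from the snippet above
dummy = torch.randn(2, 3, 32, 400).to(device)  # [b, input_channel, imgH, imgW]
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # roughly [b, seq_len, num_class], ready for CTC decoding
```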

And the following is the training procedure:

```python
""" setup loss """
import time

import numpy as np
import torch
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Averager, character, converter, dataset, device, logger and model
# are defined earlier
criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
# loss averager
loss_avg = Averager()

# filter parameters that require gradient descent
filtered_parameters = []
params_num = []
for p in filter(lambda p: p.requires_grad, model.parameters()):
    filtered_parameters.append(p)
    params_num.append(np.prod(p.size()))
sum_params_num = sum(params_num)

# setup optimizer
optimizer = optim.Adam(filtered_parameters, lr=1e-3, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
reduce_lr = [50, 100]  # epochs where you want to reduce the LR by 10

start_iter = 0
start_time = time.time()
best_accuracy = -1
best_norm_ED = -1
iteration = start_iter
init_time = time.time()
batch_size = 4
num_epochs = 10
grad_clip = 5

char_set = set(character)
batch_max_length = 100

class CustomDataset(Dataset):
    def __init__(self, data, imgH=32, imgW=400):
        self.data = data
        self.imgH = imgH
        self.imgW = imgW
        self.transform = T.Compose([
            T.Resize((self.imgH, self.imgW)),                        # resize image
            T.ToTensor(),                                            # convert to tensor
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])   # normalize to [-1, 1]
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_array, label_string = self.data[idx]
        image = Image.fromarray(image_array)
        image_tensor = self.transform(image)
        label_string = ''.join(char for char in label_string if char in char_set)
        return image_tensor, label_string[:batch_max_length]

# initialize the dataset and DataLoader
train_dataset = CustomDataset(dataset)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

torch.autograd.set_detect_anomaly(True)
for epoch in tqdm(range(num_epochs)):
    logger.log("=" * 20, "Epoch =", epoch + 1, "=" * 20)

    for i, batch in enumerate(train_loader):
        # unpack batch into separate lists
        image_tensors, labels = zip(*batch)
        image_tensors = torch.stack([torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) for img in image_tensors]).to(device)

        # handle labels
        text, length = converter.encode(labels, batch_max_length=batch_max_length)
        batch_size = image_tensors.size(0)

        # if 'CTC' in Prediction:
        preds = model(image_tensors.permute(0, 2, 3, 1))
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        preds = preds.log_softmax(2).permute(1, 0, 2)
        cost = criterion(preds, text, preds_size, length)
        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # gradient clipping with 5 (default)
        optimizer.step()
        loss_avg.add(cost)
```
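
Since `set_detect_anomaly(True)` is already enabled above, the backward error should come with a second traceback pointing at the forward op whose saved tensor was mutated. As a hypothetical diagnostic (not part of the original code), the in-place modules inside the extractor can also be listed directly, since third-party CNN code commonly uses `nn.ReLU(inplace=True)`:

```python
# hypothetical diagnostic; `model` is the new_Model instance from above
for name, module in model.FeatureExtraction.named_modules():
    if getattr(module, 'inplace', False):
        print(name, type(module).__name__)
```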

However, please note that the same code runs fine when I use another feature extractor.
For example, just by replacing
`self.FeatureExtraction = HRNet_FeatureExtractor(input_channel, output_channel)`
with
`self.FeatureExtraction = DenseNet_FeatureExtractor(input_channel, output_channel)`
within the model, the code runs without the error.
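
If the in-place op does live inside the HRNet implementation and that code can't be edited, one possible workaround (a sketch, not a confirmed fix; `disable_inplace_relu` is a hypothetical helper) is to switch its ReLUs to the out-of-place variant after construction:

```python
import torch.nn as nn

def disable_inplace_relu(module: nn.Module) -> None:
    """Hypothetical helper: make every nn.ReLU in `module` out-of-place."""
    for child in module.modules():
        if isinstance(child, nn.ReLU):
            child.inplace = False

disable_inplace_relu(model.FeatureExtraction)  # `model` from the snippet above
```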

@mohammadalihumayun okay, so the contents of HRNet_FeatureExtractor are not visible. A module by that name does not exist in timm, so I assume it's using timm's HRNet if you've filed the issue here...

As far as I can tell it's not timm's HRNet, which is what I support here; I can't help with another impl of HRNet.