pytorch/hub

Training deeplabv3_resnet50 help

corey-dawson opened this issue

Hello,
I am not sure whether this is the right place for a training question about one of these models, but I will give it a try. I am starting from the pre-trained deeplabv3_resnet50 segmentation model and fine-tuning it for my application. Unfortunately, no matter how large a GPU I use, I always get a "CUDA out of memory" error. My latest attempt was on AWS, on an instance with a 24 GB GPU. Are there any suggestions for training a segmentation model like this without running out of GPU memory? Thanks in advance.

Vars:

Dataloader batch size: 5
epochs: 5
classes: 1
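A back-of-the-envelope calculation shows how batch size and input size drive memory. torchvision's DeepLabV3-ResNet50 runs its backbone at output stride 8 (dilated layer3/layer4), so the 2048-channel layer4 output has 1/8 of the input's height and width. The helpers below (hypothetical names) compute the size of that single tensor. The 520×520 input is an assumption, since the issue does not state the image size. During training, autograd keeps many tensors of comparable size for the backward pass, so the real activation footprint is a large multiple of this figure. It also scales linearly with both batch size and pixel count.

```python
def feature_map_mib(batch: int, channels: int, height: int, width: int,
                    bytes_per_elem: int = 4) -> float:
    """Size in MiB of one dense N x C x H x W tensor (float32 by default)."""
    return batch * channels * height * width * bytes_per_elem / 1024 ** 2


def deeplab_layer4_mib(batch: int, img_h: int, img_w: int,
                       output_stride: int = 8, bytes_per_elem: int = 4) -> float:
    """Assumed: ResNet-50 layer4 output (2048 channels) at the given output stride."""
    h = -(-img_h // output_stride)  # ceil division, matching the padded stride-2 convs
    w = -(-img_w // output_stride)
    return feature_map_mib(batch, 2048, h, w, bytes_per_elem)


if __name__ == "__main__":
    # Batch size 5 from the issue, hypothetical 520x520 images.
    print(round(deeplab_layer4_mib(5, 520, 520), 2))  # 165.04
```

For five 520×520 images, the layer4 output alone takes about 165 MiB in float32 and about 83 MiB in float16. Doubling the image side length quadruples it.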

Error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 860.00 MiB (GPU 0; 21.99 GiB total capacity; 21.38 GiB already allocated; 21.42 GiB reserved in total by PyTorch)
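When the error is raised, 21.38 GiB is already allocated, and the comment in the code says this happens at `.to(device)`. Logging memory per step can tell two cases apart. In the first, memory climbs from batch to batch, which means something keeps tensors alive across iterations. In the second, memory is already exhausted on the first batch, which means the batch size and input resolution are too large for this model in float32. The sketch below uses `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.max_memory_allocated`. `gpu_memory_report` is a hypothetical helper name.

```python
import torch


def gpu_memory_report(device=None):
    """Return current and peak CUDA memory in MiB, or None when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    return {
        "allocated_mib": torch.cuda.memory_allocated(device) / mib,
        "reserved_mib": torch.cuda.memory_reserved(device) / mib,
        "peak_allocated_mib": torch.cuda.max_memory_allocated(device) / mib,
    }


# Example use inside the training loop:
#     if batch_idx % 10 == 0:
#         print(batch_idx, gpu_memory_report())
# torch.cuda.memory_summary() prints a more detailed breakdown.
if __name__ == "__main__":
    print(gpu_memory_report())
```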

Code:

  import torch
  import segmentation_models_pytorch as smp
  from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

  model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
  # Binary segmentation: smp's BINARY_MODE Dice loss expects a single logit channel
  # (foreground probability; background is implied), so both heads output 1 channel.
  num_classes = 1
  model.classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
  model.aux_classifier[4] = torch.nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))

  epochs = 5
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  model = model.to(device)
  model.train()
  losses = []
  criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  for i in range(epochs):
      # trainloader: my DataLoader (batch size 5), defined elsewhere
      for batch_idx, (images, masks) in enumerate(trainloader):
          images = images.to(device)  # error is thrown when images and masks are moved to the GPU
          masks = masks.to(device)
          outputs = model(images)["out"]
          loss = criterion(outputs, masks)
          # Store a Python float; appending the tensor itself keeps a GPU tensor
          # (with its grad_fn) alive for every step of training.
          losses.append(loss.item())

          # Backward and optimize
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
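Here is a sketch of a lower-memory version of the inner loop. It combines three common techniques. Mixed precision (`torch.autocast` with `GradScaler`) runs most of the forward pass in float16 on CUDA, which typically cuts activation memory substantially. Gradient accumulation allows a smaller DataLoader batch while keeping an effective batch size of `accum_steps` × batch size. Logging `loss.item()` avoids keeping loss tensors on the GPU. `train_one_epoch` and `accum_steps` are hypothetical names. The model is assumed to return a dict with `"out"`, as torchvision's segmentation models do. On PyTorch 2.4+, `torch.cuda.amp.GradScaler` is deprecated in favor of `torch.amp.GradScaler("cuda")` but still works.

```python
import torch
from torch import nn


def train_one_epoch(model: nn.Module, loader, criterion, optimizer,
                    device: torch.device, accum_steps: int = 1,
                    use_amp: bool = True) -> list:
    """One epoch with optional mixed precision and gradient accumulation.

    Returns the per-batch loss values as Python floats.
    """
    use_amp = use_amp and device.type == "cuda"  # AMP is only enabled on CUDA in this sketch
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    model.train()
    losses = []
    optimizer.zero_grad(set_to_none=True)
    pending = 0  # batches whose gradients have not been applied yet
    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        # Forward pass in reduced precision where autocast considers it safe.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)["out"]
        # Compute the loss in float32 for numerical stability.
        loss = criterion(outputs.float(), masks)
        # Scale by 1/accum_steps so accumulated gradients average over the group.
        scaler.scale(loss / accum_steps).backward()
        pending += 1
        if pending == accum_steps:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            pending = 0
        losses.append(loss.item())  # a float, so no GPU tensor is kept alive
    if pending:  # apply leftover gradients from a final partial group
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
    return losses
```

In the loop above, this would replace the inner `for` loop, for example `losses += train_one_epoch(model, trainloader, criterion, optimizer, device, accum_steps=2)`, with the DataLoader batch size reduced to match.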

Note: I am running this as an AWS training job on a g5.2xlarge instance. The container stats are below:
[Image: container stats]