
[Finetuning OneFormer] How to use multiple GPUs

Dear @NielsRogge, first and foremost, thank you so much for your fantastic work. I followed your tutorial and was able to finetune OneFormer. However, when I try to finetune the model on multiple GPUs, training still runs on a single GPU.

I tried two approaches:

1. Using DataParallel

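import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

# same setup code as in your tutorial (model, processor and CustomDataset come from there)
processor.image_processor.num_text = model.config.num_queries - model.config.text_encoder_n_ctx

train_dataset = CustomDataset(processor)
# note: DataParallel splits each batch along dim 0, so with batch_size=1 only cuda:0 receives data
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=16)
optimizer = AdamW(model.parameters(), lr=5e-5)

device = 'cuda'
model = model.to(device)  # DataParallel expects the parameters on cuda:0 (device_ids[0])
model = nn.DataParallel(model)  # replicates the model across the visible GPUs on every forward pass

model.train()
for epoch in range(20):  # loop over the dataset multiple times
    for batch in train_dataloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        # move tensors to cuda:0; DataParallel scatters them to the other GPUs
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        # forward pass
        outputs = model(**batch)

        # backward pass + optimize
        # DataParallel gathers one loss per replica into a vector, so average it
        loss = outputs.loss.mean()
        loss.backward()
        optimizer.step()
        print("Loss:", loss.item())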

This code runs without errors, but only GPU:0 is utilized; the other GPUs stay idle.
Here is the result from nvidia-smi while it's running:
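A likely explanation for the DataParallel result: `nn.DataParallel` cuts every input tensor along dimension 0 into at most one chunk per GPU (following `torch.chunk` semantics) and only replicates the model onto as many GPUs as there are chunks. With `batch_size=1` there is a single chunk, so everything runs on GPU:0. The pure-Python sketch below mimics that splitting rule; `split_batch` and `gpus_used` are hypothetical helpers for illustration, not PyTorch APIs. One more caveat worth checking: tensors inside lists (such as OneFormer's `mask_labels` and `class_labels`) are each split along their own dim 0 as well, which may not line up with a per-image split.

```python
import math


def split_batch(batch_size: int, num_gpus: int) -> list[int]:
    """Hypothetical helper: sizes of the chunks a batch of `batch_size`
    samples is cut into, following torch.chunk's rule (each chunk holds
    ceil(batch_size / num_gpus) samples except possibly the last one,
    and fewer than `num_gpus` chunks may be produced)."""
    chunk = math.ceil(batch_size / num_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


def gpus_used(batch_size: int, num_gpus: int) -> int:
    # one model replica per chunk, so the number of busy GPUs equals the number of chunks
    return len(split_batch(batch_size, num_gpus))


print(split_batch(1, 4), gpus_used(1, 4))  # [1] 1
print(split_batch(8, 4), gpus_used(8, 4))  # [2, 2, 2, 2] 4
print(split_batch(5, 4), gpus_used(5, 4))  # [2, 2, 1] 3
```

So with 4 GPUs, a batch of 1 leaves three GPUs idle, while a batch of 8 gives each GPU two samples. Raising `batch_size` for OneFormer may also require a custom `collate_fn` for its variable-size labels, which is not shown here.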

2. Using Accelerate
Following this tutorial, I modified the code as follows:

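from accelerate import Accelerator
from torch.optim import AdamW
from torch.utils.data import DataLoader

# run with `accelerate launch your_script.py`; a plain `python your_script.py` starts one process, i.e. one GPU
processor.image_processor.num_text = model.config.num_queries - model.config.text_encoder_n_ctx

train_dataset = CustomDataset(processor)
# val_dataset = CustomDataset(processor)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=16)
optimizer = AdamW(model.parameters(), lr=5e-5)

accelerator = Accelerator()
# prepare() puts the model on this process's GPU and shards the dataloader across processes
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
for epoch in range(20):  # loop over the dataset multiple times
    for batch in train_dataloader:
        # zero the parameter gradients
        optimizer.zero_grad()
        # no manual .to(device): the prepared dataloader already places batches on the right GPU

        # forward pass
        outputs = model(**batch)

        # backward pass + optimize
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        accelerator.print("Loss:", loss.item())  # printed by the main process only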

This code also runs without errors, but again only GPU:0 is used.
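For the Accelerate version, a common cause of this symptom is starting the script with `python your_script.py`, which creates a single process and therefore uses a single GPU. Running `accelerate config` and then `accelerate launch your_script.py` (or `accelerate launch --multi_gpu --num_processes 4 your_script.py`) starts one process per GPU, and `accelerator.prepare` then gives each process its own share of the batches. The sketch below is a simplified, pure-Python model of that sharing, assuming the default behaviour of handing out whole batches round-robin across processes; `shard_batches` is a hypothetical helper, and it ignores padding of the last batches and `drop_last` handling.

```python
def shard_batches(num_samples: int, batch_size: int, num_processes: int) -> list[list[list[int]]]:
    """Hypothetical helper: simplified model of distributing a dataloader's
    batches across processes (process i gets batches i, i + n, i + 2n, ...).
    Padding of the last batches and drop_last handling are ignored."""
    batches = [
        list(range(start, min(start + batch_size, num_samples)))
        for start in range(0, num_samples, batch_size)
    ]
    return [batches[rank::num_processes] for rank in range(num_processes)]


# plain `python your_script.py`: one process sees every batch (one GPU busy)
print(shard_batches(8, 2, 1))  # [[[0, 1], [2, 3], [4, 5], [6, 7]]]
# launched with two processes: each GPU gets half of the batches
print(shard_batches(8, 2, 2))  # [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
```

Under this model, keeping `batch_size=1` still gives each process one sample per step, so the effective batch size becomes the number of processes.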

I'm quite sure that I'm missing something here. Can you please point me in the right direction? Thank you so much!