Finetune TinySAM on custom dataset
Riley-livingston opened this issue · 1 comments
Hello, I'm trying to fine-tune the mask decoder of TinySAM on a custom dataset while freezing the weights of the `image_encoder` and `prompt_encoder`. I'm running into an issue in my training loop: `Sam.forward()` requires a `multimask_output` argument, but `MaskDecoder.forward()` doesn't accept a `multimask_output` argument.

I'm not an ML engineer, so I don't know much about the underlying code. If anyone with more knowledge has insight into how I can resolve this issue I would appreciate it, thanks!

Here is how I'm freezing the image encoder and prompt encoder to keep their original weights:
```python
for name, param in sam_model.named_parameters():
    if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
```
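As a sanity check that only the mask decoder stays trainable, and to show where the `optimizer` and `seg_loss` used in the training loop below come from, my setup is roughly the following (the Adam learning rate and the BCE loss are just placeholders, not anything specific to TinySAM):

```python
import torch

# Only mask-decoder parameters should still require gradients after freezing
trainable = [name for name, p in sam_model.named_parameters() if p.requires_grad]
assert all(name.startswith("mask_decoder") for name in trainable)

# Optimize only the trainable (mask decoder) parameters
optimizer = torch.optim.Adam(
    (p for p in sam_model.parameters() if p.requires_grad), lr=1e-4
)

# Simple segmentation loss on the predicted mask logits
seg_loss = torch.nn.BCEWithLogitsLoss()
```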
I am also providing bounding box prompts as input. Here is my custom `Dataset` class:
```python
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor, resize


class SAMDataset(Dataset):
    """
    Dataset class for the SAM model, serving images with their associated
    bounding boxes and masks.
    """
    def __init__(self, dataset, bbox_mapping, sam_model, device='cuda'):
        self.dataset = dataset
        self.bbox_mapping = bbox_mapping
        self.sam_model = sam_model
        self.device = device
        self.target_size = (1024, 1024)  # expected input size of the model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # dataset[idx] returns a dict with 'image' and 'label' keys
        pil_image = self.dataset[idx]['image']
        pil_mask = self.dataset[idx]['label']
        image_tensor = to_tensor(np.array(pil_image)).to(self.device)
        mask_tensor = to_tensor(np.array(pil_mask)).to(self.device)

        # Resize image and mask to the target size
        image_tensor = resize(image_tensor, self.target_size)
        mask_tensor = resize(mask_tensor, self.target_size)

        # Fetch bounding boxes directly without padding
        bboxes = self.bbox_mapping.get(idx + 1, [])  # adjust index if necessary
        bboxes_tensor = torch.tensor(bboxes, dtype=torch.float, device=self.device)

        return {
            'image': image_tensor,
            'bboxes': bboxes_tensor,
            'mask': mask_tensor
        }
```
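The training split is then wrapped with this class. `raw_train_split` and `bbox_mapping` below are placeholders for my own dataset split and bounding-box annotations:

```python
# Wrap the raw split; raw_train_split and bbox_mapping stand in for my own data objects
train_dataset = SAMDataset(raw_train_split, bbox_mapping, sam_model, device='cuda')
```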
### Create a DataLoader instance for the training dataset

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=False)
```

A single batch from this loader has the following shapes:

```
image  torch.Size([1, 3, 1024, 1024])
bboxes torch.Size([1, 1, 4])
mask   torch.Size([1, 1, 1024, 1024])
```
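(Those shapes come from a quick check along these lines:)

```python
# Pull one batch from the DataLoader and print each tensor's shape
batch = next(iter(train_dataloader))
for key, value in batch.items():
    print(key, value.shape)
```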
### Training Loop

```python
from statistics import mean

import torch
from tqdm import tqdm

num_epochs = 1
device = "cuda"

sam_model.to(device)
sam_model.train()

for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Prepare batched_input according to the TinySAM model's expected input format
        batched_input = [{
            'image': batch['image'].squeeze(0).to(device),
            'bboxes': batch['bboxes'].squeeze(0).to(device)
        }]

        # forward pass
        outputs_list = sam_model(batched_input, multimask_output=True)

        # Assuming outputs_list is a list of dictionaries, aggregate the masks
        # for loss computation; this would need adapting to the actual output structure
        predicted_masks = torch.stack([output['pred_mask'] for output in outputs_list]).squeeze(0)
        ground_truth_masks = batch["mask"].float().squeeze(1).to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass (compute gradients of the mask decoder parameters)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
```
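In case it helps with debugging, this is the component-level forward pass I've also been sketching out to sidestep `Sam.forward()` and its `multimask_output` argument entirely. The call signatures below are copied from the original SAM's `image_encoder`/`prompt_encoder`/`mask_decoder` interfaces, so they may well need adjusting to whatever TinySAM's `MaskDecoder.forward()` actually accepts:

```python
import torch
from torch.nn.functional import interpolate

batch = next(iter(train_dataloader))
image = batch['image'].to(device)               # (1, 3, 1024, 1024)
boxes = batch['bboxes'].squeeze(0).to(device)   # (N, 4) in pixel coordinates

# Encoders are frozen, so no gradients are needed for them
with torch.no_grad():
    image_embeddings = sam_model.image_encoder(image)
    sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
        points=None, boxes=boxes, masks=None,
    )

# multimask_output is deliberately omitted here, since MaskDecoder.forward()
# rejects it in this repo; the exact argument names should be checked against
# the MaskDecoder definition in the TinySAM source
low_res_masks, iou_predictions = sam_model.mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=sam_model.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
)

# Upsample the low-resolution mask logits back to the 1024x1024 input size
predicted_masks = interpolate(
    low_res_masks, size=(1024, 1024), mode='bilinear', align_corners=False
)
```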
Error when I don't provide `multimask_output`:

```
TypeError                                 Traceback (most recent call last)
<ipython-input-108-f41ebba752d9> in <cell line: 12>()
     21
     22 # forward pass
---> 23 outputs_list = sam_model(batched_input)
     24
     25 # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

2 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116
    117     return decorate_context

TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'
```
Error when I do provide the `multimask_output` argument:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-9d874c2eda3d> in <cell line: 12>()
     19 }]
     20 # forward pass
---> 21 outputs_list = sam_model(batched_input, multimask_output = True)
     22
     23 # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521
   1522         try:

TypeError: MaskDecoder.forward() got an unexpected keyword argument 'multimask_output'
```