No evaluation code
Opened this issue · 3 comments
Hi authors,
Thanks for the great work!
I tried adding evaluation code myself and got an mIoU of 15 on the VOC dataset, which deviates significantly from the reported number. I believe there must be some discrepancy between my reimplementation and your code. Could you please release the evaluation code?
Thanks a lot!
Thanks for your interest in our work and your feedback!
I don't have time to clean up all the evaluation pipelines, but here is the one for PascalVOC. I will try to push it to the repo whenever I have time:
Evaluation Pipeline for PascalVOC.
Don't forget to change the path to the PascalVOC dataset (`root_path_voc`). (You can download the dataset from the official PASCAL VOC website.)
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassJaccardIndex
from einops import rearrange
import gem
class ZeroShotSegmentation(torch.nn.Module):
    """Zero-shot semantic-segmentation evaluator for a GEM model.

    Every image patch is scored against text prompts of the form
    ``'a photo of a {cls}.'``. A pixel receives the best-scoring foreground
    class, or background (label 0) when the softmax confidence of that class
    is below a threshold.
    """

    def __init__(self, model, tokenizer, model_name, patch_size=16, device='cpu'):
        """
        Args:
            model: GEM wrapper whose ``.model`` provides ``encode_text`` and
                ``visual`` (the latter returning a pair of feature tensors).
            tokenizer: callable turning a list of prompt strings into a tensor.
            model_name: model identifier (stored; not otherwise used here).
            patch_size: ViT patch size, used to fold patch tokens back into a grid.
            device: device the tokenized prompts are moved to.
        """
        super().__init__()
        self.model_name = model_name
        self.device = device
        self.gem_model = model
        self.patch_size = patch_size
        self.tokenizer = tokenizer

    def _get_text_embedding(self, classes: list):
        """Return L2-normalised prompt embeddings for ``classes``, with a leading axis of size 1."""
        prompts = [f'a photo of a {cls}.' for cls in classes]
        tokenized_prompts = self.tokenizer(prompts).to(self.device)
        text_embedding = self.gem_model.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def inference(self, image, text_embedding, mask_shape):
        """Predict per-pixel foreground classes for a batch of images.

        Args:
            image: image batch of shape (B, C, H, W). The patch grid is assumed
                to be (H // patch_size) x (W // patch_size).
            text_embedding: prompt embeddings as returned by ``_get_text_embedding``.
            mask_shape: (h, w) size the patch logits are upsampled to.

        Returns:
            ``(pred, logits)``:
            - ``pred`` has shape (B, h, w) and holds labels 1..K; label 0 is
              left free for background.
            - ``logits`` has shape (B, K, h, w).
        """
        _, _, height, width = image.shape
        # Only the first (GEM) feature output is used; the second is ignored.
        feat_gem, _ = self.gem_model.model.visual(image)
        feat_gem = F.normalize(feat_gem, dim=-1)
        # Skip token 0 (presumably the CLS token) and score each patch against
        # each prompt, with a fixed temperature of 100.
        logits_gem = 100.0 * feat_gem[:, 1:] @ text_embedding.transpose(1, 2)
        logits_gem = rearrange(logits_gem, 'b (h w) c -> b c h w',
                               h=height // self.patch_size, w=width // self.patch_size)
        logits_gem = F.interpolate(logits_gem, size=mask_shape, mode='bilinear')
        # +1 shifts foreground indices past the background label 0.
        pred_gem = logits_gem.argmax(1) + 1
        return pred_gem, logits_gem

    @staticmethod
    def _apply_threshold(pred, logits, threshold):
        """Return a copy of ``pred`` where low-confidence pixels are set to background.

        A pixel is low-confidence when its maximum softmax probability over the
        class axis (dim 1 of ``logits``) is below ``threshold``. ``pred`` itself
        is not modified.
        """
        confidence = logits.softmax(dim=1).max(dim=1).values
        thresholded = pred.clone()
        thresholded[confidence < threshold] = 0
        return thresholded

    @torch.no_grad()  # evaluation only: avoid building autograd graphs
    def eval_dataset(self, dataloader, classes, device, threshold=0.85):
        """Compute the mIoU (in percent) of thresholded zero-shot predictions.

        Args:
            dataloader: yields ``(image, mask)`` batches; mask pixels equal to -1
                are ignored by the metric.
            classes: class names, with the background class first.
            device: device the images are moved to.
            threshold: minimum softmax confidence for a foreground prediction.

        Returns:
            Mean IoU over ``len(classes)`` classes, multiplied by 100.
        """
        # Background is predicted via thresholding, so it gets no text prompt.
        text_embedding = self._get_text_embedding(classes=classes[1:])
        metric_iou = MulticlassJaccardIndex(num_classes=len(classes), ignore_index=-1).to('cpu')
        for i, (image, mask) in enumerate(tqdm(dataloader)):
            image = image.to(device)
            # The metric lives on the CPU, so the target must stay there too.
            mask = mask.cpu()
            pred_gem, pred_logits_gem = self.inference(image, text_embedding, mask.shape[-2:])
            pred_th_gem = self._apply_threshold(pred_gem, pred_logits_gem, threshold)
            metric_iou(pred_th_gem.cpu(), mask)
            if i % 20 == 0:
                print(metric_iou.compute().item() * 100)
        metric_th_gem = 100 * metric_iou.compute().item()
        print(f'mIoU: {metric_th_gem}')
        return metric_th_gem
def main(model_name, device, pretrained, patch_size=16, root_path_voc='', batch_size=1):
    """Evaluate a GEM model zero-shot on the PascalVOC ``val`` split.

    Args:
        model_name: model identifier passed to ``gem.create_gem_model`` / ``gem.get_tokenizer``.
        device: device images (and prompt tokens) are moved to.
        pretrained: pretrained-weights tag passed to ``gem.create_gem_model``.
        patch_size: ViT patch size of the model.
        root_path_voc: dataset root handed to ``PascalVOC``.
        batch_size: evaluation batch size.

    Returns:
        The mIoU (in percent) computed by ``ZeroShotSegmentation.eval_dataset``.
    """
    # Masks keep their native size unless several samples must be stacked into
    # one batch, which requires a common size.
    resize_mask = batch_size > 1
    dataset = PascalVOC(root=root_path_voc, split='val',
                        transform=SegmentationTransforms((448, 448), resize_mask=resize_mask),
                        aug=False, only_image=False, only_mask=False, ignore_index=-1)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8)

    # NOTE(review): the model is not explicitly moved to `device` here; confirm
    # that gem.create_gem_model places it on the device used for evaluation.
    model = gem.create_gem_model(model_name=model_name, pretrained=pretrained)
    tokenizer = gem.get_tokenizer(model_name=model_name)

    zero_shot_evaluator = ZeroShotSegmentation(model=model, device=device, patch_size=patch_size,
                                               model_name=model_name, tokenizer=tokenizer)
    # PascalVOC.CLASSES lists 'background' first, which is what eval_dataset expects.
    miou_list_cs = zero_shot_evaluator.eval_dataset(dataloader=test_loader,
                                                    classes=PascalVOC.CLASSES,
                                                    device=device)
    return miou_list_cs
if __name__ == '__main__':
    # main() refers to PascalVOC / SegmentationTransforms as module globals,
    # which this import binds.
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    patch_size = 16
    model_name = 'ViT-B-16-quickgelu'
    pretrained = 'metaclip_400m'
    # Edit this path. When no split_file is given, PascalVOC appends
    # 'VOCdevkit/VOC2012/' to it.
    root_path_voc = '/path/to/PascalVOC/'

    print(f'model: {model_name} | pretrained: {pretrained} ')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(model_name=model_name, pretrained=pretrained, device=device, patch_size=patch_size,
         root_path_voc=root_path_voc)
Here is the dataset implementation:
from os.path import join
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from import Dataset
from torchvision.transforms import transforms
# Per-channel (R, G, B) mean and standard deviation for image normalisation;
# the names point to the OpenAI CLIP preprocessing statistics.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
# PASCAL VOC 2012 segmentation dataset. Each item is (image, label_map), or
# (image, label_map, paths) when return_path is set. label_map is built by
# matching every pixel of the RGB annotation against PALETTE.
class PascalVOC(Dataset):
# Class names; the position in this tuple is the label id ('background' is 0).
CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')
# RGB colour used for each label id in the annotation PNGs (row i <-> label i).
PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]], dtype=torch.uint8)
# NOTE(review): the parameter list of __init__ is missing from this copy. The body
# reads root, split, transform, aug, only_image, only_mask, split_file, nclass,
# return_path and ignore_index; their defaults cannot be recovered from here.
def __init__(self,
super(PascalVOC, self).__init__()
# Number of palette entries decoded in __getitem__ (all 21 PALETTE rows when nclass is None).
self.nclass = nclass if nclass is not None else self.PALETTE.shape[0]
# NOTE(review): only_image / only_mask are stored but never read in the visible code.
self.only_image = only_image
self.only_mask = only_mask
self.split = split
self.return_path = return_path
# Label written for pixels whose colour matches no palette entry.
self.ignore_index = ignore_index
assert self.split in ['train', 'trainval', 'val'], f'{self.split} must be in ["train", "trainval", "val"]'
# With aug=True, 'train' reads trainaug.txt instead (when no split_file is given).
self.split = 'trainaug' if aug and (self.split == 'train') else self.split
# Without a split_file, root must be the directory that contains VOCdevkit/.
self.root = join(root, 'VOCdevkit/VOC2012/') if split_file is None else root
self.transform = transform
# aug selects the SegmentationClassAug annotation folder.
self.anno_type = 'SegmentationClassAug' if aug else 'SegmentationClass'
# Sample ids come from split_file (relative to root) or ImageSets/Segmentation/<split>.txt.
txt_file = join(self.root, split_file) if split_file is not None \
else join(self.root, 'ImageSets', 'Segmentation', self.split + '.txt')
self.samples = []
with open(txt_file) as f:
samples_tmp = f.readlines()
# One sample id per line, surrounding whitespace stripped.
samples_tmp = list(map(lambda elem: elem.strip(), samples_tmp))
samples_list = []
self.image_files = []
self.label_files = []
# NOTE(review): as shown, self.samples is still [] here and samples_tmp /
# samples_list are never used, so no files would be collected. Presumably
# self.samples should be filled from samples_tmp -- confirm against upstream.
for sample in self.samples:
if split_file is not None:
img = f'{str(sample)}.jpg'
label = f'{str(sample)}.png'
# NOTE(review): an `else:` seems to be missing here; as shown, these two lines
# unconditionally overwrite the split_file paths above.
img = f'JPEGImages/{str(sample)}.jpg'
label = f'{self.anno_type}/{str(sample)}.png'
self.image_files.append(join(self.root, img))
self.label_files.append(join(self.root, label))
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_path = self.image_files[idx]
label_path = self.label_files[idx]
# NOTE(review): this line is garbled in this copy; presumably both files are
# opened with PIL's Image and converted to "RGB".
img, msk ="RGB"),"RGB")
# if self.img_transform is not None:
images, rgb_target = self.transform(img, msk)
# Assumes the transform returns the mask as a (3, h, w) tensor.
h, w = rgb_target.shape[1:]
# Despite its name, this is an (h, w) map of class ids, not a one-hot encoding;
# pixels left unmatched keep ignore_index.
one_hot_seg_mask = self.ignore_index * torch.ones((h, w), dtype=torch.long)
for color_idx in range(self.nclass):
# `idx` is reused as the per-channel match mask. A pixel gets label color_idx
# only when all three channels equal that palette colour.
idx = (rgb_target == self.PALETTE[color_idx].unsqueeze(-1).unsqueeze(-1))
valid_idx = (idx.sum(0) == 3)#.unsqueeze(0)
one_hot_seg_mask[valid_idx] = color_idx
if self.return_path:
path_to_img_msk = {}
path_to_img_msk["img_path"] = image_path
path_to_img_msk["label_path"] = label_path
return images, one_hot_seg_mask, path_to_img_msk
return images, one_hot_seg_mask
class ToTensorMask(nn.Module):
    """Turn a 3-D (H, W, C) mask (PIL image or array-like) into an int64 (C, H, W) tensor."""

    def forward(self, mask):
        as_array = np.array(mask)
        hwc = torch.as_tensor(as_array, dtype=torch.int64)
        # Move the channel axis to the front.
        return hwc.permute(2, 0, 1)
class SegmentationTransforms(object):
    """Paired image/mask transform for segmentation samples.

    By default, images are resized (bicubic), converted to tensors and
    normalised with the OPENAI_DATASET_* statistics. Masks are optionally
    resized and then converted to an int64 (C, H, W) tensor by ToTensorMask.
    """

    def __init__(self, size, img_transforms=None, resize_mask=False):
        """
        Args:
            size: target size passed to ``transforms.Resize``.
            img_transforms: optional replacement for the default image pipeline.
            resize_mask: resize masks to ``size`` as well (needed to batch them).
        """
        if img_transforms is None:
            # NOTE(review): ToTensor + Normalize are inferred from the
            # OPENAI_DATASET_* constants defined for this module -- confirm against upstream.
            img_transforms = transforms.Compose([
                transforms.Resize(size=size, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
            ])
        self.img_transforms = img_transforms
        # Masks encode classes as exact palette colours, so they must be resized
        # with nearest-neighbour interpolation. The default (bilinear) blends
        # colours into values that match no palette entry; those pixels would
        # then be labelled ignore_index.
        mask_resize = (transforms.Resize(size=size, interpolation=transforms.InterpolationMode.NEAREST)
                       if resize_mask else nn.Identity())
        self.mask_transforms = transforms.Compose([
            mask_resize,
            ToTensorMask(),
        ])

    def __call__(self, image, mask):
        """Apply the image and mask pipelines independently and return both."""
        return self.img_transforms(image), self.mask_transforms(mask)
if __name__ == '__main__':
# Manual check: build the train split and iterate over it once.
root = '/path/to/PascalVOC/'
dataset = PascalVOC(root=root, split='train', transform=SegmentationTransforms((448, 448), resize_mask=False),
aug=False, only_image=False, only_mask=False)
# NOTE(review): garbled line; presumably
# test_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
test_loader =, batch_size=1)
for img, mask in test_loader:
# NOTE(review): the body of this loop is missing from this copy of the script.
You will also need to install the torchmetrics library via `pip install torchmetrics` (the evaluation script also imports `einops` and `tqdm`).
Feel free to ask if you have any questions!
I am now able to reproduce your result for VOC. Thanks a lot for your reply!!
I am also interested in how models pre-trained with a single objective versus multiple objectives (i.e., CLIP and BLIP) behave differently. Would you mind sharing how your method can be implemented with BLIP as well?
Thanks again!