MjdMahasneh/Simple-PyTorch-Semantic-Segmentation-CNNs

other evaluations

Closed this issue · 2 comments

Thank you for the amazing work! I would also like to get precision, recall, and F1 scores — how can I add them to the code?

Hi,

I didn't have time to test this solution, but you can give it a try and modify it if needed:

@torch.inference_mode()
def evaluate_iou(net, dataloader, device, amp):
    """Compute per-class IoU, precision, recall and F1 over a validation set.

    Each metric is computed per batch and per class. It is then averaged over
    the number of batches, giving a mean of per-batch scores rather than a
    single dataset-level ratio.

    Args:
        net: Segmentation model exposing ``n_classes``; called as ``net(image)``.
        dataloader: Iterable of dicts with ``'image'`` and ``'mask'`` entries.
            Must support ``len()``.
        device: ``torch.device`` on which inference runs.
        amp: Whether to enable ``torch.autocast`` mixed precision.

    Returns:
        Tuple ``(classwise_iou, classwise_precision, classwise_recall,
        classwise_f1)``. Each is a list with ``net.n_classes`` entries.

    Note:
        When ``net.n_classes == 1``, the model output is treated as a single
        foreground logit. Pixels with ``sigmoid(logit) > 0.5`` are predicted as
        foreground and compared against ``mask == 1``. This assumes a 0/1 mask
        and an output of shape (N, 1, H, W); TODO confirm against the dataset
        and model.

        The model's previous train/eval mode is restored on exit.
    """
    n_classes = net.n_classes
    num_val_batches = len(dataloader)

    classwise_iou = [0.0] * n_classes
    classwise_precision = [0.0] * n_classes
    classwise_recall = [0.0] * n_classes
    classwise_f1 = [0.0] * n_classes

    # Remember the caller's mode so a model evaluated mid-training is not
    # silently left in eval mode (dropout/batch-norm disabled).
    was_training = net.training
    net.eval()
    try:
        with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
            for batch in tqdm(dataloader, total=num_val_batches, desc='IoU evaluation', unit='batch', leave=False):
                image, mask_true = batch['image'], batch['mask']
                image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                mask_true = mask_true.to(device=device, dtype=torch.long)
                mask_pred = net(image)

                if n_classes == 1:
                    # argmax over a single channel is always 0, so threshold the
                    # foreground logit instead and score the foreground class (1).
                    labels_pred = (torch.sigmoid(mask_pred.squeeze(1)) > 0.5).long()
                    classes = (1,)
                else:
                    # Hoisted: the predicted label map does not depend on the class.
                    labels_pred = mask_pred.argmax(dim=1)
                    classes = range(n_classes)

                for idx, cls in enumerate(classes):
                    mask_pred_cls = (labels_pred == cls).float()
                    mask_true_cls = (mask_true == cls).float()

                    # Pixels predicted as `cls` that really are `cls` (true positives).
                    intersection = (mask_pred_cls * mask_true_cls).sum()

                    iou_cls = iou_score(mask_pred_cls, mask_true_cls)
                    # The 1e-6 terms avoid division by zero when a class is absent.
                    precision_cls = intersection / (mask_pred_cls.sum() + 1e-6)
                    recall_cls = intersection / (mask_true_cls.sum() + 1e-6)
                    f1_cls = 2 * (precision_cls * recall_cls) / (precision_cls + recall_cls + 1e-6)

                    classwise_iou[idx] += iou_cls
                    classwise_precision[idx] += precision_cls
                    classwise_recall[idx] += recall_cls
                    classwise_f1[idx] += f1_cls
    finally:
        net.train(was_training)

    # Average the accumulated per-batch metrics; max() guards an empty dataloader.
    num_batches = max(num_val_batches, 1)
    classwise_iou = [iou / num_batches for iou in classwise_iou]
    classwise_precision = [prec / num_batches for prec in classwise_precision]
    classwise_recall = [rec / num_batches for rec in classwise_recall]
    classwise_f1 = [f1 / num_batches for f1 in classwise_f1]

    # Overall (macro/micro) precision, recall and F1 across classes can be derived
    # from these lists; which one is appropriate depends on the evaluation strategy.
    return classwise_iou, classwise_precision, classwise_recall, classwise_f1

Closing, as this isn't an issue.