Bias in Multiblock Mask Collator

joeagriffith opened this issue

I've been sampling the multi-block mask collator and plotting the masks to understand how they look, and I believe I've found a bias that may have a significant impact on the training of any model that uses this class.

The patterns below appear consistently for batch sizes >= 128: many patches near the centre of the image are never included in any encoder mask. Note that this behaviour only occurs with allow_overlap=False.
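For intuition on why the centre is affected: with allow_overlap=False, each encoder mask is the sampled encoder block minus the union of the predictor blocks, and a rectangular block whose top-left corner is drawn uniformly over valid positions covers central patches far more often than border patches, so the centre is removed from the encoder mask almost every time. Below is a minimal standalone sketch of that placement bias; the 5x6 block shape is my assumption, chosen to be roughly a pred_mask_scale of ~0.15 of the 196 patches, and the sketch does not use the repo code.

import torch

GRID = 14                   # 14x14 patch grid for 224x224 images with 16x16 patches
BLOCK_H, BLOCK_W = 5, 6     # assumed block shape, ~0.15 of 196 patches
N_SAMPLES = 10_000

coverage = torch.zeros(GRID, GRID)
for _ in range(N_SAMPLES):
    # top-left corner drawn uniformly over all valid positions,
    # mirroring how the multiblock collator places a block
    top = torch.randint(0, GRID - BLOCK_H + 1, (1,)).item()
    left = torch.randint(0, GRID - BLOCK_W + 1, (1,)).item()
    coverage[top:top + BLOCK_H, left:left + BLOCK_W] += 1

# per-patch probability of falling inside a sampled block:
# it peaks in the centre and falls off toward the borders
print(coverage / N_SAMPLES)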

Here are four examples I've sampled using the code below, with no cherry-picking.
Each image visualises the encoder masks for one batch of 128 samples, generated using the default arguments of the multiblock mask collator.
Each pixel represents a patch, and is white iff that patch is included in at least one of the encoder masks in the batch.
Repro code is below the figures.

[Four sampled batches: each image is a 14x14 grid of patches in which a central region is never white, i.e. never included in any encoder mask.]

import torch
import matplotlib.pyplot as plt
# in the ijepa repo the collator is defined in src/masks/multiblock.py
from src.masks.multiblock import MaskCollator

collator = MaskCollator()  # defaults: 224x224 input, 16x16 patches -> 14x14 grid of 196 patches

# a batch of 128 random images, matching the figures above
batch = [torch.randn(3, 224, 224) for _ in range(128)]
batch, enc_masks, pred_masks = collator(batch)

def display_mask(mask):
    # mask is a tensor of patch indices in [0, 195];
    # it can be a single mask or a stack of masks for the whole batch
    # display a 14x14 grid where a cell is on iff its index appears in mask
    grid = torch.zeros(196)
    grid[mask.flatten()] = 1
    plt.imshow(grid.view(14, 14), cmap='gray')
    plt.show()

# change the second index from ':' to an integer to visualise an individual mask
display_mask(enc_masks[0][:])
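A more informative variant plots the per-patch inclusion frequency rather than the binary union, which makes the gradient of the bias visible, and re-running it with allow_overlap=True should make the central hole disappear. Here is a sketch along the same lines as the repro above, assuming the same import path and that allow_overlap is a constructor argument of MaskCollator:

import torch
import matplotlib.pyplot as plt
from src.masks.multiblock import MaskCollator  # assumed path, as above

def enc_mask_frequency(allow_overlap, batch_size=128):
    # fraction of images in the batch whose encoder mask includes each patch
    collator = MaskCollator(allow_overlap=allow_overlap)
    images = [torch.randn(3, 224, 224) for _ in range(batch_size)]
    _, enc_masks, _ = collator(images)
    freq = torch.zeros(196)
    for row in enc_masks[0]:  # one row of kept-patch indices per image
        freq[row] += 1
    return (freq / batch_size).view(14, 14)

for overlap in (False, True):
    plt.imshow(enc_mask_frequency(overlap), cmap='gray', vmin=0, vmax=1)
    plt.title(f'enc mask inclusion frequency, allow_overlap={overlap}')
    plt.show()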