Primary language: Python. License: MIT.

Multi-class Focal Loss

An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper (https://arxiv.org/abs/1708.02002), generalized to the multi-class case.

It is essentially an enhancement to cross-entropy loss and is useful for classification tasks with a large class imbalance. It down-weights easy (already well-classified) examples so that training focuses on the hard ones.
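To make the down-weighting concrete, here is a minimal pure-Python sketch of the per-example focal term from the RetinaNet paper, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), next to plain cross-entropy. The function names `focal_term` and `cross_entropy_term` are hypothetical and used only for illustration; they are not part of this repo.

```python
import math

def cross_entropy_term(p_t):
    """Cross-entropy for one example, given the predicted probability of its true class."""
    return -math.log(p_t)

def focal_term(p_t, gamma=2.0, alpha_t=1.0):
    """Focal loss for one example: cross-entropy scaled by (1 - p_t) ** gamma and alpha_t."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# An easy example (true class predicted with p = 0.9) and a hard one (p = 0.1).
for p_t in (0.9, 0.1):
    ce = cross_entropy_term(p_t)
    fl = focal_term(p_t)
    print(f"p_t={p_t}: CE={ce:.4f}  focal={fl:.4f}  focal/CE={fl / ce:.2f}")
```

With gamma = 2, the easy example's loss is scaled by (1 - 0.9)^2 = 0.01, while the hard example's loss is scaled by (1 - 0.1)^2 = 0.81, so the hard example dominates the total. With gamma = 0 and alpha_t = 1 the focal term reduces to plain cross-entropy.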

Usage

  • FocalLoss is an nn.Module that behaves much like nn.CrossEntropyLoss(), i.e. it

    • supports the reduction and ignore_index params, and
    • is able to work with 2D inputs of shape (N, C) as well as K-dimensional inputs of shape (N, C, d1, d2, ..., dK).
  • Example usage

    # alpha: per-class weights; gamma: focusing parameter
    focal_loss = FocalLoss(alpha, gamma)
    ...
    inp, targets = batch
    out = model(inp)
    loss = focal_loss(out, targets)
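The shape and ignore_index conventions above can be illustrated with a small functional sketch. This is not the repo's implementation: `focal_loss_sketch` is a hypothetical helper written for this example, it supports only mean reduction, and it assumes `alpha` (if given) is a tensor of shape (C,).

```python
import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, targets, alpha=None, gamma=2.0, ignore_index=-100):
    """Illustrative multi-class focal loss (mean reduction); not the repo's code."""
    num_classes = logits.shape[1]
    # Move the class dimension last and flatten:
    # (N, C, d1, ..., dK) -> (N * d1 * ... * dK, C); a 2D (N, C) input is unchanged.
    logits = logits.movedim(1, -1).reshape(-1, num_classes)
    targets = targets.reshape(-1)

    # Drop positions whose target equals ignore_index.
    keep = targets != ignore_index
    logits, targets = logits[keep], targets[keep]
    if targets.numel() == 0:
        return logits.new_zeros(())

    # Log-probability of the true class for each remaining position.
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)

    ce = -log_pt
    if alpha is not None:
        ce = alpha[targets] * ce  # per-class weighting

    # Focal modulating factor (1 - p_t) ** gamma down-weights easy examples.
    return ((1.0 - log_pt.exp()) ** gamma * ce).mean()

# 2D input: (N, C) logits with (N,) targets.
x2d, y2d = torch.randn(8, 3), torch.randint(0, 3, (8,))
# K-dimensional input: (N, C, d1, d2) logits with (N, d1, d2) targets.
x4d, y4d = torch.randn(2, 3, 4, 4), torch.randint(0, 3, (2, 4, 4))
y4d[0, 0, 0] = -100  # this position is excluded from the loss

print(focal_loss_sketch(x2d, y2d), focal_loss_sketch(x4d, y4d))
```

In this sketch, setting gamma=0 with no alpha gives the same value as F.cross_entropy on the same inputs, which is a convenient sanity check when comparing against nn.CrossEntropyLoss.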

Loading through torch.hub

This repo supports loading modules through torch.hub. For example, FocalLoss can be loaded with:

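import torch

# Instantiate FocalLoss through the repo's torch.hub entry point.
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='FocalLoss',
    alpha=torch.tensor([.75, .25]),
    gamma=2,
    reduction='mean',
    force_reload=False
)
# 10 samples, 2 classes: logits of shape (10, 2), integer targets of shape (10,)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)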

Or:

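import torch

# Same loss through the focal_loss entry point; here alpha is a plain list
# and device/dtype are passed explicitly.
focal_loss = torch.hub.load(
    'adeelh/pytorch-multi-class-focal-loss',
    model='focal_loss',
    alpha=[.75, .25],
    gamma=2,
    reduction='mean',
    device='cpu',
    dtype=torch.float32,
    force_reload=False
)
# 10 samples, 2 classes: logits of shape (10, 2), integer targets of shape (10,)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)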