
Evidential deep learning

This repo implements Evidential Deep Learning to Quantify Classification Uncertainty. The paper proposes to model class probabilities with a Dirichlet distribution parametrized by a neural network: the network outputs a positive vector of "evidences" supporting each class, and this vector serves as the parameters of the Dirichlet distribution. From it, we can recover the probability vector as the mean of the distribution, but more importantly extract per-class "evidences" as measures of belief, together with a global uncertainty that can be seen as a distinct class meaning "I don't know". The Dirichlet distribution is a natural choice because it is the conjugate prior of the Multinomial distribution.

Under the Subjective Logic framework, a belief mass assignment represents the belief that the truth is on a given state (a class in this setting), and it also provides an overall uncertainty mass $u$ such that $u + \sum_{k} b_k = 1$ for $K$ possible states (or classes).

Belief masses are computed from evidences: $b_k = \frac{e_k}{S}$ where $S = \sum_{i=1}^{K} (e_i + 1)$, and the overall uncertainty is $u = \frac{K}{S}$. These belief mass assignments correspond to a Dirichlet distribution with parameters $\alpha_{k} = e_k + 1$, and this distribution is used as a prior over the class probabilities. The expected probability of class $k$ is the mean of the Dirichlet distribution: $P_k = \frac{\alpha_{k}}{\sum_{i} \alpha_{i}} = \frac{\alpha_{k}}{S}$.
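
As a quick numerical check of these formulas, here is a minimal sketch with a made-up evidence vector for $K = 3$ classes:

import torch

# made-up evidence vector for K = 3 classes, as a network could output
evidence = torch.tensor([4.0, 1.0, 0.0])

alpha = evidence + 1.0           # Dirichlet parameters: alpha_k = e_k + 1
S = alpha.sum()                  # Dirichlet strength: S = sum_i (e_i + 1) = 8
belief = evidence / S            # belief masses: b_k = e_k / S -> [0.5, 0.125, 0.0]
uncertainty = alpha.numel() / S  # overall uncertainty: u = K / S = 0.375
prob = alpha / S                 # expected probabilities: mean of the Dirichlet

# belief masses and uncertainty sum to one, as required by Subjective Logic
assert torch.isclose(belief.sum() + uncertainty, torch.tensor(1.0))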

The paper proposes to parametrize this Dirichlet distribution with a neural network that outputs a positive vector of evidences, one per class (for instance by applying a ReLU or softplus activation to the last layer), instead of outputting logits or class probabilities as in the regular classification setting. The model is trained with one of several possible losses described in the paper.
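
The repo provides these losses, notably the SSBayesRiskLoss and KLDivergenceLoss used in the snippet further below. As a rough sketch of the main one, the sum-of-squares Bayes risk (Eq. 5 in the paper), assuming non-negative evidences and one-hot targets:

import torch

def ss_bayes_risk(evidences, targets_onehot):
    # evidences and targets_onehot both have shape (batch, num_classes)
    alpha = evidences + 1.0
    S = alpha.sum(dim=-1, keepdim=True)
    p = alpha / S                          # expected class probabilities
    squared_error = (targets_onehot - p) ** 2
    variance = p * (1.0 - p) / (S + 1.0)   # Dirichlet variance term
    return (squared_error + variance).sum(dim=-1).mean()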

Example on MNIST

Using a simple convolutional neural network predicting evidences, we can extract the class probabilities, beliefs and total uncertainty. The second sample from the top-left is interesting: the model assigns some belief to class 5 for this wiggly 9, but still predicts the correct class.

Figures: images with predictions, and the corresponding beliefs and uncertainty.

Dependencies

The repository has the following dependencies:

  • python 3.9+
  • pytorch 2+

Getting Started

git clone https://github.com/clabrugere/evidential-deeplearning.git

Usage

See the example notebook.

import torch
from torch import nn

# EDLClassifier, SSBayesRiskLoss and KLDivergenceLoss come from this repo

# load your dataset
train_dataloader = ...
test_dataloader = ...
device = ...
max_epoch = ...

# encoder can be arbitrary, for example a simple convnet here
encoder = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(20, 50, 5),
    nn.ReLU(),
    nn.MaxPool2d(2, 2)
)
model = EDLClassifier(encoder, dim_encoder_out=50 * 4 * 4, dim_hidden=500, num_classes=10, dropout=0.2)
model.to(device)

bayes_risk = SSBayesRiskLoss()
kld_loss = KLDivergenceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)
model.train()

# the losses expect one-hot encoded targets
eye = torch.eye(10, dtype=torch.float32, device=device)

for epoch in range(max_epoch):
    for x, labels in train_dataloader:
        x, labels = x.to(device), labels.to(device)
        labels = eye[labels]
        evidences = model(x)

        # anneal the KL regularizer so it doesn't dominate early in training
        annealing_coef = min(1.0, epoch / max_epoch)
        loss = bayes_risk(evidences, labels) + annealing_coef * kld_loss(evidences, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# make predictions with uncertainty on the test set
model.eval()
predictions, uncertainty, beliefs, labels = [], [], [], []
with torch.no_grad():
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        y_pred, u, b = model.predict(x)
        labels.append(y)
        predictions.append(y_pred)
        uncertainty.append(u)
        beliefs.append(b)

labels = torch.concat(labels, dim=0)
predictions = torch.concat(predictions, dim=0)
uncertainty = torch.concat(uncertainty, dim=0)
beliefs = torch.concat(beliefs, dim=0)
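
A natural use of the collected uncertainty is to abstain on samples the model is unsure about. A minimal sketch, assuming model.predict returns predicted class indices and a per-sample uncertainty, with an arbitrary threshold:

# flag samples whose uncertainty exceeds an arbitrary threshold and treat them
# as "I don't know" instead of trusting the predicted class
threshold = 0.5
unsure = uncertainty.reshape(-1) > threshold
print(f"flagged {unsure.sum().item()} of {len(unsure)} test samples as uncertain")

# accuracy restricted to the confident samples
confident_accuracy = (predictions[~unsure] == labels[~unsure]).float().mean()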

References

  • Murat Sensoy, Lance Kaplan, Melih Kandemir. "Evidential Deep Learning to Quantify Classification Uncertainty". NeurIPS 2018. arXiv:1806.01768