DLR-RM/curvature

Typo in README file

StaelTchinda opened this issue · 1 comment

There seems to be a typo in the code snippet in the Get Started section. The loss is computed from the logits and the sampled_labels (which are themselves drawn from the logits). However, the ground-truth labels should be used instead of sampled_labels.

...
for images, labels in tqdm.tqdm(train_data):
    logits = model(images.to(device))

    # We compute the 'true' Fisher information matrix (FIM),
    # by taking the expectation over the model distribution.
    # To obtain the empirical FIM, just use the labels from
    # the data distribution directly.
    dist = torch.distributions.Categorical(logits=logits)
    sampled_labels = dist.sample()

    loss = criterion(logits, sampled_labels)
...
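The comments in the snippet distinguish the "true" Fisher information matrix (labels sampled from the model) from the empirical one (ground-truth labels). The minimal plain-Python sketch below makes that difference concrete for a single softmax output. It works with respect to the logits rather than the network parameters to stay small. The parameter-space FIM follows by chaining with the Jacobian of the logits. All helper names (softmax, true_fisher, empirical_fisher, sampled_fisher) are hypothetical illustrations, not part of the curvature library's API. For cross-entropy, the gradient with respect to the logits is p - onehot(y). Averaging its outer product over y ~ p gives the true FIM, which equals diag(p) - p p^T. Using only the observed label gives the empirical FIM. Sampling labels, as dist.sample() does in the snippet, is a Monte Carlo estimate of the true FIM.

```python
import math
import random

# Hypothetical helpers (not part of the curvature library): Fisher information
# of a categorical (softmax) output with respect to its logits.

def softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_nll(probs, label):
    # Gradient of the cross-entropy loss -log p_label w.r.t. the logits: p - onehot(label).
    return [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]

def accumulate_outer(acc, g, weight):
    # acc += weight * g g^T, updated in place.
    for i, gi in enumerate(g):
        for j, gj in enumerate(g):
            acc[i][j] += weight * gi * gj

def zeros(n):
    return [[0.0] * n for _ in range(n)]

def true_fisher(logits):
    # "True" FIM: exact expectation over labels drawn from the model, E_{y~p}[g_y g_y^T].
    probs = softmax(logits)
    fim = zeros(len(probs))
    for y, p_y in enumerate(probs):
        accumulate_outer(fim, grad_nll(probs, y), p_y)
    return fim

def empirical_fisher(logits, label):
    # Empirical FIM: uses only the observed (ground-truth) label.
    probs = softmax(logits)
    fim = zeros(len(probs))
    accumulate_outer(fim, grad_nll(probs, label), 1.0)
    return fim

def sampled_fisher(logits, num_samples, rng):
    # Monte Carlo estimate of the true FIM: sample y ~ p (the role of
    # dist.sample() in the snippet) and average g_y g_y^T.
    probs = softmax(logits)
    fim = zeros(len(probs))
    labels = rng.choices(range(len(probs)), weights=probs, k=num_samples)
    for y in labels:
        accumulate_outer(fim, grad_nll(probs, y), 1.0 / num_samples)
    return fim

# Example: one data point whose ground-truth label (2) is the least likely class.
logits = [2.0, 0.5, -1.0]
true_f = true_fisher(logits)
emp_f = empirical_fisher(logits, label=2)
mc_f = sampled_fisher(logits, num_samples=20000, rng=random.Random(0))
```

As num_samples grows, mc_f approaches true_f. By contrast, emp_f depends on which label the data point happens to carry. Here it is several times larger than true_f in the [0][0] entry, because the ground-truth class is one the model considers unlikely.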