yinboc/prototypical-network-pytorch

About loss function

Closed this issue · 13 comments

Dear Chen,
I noticed that you used a cross-entropy loss function in your code, but the author of the paper minimized the distance between query and support examples that share the same label, and used that as the loss function. Here's his code: https://github.com/jakesnell/prototypical-networks
So, do you have any doubts about the author's loss function?
Thank you!

Hi,
It is my understanding that, in the original paper of Prototypical Networks, they used cross-entropy loss on page 3.
Also, in their code, I saw the -log term in their loss.
So I guess the loss function would be the same?

Thank you for your reply.
But are you sure the following equation is the cross-entropy loss? In my view, the goal of this loss on page 3 is to minimize the distance between matching pairs and to maximize the sum of the distances between mismatching pairs (although it has exp() and log() terms, I think you can understand what I mean), which is different from the cross-entropy loss, whose goal is to minimize the divergence between the true label distribution and the predicted one.

[screenshot of the loss equation from page 3 of the paper]

Besides, I think the author's code is somewhat different from his paper, too. Just look at this:
[screenshot of the author's loss code]
When the log_softmax is calculated, line 52 shows that only the log_softmax values of the matching pairs are gathered into the loss, regardless of the mismatching pairs. However, this operation does match page 2 of his paper:

[screenshot of the softmax-over-distances equation from page 2 of the paper]

Those are my two points of confusion. Anyway, I think your cross-entropy loss is more reasonable and outperforms the author's loss in his code.

Thanks for your question.
The cross-entropy loss is: -x[class] + \log( \sum_j \exp(x[j]) ), which is the same as the formulation above. You could see torch.nn.CrossEntropyLoss for details.
-\log p_{\phi}(y = k | x) is also equivalent; you just need to expand p_{\phi}.
Let me know if I am mistaken.
Best.
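For concreteness, here is a minimal sketch (not code from the repo; the tensors query, protos, and labels are made up for illustration) of why cross-entropy over negative-distance logits equals the paper's -log p_phi(y = k | x):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_query, n_way, dim = 4, 5, 16
query = torch.randn(n_query, dim)         # embedded query points
protos = torch.randn(n_way, dim)          # one prototype per class
labels = torch.randint(n_way, (n_query,))

# logits: negative squared Euclidean distance to each prototype
logits = -torch.cdist(query, protos).pow(2)

# PyTorch cross-entropy: -x[class] + log(sum_j exp(x[j])), averaged over queries
ce = F.cross_entropy(logits, labels)

# the paper's loss: -log p_phi(y = k | x), a softmax over negative distances
paper = -F.log_softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).mean()

print(torch.allclose(ce, paper))  # True
```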

@cyvius96 Hi, I have some problems below:

  1. The paper emphasizes that the mean of the support set is the prototype representation of each class, but in your code, for example, choose 2 classes with 3 images per class in the training phase, so I have 6 images: class1_1, class2_1, class1_2, class2_2, class1_3, class2_3. They are divided into a support set (4 images: class1_2, class2_2, class1_3, class2_3) and a query set (2 images: class1_1, class2_1), and then the distance between the support set and the query set is calculated:
    [screenshot of the distance computation]
    At last, it chooses the maximum distance as the predicted class.
    My confusion is that it calculates the distance from the test image to the class1_2 and class1_3 images, but class1_2 and class1_3 are the same class; according to the paper, we should calculate the mean of the support examples that belong to the same class.
    Thank you in advance!

https://github.com/cyvius96/prototypical-network-pytorch/blob/a3f8f1e1afd7fcb8cab64ba89268a80790761f88/train.py#L74
I think this line computes the mean of shots for one class.
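As a sketch of what that line does (shapes are illustrative; I am assuming the repo's episode layout, where the support embeddings are ordered shot-major, i.e. shot consecutive blocks of train_way classes):

```python
import torch

shot, train_way, dim = 5, 30, 1600

# support embeddings, ordered as `shot` consecutive blocks of `train_way` classes
proto = torch.randn(shot * train_way, dim)

# group by shot, then average the shots of each class into a single prototype
proto = proto.reshape(shot, train_way, -1).mean(dim=0)

print(proto.shape)  # torch.Size([30, 1600]): one prototype per class
```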

@cyvius96 I tested this code, proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0), and it has no effect.
Test code:
[screenshot of test code]
Test result:
[screenshot of test result]
I used torch.equal() to test, and it turns out there is no change.

Of course: since it is 30-way 1-shot, there are 30 prototypes (each with dim 1600) for the 30 classes, so taking the mean over a single shot changes nothing.
If you test n-way 5-shot, there will be n prototypes instead of 5n; images in the same class are reduced by the mean.
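A quick check of both cases (a sketch with made-up embeddings; dim 1600 as above):

```python
import torch

dim = 1600

# 30-way 1-shot: the mean is over a single shot, so the tensor is unchanged
x = torch.randn(1 * 30, dim)
print(torch.equal(x, x.reshape(1, 30, -1).mean(dim=0)))  # True

# 5-way 5-shot: 25 support embeddings reduce to 5 prototypes
y = torch.randn(5 * 5, dim)
print(y.reshape(5, 5, -1).mean(dim=0).shape)  # torch.Size([5, 1600])
```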

@cyvius96 Yeah, you are right! Thank you!
I am a rookie in this field, and I have two more problems:

  1. In your code, the query variable is always bigger than the shot variable. Is that normal? Can you give me a reasonable explanation or some links to resources?
  2. I have written a demo.py. For example: I train a model that can classify car and truck, and I provide 3 images at inference time (a car image and a truck image as the support set, and one test image as the query set), but I find it performs badly. Should I provide more images as the support set (and then compute the mean of the shots)?

You are welcome.

  1. It is the setting adopted by most recent works.
  2. Umm... maybe?
    Good luck with your research.

I found that a formula in the paper has an error:
[screenshots of the prototype computation in Algorithm 1 of the paper]
It should not be 1/N_C.

@cyvius96 Hi, Can you help me check the formula?

Yes, I think it should be 1/N_S.
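For reference, my reading of the corrected prototype computation (Eq. 1 in the paper, with S_k the support set of class k):

```latex
% prototype of class k: mean of the embedded support points of that class
c_k = \frac{1}{|S_k|} \sum_{(x_i,\, y_i) \in S_k} f_\phi(x_i)
```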

@cyvius96 Thank you for your kindness.
In the paper, the author uses a softmax loss (softmax function + cross-entropy loss), but you use cross-entropy loss instead; I think there is some difference between softmax loss and cross-entropy loss.
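For what it's worth, PyTorch's cross-entropy already fuses the softmax into the loss (log_softmax followed by nll_loss), so "softmax + cross-entropy" and the cross-entropy used in the repo coincide numerically; a minimal check with made-up logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)          # 8 samples, 5 classes
target = torch.randint(5, (8,))

# "softmax loss": softmax, then log, then negative log-likelihood
softmax_loss = F.nll_loss(torch.log(F.softmax(logits, dim=1)), target)

# PyTorch's cross-entropy performs the same two steps in one fused call
ce_loss = F.cross_entropy(logits, target)

print(torch.allclose(softmax_loss, ce_loss))  # True
```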