Novartis/torchsurv

Neg partial log likelihood loss is 0 every time the batch size is 1

mayurmallya opened this issue · 4 comments

Hi there,

Thanks for sharing this wonderful library!

I was trying to run a survival analysis using the Cox proportional hazards model and, due to GPU constraints, I have to use a batch size of 1. Every time I run the model, I observe that the loss value is always 0 when I use cox.neg_partial_log_likelihood.

I looked into the implementation of _partial_likelihood_cox and it seems that log_denominator gets the same value as log_hz_sorted when the batch size is 1, which makes the loss 0.

I was wondering if there is a workaround for this issue; please let me know. I am also attaching the corresponding line of code:

log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)
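
For reference, a minimal snippet (made-up value) showing what happens with a single element: `torch.logcumsumexp` over a length-1 tensor returns the element itself, so the numerator and denominator cancel.

```python
import torch

# A batch of size 1: a single predicted log hazard (made-up value).
log_hz_sorted = torch.tensor([0.7])

# Same line as in _partial_likelihood_cox: log of the cumulative sum of
# exp(log hazards) over the risk set, accumulated from the end of the batch.
log_denominator = torch.logcumsumexp(log_hz_sorted.flip(0), dim=0).flip(0)

print(log_denominator)                  # tensor([0.7000])
print(log_hz_sorted - log_denominator)  # tensor([0.]) -> loss is 0
```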

Thank you in advance!

Dear @mayurmallya,

Thank you for your interest and for your question!

Summary
When the batch size is 1, the Cox partial log likelihood evaluates to 0. This follows from the definition of the Cox partial log likelihood itself; it is not a bug in the code. To get meaningful results, you need to increase the sample size: the more participants you include, the more informative the partial log likelihood will be. However, there is a trade-off between statistical efficiency and computational cost.
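
For illustration, here is a minimal reproduction sketch; the values are made up, and the call signature `neg_partial_log_likelihood(log_hz, event, time)` and input shapes are assumptions for this example.

```python
import torch
from torchsurv.loss import cox

# Batch of size 1 (made-up values). The call signature
# neg_partial_log_likelihood(log_hz, event, time) is assumed here.
log_hz = torch.tensor([[0.7]])      # model output for one subject
event = torch.tensor([True])        # the event was observed
time = torch.tensor([5.0])          # event time

loss = cox.neg_partial_log_likelihood(log_hz, event, time)
print(loss)  # tensor(0.) -- the single subject is only compared with itself
```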

More details
The Cox partial likelihood is constructed from a product of conditional probabilities, specifically the probability that a given subject experiences an event compared with all the other subjects who are still at risk of having an event. Unlike most likelihood functions, the Cox partial likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply the results together.

Including more subjects in your sample results in a more refined and accurate ranking, which improves the estimation of the log hazards. Conversely, with only one subject, the likelihood provides no information for parameter estimation (because the subject is compared to ... no one).
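
To make the risk-set construction concrete, here is a small hand-rolled sketch (toy log hazards, every event observed, no ties, subjects already sorted by time) of the sum the loss computes:

```python
import torch

# Toy batch of 3 subjects, already sorted by event time (ascending),
# all with an observed event and no ties; made-up log hazards.
log_hz = torch.tensor([0.2, -0.5, 1.0])

pll = torch.tensor(0.0)
for i in range(len(log_hz)):
    risk_set = log_hz[i:]  # subjects still at risk at the i-th event time
    pll = pll + log_hz[i] - torch.logsumexp(risk_set, dim=0)

print(pll)  # non-zero: each subject is ranked against everyone still at risk
```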

In maths
From the documentation, the Cox partial log likelihood is

$$ pll = \sum_{i: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right) $$

With only one subject, if $\delta_1 = 0$ the sum over events is empty and the pll is 0; if $\delta_1 = 1$, the risk set contains only that one subject, so

$$ pll = \left(\log \theta_1 - \log\left(\theta_1 \right) \right) = 0. $$
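
For contrast, with two subjects where subject 1 has the earlier event time and both events are observed,

$$ pll = \left(\log \theta_1 - \log\left(\theta_1 + \theta_2\right) \right) + \left(\log \theta_2 - \log \theta_2 \right) = \log \frac{\theta_1}{\theta_1 + \theta_2}, $$

which does depend on the estimated log hazards, so the loss carries information.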

Side note
The partial log likelihood function can be viewed as a ranking measure: Raykar et al. (2007) showed that maximizing the Cox partial likelihood is approximately equivalent to maximizing the C-index. You can find more details in their paper.

I hope this helps,

Melodie

Thank you @melodiemonod for the detailed answer, much appreciated!

If I have 300 samples in the dataset, I believe the ideal scenario would be a batch size of 300, right? But because of the computational constraints, the batch size has to be lower, let's say 10. In that case, the likelihood would be calculated separately for subjects 1-10, 11-20, and so on, right?

I'm just trying to understand what you meant by:

Unlike most likelihood functions, the Cox partial likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply the results together.

Also, based on your experience, what batch size would you recommend? Or is it simply the higher the better?

Thank you once again :)

Hi @mayurmallya,

1/

I'm just trying to understand what you meant by:

Unlike most likelihood functions, the Cox partial likelihood depends on all observations in the set; you cannot calculate it separately for subjects 1-5 and 6-10 and then multiply the results together.

When dealing with the Cox partial likelihood, you cannot decompose the likelihood across subsets of subjects:
$p(Y_1, Y_2) \neq p(Y_1) \times p(Y_2)$
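
Concretely, when you train with mini-batches of 10, the loss for each batch only compares subjects within that batch; summing the mini-batch losses is not the same as the partial log likelihood of the full sample, because the risk sets get truncated. A toy sketch (made-up values, all events observed, no ties, subjects sorted by time):

```python
import torch

def pll(log_hz: torch.Tensor) -> torch.Tensor:
    """Cox partial log likelihood for a batch sorted by time, assuming
    every subject has an observed event and there are no ties."""
    log_denominator = torch.logcumsumexp(log_hz.flip(0), dim=0).flip(0)
    return (log_hz - log_denominator).sum()

# Made-up log hazards for 4 subjects, sorted by event time.
log_hz = torch.tensor([0.2, -0.5, 1.0, 0.3])

full = pll(log_hz)                         # risk sets use all 4 subjects
split = pll(log_hz[:2]) + pll(log_hz[2:])  # two "mini-batches" of 2

print(full, split)  # different values: splitting truncates the risk sets
```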

2/

Also, based on your experience, what batch size would you recommend? Or is it simply the higher the better?

It's a trade-off between convergence speed and computational resources. The primary constraint is often the memory available on your GPU or TPU; on the other hand, larger batch sizes provide more accurate gradient estimates (and, for this loss, larger risk sets), potentially leading to more stable and faster convergence. A practical guideline:

- Start small: begin with a small batch size (e.g., 32 or 64) to ensure your model trains correctly without memory issues.
- Monitor performance: track the loss and your evaluation metrics on both the training and validation sets to see if increasing the batch size helps.
- Scale up gradually: if it helps and you have the memory capacity, increase the batch size and check that it speeds up training without compromising the model's ability to generalize.
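
For what it's worth, here is a minimal training-loop sketch along those lines; the data, the toy model, and the `cox.neg_partial_log_likelihood(log_hz, event, time)` call signature are all assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchsurv.loss import cox

# Hypothetical dataset of 300 subjects with 16 covariates each (made-up data).
x = torch.randn(300, 16)
event = torch.rand(300) < 0.7          # ~70% observed events (bool tensor)
time = torch.rand(300) * 10            # continuous event/censoring times

model = torch.nn.Linear(16, 1)         # toy model: log hazard = linear predictor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# batch_size controls the size of the risk sets the loss sees.
loader = DataLoader(TensorDataset(x, event, time), batch_size=32, shuffle=True)

for xb, eb, tb in loader:
    optimizer.zero_grad()
    # Assumed signature: neg_partial_log_likelihood(log_hz, event, time).
    # In practice you may want to skip batches with no observed events.
    loss = cox.neg_partial_log_likelihood(model(xb), eb, tb)
    loss.backward()
    optimizer.step()
```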

Best regards

Melodie

Thank you @ahmedhshahin and @melodiemonod
Much appreciated! :)