Oufattole/meds-torch

Model Performance is Suboptimal


I observed that the training loss (for both the long_los and icu_mortality tasks in the MIMIC-IV tutorial) is unstable and non-converging for supervised training. Even after tuning, the loss just oscillates across epochs.

I added a synthetic data test in the tune_debug branch that should be easily learnable: the model only has to learn whether the code "ADMISSION//CARDIAC" or "ADMISSION//ONCOLOGY" appears in the patient's history. I trained an LSTM on this and got the following metrics over 1000 epochs with a learning rate of 0.1:
*[LSTM metric plots]*

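For concreteness, the synthetic task boils down to something like the sketch below (this is an illustrative generator, not the actual code in the tune_debug branch; every code other than the two admission codes is made up):

```python
import random

# Illustrative sketch of the synthetic task: the label depends only on whether
# one of two admission codes appears anywhere in the patient's history.
TRIGGER_CODES = {"ADMISSION//CARDIAC", "ADMISSION//ONCOLOGY"}
OTHER_CODES = [f"LAB//{i}" for i in range(20)] + ["ADMISSION//GENERAL"]  # made-up filler codes

def make_patient(rng: random.Random, max_len: int = 32) -> tuple[list[str], int]:
    """Return (code sequence, label), where label = 1 iff a trigger code is present."""
    codes = [rng.choice(OTHER_CODES) for _ in range(rng.randint(4, max_len))]
    if rng.random() < 0.5:  # roughly half the patients get a trigger code inserted
        codes.insert(rng.randrange(len(codes)), rng.choice(sorted(TRIGGER_CODES)))
    label = int(any(c in TRIGGER_CODES for c in codes))
    return codes, label

rng = random.Random(0)
dataset = [make_patient(rng) for _ in range(1_000)]
```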
I ran the LSTM test with this command:

`pytest -k test_model_memorization_train[pytorch_dataset-supervised-earlyfusionNone-triplet_encoder-lstm] -s`

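Conceptually, that memorization test amounts to something like the following: fit a small LSTM on the trivially separable `dataset` from the sketch above and check that the training loss collapses toward zero. The model, pooling, and training loop here are stand-ins of my own, not the actual meds-torch components the test exercises.

```python
import torch
import torch.nn as nn

# Illustrative memorization check on the synthetic `dataset` from the sketch
# above; the real test_model_memorization_train exercises the full pipeline.
vocab = sorted({c for codes, _ in dataset for c in codes})
stoi = {c: i + 1 for i, c in enumerate(vocab)}  # index 0 is reserved for padding

max_len = max(len(codes) for codes, _ in dataset)
x = torch.zeros(len(dataset), max_len, dtype=torch.long)
y = torch.tensor([label for _, label in dataset], dtype=torch.float)
for i, (codes, _) in enumerate(dataset):
    x[i, : len(codes)] = torch.tensor([stoi[c] for c in codes])

class TinyLSTM(nn.Module):
    def __init__(self, vocab_size: int, dim: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.lstm = nn.LSTM(dim, dim, batch_first=True)
        self.head = nn.Linear(dim, 1)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.lstm(self.embed(tokens))
        mask = (tokens != 0).unsqueeze(-1).float()
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)  # mean over real tokens
        return self.head(pooled).squeeze(-1)

model = TinyLSTM(len(vocab) + 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()
for epoch in range(200):  # full-batch training on the tiny synthetic set
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()

print(f"final training loss: {loss.item():.4f}")  # should end up close to 0 on this task
```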
I also trained a transformer with the same data and task and got the results below (the learning rate for each of the three runs I tried is shown in the legend):
*[transformer metric plots, one curve per learning rate]*

The command used is:

`pytest -k test_model_memorization_train[pytorch_dataset-supervised-earlyfusionNone-triplet_encoder-transformer_encoder] -s`

Dropout seemed like a likely cause of the issue: four types of dropout were being applied concurrently in the runs above. So I turned them all off and then tried enabling one at a time, but in every case training performance did not improve. The results are below, followed by a generic sketch of the four dropout sites.

No dropout (all set to 0)
*[metric plots]*

Only layer and embed dropout (0.1)
*[metric plots]*

Only layer dropout (0.1)
*[metric plots]*

Only attention dropout (0.1)
*[metric plots]*

Only feed-forward dropout (0.1)
*[metric plots]*

Only embed dropout (0.1)
*[metric plots]*

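As referenced above, here is a generic sketch of the four dropout sites I was toggling (embed, attention, feed-forward, and layer/residual dropout) in a standard transformer encoder. This is only an illustration of where each one acts, not the actual meds-torch encoder or its config keys:

```python
import torch
import torch.nn as nn

# Generic transformer encoder showing the four dropout sites discussed above;
# not the meds-torch implementation, just an illustration of where each acts.
class EncoderBlock(nn.Module):
    def __init__(self, dim: int, heads: int, attn_drop: float, ff_drop: float, layer_drop: float):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_drop, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Dropout(ff_drop), nn.Linear(4 * dim, dim)
        )
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.layer_drop = nn.Dropout(layer_drop)  # applied to each residual branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + self.layer_drop(attn_out)
        return x + self.layer_drop(self.ff(self.norm2(x)))

class Encoder(nn.Module):
    def __init__(self, vocab_size: int, dim: int = 64, embed_drop: float = 0.0, **block_kwargs):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.embed_drop = nn.Dropout(embed_drop)  # dropout on the token embeddings
        self.blocks = nn.ModuleList([EncoderBlock(dim, heads=4, **block_kwargs) for _ in range(2)])

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embed_drop(self.embed(tokens))
        for block in self.blocks:
            x = block(x)
        return x

# The "no dropout" configuration from the first set of plots above:
model = Encoder(vocab_size=100, embed_drop=0.0, attn_drop=0.0, ff_drop=0.0, layer_drop=0.0)
out = model(torch.randint(1, 100, (2, 16)))  # (batch, seq, dim)
```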
I tried training the transformer encoder with an extremely low learning rate (1e-16). With such a small learning rate I would expect very little change in the metrics over epochs, and the validation metrics follow this, but the training loss still oscillates:

*[metric plots]*
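One way to check whether that residual oscillation is just batch-to-batch variance in the logged training loss, rather than the parameters actually moving, is to evaluate a frozen model over the training batches and look at the spread of per-batch losses. A rough sketch, assuming a standard DataLoader and loss function (the names in the commented usage are placeholders):

```python
import statistics
import torch

# Rough diagnostic: with the model frozen (equivalent to lr ~ 0), any variation
# in per-batch training loss comes purely from the composition of each batch.
@torch.no_grad()
def per_batch_loss_spread(model, loader, loss_fn) -> tuple[float, float]:
    model.eval()  # also disables dropout, so only batch composition varies
    losses = [loss_fn(model(xb), yb).item() for xb, yb in loader]
    return statistics.mean(losses), statistics.stdev(losses)

# Hypothetical usage (model, train_loader, and loss_fn stand in for the real objects):
# mean, spread = per_batch_loss_spread(model, train_loader, torch.nn.BCEWithLogitsLoss())
# print(f"per-batch train loss: {mean:.4f} +/- {spread:.4f}")
```

If that spread is on the same scale as the oscillation in the plots, the curve is dominated by batch noise rather than by the parameters diverging.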

I also added gradient plots, using wandb to log the model gradients:
https://api.wandb.ai/links/noufattole/9b0usj5g
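For completeness, the gradient logging follows the standard wandb pattern of watching the model; a minimal sketch (the project name is a placeholder, not the actual run configuration):

```python
import torch.nn as nn
import wandb

# Standard wandb pattern for logging gradient histograms; not necessarily the
# exact invocation used to produce the linked report.
model = nn.Linear(8, 1)  # stand-in for the actual meds-torch model
run = wandb.init(project="meds-torch-debug", mode="offline")  # placeholder project name
wandb.watch(model, log="gradients", log_freq=100)  # hooks log gradients every 100 backward passes
# ... run the usual training loop; gradient histograms show up under the run's
# "gradients/" section after enough logged steps.
run.finish()
```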