cvlab-stonybrook/PathLDM

Replicating Fine-tuning Results

Closed this issue · 27 comments

@srikarym I'm trying to replicate the results provided in your paper using this codebase. However, when running your code, I obtain the following loss plots in Weights & Biases:

[Screenshot: training loss plots from Weights & Biases]

As you can see, the training loss barely changes over the course of a couple of training epochs.

I'm using the following datasets and checkpoints directly downloaded from this repo:

  1. Data
  2. Report Summaries
  3. center_crop_real_stats.npz
  4. First stage model weights
  5. Config files

I then run the code on 3 GPUs using the following command:
python -u main.py -t --gpus 0,1,2 --base config.yaml

Given the model does not appear to be learning at all, can you please aid in fixing this issue?

I believe it's an issue commonly faced in diffusion model training. The denoising loss is not entirely reflective of image generation quality. You could visually check whether the generated images are getting better, or measure the FID.

Hello, have you encountered a situation where the FID remains at a large value throughout training, as I did in #23?

@srikarym I had a similar issue as @sjjadsa found in #23. Neither the loss nor the FID score changes throughout training despite using the data, checkpoints, and configurations you provided.

I believe #23 didn't use the entire training data. The user reports they've finished 531,899 epochs. That's physically impossible as each epoch took us around 2 days.

@srikarym That makes sense. Something doesn't seem right based on that number.

I still have the problem that my FID is around 150 despite training for 50 epochs.
This is using all of the provided checkpoints and data as mentioned in the original comment.
I also used 3 GPUs and seemingly the same batch size to avoid issues with the learning rate.
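For context on why GPU count and batch size interact with the learning rate: in the original latent-diffusion `main.py`, enabling `scale_lr` multiplies the base learning rate by the number of gradient-accumulation steps, the number of GPUs, and the per-GPU batch size. The sketch below assumes this repo keeps that convention, which has not been verified here. The function name and all numeric values are hypothetical and are not taken from the PathLDM config.

```python
def effective_lr(base_lr: float, n_gpus: int, batch_size: int,
                 accumulate_grad_batches: int = 1) -> float:
    """Effective learning rate under the latent-diffusion `scale_lr` convention (assumed)."""
    return accumulate_grad_batches * n_gpus * batch_size * base_lr


# Illustrative values only: the same config run on 3 vs. 4 GPUs
# ends up with different effective learning rates.
lr_3gpu = effective_lr(base_lr=1e-6, n_gpus=3, batch_size=16)
lr_4gpu = effective_lr(base_lr=1e-6, n_gpus=4, batch_size=16)
print(f"3 GPUs: {lr_3gpu:.2e}, 4 GPUs: {lr_4gpu:.2e}")
```

Under this assumption, matching the authors' run requires matching the product of GPU count, per-GPU batch size, and accumulation steps, not just the base learning rate in the config.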

How long did you train for to achieve the results in the paper?

How do the generated samples look?

@srikarym

[Image: generated samples]

The images attached are from the "samples_gs" ones. They don't look great.
The overall structure isn't correct and doesn't resemble a real histology image.
I think the main issue is that using the FID metric doesn't allow for proper evaluation.

@LoadinggniaoL What operations did you perform to solve the problem of FID score not changing?

@sjjadsa I wasn't able to fix that. Despite trying around 50 configurations, I was not able to replicate their results using the provided data, checkpoints, and even their configuration files.

The images shown above are from the best run where I obtained an FID of around 70. At this point I'm no longer using this codebase and will rewrite the code myself.

@LoadinggniaoL did you generate samples using the same text report for computing FID? We randomly pick reports, generate 10k samples and compute FID.

@srikarym I computed the FID in a similar way. I randomly selected around 5000 text reports then sent the generated images through the pytorch-fid library. The images are in the expected range, but ignoring the FID, the output images are clearly not histologically relevant. They display features atypical of histology images so I tend to believe the high FID score is accurate.
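For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images. Below is a minimal NumPy sketch of that distance, assuming the feature means and covariances are already available (pytorch-fid stores precomputed statistics in `.npz` files under the keys `mu` and `sigma`; whether the provided `center_crop_real_stats.npz` follows that layout is an assumption). The function name is hypothetical, and the trace term is computed from eigenvalues, whereas pytorch-fid itself uses `scipy.linalg.sqrtm`.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # tr(sqrt(S1 @ S2)) equals the sum of square roots of the eigenvalues of
    # S1 @ S2, which are real and non-negative for PSD covariance matrices.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Toy usage with random stand-in "features" (not real Inception activations).
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
fake = rng.normal(loc=0.5, size=(500, 8))
score = frechet_distance(real.mean(0), np.cov(real, rowvar=False),
                         fake.mean(0), np.cov(fake, rowvar=False))
print(score)
```

Lower values mean the two feature distributions are closer. Because only the mean and covariance enter the formula, the score depends on how many samples are used and on which real images the reference statistics were computed from.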

Can you provide the exact training setup including the data, configuration files, and checkpoints in order for us to replicate these results? It's unclear exactly which setup was used based on prior comments. For instance you've mentioned 12 million training images multiple times yet the provided BRCA data only includes about 1.2 million images (assuming this is a simple mistake).

Where did you get these 5000 reports from? We used ~1000 WSIs and report summaries from TCGA BRCA for training. Did you also prepend low / high tumor and TIL to the beginning of the summarized report?

The 5000 reports were created based on a subset of the original ~1000 reports. I prepended various combinations of the "tumor/til" prefixes as well as other commands like "histopathology whole-slide image with . Regardless of the origin, the FID score remained the same, indicating that there is a more fundamental problem with the model.

Also, in the paper you mention 3.2 million patches but then the repo provides 1.2 million. What are the expected results on the 1.2 million patch version?

As a follow-up, in the paper it is worded in a way that suggests you used the training set text for validation.
Is this the case?

The data provided contains 1.2 million patches at 448x448 resolution and 10x magnification. During training, we take 256x256 random crops, which makes the expected size of the dataset 3.2 million.
Did you take random crops during training? The FID statistics we provide are for 256x256 real image crops at 10x
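The cropping step described above can be sketched as follows. This is a minimal illustration using only the standard library, not the repo's actual data loader. The function name is hypothetical, and it assumes square 448x448 source patches and uniformly sampled crop positions.

```python
import random


def random_crop_box(src_size=448, crop_size=256, rng=random):
    """Return a (left, upper, right, lower) box for a random square crop."""
    if crop_size > src_size:
        raise ValueError("crop_size must not exceed src_size")
    max_offset = src_size - crop_size  # 192 for 448 -> 256
    left = rng.randint(0, max_offset)  # randint is inclusive on both ends
    top = rng.randint(0, max_offset)
    return (left, top, left + crop_size, top + crop_size)


rng = random.Random(0)  # seeded only to make the demo repeatable
boxes = [random_crop_box(rng=rng) for _ in range(5)]
for box in boxes:
    print(box)
```

The returned tuple follows the `(left, upper, right, lower)` order that Pillow's `Image.crop` expects. Since the provided FID statistics are for 256x256 real crops at 10x, generated images should be compared at that same resolution.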

I am using your exact codebase and did take random crops during training.

On the other question, it seems the training set was used to compute the FID score in the paper. Can you clarify whether this is the case?

Can you please provide the links to the exact checkpoints, configuration, and anything else used during training in order to replicate your results?

All the FID scores reported in our paper are obtained using the same text reports used for training. This includes both our models and comparisons such as Stable Diffusion.
When we used the validation set text reports, FID was ~10.

Wouldn't this unfairly bias the results towards your model though? Considering the base stable diffusion checkpoints didn't have access to that training data.

Also, would it be possible to provide the exact checkpoints tested? It's still unclear based on the prior answers to issues on this repo.

The comparison is with Stable Diffusion finetuned on these patches, not the base version.
You can find checkpoints here. The best-performing model was finetuned from ImageNet weights and conditioned on text + tumor / TIL using the PLIP encoder.

@srikarym Did you do anything to verify the model wasn't overfitting on the training set then? If reporting on the training set, there's no indication that the FID means anything. You can achieve a low (i.e., good) FID simply by reproducing memorized training data.

@srikarym Also, I'm asking for the starting model checkpoints, not just the final ones.

We used cin256-v2 model for the U-Net, which is an ImageNet pretrained model provided by the original LDM repo (see this).
For the Autoencoder, we finetune the vq-f4 VAE on 10x BRCA image patches. The VAE weights can be extracted from our final diffusion checkpoint (#6)
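A minimal sketch of what extracting the VAE weights could look like, assuming the final checkpoint follows the original latent-diffusion layout, where the autoencoder parameters sit in the `state_dict` under the `first_stage_model.` prefix. The helper name and the toy keys below are hypothetical. With a real checkpoint, the dictionary would come from something like `torch.load(path, map_location="cpu")["state_dict"]`.

```python
def extract_first_stage_weights(state_dict, prefix="first_stage_model."):
    """Keep only autoencoder entries and strip the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy stand-in for a diffusion checkpoint's state_dict (values would be tensors).
toy_state_dict = {
    "model.diffusion_model.input_blocks.0.0.weight": "unet-tensor",
    "first_stage_model.encoder.conv_in.weight": "vae-tensor-a",
    "first_stage_model.decoder.conv_out.bias": "vae-tensor-b",
    "cond_stage_model.some.weight": "text-encoder-tensor",
}
vae_weights = extract_first_stage_weights(toy_state_dict)
print(sorted(vae_weights))
```

The stripped keys can then be saved on their own and loaded as the first-stage model when starting a fresh finetuning run from the ImageNet-pretrained U-Net.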

Regarding overfitting: we perform data augmentation for a patch-level tumor classification task and observe that synthetic data improves performance.

Thanks for the clarification on the checkpoints! This helps a lot.

Back to overfitting: the experiment you describe doesn't demonstrate added value from the synthetic data unless the number of training samples was held constant across the real and synthetic settings. Otherwise, you cannot rule out that the performance improvement is simply due to adding more data, regardless of its quality. Also, in Table 5 of the paper, which dataset did you use to evaluate performance? Was the accuracy computed on a single dataset, and if so, which one?

In Table 5, we used the same 10x patches from TCGA BRCA for training the tumor classifier. To generate synthetic samples, we randomly pick a text report, append low / high tumor, and assign the corresponding label to the synthetic patch.
Since the text reports are at the WSI level, the type of downstream / augmentation tasks we can perform is limited.
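The sampling procedure described above can be sketched as follows. This is a hedged illustration only: the function name, the `"{level} tumor; {report}"` prompt format, and the example reports are all hypothetical, and the exact conditioning string used by the authors is not shown in this thread.

```python
import random


def make_synthetic_prompts(reports, n_samples, rng):
    """Pair randomly chosen reports with a random tumor level and matching label."""
    samples = []
    for _ in range(n_samples):
        report = rng.choice(reports)
        label = rng.choice([0, 1])  # 0 = low tumor, 1 = high tumor
        level = "high" if label == 1 else "low"
        # Hypothetical prompt format: tumor level placed before the report summary.
        samples.append((f"{level} tumor; {report}", label))
    return samples


# Placeholder report summaries, not real TCGA text.
reports = ["report summary A", "report summary B", "report summary C"]
prompts = make_synthetic_prompts(reports, n_samples=4, rng=random.Random(0))
for prompt, label in prompts:
    print(label, prompt)
```

Each prompt would be passed to the diffusion model to generate one patch, and the paired label becomes the classification target for that synthetic patch.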

But when doing the training did you hold the number of samples constant regardless of real or synthetic?
It's unclear whether this is an equitable comparison.

We used an equal number of real and synthetic samples. When mixing both, the training set is doubled. It's a common setup used in evaluating synthetic samples from diffusion models https://arxiv.org/pdf/2304.08466
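To make the sample counts in this setup explicit, here is a small sketch of the three training sets described above, plus a size-matched mixed set as a possible control for the "more data" concern raised earlier in the thread. The size-matched variant is an illustration, not something the authors report running, and all names here are hypothetical.

```python
def build_training_sets(real, synthetic):
    """Build real-only, synthetic-only, and mixed training sets from two lists."""
    n = min(len(real), len(synthetic))
    real_n, synth_n = real[:n], synthetic[:n]
    half = n // 2
    return {
        "real": real_n,                                  # n samples
        "synthetic": synth_n,                            # n samples
        "mixed": real_n + synth_n,                       # 2n samples, as described above
        "mixed_size_matched": real_n[:half] + synth_n[:n - half],  # n samples (control)
    }


# Toy stand-ins for (patch, label) pairs.
real = [("real", i) for i in range(10)]
synthetic = [("synthetic", i) for i in range(10)]
sets = build_training_sets(real, synthetic)
print({name: len(items) for name, items in sets.items()})
```

Comparing `real` against `mixed_size_matched` keeps the total number of training samples fixed, so any difference in classifier accuracy cannot be attributed to dataset size alone.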

Thanks so much for the information and all of the help!

As a brief follow-up, did you ever experiment with whether the FID metric is a good indicator of histology image quality?
Working through that right now as I'm unsure whether the inception model can capture histology features adequately.

It's not the best metric for histology images, but it's good enough to compare different models. In our follow-up CVPR paper, we train diffusion models conditioned on SSL embeddings and use multiple metrics such as CLIP FID, embedding similarity, etc.