Project-MONAI/GenerativeModels

Possible Improvements for Current Repository

dongyang0122 opened this issue · 6 comments

Upon thorough examination of the repository, we believe it could be enriched by the introduction of additional features. These enhancements aim to augment the repository's functionality and extend the available modules for the MONAI user community. The suggested enhancements are detailed below.

  • We propose the development of varied conditional encoder modules, as depicted in the original latent-diffusion repository, for the generation of N-Dimensional medical images. The prospective supplementary modules are outlined as follows:
    • ClassEmbedder
    • TransformerEmbedder
    • BERTTokenizer
    • BERTEmbedder
    • SpatialRescaler
    • FrozenCLIPTextEmbedder
    • FrozenClipImageEmbedder, etc.
    • Furthermore, it is crucial to incorporate comprehensive tutorials for each newly implemented encoder.
  • Consider refining the implementation of latent-diffusion to accommodate various condition types. Currently, it exclusively supports "cross-attention". We propose the inclusion of two or more additional options to enhance the system's capabilities.
    1. concat
    2. hybrid, etc.
  • Suggested improvements relating to GPU.
    • Inclusion of activation checkpointing for memory optimization along with associated tutorials.
    • Integration of distributed model training, accompanied by relevant tutorials.
  • Introduction of PyTorch ConvTranspose support in the decoder to prevent int32 limitation on torch.nn.functional.interpolate for large tensors.
  • We propose offering pre-trained diffusion model weights, accessible via the Cloud, for user integration within their specific applications, accompanied by a comprehensive demo or tutorial for ease of use.

We express keen interest in proceeding with comprehensive discussions concerning any of the items outlined above.

@marksgraham @Warvito @ericspod Receiving insights and feedback from you and your team would be greatly beneficial. Thanks.

Hi @dongyang0122 ,

We propose the development of varied conditional encoder modules, as depicted in the original latent-diffusion repository, for the generation of N-Dimensional medical images. The prospective supplementary modules are outlined as follows:
ClassEmbedder
TransformerEmbedder
BERTTokenizer
BERTEmbedder
SpatialRescaler
FrozenCLIPTextEmbedder
FrozenClipImageEmbedder, etc.
Furthermore, it is crucial to incorporate comprehensive tutorials for each newly implemented encoder.

Regarding ClassEmbedder, the current diffusion model already incorporates class conditioning. Our implementation is closer to the one in Hugging Face's diffusers than to the one from CompVis. Since it is mainly an embedding layer, I think moving it to a separate class is unnecessary.
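To illustrate why class conditioning amounts to little more than an embedding layer, here is a minimal PyTorch sketch. It assumes the class embedding is simply added to the timestep embedding, which is my understanding of the pattern in diffusers; the names `TinyConditioning`, `num_classes`, and `embed_dim` are hypothetical and are not taken from the MONAI code.

```python
import torch
import torch.nn as nn


class TinyConditioning(nn.Module):
    """Hypothetical stand-in: combines a timestep embedding with a class embedding."""

    def __init__(self, num_classes: int, embed_dim: int) -> None:
        super().__init__()
        # One learnable vector per class label.
        self.class_embedding = nn.Embedding(num_classes, embed_dim)

    def forward(self, time_emb: torch.Tensor, class_labels: torch.Tensor) -> torch.Tensor:
        # Look up the class vectors and add them to the timestep embedding.
        return time_emb + self.class_embedding(class_labels)


cond = TinyConditioning(num_classes=3, embed_dim=8)
time_emb = torch.zeros(2, 8)         # placeholder timestep embedding for a batch of 2
class_labels = torch.tensor([0, 2])  # one integer label per sample
emb = cond(time_emb, class_labels)
print(emb.shape)  # torch.Size([2, 8])
```

Under this assumption the conditioning path is one `nn.Embedding` lookup and one addition, which is why a dedicated wrapper class adds little.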

Regarding the embedders and extra classes for text (TransformerEmbedder, BERTTokenizer, BERTEmbedder, FrozenCLIPTextEmbedder, FrozenClipImageEmbedder), I think it depends mostly on the focus the MONAI team wants for the repository: whether MONAI is mainly for medical imaging or also incorporates NLP elements. Personally, I think some of these components are already quite mature in other packages (like transformers (https://huggingface.co/docs/transformers/index)), so incorporating them into MONAI does not look necessary or a priority. @dongyang0122 What advantages would we gain by adding these to MONAI? @ericspod Does the MONAI team plan to add NLP components in the future?

Regarding SpatialRescaler, if you all agree that it is necessary, I think it is okay and straightforward to add.

Consider refining the implementation of latent-diffusion to accommodate various condition types. Currently, it exclusively supports "cross-attention". We propose the inclusion of two or more additional options to enhance the system's capabilities.
concat
hybrid, etc.

Currently, the model architecture supports all LDM conditionings (crossattn, concat, hybrid and adm (https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L1415C40-L1415C43)). Users can apply concat conditioning by concatenating the conditioning image with the noisy input images during training and sampling, and hybrid conditioning by combining that concatenation with cross-attention. For this reason, I would argue it is not necessary to change the DiffusionModelUNet class. If you are referring to changing the LatentDiffusionInferer, I guess it would be possible, but since these types of conditioning are not very popular, I would recommend keeping a simpler Inferer without them, which makes it easier to understand and use. Any thoughts on this, @marksgraham @ericspod?
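As a rough illustration of the concat conditioning described here, the following PyTorch sketch stacks a conditioning image onto the noisy image along the channel axis before calling the network. The `denoiser` is a hypothetical one-layer stand-in, not DiffusionModelUNet; the only requirement it demonstrates is that the network's input channels must cover image plus condition.

```python
import torch
import torch.nn as nn

image_channels, cond_channels = 1, 1

# Hypothetical denoiser: any network whose input channels cover image + condition.
denoiser = nn.Conv2d(image_channels + cond_channels, image_channels, kernel_size=3, padding=1)

noisy_image = torch.randn(2, image_channels, 32, 32)  # noisy input at the current timestep
condition = torch.randn(2, cond_channels, 32, 32)     # e.g. a low-resolution or source image

# Concat conditioning: stack the condition onto the noisy image along the channel axis.
model_input = torch.cat([noisy_image, condition], dim=1)
noise_pred = denoiser(model_input)

print(model_input.shape)  # torch.Size([2, 2, 32, 32])
print(noise_pred.shape)   # torch.Size([2, 1, 32, 32])
```

Hybrid conditioning would use the same concatenation and additionally pass a context tensor to the network's cross-attention layers.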

Suggested improvements relating to GPU.
Inclusion of activation checkpointing for memory optimization along with associated tutorials.
Integration of distributed model training, accompanied by relevant tutorials.

During the implementation, we tried to stay close to the other MONAI classes and examples. @dongyang0122 Could you please share some examples of what you had in mind (for activation checkpointing and distributed model training) from Core or other MONAI repositories?

Introduction of PyTorch ConvTranspose support in the decoder to prevent int32 limitation on torch.nn.functional.interpolate for large tensors.

@dongyang0122 Could you please elaborate on how ConvTranspose would be used here? Currently, we use interpolate in the AutoencoderKL and DiffusionModelUNet to stay close to the original architecture.

We propose offering pre-trained diffusion model weights, accessible via the Cloud, for user integration within their specific applications, accompanied by a comprehensive demo or tutorial for ease of use.

Yes, this sounds great. Currently, we have made just a few available in the model-zoo section (https://github.com/Project-MONAI/GenerativeModels/blob/main/model-zoo/README.md), but having contributions from other researchers in the community would be great. Thank you.

These are just my thoughts. @ericspod @marksgraham, what are your insights on these points?

Hi,

Regarding the text embedders, I agree that whether they're included should be more of a strategic call on whether MONAI intends to support NLP more. Perhaps Eric can share his thoughts when he is back from leave. Having said that, if these components are available in the transformers package as @Warvito says, perhaps as a minimum we can provide a tutorial showing users how to use text-based embeddings within MONAI Generative?

Regarding concat and hybrid conditioning - they may be less popular, but I have already been asked by a user if we support them (for transformer training, in this case). I'm ambivalent on whether we include them in the Inferers but suggest as a minimum we include demonstrations of how a user can use concat/hybrid conditioning in a tutorial.

I'll wait for @dongyang0122 to reply to Walter's queries re the GPU and ConvTranspose points.

Hi @Warvito and @marksgraham,

Thanks for your response!

  1. SpatialRescaler: I also think this is quite a helpful module for conditional LDM, which can simply map the condition to the same spatial size of the latent feature. I will draft a PR for it soon. For other encoder modules, we can have more discussion for them.
  2. I think Dong refers to LatentDiffusionInferer. The current implementation of LatentDiffusionInferer only supports cross-attention, which requires users to write extra code (e.g., a "for loop") to do the inference. See the examples in 2d_super_resolution and image_to_image_translation. This "for loop" basically does the same thing as LatentDiffusionInferer, except that it uses concat conditioning. Letting LatentDiffusionInferer support different condition options would simplify the inference code a lot.
  3. Activation checkpointing is particularly useful when we want to synthesize a large volume (e.g., 512x512x512). I have implemented it for AutoencoderKL. I will submit a PR to show how to integrate it.
  4. Supporting ConvTranspose is also aimed at synthesizing a large volume. PyTorch F.interpolate expects output.numel() <= numeric_limits. See the attached screenshot. This limit is easily reached with large feature maps.
    image
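A quick back-of-the-envelope check shows how easily point 4 bites for 3D volumes. The sketch below assumes the limit is the int32 maximum, 2**31 - 1, which is what the int32 limitation mentioned at the top of this issue suggests; `exceeds_int32` is a hypothetical helper, not a PyTorch function.

```python
import math

INT32_MAX = 2**31 - 1  # assumed limit, based on the int32 limitation mentioned in this issue


def exceeds_int32(shape: tuple) -> bool:
    """Return True if a tensor of this shape has more elements than INT32_MAX."""
    return math.prod(shape) > INT32_MAX


# Shapes are (batch, channels, depth, height, width).
print(math.prod((1, 8, 512, 512, 512)))       # 1073741824
print(exceeds_int32((1, 8, 512, 512, 512)))   # False
print(exceeds_int32((1, 16, 512, 512, 512)))  # True
```

Under that assumption, a single 512x512x512 feature map with 16 channels already has 2**31 elements, one more than the limit.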

@dongyang0122 Please correct me if I missed any points.

Thank you for the discussion! GPU memory is a bottleneck for synthesizing large 3D volumes. Related issues like ConvTranspose might be a must-have for 3D medical imaging...

Hi

  1. SpatialRescaler: I also think this is quite a helpful module for conditional LDM, which can simply map the condition to the same spatial size of the latent feature. I will draft a PR for it soon. For other encoder modules, we can have more discussion for them.

Sounds good!
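For readers unfamiliar with the idea, mapping a condition to the spatial size of the latent feature can be sketched with a single resize. This is only an illustration under my own assumptions, not the CompVis SpatialRescaler or the planned PR; `rescale_condition` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def rescale_condition(condition: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
    """Resize `condition` so its spatial dimensions match those of `latent`."""
    target_size = tuple(latent.shape[2:])  # skip the batch and channel dimensions
    return F.interpolate(condition, size=target_size, mode="nearest")


condition = torch.randn(1, 1, 64, 64, 64)  # e.g. a mask at image resolution
latent = torch.randn(1, 4, 16, 16, 16)     # latent feature from the autoencoder
rescaled = rescale_condition(condition, latent)
print(rescaled.shape)  # torch.Size([1, 1, 16, 16, 16])
```

Once the spatial sizes match, the rescaled condition can be concatenated with the latent along the channel axis.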

  2. I think Dong refers to LatentDiffusionInferer. The current implementation of LatentDiffusionInferer only supports cross-attention, which requires users to write extra code (e.g., a "for loop") to do the inference. See the examples in 2d_super_resolution and image_to_image_translation. This "for loop" basically does the same thing as LatentDiffusionInferer, except that it uses concat conditioning. Letting LatentDiffusionInferer support different condition options would simplify the inference code a lot.

Hmm, it is a good point that we already use this in our own tutorials, and that we can't currently use the Inferer class there. I think I would support including it then, if @Warvito agrees? I think if we add it to LatentDiffusionInferer we should also add it to DiffusionInferer to keep their use as consistent as possible.

  3. Activation checkpointing is particularly useful when we want to synthesize a large volume (e.g., 512x512x512). I have implemented it for AutoencoderKL. I will submit a PR to show how to integrate it.

Look forward to the PR!
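Until that PR is available, the general technique can be sketched with PyTorch's `torch.utils.checkpoint`. This is a generic illustration, not the AutoencoderKL implementation mentioned above; `CheckpointedBlock` is a hypothetical module.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Hypothetical convolutional block whose activations are recomputed in backward."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Intermediate activations inside self.block are not stored;
        # they are recomputed during the backward pass.
        return checkpoint(self.block, x, use_reentrant=False)


block = CheckpointedBlock(channels=2)
x = torch.randn(1, 2, 8, 8, 8, requires_grad=True)
out = block(x)
out.sum().backward()
print(out.shape)  # torch.Size([1, 2, 8, 8, 8])
```

The trade-off is extra compute in the backward pass in exchange for lower peak memory, which matters most for large 3D volumes.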

  4. Supporting ConvTranspose is also aimed at synthesizing a large volume. PyTorch F.interpolate expects output.numel() <= numeric_limits. See the attached screenshot. This limit is easily reached with large feature maps.
    image

I would support having this option available to users then (with the default kept to interpolate for backwards compatibility).
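One way such an option could look is sketched below. This is an assumption about the design, not existing MONAI code: `Upsample3d` and the `use_convtranspose` flag are hypothetical names, and the default path keeps interpolate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Upsample3d(nn.Module):
    """Hypothetical 2x upsampling block with a switchable implementation."""

    def __init__(self, channels: int, use_convtranspose: bool = False) -> None:
        super().__init__()
        if use_convtranspose:
            # kernel_size=2 with stride=2 doubles each spatial dimension.
            self.up = nn.ConvTranspose3d(channels, channels, kernel_size=2, stride=2)
        else:
            self.up = None  # default: fall back to interpolate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.up is not None:
            return self.up(x)
        return F.interpolate(x, scale_factor=2.0, mode="nearest")


x = torch.randn(1, 4, 4, 4, 4)
print(Upsample3d(4)(x).shape)                          # torch.Size([1, 4, 8, 8, 8])
print(Upsample3d(4, use_convtranspose=True)(x).shape)  # torch.Size([1, 4, 8, 8, 8])
```

Note that the transposed convolution has learnable weights, so the two paths are not numerically equivalent and a model trained with one would need retraining or fine-tuning to use the other.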