MEDS-torch

Description

This repository provides a comprehensive suite for advanced machine learning over Electronic Health Records (EHR) using PyTorch, PyTorch Lightning, and Hydra for configuration management. The project leverages MEDS_Polars, a robust system for transforming EHR data into a structured, tabular format that enhances the accessibility and analyzability of medical datasets. By employing a variety of tokenization strategies and neural network architectures, this framework facilitates the development and testing of models that can predict, generate, and understand complex medical trajectories.

Key features include:

  • Configurable ML Pipeline: Utilize Hydra to dynamically adjust configurations and seamlessly integrate with PyTorch Lightning for scalable training across multiple environments.
  • Advanced Tokenization Techniques: Explore different approaches to processing EHR data, such as triplet tokenization and code-specific embeddings, to capture the nuances of medical information.
  • Pre-training Strategies: Leverage contrastive learning, autoregressive token forecasting, and other pre-training techniques to boost model performance with MEDS data.
  • Transfer Learning: Implement and test transfer learning scenarios to adapt pre-trained models to new tasks or datasets effectively.
  • Generative and Supervised Models: Support for zero-shot generative models and supervised training allows for a broad application of the framework in predictive and generative tasks within healthcare.

The goal of this project is to push the boundaries of what's possible in healthcare machine learning by providing a flexible, robust, and scalable platform that accommodates a wide range of research and operational needs. Whether you're conducting academic research, developing clinical applications, or exploring new machine learning methodologies, this repository offers the tools and flexibility needed to innovate and excel in the field of medical data analysis.

Installation

Pip

# clone project
git clone git@github.com:Oufattole/meds-torch.git
cd meds-torch

# [OPTIONAL] create conda environment
conda create -n meds-torch python=3.12
conda activate meds-torch

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -e .

How to run

Train model with default configuration

# train on CPU
python -m meds_torch.train trainer=cpu

# train on GPU
python -m meds_torch.train trainer=gpu

Train model with chosen experiment configuration from configs/experiment/

python -m meds_torch.train experiment=experiment_name.yaml

You can override any parameter from the command line:

python -m meds_torch.train trainer.max_epochs=20 data.batch_size=64
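Each override pairs a dotted key path with a value. The sketch below illustrates that idea on a plain nested dict. It is not Hydra's implementation: `apply_overrides` and the default values are hypothetical names chosen for illustration, and Hydra additionally handles config groups, typing, and interpolation.

```python
import ast
from copy import deepcopy


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of `config` with `key.path=value` overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        # Walk (or create) the nested dicts named by the dotted path.
        for name in parents:
            node = node.setdefault(name, {})
        # Parse numbers and other literals; fall back to the raw string.
        try:
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value
        node[leaf] = value
    return result


# Hypothetical defaults; the real values live in the repository's configs.
defaults = {"trainer": {"max_epochs": 10}, "data": {"batch_size": 32}}
merged = apply_overrides(defaults, ["trainer.max_epochs=20", "data.batch_size=64"])
print(merged)
```

The original `defaults` dict is left untouched, which loosely mirrors how the command-line values apply only to the current run.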

📌  Introduction

Why you might want to use it:

✅ Save on boilerplate
Easily add new models, datasets, tasks, experiments, and train on different accelerators, like multi-GPU, TPU or SLURM clusters.

✅ Support different tokenization methods for EHR data

  • Triplet Tokenization (TODO: explain each subtype on Read the Docs)
  • Everything is text (TODO: explain each subtype on Read the Docs)
  • Everything is a code (TODO: explain each subtype on Read the Docs)
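As a rough illustration of the triplet idea, each observation can be represented as a (time, code, value) triple. The sketch below assumes MEDS-style event rows with `time`, `code`, and `numeric_value` fields. The helper `to_triplets`, the on-the-fly vocabulary, and the sample codes are hypothetical and are not meds-torch's actual tokenizer.

```python
from datetime import datetime


def to_triplets(events, vocab):
    """Convert one subject's events into (hours_since_first, code_id, value) triples."""
    if not events:
        return []
    events = sorted(events, key=lambda e: e["time"])
    start = events[0]["time"]
    triplets = []
    for event in events:
        hours = (event["time"] - start).total_seconds() / 3600.0
        # Assign the next free integer id to codes not seen before.
        code_id = vocab.setdefault(event["code"], len(vocab))
        # Value stays None for events without a numeric measurement.
        triplets.append((hours, code_id, event.get("numeric_value")))
    return triplets


# Made-up sample data for illustration only.
events = [
    {"time": datetime(2024, 1, 1, 8, 0), "code": "LAB//HR", "numeric_value": 82.0},
    {"time": datetime(2024, 1, 1, 10, 30), "code": "DX//I10", "numeric_value": None},
    {"time": datetime(2024, 1, 1, 14, 0), "code": "LAB//HR", "numeric_value": 90.0},
]
vocab = {}
triplets = to_triplets(events, vocab)
print(triplets)
```

A model would typically embed the three components separately and combine them into one token embedding; how missing values are handled (here `None`) is a design choice left open in this sketch.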

✅ MEDS data pretraining (and Transfer Learning Support)

  • General Contrastive Window Pretraining
  • STraTS Value Forecasting
  • Autoregressive Token Forecasting
  • Token Masked Imputation
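Two of these objectives differ mainly in how inputs and targets are derived from a token sequence. The sketch below shows that difference on plain integer token ids. `MASK_ID`, `IGNORE`, and both function names are hypothetical; `-100` is chosen only because it matches the default `ignore_index` of PyTorch's `CrossEntropyLoss`. This is an illustration, not the repository's implementation.

```python
import random

MASK_ID = 0    # hypothetical id reserved for the mask token
IGNORE = -100  # hypothetical label meaning "do not score this position"


def autoregressive_pairs(tokens):
    """Autoregressive forecasting: predict token t+1 from tokens up to t."""
    return tokens[:-1], tokens[1:]


def mask_tokens(tokens, mask_prob, rng):
    """Masked imputation: hide a random subset and predict only the hidden tokens."""
    inputs, labels = [], []
    for token in tokens:
        if rng.random() < mask_prob:
            inputs.append(MASK_ID)
            labels.append(token)
        else:
            inputs.append(token)
            labels.append(IGNORE)
    return inputs, labels


tokens = [5, 9, 3, 7, 2]
ar_inputs, ar_targets = autoregressive_pairs(tokens)
masked_inputs, masked_labels = mask_tokens(tokens, mask_prob=0.3, rng=random.Random(0))
print(ar_inputs, ar_targets)
print(masked_inputs, masked_labels)
```

In the autoregressive case every position contributes to the loss, whereas in the masked case only the hidden positions do.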

✅ Zero-shot Generative Model Support

  • Generate future patient trajectories in MEDS format using Autoregressive Token Forecasting.

✅ Supervised Model Support

  • Randomly initialize a model and train it in a supervised manner on your MEDS-format medical data.
  • Load pretrained model weights.

✅ Education
Thoroughly commented. You can use this repo as a learning resource.

✅ Reusability
Collection of useful MLOps tools, configs, and code snippets. You can use this repo as a reference for various utilities.

Why you might not want to use it:

❌ Things break from time to time
Lightning and Hydra are still evolving and integrate many libraries, which means sometimes things break. For the list of currently known problems visit this page.

❌ Not adjusted for data engineering
The template is not designed for building data pipelines whose stages depend on each other. It is better suited to model prototyping on ready-to-use data.

❌ Overfitted to simple use case
The configuration setup is built with simple Lightning training in mind. Adapting it to other use cases, e.g. Lightning Fabric, may take some effort.

❌ Might not support your workflow
For example, you can't resume a Hydra-based multirun or hyperparameter search.

Loggers

By default, the wandb logger is installed with the repo. To use a different logger, install one of the packages below:

# neptune-client
# mlflow
# comet-ml
# aim>=3.16.2  # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550

Development Help

pytest-instafail shows failures and errors instantly instead of waiting until the end of the test session. Run it with:

pytest --instafail

To run failing tests continuously each time you edit code until they pass:

pytest --looponfail

To run tests on 8 parallel workers (requires pytest-xdist):

pytest -n 8