Oufattole/meds-torch

Add Event Stream Modeling


We should add support for Event Stream GPT (ESGPT) models. Event-stream tokenization is already supported in the PyTorch dataset class: just set the collate type to event_stream in your Hydra config. I think we mostly need to copy some code from the ESGPT GitHub repository to get this running.

Getting Started

  • Branch off the dev branch (it is the most up to date)
  • Event stream data support is already implemented in the collate function (see the example output batches in the tests)
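To make the collate step concrete, here is a minimal pure-Python sketch of what event-stream collation does. Each event (one timestamp) becomes a single sequence position. All measurements recorded at that timestamp are padded into one per-event row, which gives [batch, events, measurements] arrays.

A few parts of this sketch are hypothetical:
- the raw input format;
- the collate_event_stream helper;
- the field names, which are chosen to echo ESGPT-style batches;
- the choice to treat time_delta as "time until the next event".

The real meds-torch collate output may use different keys, shapes, and conventions, so check the test batch examples for the actual structure.

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical raw input: one subject is a list of events, and each event is
# (timestamp, [(code_id, value or None), ...]). Code id 0 is reserved for padding.
Event = tuple[float, list[tuple[int, Optional[float]]]]


@dataclass
class EventStreamBatch:
    event_mask: list = field(default_factory=list)           # [B, E]: True for real events
    time_delta: list = field(default_factory=list)           # [B, E]: time until the next event
    dynamic_indices: list = field(default_factory=list)      # [B, E, M]: code ids, 0 = padding
    dynamic_values: list = field(default_factory=list)       # [B, E, M]: values, 0.0 if absent
    dynamic_values_mask: list = field(default_factory=list)  # [B, E, M]: True if a value exists


def collate_event_stream(subjects: list[list[Event]]) -> EventStreamBatch:
    """Pad every subject to E events and every event to M measurements."""
    n_events = max(len(events) for events in subjects)
    n_meas = max((len(meas) for events in subjects for _, meas in events), default=0)
    batch = EventStreamBatch()
    for events in subjects:
        times = [t for t, _ in events]
        deltas = [nxt - cur for cur, nxt in zip(times, times[1:])]
        deltas += [0.0] * (n_events - len(deltas))  # last real event and padding get 0.0
        batch.event_mask.append([i < len(events) for i in range(n_events)])
        batch.time_delta.append(deltas)
        idx_rows, val_rows, vmask_rows = [], [], []
        for i in range(n_events):
            meas = events[i][1] if i < len(events) else []  # padded events have no measurements
            pad = n_meas - len(meas)
            idx_rows.append([code for code, _ in meas] + [0] * pad)
            val_rows.append([0.0 if v is None else v for _, v in meas] + [0.0] * pad)
            vmask_rows.append([v is not None for _, v in meas] + [False] * pad)
        batch.dynamic_indices.append(idx_rows)
        batch.dynamic_values.append(val_rows)
        batch.dynamic_values_mask.append(vmask_rows)
    return batch


subjects = [
    [(0.0, [(5, None), (7, 1.2)]), (2.5, [(9, 0.4)])],  # subject with two events
    [(1.0, [(3, None)])],                                # subject with one event
]
b = collate_event_stream(subjects)
print(b.dynamic_indices)  # [[[5, 7], [9, 0]], [[3, 0], [0, 0]]]
```

Compared with a layout where every measurement is its own token, this keeps the sequence length equal to the number of events. An event-level encoder consumes exactly that.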

Implementation Steps

  1. Create event_stream_input_encoder.py:

  2. Implement the custom hierarchical ESGPT architecture:

  3. Supervised Model:

    • Use the existing supervised_model PyTorch Lightning class
    • Override model.input_encoder and model.backbone with the new ESGPT components
    • For ESGPT pretraining, create a new PyTorch Lightning class in the models folder
  4. Add Integration Tests:
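For steps 1 and 3, here is a minimal PyTorch sketch of one possible event-level input encoder. It also shows the attribute swap of input_encoder and backbone described in step 3. Everything here is hypothetical:
- EventStreamInputEncoder, EventBackbone, and ToySupervisedModel are illustrative stand-ins. The real supervised_model class is built from its own config.
- The batch keys are assumed, not taken from meds-torch's collate output.
- Summing a code embedding and a value-scaled embedding per measurement is one simple choice, not ESGPT's exact data embedding layer.
- A plain transformer over events stands in for the hierarchical ESGPT backbone from step 2.

```python
import torch
from torch import nn


class EventStreamInputEncoder(nn.Module):
    """Hypothetical encoder: one embedding per event, pooled over its measurements."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        # Code id 0 is assumed to be padding; padding_idx keeps its embedding at zero.
        self.code_emb = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.value_emb = nn.Embedding(vocab_size, dim, padding_idx=0)
        self.time_proj = nn.Linear(1, dim)

    def forward(self, dynamic_indices, dynamic_values, dynamic_values_mask, time_delta):
        # dynamic_*: [B, E, M]; time_delta: [B, E], assumed non-negative.
        values = dynamic_values * dynamic_values_mask.float()  # zero out missing values
        meas = self.code_emb(dynamic_indices) + values.unsqueeze(-1) * self.value_emb(dynamic_indices)
        events = meas.sum(dim=2)  # [B, E, D]; padded measurements add zero vectors
        return events + self.time_proj(torch.log1p(time_delta).unsqueeze(-1))


class EventBackbone(nn.Module):
    """Plain transformer over events (stand-in for the hierarchical ESGPT backbone)."""

    def __init__(self, dim: int, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(dim, n_heads, dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)

    def forward(self, x, event_mask):
        # src_key_padding_mask expects True at positions to ignore, i.e. padded events.
        return self.encoder(x, src_key_padding_mask=~event_mask)


class ToySupervisedModel(nn.Module):
    """Hypothetical stand-in exposing input_encoder and backbone attributes."""

    def __init__(self, input_encoder: nn.Module, backbone: nn.Module, dim: int):
        super().__init__()
        self.input_encoder = input_encoder
        self.backbone = backbone
        self.head = nn.Linear(dim, 1)

    def forward(self, batch):
        x = self.input_encoder(batch["dynamic_indices"], batch["dynamic_values"],
                               batch["dynamic_values_mask"], batch["time_delta"])
        h = self.backbone(x, batch["event_mask"])
        mask = batch["event_mask"].unsqueeze(-1).float()
        pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1.0)  # mean over real events
        return self.head(pooled).squeeze(-1)  # one logit per subject


V, D = 50, 16
model = ToySupervisedModel(nn.Identity(), nn.Identity(), D)  # placeholder components
model.input_encoder = EventStreamInputEncoder(V, D)  # override, as in step 3
model.backbone = EventBackbone(D)

B, E, M = 2, 3, 4
batch = {
    "dynamic_indices": torch.randint(1, V, (B, E, M)),
    "dynamic_values": torch.randn(B, E, M),
    "dynamic_values_mask": torch.rand(B, E, M) > 0.5,
    "time_delta": torch.rand(B, E),
    "event_mask": torch.tensor([[True, True, True], [True, True, False]]),
}
logits = model(batch)
print(logits.shape)  # torch.Size([2])
```

Overriding by plain attribute assignment works because nn.Module re-registers a submodule when you assign to the attribute. The new encoder's parameters therefore show up in model.parameters() and reach the optimizer.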
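For the ESGPT pretraining bullet in step 3, here is a rough, hypothetical sketch of the kind of next-event loss a pretraining Lightning class could compute. It has two terms:
- an exponential negative log-likelihood for the time until the next event;
- a multi-label binary cross-entropy over which codes the next event contains.

NextEventHead and multi_hot are illustrative names. The sketch ignores numeric values and is not ESGPT's actual generative head, so treat the ESGPT repository as the reference. It also assumes the hidden states h come from a causally masked backbone; otherwise position i can see event i + 1 and the target leaks.

```python
import torch
import torch.nn.functional as F
from torch import nn


def multi_hot(dynamic_indices: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """[B, E, M] code ids -> [B, E, V] multi-hot targets (code 0 assumed to be padding)."""
    out = torch.zeros(*dynamic_indices.shape[:2], vocab_size, device=dynamic_indices.device)
    out.scatter_(2, dynamic_indices, 1.0)
    out[..., 0] = 0.0  # padding is not a real code
    return out


class NextEventHead(nn.Module):
    """Hypothetical head: when is the next event, and which codes does it contain?"""

    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.rate = nn.Linear(dim, 1)            # rate of an exponential time-to-event
        self.codes = nn.Linear(dim, vocab_size)  # multi-label logits over codes

    def loss(self, h, time_delta, dynamic_indices, event_mask):
        # h: [B, E, D] per-event hidden states from a causal backbone (assumed);
        # time_delta[:, i]: time from event i to event i + 1; event_mask: [B, E].
        targets = multi_hot(dynamic_indices, self.vocab_size)[:, 1:]  # codes of event i + 1
        has_next = (event_mask[:, :-1] & event_mask[:, 1:]).float()   # [B, E - 1]
        h, dt = h[:, :-1], time_delta[:, :-1]
        rate = F.softplus(self.rate(h)).squeeze(-1) + 1e-6
        tte_nll = -torch.log(rate) + rate * dt  # exponential NLL of the observed gap
        code_bce = F.binary_cross_entropy_with_logits(
            self.codes(h), targets, reduction="none"
        ).mean(-1)
        return ((tte_nll + code_bce) * has_next).sum() / has_next.sum().clamp(min=1.0)


torch.manual_seed(0)
head = NextEventHead(dim=8, vocab_size=10)
h = torch.randn(2, 3, 8, requires_grad=True)
loss = head.loss(
    h,
    time_delta=torch.rand(2, 3),
    dynamic_indices=torch.randint(1, 10, (2, 3, 4)),
    event_mask=torch.tensor([[True, True, True], [True, True, False]]),
)
loss.backward()
```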

Let me know what else I can clarify, @mmcdermott.