Lightning-AI/pytorch-lightning

`hparams` not loaded when loading checkpoint via LightningCLI

YouRik opened this issue · 0 comments

Bug description

Loading a model from a checkpoint using LitModule.load_from_checkpoint(...) restores the weights and also uses the stored hparams to initialize the model through its constructor.
However, running the LightningCLI with the --ckpt_path parameter only loads the weights, not the hyperparameters. Debugging shows that the model is initialized through its constructor without the hyperparameters that were stored in the checkpoint. This makes it necessary to provide the hyperparameters again (e.g. through the config), even though they are already stored in the checkpoint.

What version are you seeing the problem on?

v2.3, v2.4

How to reproduce the bug

# 1. bug_report.py

import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI


class MyDataset(Dataset):
    """Synthetic classification dataset of random features and binary labels.

    Each item is a pair ``(features, label)``:
    - ``features`` is a length-10 float tensor drawn with ``torch.randn``.
    - ``label`` is an integer drawn from ``{0, 1}`` with ``torch.randint(0, 2, ...)``.

    Args:
        size: Number of samples to generate; also the reported length.
    """

    def __init__(self, size=1000):
        self.size = size
        # Features are drawn before labels; keep this order so that seeded runs
        # consume the global RNG identically.
        features = torch.randn(size, 10)
        targets = torch.randint(0, 2, (size,))
        self.data, self.labels = features, targets

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label


class MyDataModule(L.LightningDataModule):
    """Data module serving synthetic train/val/test splits built from ``MyDataset``.

    Args:
        batch_size: Number of samples per batch, used by every dataloader.
    """

    def __init__(self, batch_size=32):
        super().__init__()
        # Must be called directly here: it records this constructor's arguments
        # (``batch_size``) as the module's hyperparameters.
        self.save_hyperparameters()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # ``stage`` is not inspected: all three splits are (re)built on every call.
        # The tuple is evaluated left to right, so datasets are created in the
        # order train, val, test.
        self.train_dataset, self.val_dataset, self.test_dataset = (
            MyDataset(size=1000),
            MyDataset(size=200),
            MyDataset(size=200),
        )

    def train_dataloader(self):
        # Only the training split is shuffled.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
        )


class MyLightningModule(L.LightningModule):
    """Two-layer MLP classifier mapping 10 input features to 2 class logits.

    Args:
        hidden_dim: Width of the hidden layer. It is recorded with
            ``save_hyperparameters`` and therefore stored alongside the weights
            in saved checkpoints.
    """

    def __init__(self, hidden_dim=16):
        super().__init__()
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(10, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        hidden = F.relu(self.layer_1(x))
        return self.layer_2(hidden)

    def _step_loss(self, batch):
        """Return the cross-entropy loss of the model on one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        # Fixed learning rate; it is not a constructor argument, so it is not
        # saved as a hyperparameter.
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def cli_main():
    """Run the Lightning command-line interface for this example.

    ``LightningCLI`` parses the command line and runs the requested subcommand
    (e.g. ``fit --config ./config.yaml`` or ``test --ckpt_path ...``) using
    ``MyLightningModule`` and ``MyDataModule``.

    NOTE: as reported in this issue, passing ``--ckpt_path`` restores only the
    weights. The model is constructed from the parsed config and defaults, not
    from the hyperparameters stored in the checkpoint.
    """
    # The CLI does its work during construction; the instance is not used afterwards.
    LightningCLI(MyLightningModule, MyDataModule)


if __name__ == "__main__":
    cli_main()


# 2. config.yaml
seed_everything: 21

trainer:
  accelerator: cuda
  devices: 1
  max_epochs: 100
  log_every_n_steps: 1

model:
  hidden_dim: 8

# 3. Train the model with: python bug_report.py fit --config ./config.yaml
# 4. Observe that a checkpoint has been created after training
# 5. Test the model with: python bug_report.py test --ckpt_path lightning_logs/version_0/checkpoints/epoch=99-step=3200.ckpt
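# 6. Observe (through debugging or the error message) that the hyperparameters saved in the checkpoint are not passed to the constructor of MyLightningModule: without --config the model is built with the default hidden_dim=16 instead of the stored hidden_dim=8, so loading the checkpoint weights fails with a size mismatch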

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.4.1
#- Python version: 3.10.12
#- OS: Linux
#- CUDA/cuDNN version: 12.2
#- GPU models and configuration: NVIDIA RTX 6000 Ada
#- How you installed Lightning: pip

More info

No response