`hparams` not loaded when loading checkpoint via LightningCLI
YouRik opened this issue · 0 comments
YouRik commented
Bug description
Loading a model from a checkpoint using LitModule.load_from_checkpoint(...)
loads the weights and stored hparams to initialize the model through its constructor.
However, running the LightningCLI
with the --ckpt_path
parameter only loads the weights, not the hyperparameters. Debugging shows that the model is initialized through its constructor without the correct hyperparameters stored in the checkpoint, so the hyperparameters have to be supplied again (e.g. via the config), even though they are already present in the checkpoint.
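For comparison, a minimal sketch of the two loading paths (assuming the reproduction script below is saved as bug_report.py and a checkpoint from step 5 exists; the path used here is just the one produced by the example run):

from bug_report import MyLightningModule

ckpt_path = "lightning_logs/version_0/checkpoints/epoch=99-step=3200.ckpt"

# Direct loading: the hparams stored in the checkpoint (hidden_dim=8) are passed
# to the constructor, so the layer shapes match the saved weights.
model = MyLightningModule.load_from_checkpoint(ckpt_path)
print(model.hparams)  # expected to show hidden_dim=8

# CLI loading, in contrast:
#   python bug_report.py test --ckpt_path lightning_logs/version_0/checkpoints/epoch=99-step=3200.ckpt
# constructs MyLightningModule() with the default hidden_dim=16 first and only
# then restores the weights, which is where the mismatch shows up.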
What version are you seeing the problem on?
v2.3, v2.4
How to reproduce the bug
# 1. bug_report.py
import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI
class MyDataset(Dataset):
    # Synthetic random classification data: 10 features, 2 classes.
    def __init__(self, size=1000):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = MyDataset(size=1000)
        self.val_dataset = MyDataset(size=200)
        self.test_dataset = MyDataset(size=200)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


class MyLightningModule(L.LightningModule):
    def __init__(self, hidden_dim=16):
        super().__init__()
        # hidden_dim is stored via save_hyperparameters() and therefore ends up
        # in the checkpoint under "hyper_parameters".
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(10, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        x = F.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def cli_main():
    cli = LightningCLI(MyLightningModule, MyDataModule)


if __name__ == "__main__":
    cli_main()
# 2. config.yaml
seed_everything: 21
trainer:
  accelerator: cuda
  devices: 1
  max_epochs: 100
  log_every_n_steps: 1
model:
  hidden_dim: 8
# 3. Train the model with: python bug_report.py fit --config ./config.yaml
# 4. Observe that a checkpoint has been created after training
# 5. Test the model with: python bug_report.py test --ckpt_path lightning_logs/version_0/checkpoints/epoch=99-step=3200.ckpt
# 6. Observe (through debugging or the resulting error message) that the hyperparameters saved in the checkpoint are not passed to the constructor of MyLightningModule
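To verify that the hyperparameters really are stored in the checkpoint file, the saved dict can be inspected directly (a quick sketch, using the checkpoint path from step 5):

import torch

ckpt = torch.load(
    "lightning_logs/version_0/checkpoints/epoch=99-step=3200.ckpt",
    map_location="cpu",
    weights_only=False,  # the checkpoint also contains non-tensor metadata
)
print(ckpt["hyper_parameters"])  # expected to contain hidden_dim: 8
# The model under test is nevertheless constructed with the default hidden_dim=16,
# which is consistent with the behaviour observed in step 6.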
Error messages and logs
Environment
Current environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.1
- Python version: 3.10.12
- OS: Linux
- CUDA/cuDNN version: 12.2
- GPU models and configuration: NVIDIA RTX 6000 Ada
- How you installed Lightning: pip
More info
No response