Lightning-AI/pytorch-lightning

Module gets optimized in a vanilla loop, but not with a Trainer

kirilllzaitsev opened this issue · 5 comments

### Bug description

I'm referring to the official MNIST example from the 1.5.0 docs. Collected into a single script and adapted for 2.1.3 (with a RichProgressBar added), it looks as follows:

### What version are you seeing the problem on?

v2.1

### How to reproduce the bug

```python
import os

import torch
from lightning.pytorch import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", "./data")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64


class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


model = LitMNIST(learning_rate=2e-4)
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

pbar_callback = RichProgressBar(
    refresh_rate=1,
    leave=False,
    theme=RichProgressBarTheme(
        metrics="grey7",
        metrics_text_delimiter="\t",
        metrics_format=".3f",
    ),
)
callbacks = [pbar_callback]
trainer = Trainer(
    accelerator="auto",
    max_epochs=200,
)
trainer.fit(model)
```

Both `train_loss` and `val_loss` stay stuck at their initial values, and they do not respond to changing the learning rate, to an overfitting setup (a single training sample), etc.

And this is the vanilla loop, used in place of the Trainer, with which the model does train:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.prepare_data()
model.setup("fit")

num_epochs = 10
for epoch in range(num_epochs):
    print(f"### Epoch {epoch+1}/{num_epochs} ###")
    model.train()
    train_loss = 0
    stage = "train"
    for batch in model.train_dataloader():
        optimizer.zero_grad()
        batch = model.transfer_batch_to_device(batch, device, 0)
        loss = model.training_step(batch, batch_idx=0)
        print(f"{loss=}")

        loss.backward()
        optimizer.step()
```

producing:

```
train_loss.item()=2.2950997352600098
train_loss.item()=2.3004744052886963
train_loss.item()=2.2909865379333496
train_loss.item()=2.261011838912964
train_loss.item()=2.258417844772339
train_loss.item()=2.260951280593872
train_loss.item()=2.2366206645965576
train_loss.item()=2.2260756492614746
train_loss.item()=2.2181403636932373
train_loss.item()=2.206907033920288
train_loss.item()=2.1996991634368896
train_loss.item()=2.208554983139038
train_loss.item()=2.186537027359009
train_loss.item()=2.19142484664917
train_loss.item()=2.1634175777435303
train_loss.item()=2.1367063522338867
train_loss.item()=2.136497974395752
train_loss.item()=2.106687068939209
train_loss.item()=2.1067605018615723
train_loss.item()=2.105369806289673
train_loss.item()=2.0849668979644775
train_loss.item()=2.0725436210632324
train_loss.item()=2.061047077178955
train_loss.item()=2.0679869651794434
train_loss.item()=2.0411217212677
train_loss.item()=2.0361852645874023
train_loss.item()=2.0023090839385986
train_loss.item()=1.9851876497268677
train_loss.item()=2.0050344467163086
train_loss.item()=1.981319785118103
train_loss.item()=1.9606249332427979
train_loss.item()=1.9484590291976929
train_loss.item()=1.8978023529052734
train_loss.item()=1.901628851890564
train_loss.item()=1.8795408010482788
train_loss.item()=1.8782236576080322
train_loss.item()=1.8133445978164673
train_loss.item()=1.8485740423202515
train_loss.item()=1.7907288074493408
train_loss.item()=1.798086166381836
train_loss.item()=1.7626056671142578
train_loss.item()=1.7450191974639893
train_loss.item()=1.7486826181411743
train_loss.item()=1.7234904766082764
train_loss.item()=1.6913076639175415
train_loss.item()=1.6666518449783325
train_loss.item()=1.6856220960617065
train_loss.item()=1.623719334602356
train_loss.item()=1.6575146913528442
train_loss.item()=1.654606580734253
train_loss.item()=1.5434222221374512
train_loss.item()=1.5818841457366943
train_loss.item()=1.6028629541397095
train_loss.item()=1.5772353410720825
train_loss.item()=1.5514637231826782
train_loss.item()=1.4747498035430908
train_loss.item()=1.4480849504470825
train_loss.item()=1.4451349973678589
train_loss.item()=1.4886677265167236
train_loss.item()=1.3995665311813354
train_loss.item()=1.3996336460113525
train_loss.item()=1.3508365154266357
train_loss.item()=1.4039489030838013
train_loss.item()=1.356732726097107
```

The module also has a small extension:

```python
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        if isinstance(batch, dict):
            for key in batch.keys():
                batch[key] = batch[key].to(device)
        else:
            batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
        return batch
```

I first ran into this issue in a completely different setup, which is what made me try the MNIST example. Any idea what could be causing it?

### Error messages and logs

Both `train_loss` and `val_loss` are stuck at their initial values when training with `Trainer`, but they decrease as expected with the vanilla loop above.
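To make "stuck" versus "decreasing" a yes/no check, one option is to parse a printed loss log and compare the mean of its first and last few values. The sketch below is stdlib-only and assumes the `name=value` print format of the vanilla-loop log above; the helper names `parse_losses` and `loss_decreased`, the window of 5 and the 5% threshold are all hypothetical choices for illustration, not part of Lightning.

```python
import re
from statistics import mean
from typing import List

# Matches a trailing "=<number>" as in "train_loss.item()=2.2950997352600098"
# (print format assumed from the vanilla-loop log).
_VALUE_AT_END = re.compile(r"=\s*([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s*$")


def parse_losses(log_text: str) -> List[float]:
    """Extract the trailing '=<number>' value from each matching line of a log."""
    values = []
    for line in log_text.splitlines():
        match = _VALUE_AT_END.search(line.strip())
        if match:
            values.append(float(match.group(1)))
    return values


def loss_decreased(losses: List[float], window: int = 5, min_rel_drop: float = 0.05) -> bool:
    """True if the mean of the last `window` losses is at least `min_rel_drop`
    (relative) below the mean of the first `window` losses. A loss that stays
    at its initial value returns False."""
    if len(losses) < 2 * window:
        raise ValueError("need at least 2 * window loss values")
    start = mean(losses[:window])
    end = mean(losses[-window:])
    return (start - end) / start >= min_rel_drop
```

Applied to the first and last five values printed by the vanilla loop, the mean goes from about 2.28 to about 1.38 (a drop of roughly 39%), so the check passes; a run whose loss stays at its starting value, as described for the Trainer, would fail it.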


### Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 3070 Ti Laptop GPU
        - available:         True
        - version:           11.7
* Lightning:
        - efficientnet-pytorch: 0.7.1
        - functorch:         1.13.1
        - lightning:         2.1.3
        - lightning-cloud:   0.5.34
        - lightning-utilities: 0.8.0
        - open-clip-torch:   2.19.0
        - pytorch-lightning: 2.0.5
        - segmentation-models-pytorch: 0.3.3
        - torch:             2.0.1+cu117
        - torch-ema:         0.3
        - torch-fidelity:    0.3.0
        - torchmetrics:      0.7.3
        - torchnet:          0.0.4
        - torchtyping:       0.1.4
        - torchvision:       0.15.2+cu117
* Packages:
        - absl-py:           1.4.0
        - action-detection:  0.1.0
        - addict:            2.4.0
        - afm-op:            0.0
        - aicrowd-cli:       0.1.15
        - aiofiles:          22.1.0
        - aiohttp:           3.7.4.post0
        - aiosignal:         1.3.1
        - aiosqlite:         0.19.0
        - albumentations:    1.3.0
        - alembic:           1.11.1
        - alternet:          0.0.1
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.6.2
        - appdirs:           1.4.4
        - argcomplete:       3.1.1
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - array-record:      0.5.0
        - arrow:             1.2.3
        - astor:             0.8.1
        - astroid:           2.14.2
        - asttokens:         2.2.1
        - astunparse:        1.6.3
        - async-timeout:     3.0.1
        - attrs:             23.1.0
        - audioread:         3.0.0
        - autopage:          0.5.1
        - autopep8:          2.0.2
        - av:                10.0.0
        - awswrangler:       3.1.1
        - azure-common:      1.1.28
        - azure-core:        1.29.4
        - azure-identity:    1.14.0
        - azure-mgmt-core:   1.4.0
        - azure-mgmt-monitor: 6.0.2
        - azure-monitor-query: 1.2.0
        - babel:             2.12.1
        - backcall:          0.2.0
        - beautifulsoup4:    4.12.2
        - bidict:            0.22.1
        - binaryornot:       0.4.4
        - black:             23.3.0
        - bleach:            6.0.0
        - blessed:           1.20.0
        - blinker:           1.6.2
        - boto3:             1.26.130
        - botocore:          1.29.130
        - brewer2mpl:        1.4.1
        - brotli:            1.0.9
        - brotlipy:          0.7.0
        - cached-property:   1.5.2
        - cachetools:        5.3.0
        - catkin-pkg:        0.5.2
        - catkin-tools:      0.9.4
        - catkin-tools-fetch: 0.3.5
        - certifi:           2023.11.17
        - cffi:              1.15.1
        - cfgv:              3.3.1
        - chardet:           4.0.0
        - charset-normalizer: 2.1.1
        - click:             8.1.4
        - cliff:             4.3.0
        - clip:              1.0
        - cloudpickle:       2.2.1
        - cmaes:             0.9.1
        - cmake:             3.26.3
        - cmd2:              2.4.3
        - colorama:          0.4.6
        - colorlog:          6.7.0
        - comet-ml:          3.33.11
        - comm:              0.1.3
        - commonmark:        0.9.1
        - configargparse:    1.5.3
        - configobj:         5.0.8
        - configupdater:     3.1.1
        - contourpy:         1.0.7
        - contrails:         0.0.1
        - cookiecutter:      2.3.0
        - coverage:          7.3.2
        - crc8:              0.1.0
        - croniter:          1.3.14
        - cryptography:      40.0.2
        - customtkinter:     5.1.3
        - cycler:            0.11.0
        - cython:            0.29.34
        - dacite:            1.8.1
        - darkdetect:        0.8.0
        - dash:              2.9.3
        - dash-core-components: 2.0.0
        - dash-html-components: 2.0.0
        - dash-table:        5.0.0
        - dask:              2023.3.1
        - dataclass-array:   1.5.1
        - datasets:          2.12.0
        - dateutils:         0.6.12
        - debugpy:           1.6.7
        - decorator:         5.1.1
        - deepdiff:          6.3.0
        - deeplsd:           0.0
        - defusedxml:        0.7.1
        - deprecated:        1.2.13
        - descartes:         1.1.0
        - detectron2:        0.6
        - dgp:               1.0
        - diffdist:          0.1
        - dill:              0.3.6
        - diskcache:         5.6.3
        - distiller:         0.4.0rc0
        - distlib:           0.3.6
        - distro:            1.8.0
        - dm-tree:           0.1.8
        - dnet:              0.0.1
        - dnspython:         2.4.0
        - docker:            6.1.2
        - docker-pycreds:    0.4.0
        - docstring-parser:  0.14.1
        - docutils:          0.20.1
        - drjit:             0.4.3
        - dulwich:           0.21.5
        - efficientnet-pytorch: 0.7.1
        - einops:            0.6.1
        - einsum:            0.3.0
        - empy:              3.3.4
        - et-xmlfile:        1.1.0
        - etils:             1.5.1
        - eventlet:          0.33.3
        - everett:           3.1.0
        - exceptiongroup:    1.0.4
        - executing:         1.2.0
        - fairscale:         0.4.13
        - fastapi:           0.88.0
        - fastjsonschema:    2.16.3
        - fiftyone:          0.21.4
        - fiftyone-brain:    0.13.0
        - fiftyone-db:       0.4.0
        - fiftyone-db-ubuntu2204: 0.4.0
        - filelock:          3.12.0
        - fire:              0.5.0
        - flake8:            6.0.0
        - flake8-black:      0.3.6
        - flask:             2.3.2
        - flatbuffers:       23.3.3
        - flow-vis:          0.1
        - fonttools:         4.39.3
        - fqdn:              1.5.1
        - frozendict:        2.3.8
        - frozenlist:        1.3.3
        - fsspec:            2023.10.0
        - ftfy:              6.1.1
        - functorch:         1.13.1
        - future:            0.18.3
        - fvcore:            0.1.5.post20221221
        - gast:              0.4.0
        - gcsfs:             2023.10.0
        - gcvit:             0.0.1
        - gdown:             4.7.1
        - gitdb:             4.0.10
        - gitpython:         3.1.31
        - glob2:             0.7
        - google-api-core:   2.12.0
        - google-auth:       2.16.2
        - google-auth-oauthlib: 0.4.6
        - google-cloud-core: 2.3.3
        - google-cloud-storage: 2.11.0
        - google-crc32c:     1.5.0
        - google-pasta:      0.2.0
        - google-resumable-media: 2.6.0
        - googleapis-common-protos: 1.61.0
        - grad-cam:          1.4.8
        - graphql-core:      3.2.3
        - greenlet:          2.0.2
        - grpcio:            1.51.1
        - h11:               0.14.0
        - h2:                4.1.0
        - h5py:              3.8.0
        - homography-est:    0.0.0
        - hpack:             4.0.0
        - httpcore:          0.17.3
        - httpx:             0.24.1
        - huggingface-hub:   0.14.1
        - hydra-colorlog:    1.2.0
        - hydra-core:        1.3.2
        - hydra-optuna-sweeper: 1.2.0
        - hypercorn:         0.14.4
        - hyperframe:        6.0.1
        - identify:          2.5.24
        - idna:              3.4
        - imagededup:        0.3.2
        - imageio:           2.28.1
        - imgaug:            0.4.0
        - immutabledict:     2.2.0
        - importlib-metadata: 6.6.0
        - importlib-resources: 5.12.0
        - inflate64:         0.3.1
        - ini2toml:          0.12
        - iniconfig:         1.1.1
        - inquirer:          3.1.3
        - iopath:            0.1.9
        - ip-basic:          1.0.0
        - ipdb:              0.13.13
        - ipykernel:         6.22.0
        - ipympl:            0.9.3
        - ipython:           8.13.2
        - ipython-genutils:  0.2.0
        - ipywidgets:        7.7.2
        - iso8601:           1.1.0
        - isodate:           0.6.1
        - isoduration:       20.11.0
        - isort:             5.12.0
        - itsdangerous:      2.1.2
        - jax:               0.4.8
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - jmespath:          1.0.1
        - joblib:            1.2.0
        - json5:             0.9.11
        - jsonargparse:      4.21.0
        - jsonlines:         3.1.0
        - jsonpatch:         1.32
        - jsonpointer:       2.3
        - jsonschema:        4.17.3
        - jupyter:           1.0.0
        - jupyter-client:    8.2.0
        - jupyter-console:   6.6.3
        - jupyter-core:      5.3.0
        - jupyter-events:    0.6.3
        - jupyter-http-over-ws: 0.0.8
        - jupyter-server:    2.5.0
        - jupyter-server-fileid: 0.9.0
        - jupyter-server-terminals: 0.4.4
        - jupyter-server-ydoc: 0.8.0
        - jupyter-ydoc:      0.2.4
        - jupyterlab:        3.6.3
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.22.1
        - jupyterlab-widgets: 1.1.7
        - kaggle:            1.5.15
        - kaleido:           0.2.1
        - kbnet:             0.1.0
        - keras:             2.11.0
        - keras-preprocessing: 1.1.2
        - kiwisolver:        1.4.4
        - kornia:            0.6.0
        - lark:              1.1.8
        - layout-aware-monodepth: 0.0.1
        - lazy-loader:       0.2
        - lazy-object-proxy: 1.6.0
        - lerf:              0.1.0
        - libclang:          16.0.0
        - librosa:           0.10.0.post2
        - lightning:         2.1.3
        - lightning-cloud:   0.5.34
        - lightning-utilities: 0.8.0
        - line-refinement:   0.0.0
        - lit:               16.0.2
        - llvmlite:          0.40.0
        - locket:            1.0.0
        - lpips:             0.1.4
        - lvis:              0.5.3
        - mako:              1.2.4
        - markdown:          3.4.3
        - markdown-it-py:    2.2.0
        - markupsafe:        2.1.2
        - mask2former:       0.1
        - matplotlib:        3.6.1
        - matplotlib-inline: 0.1.6
        - mccabe:            0.7.0
        - mdurl:             0.1.2
        - mediapy:           1.1.6
        - mistune:           2.0.5
        - mitsuba:           3.4.0
        - mkl-fft:           1.3.6
        - mkl-random:        1.2.2
        - mkl-service:       2.4.0
        - ml-dtypes:         0.1.0
        - mock:              5.1.0
        - modelip:           0.0.1
        - mongoengine:       0.24.2
        - monodepth2:        0.0.1
        - motor:             3.2.0
        - mpmath:            1.3.0
        - msal:              1.24.0
        - msal-extensions:   1.0.0
        - msgpack:           1.0.5
        - msgpack-numpy:     0.4.8
        - multidict:         6.0.4
        - multiprocess:      0.70.14
        - multiscaledeformableattention: 1.0
        - multivolumefile:   0.2.3
        - munch:             2.5.0
        - mypy:              0.981
        - mypy-extensions:   0.4.3
        - nbclassic:         1.0.0
        - nbclient:          0.7.4
        - nbconvert:         7.3.1
        - nbformat:          5.7.0
        - nerfacc:           0.5.2
        - nerfstudio:        0.2.2
        - nest-asyncio:      1.5.6
        - networkx:          3.1
        - ninja:             1.11.1
        - nltk:              3.8.1
        - nodeenv:           1.8.0
        - notebook:          6.5.4
        - notebook-shim:     0.2.3
        - numba:             0.57.0
        - numpy:             1.23.5
        - nuscenes-devkit:   1.1.9
        - nvidia-cublas-cu11: 11.10.3.66
        - nvidia-cuda-cupti-cu11: 11.7.101
        - nvidia-cuda-nvrtc-cu11: 11.7.99
        - nvidia-cuda-runtime-cu11: 11.7.99
        - nvidia-cudnn-cu11: 8.5.0.96
        - nvidia-cufft-cu11: 10.9.0.58
        - nvidia-curand-cu11: 10.2.10.91
        - nvidia-cusolver-cu11: 11.4.0.1
        - nvidia-cusparse-cu11: 11.7.4.91
        - nvidia-nccl-cu11:  2.14.3
        - nvidia-nvtx-cu11:  11.7.91
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - onnx:              1.14.0
        - onnx-tf:           1.10.0
        - open-clip-torch:   2.19.0
        - open3d:            0.17.0
        - opencv-contrib-python: 4.7.0.72
        - opencv-python:     4.8.0.74
        - opencv-python-headless: 4.7.0.72
        - openexr:           1.3.9
        - openocd:           0.1.1
        - openpyxl:          3.1.2
        - opt-einsum:        3.3.0
        - optuna:            2.10.1
        - ordered-set:       4.1.0
        - osrf-pycommon:     2.0.2
        - packaging:         23.1
        - pandas:            1.5.3
        - pandoc:            2.3
        - pandocfilters:     1.5.0
        - panopticapi:       0.1
        - parso:             0.8.3
        - partd:             1.4.1
        - pathspec:          0.11.1
        - pathtools:         0.1.2
        - pbr:               5.11.1
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            10.1.0
        - pip:               23.1.2
        - platformdirs:      3.5.1
        - plotly:            5.13.1
        - pluggy:            1.0.0
        - plumbum:           1.8.2
        - ply:               3.11
        - pointnet2:         3.0.0
        - pooch:             1.7.0
        - portalocker:       2.7.0
        - pprintpp:          0.4.0
        - pre-commit:        3.3.3
        - pretrainedmodels:  0.7.4
        - prettytable:       3.8.0
        - priority:          2.0.0
        - prometheus-client: 0.16.0
        - promise:           2.3
        - prompt-toolkit:    3.0.38
        - protobuf:          3.19.6
        - psutil:            5.9.0
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - py7zr:             0.20.5
        - pyarrow:           10.0.0
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pybcj:             1.0.1
        - pycocotools:       2.0.6
        - pycodestyle:       2.10.0
        - pycparser:         2.21
        - pycryptodomex:     3.18.0
        - pydantic:          1.10.7
        - pydensecrf:        1.0rc2
        - pydeprecate:       0.3.1
        - pydot:             1.4.2
        - pyequilib:         0.5.6
        - pyflakes:          3.0.1
        - pygithub:          1.58.1
        - pygments:          2.15.1
        - pyjwt:             2.7.0
        - pykitti:           0.3.1
        - pyliblzfse:        0.4.1
        - pylint:            2.16.2
        - pymcubes:          0.1.4
        - pymeshlab:         2022.2.post4
        - pymongo:           4.4.1
        - pynacl:            1.5.0
        - pyngrok:           6.0.0
        - pyntcloud:         0.3.1
        - pyopenssl:         23.1.1
        - pyparsing:         3.0.9
        - pyperclip:         1.8.2
        - pyppmd:            1.0.0
        - pyquaternion:      0.9.9
        - pyre-extensions:   0.0.23
        - pyrootutils:       1.0.4
        - pyrsistent:        0.19.3
        - pyserial:          3.5
        - pysocks:           1.7.1
        - pytest:            7.3.1
        - python-box:        6.1.0
        - python-dateutil:   2.8.2
        - python-dotenv:     1.0.0
        - python-editor:     1.0.4
        - python-engineio:   4.4.1
        - python-json-logger: 2.0.7
        - python-multipart:  0.0.6
        - python-slugify:    5.0.2
        - python-socketio:   5.8.0
        - python-version:    0.0.2
        - pytlsd:            0.0.3
        - pytorch-lightning: 2.0.5
        - pytsmod:           0.3.6
        - pytz:              2023.3
        - pyu2f:             0.1.5
        - pywavelets:        1.4.1
        - pyyaml:            6.0
        - pyzmq:             25.1.0
        - pyzstd:            0.15.9
        - qtconsole:         5.4.3
        - qtpy:              2.3.1
        - qudida:            0.0.4
        - rarfile:           4.0
        - raymarching:       0.1.3
        - readchar:          4.0.5
        - regex:             2023.5.5
        - requests:          2.30.0
        - requests-mock:     1.11.0
        - requests-oauthlib: 1.3.1
        - requests-toolbelt: 0.10.1
        - responses:         0.18.0
        - retrying:          1.3.4
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rich:              13.4.2
        - rsa:               4.9
        - rsl-depth-completion: 0.1.0
        - rsl-rl:            1.0.2
        - s3transfer:        0.6.1
        - safetensors:       0.3.1
        - scikit-build:      0.17.6
        - scikit-image:      0.20.0
        - scikit-learn:      1.2.2
        - scipy:             1.10.1
        - seaborn:           0.12.2
        - segmentation-models-pytorch: 0.3.3
        - semantic-version:  2.10.0
        - semver:            2.13.0
        - send2trash:        1.8.2
        - sentencepiece:     0.1.99
        - sentry-sdk:        1.22.1
        - setproctitle:      1.3.2
        - setuptools:        67.6.0
        - sfmnext:           0.0.1
        - shap:              0.41.0
        - shapely:           2.0.1
        - shtab:             1.6.1
        - simplejson:        3.19.1
        - six:               1.16.0
        - slicer:            0.0.7
        - smmap:             5.0.0
        - snakeviz:          2.2.0
        - sniffio:           1.3.0
        - sortedcontainers:  2.4.0
        - soundfile:         0.12.1
        - soupsieve:         2.4.1
        - soxr:              0.3.5
        - sqlalchemy:        2.0.18
        - sse-starlette:     0.10.3
        - sseclient-py:      1.7.2
        - stable-diffusion-sdkit: 2.1.3
        - stack-data:        0.6.2
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - stevedore:         5.1.0
        - strawberry-graphql: 0.138.1
        - sympy:             1.11.1
        - tabulate:          0.9.0
        - tdqm:              0.0.1
        - tenacity:          8.2.2
        - tensorboard:       2.11.2
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorboardx:      2.6
        - tensorflow:        2.11.0
        - tensorflow-addons: 0.20.0
        - tensorflow-datasets: 4.9.0
        - tensorflow-estimator: 2.11.0
        - tensorflow-graphics: 2021.12.3
        - tensorflow-io-gcs-filesystem: 0.32.0
        - tensorflow-metadata: 1.13.0
        - tensorflow-model-optimization: 0.7.4
        - tensorflow-probability: 0.19.0
        - termcolor:         2.3.0
        - terminado:         0.17.1
        - terrain-reconstruction: 0.0.1
        - terrain-representation: 0.0.1
        - test-tube:         0.7.5
        - testing-practice:  0.0.1
        - text-unidecode:    1.3
        - texttable:         1.6.7
        - threadpoolctl:     3.1.0
        - tifffile:          2023.4.12
        - timm:              0.9.2
        - tinycss2:          1.2.1
        - tinycudann:        1.7
        - tokenizers:        0.13.3
        - toml:              0.10.2
        - tomli:             2.0.1
        - tomlkit:           0.11.1
        - toolz:             0.12.0
        - torch:             2.0.1+cu117
        - torch-ema:         0.3
        - torch-fidelity:    0.3.0
        - torchmetrics:      0.7.3
        - torchnet:          0.0.4
        - torchtyping:       0.1.4
        - torchvision:       0.15.2+cu117
        - tornado:           6.3.1
        - tqdm:              4.65.0
        - tr:                0.0.1
        - traitlets:         5.9.0
        - transformers:      4.26.1
        - trimesh:           3.21.5
        - triton:            2.0.0
        - ttach:             0.0.3
        - typeguard:         2.13.3
        - typeshed-client:   2.3.0
        - typing-extensions: 4.8.0
        - typing-inspect:    0.8.0
        - tyro:              0.5.2
        - tzdata:            2023.3
        - tzlocal:           5.0.1
        - ultralytics:       8.0.124
        - universal-analytics-python3: 1.1.1
        - uri-template:      1.2.0
        - urllib3:           1.26.15
        - uvicorn:           0.22.0
        - vedo:              2023.5.0
        - virtualenv:        20.23.1
        - visdom:            0.2.4
        - viser:             0.0.10
        - vision-mtl:        0.1.0
        - visu3d:            1.5.1
        - voxel51-eta:       0.10.0
        - vtk:               9.3.0
        - wandb:             0.15.2
        - waymo-open-dataset-tf-2-11-0: 1.6.0
        - wcwidth:           0.2.6
        - webcolors:         1.13
        - webencodings:      0.5.1
        - websocket-client:  1.3.3
        - websockets:        11.0.2
        - werkzeug:          2.3.3
        - wheel:             0.40.0
        - widgetsnbextension: 3.6.6
        - wrapt:             1.15.0
        - wsproto:           1.2.0
        - wurlitzer:         3.0.3
        - xarray:            2022.6.0
        - xatlas:            0.0.7
        - xformers:          0.0.16
        - xlsxwriter:        3.1.0
        - xmltodict:         0.13.0
        - xxhash:            3.2.0
        - y-py:              0.5.9
        - yacs:              0.1.8
        - yamllint:          1.31.0
        - yarl:              1.9.2
        - ypy-websocket:     0.8.2
        - zipp:              3.15.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.16
        - release:           5.19.0-1028-lowlatency
        - version:           #29~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Jun 21 09:41:37 UTC 2

</details>                                       

### More info

_No response_

@kirilllzaitsev Can you be more specific about what you mean by "stuck"?

MNIST is very easy to optimize for this classifier. After the first epoch we already get a > 90% accuracy on the validation set:

Epoch 1: 100%|██████████| 860/860 [00:08<00:00, 97.71it/s, v_num=54, train_loss=0.242, val_loss=0.288, val_acc=0.913]

At epoch 8 I get 96%:

Epoch 8: 100%|██████████| 860/860 [00:09<00:00, 95.09it/s, v_num=54, train_loss=0.025, val_loss=0.122, val_acc=0.964]

I don't see any issue with this code. Can you point it out, please?

torch 2.1.1, lightning 2.1.3, torchmetrics 0.7.3
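The `860/860` totals in the progress bars above can be checked against the repro's batch-size logic. `train_dataloader` uses `DataLoader`'s default `drop_last=False`, so one epoch over the 55000-sample training split has ⌈55000 / BATCH_SIZE⌉ batches. The arithmetic below uses a hypothetical helper (not a Lightning API) and assumes the bar total counts training batches only; it gives 860 for the CPU batch size of 64 and 215 for the GPU batch size of 256, which suggests, but does not prove, that this run used `BATCH_SIZE = 64`.

```python
import math

TRAIN_SAMPLES = 55000  # training split from random_split(mnist_full, [55000, 5000])


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch for a sized dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(TRAIN_SAMPLES, 64))   # BATCH_SIZE without a GPU -> 860
print(batches_per_epoch(TRAIN_SAMPLES, 256))  # BATCH_SIZE with a GPU -> 215
```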

This is what I mean by "stuck":
[image: loss plots]
Please note that this happens only with the Trainer; the vanilla loop works fine.

Replacing all `lightning.pytorch` imports with `pytorch_lightning` fixed it. With the imports as the only change, the same code that produced the plots above now gives:
[image: loss plots after switching the imports]
However, this did not fix the same problem in the other use case I'm working on (the one mentioned in the original report).
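The environment above has both `lightning` 2.1.3 and `pytorch-lightning` 2.0.5 installed, and switching the imports between the two packages changed the outcome. Whether that version mismatch is the root cause is not established in this thread; as a hedged diagnostic, a stdlib-only sketch like the following can report the installed versions and flag side-by-side installs with differing versions. The function names `installed_versions` and `lightning_version_mismatch` are hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("lightning", "pytorch-lightning", "torchmetrics"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def lightning_version_mismatch(versions: Dict[str, Optional[str]]) -> bool:
    """True when both the unified `lightning` package and the standalone
    `pytorch-lightning` package are installed, with different versions."""
    unified = versions.get("lightning")
    standalone = versions.get("pytorch-lightning")
    return unified is not None and standalone is not None and unified != standalone


if __name__ == "__main__":
    found = installed_versions()
    print(found)
    if lightning_version_mismatch(found):
        print("warning: lightning and pytorch-lightning are installed with different versions")
```

For the environment reported here, the check would flag `lightning` 2.1.3 next to `pytorch-lightning` 2.0.5.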