Lightning-AI/pytorch-lightning

WandbLogger causes an error on TPU v3-8

buoyancy99 opened this issue · 0 comments

Bug description

Training with WandbLogger on 8 TPU cores errors out during trainer.fit.

Here are a few ablations I ran (a sketch of the corresponding Trainer arguments follows the list):
1 TPU + WandbLogger -> works
8 TPU -> works
8 TPU + WandbLogger -> fails
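
For reference, the three ablations only differ in the Trainer arguments. A sketch of just the toggled lines (it assumes the same wandb_logger, model, and dm defined in the full script below):

# Sketch of the three configurations above (assumes the wandb_logger, model,
# and dm from the full script below); only the logger/devices arguments change.
pl.Trainer(logger=wandb_logger, accelerator="tpu", devices=1)  # works
pl.Trainer(accelerator="tpu", devices=8)                       # works (default logger)
pl.Trainer(logger=wandb_logger, accelerator="tpu", devices=8)  # fails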

I provide a minimal code example below.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

Here is a minimal example that follows the official TPU and WandbLogger examples:

import pytorch_lightning as pl

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets import MNIST

wandb_logger = WandbLogger(project="MNIST")

# Note - you must have torchvision installed for this example

BATCH_SIZE = 1024

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

class LitModel(pl.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=3,
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)

Error messages and logs

/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/__init__.py:202: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.
  warnings.warn(
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: buoyancy99 (scene-representation-group). Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.0 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.7
wandb: Run data is saved locally in ./wandb/run-20240913_032210-ja31ya3a
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run solar-grass-3
wandb: ⭐️ View project at https://wandb.ai/scene-representation-group/MNIST
wandb: 🚀 View run at https://wandb.ai/scene-representation-group/MNIST/runs/ja31ya3a
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/dengmingyang2/.local/lib/python3.10/site-packages/wandb/sdk/service/server_sock.py", line 111, in run
    shandler(sreq)
  File "/home/dengmingyang2/.local/lib/python3.10/site-packages/wandb/sdk/service/server_sock.py", line 150, in server_inform_attach
    self._mux._streams[stream_id]._settings._proto,
KeyError: 'ja31ya3a'

The Exception in thread SockSrvRdThr traceback above repeats many times due to multithreading, and is then followed by repetitions of the following error:

Traceback (most recent call last):
  File "/home/boyuan/debug.py", line 106, in <module>
    trainer.fit(model, dm)
  File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 46, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 98, in launch
    process_context = xmp.spawn(
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 214, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 174, in run_multiprocess
    replica_results = list(
  File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 175, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
wandb.errors.UsageError: Unable to attach to run ja31ya3a
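
The run id in the KeyError ('ja31ya3a') matches the run created in the main process (see the run URL above), so it looks like the eight worker processes launched by xmp.spawn fail to attach to the parent's wandb run. One experiment I have not verified on the TPU VM, based on wandb's documented WANDB_START_METHOD environment variable, would be to force the wandb start method before the logger is created; this is only a guess, not a confirmed fix:

import os

# Hypothetical, unverified workaround: control how the wandb internal service
# is started before the XLA launcher spawns the worker processes.
os.environ["WANDB_START_METHOD"] = "thread"

wandb_logger = WandbLogger(project="MNIST")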

Environment

Current environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4
- Python version: 3.10
- OS: Linux
- CUDA/cuDNN version: n/a (TPU)
- GPU models and configuration: TPU v3, 8 TPU cores
- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response