WandbLogger causes an error on TPU v3-8
buoyancy99 opened this issue · 0 comments
buoyancy99 commented
Bug description
Using WandbLogger with 8 TPU cores errors out.
Here are a few ablations I did:
1 TPU + WandbLogger -> works
8 TPU -> works
8 TPU + WandbLogger -> fails
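Concretely, the three ablations differ only in the Trainer arguments (a sketch for reference; wandb_logger is constructed exactly as in the full repro below):

# works: 1 TPU core + WandbLogger
trainer = pl.Trainer(logger=wandb_logger, accelerator="tpu", devices=1)
# works: 8 TPU cores, no WandbLogger
trainer = pl.Trainer(accelerator="tpu", devices=8)
# fails: 8 TPU cores + WandbLogger
trainer = pl.Trainer(logger=wandb_logger, accelerator="tpu", devices=8)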
I provide a minimal code example below.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
Here is a minimal example that follows the official TPU and WandbLogger examples:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.loggers import WandbLogger  # same namespace as `pl` above
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

wandb_logger = WandbLogger(project="MNIST")

# Note - you must have torchvision installed for this example
BATCH_SIZE = 1024


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class LitModel(pl.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=3,
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)
Error messages and logs
/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/__init__.py:202: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.
warnings.warn(
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: buoyancy99 (scene-representation-group). Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.0 is available! To upgrade, please run:
wandb: $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.7
wandb: Run data is saved locally in ./wandb/run-20240913_032210-ja31ya3a
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run solar-grass-3
wandb: ⭐️ View project at https://wandb.ai/scene-representation-group/MNIST
wandb: 🚀 View run at https://wandb.ai/scene-representation-group/MNIST/runs/ja31ya3a
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/home/dengmingyang2/.local/lib/python3.10/site-packages/wandb/sdk/service/server_sock.py", line 111, in run
shandler(sreq)
File "/home/dengmingyang2/.local/lib/python3.10/site-packages/wandb/sdk/service/server_sock.py", line 150, in server_inform_attach
self._mux._streams[stream_id]._settings._proto,
KeyError: 'ja31ya3a'
The Exception in thread SockSrvRdThr traceback above appears many times due to multithreading, followed by repetitions of the following error:
Traceback (most recent call last):
File "/home/boyuan/debug.py", line 106, in <module>
trainer.fit(model, dm)
File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 46, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/home/boyuan/.local/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/xla.py", line 98, in launch
process_context = xmp.spawn(
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 214, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 174, in run_multiprocess
replica_results = list(
File "/home/boyuan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 175, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
wandb.errors.UsageError: Unable to attach to run ja31ya3a
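One possible direction for a workaround, untested on TPU and offered only as a sketch: wandb's multiprocessing guidance suggests pre-starting the wandb service in the parent process with wandb.setup() and, if necessary, forcing thread-based service startup via the WANDB_START_METHOD environment variable, so that the spawned workers attach to a single shared service rather than each starting their own.

import os

import wandb
from pytorch_lightning.loggers import WandbLogger

# Untested sketch: pre-start the wandb service in the parent process
# (per wandb's multiprocessing guidance) before Lightning spawns the
# eight TPU workers, so each child can attach to one shared service.
os.environ["WANDB_START_METHOD"] = "thread"
wandb.setup()

wandb_logger = WandbLogger(project="MNIST")
# ...then construct the Trainer and call trainer.fit as in the repro above.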
Environment
Current environment
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.4
#- Python version (e.g., 3.12): 3.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: N/A (TPU)
#- GPU models and configuration: TPU v3, 8 TPU cores
#- How you installed Lightning (`conda`, `pip`, source): pip