lightly-ai/lightly

DINO fails with PyTorch Fabric + multiple GPU

Closed this issue · 8 comments

Based on this tutorial, if you use PyTorch Fabric for distributed training, it fails during the backward pass when using more than one GPU.

Tested with PyTorch Lightning multiple GPU + DINO = works.
Tested with PyTorch Fabric single GPU + DINO = works.
Tested with PyTorch Fabric multiple GPU + DINO = fails.

Repro:

import copy

import lightning as L
from lightning.fabric import Fabric
import torch
import torchvision
from torch import nn

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule


class DINO(L.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        # input_dim = backbone.embed_dim

        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x, teacher: bool = False):
        if teacher:
            y = self.student_backbone(x).flatten(start_dim=1)
            z = self.student_head(y)
        else:
            y = self.teacher_backbone(x).flatten(start_dim=1)
            z = self.teacher_head(y)
        return z

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

# Set devices=2 to reproduce the failure (devices=1 runs fine)
fabric = Fabric(accelerator='cuda', num_nodes=1, devices=1)
fabric.launch()

torch.autograd.set_detect_anomaly(True)

input_dim = 512
model = DINO()


transform = DINOTransform()
dataset = torchvision.datasets.VOCDetection(
    "./data",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

dataloader = fabric.setup_dataloaders(dataloader)

criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(fabric.device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

epochs = 10

print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        fabric.barrier()
        views = batch[0]
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)
        views = [view.to(fabric.device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward(view, teacher=True) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        fabric.backward(loss)
        model.on_after_backward()
        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Below is the error output when switching devices from 1 to 2:

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

<user>/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
<user>/.venv/lib/python3.11/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
<user>/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
<user>/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,
<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,
<user>/.venv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "<user>/simple_test.py", line 101, in <module>
    teacher_out = [model.forward(view, teacher=True) for view in global_views]
  File "<user>/simple_test.py", line 101, in <listcomp>
    teacher_out = [model.forward(view, teacher=True) for view in global_views]
  File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py", line 139, in forward
    output = self._forward_module(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1593, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1411, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/simple_test.py", line 40, in forward
    y = self.student_backbone(x).flatten(start_dim=1)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py", line 97, in forward
    out = self.bn2(out)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
    return F.batch_norm(
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2509, in batch_norm
    return torch.batch_norm(
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
<user>/.venv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "<user>/simple_test.py", line 101, in <module>
    teacher_out = [model.forward(view, teacher=True) for view in global_views]
  File "<user>/simple_test.py", line 101, in <listcomp>
    teacher_out = [model.forward(view, teacher=True) for view in global_views]
  File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py", line 139, in forward
    output = self._forward_module(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1593, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1411, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/simple_test.py", line 40, in forward
    y = self.student_backbone(x).flatten(start_dim=1)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py", line 97, in forward
    out = self.bn2(out)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
    return F.batch_norm(
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2509, in batch_norm
    return torch.batch_norm(
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:111.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]: Traceback (most recent call last):
[rank1]:   File "<user>/simple_test.py", line 105, in <module>
[rank1]:     fabric.backward(loss)
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 448, in backward
[rank1]:     self._strategy.backward(tensor, module, *args, **kwargs)
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
[rank1]:     self.precision.backward(tensor, module, *args, **kwargs)
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
[rank1]:     tensor.backward(*args, **kwargs)
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "<user>/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 11; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
[rank0]: Traceback (most recent call last):
[rank0]:   File "<user>/simple_test.py", line 105, in <module>
[rank0]:     fabric.backward(loss)
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 448, in backward
[rank0]:     self._strategy.backward(tensor, module, *args, **kwargs)
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/strategies/strategy.py", line 191, in backward
[rank0]:     self.precision.backward(tensor, module, *args, **kwargs)
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
[rank0]:     tensor.backward(*args, **kwargs)
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "<user>/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 11; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
srun: error: task 1: Exited with exit code 1
srun: error: task 0: Exited with exit code 1

Hi @Chrispresso, thank you for reporting the issue! Could you please add the error output to the description?

@philippmwirth I just updated the description with the error output and traceback. The main error is this:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 11; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Thanks for providing the full error trace! To me it looks like the error happens in torchvision's resnet:

  File "<user>/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py", line 97, in forward
    out = self.bn2(out)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
    return F.batch_norm(
  File "<user>/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2509, in batch_norm
    return torch.batch_norm(

Can you try a minimal example using only torchvision resnet (no Lightly prediction or projection heads) and with a dummy loss function instead of DINO?

Not sure if it's related, but it seems there's a small bug in your code: you cancel the last-layer gradients twice (a corrected sketch follows the snippet below).

        model.on_after_backward()
        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
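
For reference, that part of the loop with a single cancellation call could look like this (just a sketch, assuming the rest of the script stays as in the description):

        fabric.backward(loss)
        # Cancel the last-layer gradients of the student head exactly once per step.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()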

Accidentally cancelling the gradients twice wasn't the issue. I changed that and still see the problem, which makes sense since the second call is effectively a no-op.

I've swapped the code to use just resnet18 and don't see a problem with 2 devices. The following code runs without a problem:

from lightning.fabric import Fabric
import torch
import torchvision
from torch import nn
import torch.nn.functional as F


fabric = Fabric(accelerator='cuda', num_nodes=1, devices=2)
fabric.launch()

torch.autograd.set_detect_anomaly(True)

input_dim = 512
model = torchvision.models.resnet18()


transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop((96, 96)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: torch.cat([x, x, x], 1))
])
dataset = torchvision.datasets.VOCDetection(
    "./data",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

dataloader = fabric.setup_dataloaders(dataloader)

def criterion(yhat: torch.Tensor, _ignore):
    ones = torch.ones(yhat.shape, device=yhat.device)
    return F.mse_loss(yhat, ones)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

epochs = 10

print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        X = batch[0]
        fabric.barrier()
        pred = model(X)
        loss = criterion(pred, X)
        total_loss += loss.detach()
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

If I swap the model to this:

class DINO(L.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        input_dim = backbone.embed_dim

        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x, teacher: bool = False):
        if teacher:
            y = self.student_backbone(x).flatten(start_dim=1)
            z = self.student_head(y)
        else:
            y = self.teacher_backbone(x).flatten(start_dim=1)
            z = self.teacher_head(y)
        return z

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

which only changes the backbone from resnet18 to dino_vits16, and it works. So the issue seems to be the combination of the resnet18 backbone, Fabric, and the DINO setup.

Curious, it could be due to the combination of multiple forward passes, batch norm, and distributed Fabric. Have you already searched for similar issues in https://github.com/Lightning-AI/pytorch-lightning/issues?
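
If you want to isolate it further, something along these lines might reproduce it with plain DDP (a rough, untested sketch; launch with e.g. torchrun --nproc_per_node=2 bn_ddp_repro.py on two GPUs):

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # A tiny model with batch norm stands in for the resnet18 backbone.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten()).cuda()
    # broadcast_buffers=True is the DDP default; it re-broadcasts buffers such
    # as the batch norm running stats in place at the start of every forward.
    ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=True)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

    x = torch.randn(4, 3, 32, 32, device="cuda")
    # Two forward passes before a single backward, mirroring the multi-view
    # DINO step; the second forward may invalidate tensors saved by the first.
    loss = ddp_model(x).sum() + ddp_model(x).sum()
    loss.backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()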

There wasn't anything there from what I could tell. I did track down this comment, though. I figured the batch norm internals (the running statistics) might be registered as buffers (similar to DINOLoss), so I took a shot and changed the Fabric setup to the following:

from lightning.fabric.strategies.ddp import DDPStrategy
fabric = Fabric(accelerator='cuda', num_nodes=1, devices=2, strategy=DDPStrategy(broadcast_buffers=False))

This now works for training with the resnet18 backbone across multiple devices. The issue seems to be that broadcast_buffers performs an in-place update of the buffers, so you only see the error when you make multiple forward calls before the backward pass. By explicitly using a DDPStrategy with broadcast_buffers=False, I can get around the problem, so this can probably be closed.
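
For anyone hitting the same error outside Fabric, the equivalent flag exists directly on torch's DDP wrapper. A sketch, assuming a process group is already initialized and using the DINO module from the repro above:

import os

from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = DINO().cuda(local_rank)  # the DINO module from the repro above
# broadcast_buffers=False skips the per-forward, in-place broadcast of buffers
# (e.g. batch norm running stats), which is what the DDPStrategy workaround
# above disables as well.
ddp_model = DDP(model, device_ids=[local_rank], broadcast_buffers=False)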

Thank you for figuring out the workaround! I will close this issue now. Feel free to reopen if anything else comes up.