0.5.6 failing with a multiple-GPU error
shivam-aggarwal opened this issue · 3 comments
shivam-aggarwal commented
Not able to run the code on multiple GPUs using torch.nn.DataParallel.
@lucidrains this is the cause: self.hidden doesn't have a key for the current device.

```python
hidden = self.hidden[x.device]
```

KeyError: device(type='cuda', index=0)

This is the exact line that fails. Please help.
Thanks!
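For context, here is a minimal sketch of the capture mechanism the traceback points at (reconstructed from the error, not verbatim library code; the method bodies are my guesses). A forward hook stores the hidden activation in a dict keyed by device, and get_representation reads it back with self.hidden[x.device]. Under DataParallel the replicas end up sharing that one dict, which is how a lookup on cuda:0 can miss:

```python
import torch
from torch import nn

class NetWrapper(nn.Module):
    def __init__(self, net, layer=-2):
        super().__init__()
        self.net = net
        self.layer = layer
        self.hidden = {}                 # device -> hidden activation
        self.hook_registered = False

    def _find_layer(self):
        # simplified: index into the top-level children
        return list(self.net.children())[self.layer]

    def _hook(self, _, inputs, output):
        # the hook keys the activation by the input's device, so each
        # DataParallel replica writes under its own GPU's key
        device = inputs[0].device
        self.hidden[device] = output.flatten(1)

    def _register_hook(self):
        self._find_layer().register_forward_hook(self._hook)
        self.hook_registered = True
```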
lucidrains commented
@shivam-aggarwal can you give DDP a try? I think DDP is the preferred way to go distributed these days anyhow
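For anyone who lands here, a minimal sketch of what that looks like, assuming the README-style resnet50/BYOL setup; the loop body and batch are illustrative placeholders, not the repo's own training script. Launch it with `torchrun --nproc_per_node=<num_gpus>` (or torch.distributed.launch on older PyTorch):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision import models
from byol_pytorch import BYOL

dist.init_process_group(backend='nccl')   # env vars set by the launcher
rank = dist.get_rank()                    # single node assumed: rank == local GPU id
torch.cuda.set_device(rank)

resnet = models.resnet50(pretrained=True)
learner = BYOL(resnet, image_size=256, hidden_layer='avgpool').cuda(rank)
learner = DDP(learner, device_ids=[rank])

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

images = torch.randn(4, 3, 256, 256, device=rank)  # stand-in for a real batch
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.module.update_moving_average()  # EMA update lives on the wrapped module
```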
lich99 commented
@shivam-aggarwal maybe you can try replacing the clear() call with pop() in NetWrapper.get_representation(), so it looks like:
```python
def get_representation(self, x):
    if self.layer == -1:
        return self.net(x)

    if not self.hook_registered:
        self._register_hook()

    _ = self.net(x)
    hidden = self.hidden.pop(x.device)
    assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
    return hidden
```
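If I'm reading the hook mechanism right, the reason this helps: under DataParallel every replica shares the same self.hidden dict, so a blanket clear() from one replica can wipe the entry another GPU is just about to read, while pop(x.device) consumes only the caller's own entry.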
opeide commented
I had the same issue with DDP; in my case the culprit was torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) (in my own code).
I guess a BatchNorm layer was serving as my hidden output.
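My reading of this failure mode (an assumption, not confirmed in the thread): convert_sync_batchnorm rebuilds the module tree, replacing every BatchNorm with a fresh SyncBatchNorm, so a forward hook registered on the old BN layer is silently dropped. A toy demonstration:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
bn = model[1]
bn.register_forward_hook(lambda m, i, o: print('hook fired'))

model(torch.randn(1, 3, 8, 8))  # prints 'hook fired'

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model[1] is bn)           # False: the hooked BN layer was replaced
```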