[Bug]: - Layer call changes/breaks graph?
HeinrichAD opened this issue · 2 comments
Module
Layers
Contact Details
No response
Current Behavior
It seems that calling a layer changes (or breaks) its tensor graph.
The problem is easy to see when trying to copy a layer after it has been called: a deepcopy at that point raises a runtime error:
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
Note that this does not happen if the layer is exported with vanilla_export() before it is called.
Expected Behavior
A layer call shouldn't change or break its tensor graph, and a deepcopy should still be possible after a layer call.
Version
v0.1.0
Environment
- OS: Linux arch 5.18.9-arch1-1
- Python version: 3.7
- PyTorch version: 1.11.0+cu102
- Packages used version: deel-torchlip torch
Relevant log output
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
To Reproduce
The example uses SpectralLinear, but the error is not layer-dependent. I also tested FrobeniusLinear, SpectralConv2d and FrobeniusConv2d and got the same result.
#!/usr/bin/env python3
from copy import deepcopy
from deel.torchlip import SpectralLinear
import torch
lin = SpectralLinear(1, 1)
#lin = lin.vanilla_export()
x0 = torch.tensor([1.])
copy = deepcopy(lin) # ok
lin(x0)
copy = deepcopy(lin) # fail if not vanilla_export
# RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
Hello @HeinrichAD,
The Lipschitz layers use reparameterization to compute normalized weights. In particular, the SpectralLinear and SpectralConv2d modules are based on torch.nn.utils.spectral_norm. However, it seems that deepcopying layers with weight reparameterization does not work in PyTorch (see this or this issue), and unfortunately PyTorch doesn't provide any workaround.
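The mechanism can be reproduced with plain PyTorch, without deel-torchlip. This is a minimal sketch that assumes only the behaviour of the hook-based torch.nn.utils.spectral_norm: before the first call, weight is stored on the module as a plain leaf tensor. Each call's pre-forward hook then recomputes it from weight_orig, leaving a non-leaf tensor (one with a grad_fn) among the module's attributes, and Tensor.__deepcopy__ refuses to copy that.

```python
import copy

import torch
from torch import nn

# Plain-PyTorch analogue of the report (no deel-torchlip involved): a linear
# layer reparameterized with the hook-based torch.nn.utils.spectral_norm.
torch.manual_seed(0)
layer = nn.utils.spectral_norm(nn.Linear(1, 1))

copy.deepcopy(layer)  # ok: before the first call, `weight` is a leaf tensor

layer(torch.tensor([1.0]))  # pre-forward hook recomputes `weight` from `weight_orig`

# `weight` is now a plain module attribute that carries a grad_fn,
# i.e. a non-leaf tensor.
print(layer.weight.is_leaf)  # False

failure = None
try:
    copy.deepcopy(layer)
except RuntimeError as err:  # Tensor.__deepcopy__ rejects non-leaf tensors
    failure = str(err)
print(failure)
```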
@cofri, thank you for the clarification.
That's unfortunate, but it apparently seems unavoidable.
A possible workaround is to always evaluate (and copy, etc.) on the vanilla export. That's fine for me.
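Another possible workaround, sketched here for the plain torch.nn.utils.spectral_norm case only, is to detach the recomputed tensors just for the duration of the copy. deepcopy_detached is a hypothetical helper, not part of PyTorch or deel-torchlip. It assumes that the only non-copyable state is non-leaf tensors stored as plain module attributes, and it has not been checked against deel-torchlip's own layers.

```python
import copy

import torch
from torch import nn


def deepcopy_detached(module: nn.Module) -> nn.Module:
    """Hypothetical helper (not part of PyTorch or deel-torchlip).

    Temporarily replaces every non-leaf tensor stored as a plain attribute of
    the module or its submodules with a detached copy, deep-copies the module,
    then restores the original attributes on the source module.
    """
    swapped = []
    for sub in module.modules():
        # vars() holds plain attributes; parameters and buffers live in
        # the _parameters / _buffers dicts and are copied normally.
        for name, value in list(vars(sub).items()):
            if isinstance(value, torch.Tensor) and not value.is_leaf:
                swapped.append((sub, name, value))
                setattr(sub, name, value.detach())
    try:
        return copy.deepcopy(module)
    finally:
        # Put the original (graph-connected) tensors back on the source module.
        for sub, name, value in swapped:
            setattr(sub, name, value)


torch.manual_seed(0)
layer = nn.utils.spectral_norm(nn.Linear(3, 2))
x = torch.randn(4, 3)
layer(x)  # after this call, a plain copy.deepcopy(layer) would raise

clone = deepcopy_detached(layer)

# In eval mode no power iteration runs, so both layers derive identical
# weights from the same weight_orig, _u and _v.
layer.eval()
clone.eval()
with torch.no_grad():
    same = torch.allclose(layer(x), clone(x))
print(same)  # True
```

Detaching is harmless for this kind of layer because the pre-forward hook recomputes weight on every call, so the copy gets a graph-connected weight again on its first forward pass.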