`torch.nn.Module` wrapper for transforms does not seem to work for custom transforms
akashc1 opened this issue · 1 comment
akashc1 commented
I tried creating custom transforms and using the module wrapper. Even a no-op transform wrapped in an nn.Module
does not work in the FFCV framework. It appears to be a fundamental bug in how the module wrapper is implemented:
- FFCV statically uses numba to compile every operation in the pipeline.
- numba cannot resolve the type of a torch module, so compilation fails for that operation.
I created a script to demonstrate this using a trivial no-op module: #8
import torch.nn as nn

class DummyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, img):
        # No-op: return the input unchanged
        return img
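For context, a minimal sketch of how such a module would sit in a pipeline; the .beton path, batch size, and other loader arguments below are placeholders rather than values from the original report:

from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder

# Placing the raw nn.Module directly in the pipeline triggers the failure
# described above.
loader = Loader('/path/to/dataset.beton',  # placeholder dataset file
                batch_size=16,
                num_workers=4,
                order=OrderOption.SEQUENTIAL,
                pipelines={'image': [SimpleRGBImageDecoder(), DummyModule()]})

# numba attempts to compile the pipeline and cannot type the torch module;
# the error surfaces on first iteration (or at Loader construction,
# depending on the FFCV version).
for batch in loader:
    break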
grez72 commented
I was having a similar issue, and found that adding ToTensor() to the pipeline before my nn.Module transforms fixed it. Perhaps this will work in your case too:
from ffcv.transforms import ToTensor
from ffcv.fields.decoders import SimpleRGBImageDecoder

# ToTensor converts the decoded NumPy array to a torch tensor before the
# nn.Module transform runs.
img_pipeline = [SimpleRGBImageDecoder(), ToTensor(), DummyModule()]
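Presumably this works because ToTensor marks the boundary between the two stages of an FFCV pipeline: operations before it run on NumPy arrays and are JIT-compiled by numba, while operations after it receive torch tensors and run in eager PyTorch, so numba never has to type-resolve the torch module. A sketch of the fixed pipeline in a loader, under the same placeholder assumptions as above:

from ffcv.loader import Loader, OrderOption

loader = Loader('/path/to/dataset.beton',  # placeholder dataset file
                batch_size=16,
                num_workers=4,
                order=OrderOption.SEQUENTIAL,
                pipelines={'image': img_pipeline})

# With ToTensor ahead of DummyModule, iteration should proceed without
# the numba typing error.
for batch in loader:
    break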