facebookresearch/FFCV-SSL

`torch.nn.Module` wrapper for transforms does not seem to work for custom transforms

akashc1 opened this issue · 1 comment

I tried creating custom transforms and wrapping them with the module wrapper. Even a no-op transform written as an nn.Module does not work in the FFCV framework. This appears to be a fundamental bug in how the module wrapper is implemented:

  • FFCV uses numba to statically compile the operations in a pipeline
  • numba cannot infer the type of a torch module
  • so compilation of the operation fails.
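The failure chain in the list above can be modeled in plain Python. This is a toy, not FFCV or numba code: `SUPPORTED_TYPES`, `Stage`, `compile_stage`, and `FakeModule` are hypothetical stand-ins. Roughly, numba can only compile a function if it can infer a concrete type for every value the function uses, and an opaque Python object such as an `nn.Module` instance has no such type. So "compiling" a stage that closes over one fails.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List

# Hypothetical stand-in for the value types the toy compiler can type.
# (Real numba supports many more, e.g. NumPy arrays, but not arbitrary objects.)
SUPPORTED_TYPES = (int, float, bool)


class CompilationError(TypeError):
    """Raised when a captured value has a type the toy compiler cannot infer."""


@dataclass
class Stage:
    fn: Callable[[Any], Any]
    captured: List[Any] = field(default_factory=list)  # values fn closes over


def compile_stage(stage: Stage) -> Callable[[Any], Any]:
    """Return stage.fn if every captured value has a supported type, else raise."""
    for value in stage.captured:
        if not isinstance(value, SUPPORTED_TYPES):
            raise CompilationError(f"cannot determine type of {type(value).__name__}")
    return stage.fn


class FakeModule:
    """Plays the role of DummyModule: an opaque Python object."""

    def __call__(self, img):
        return img


module = FakeModule()
no_op = Stage(fn=lambda img: module(img), captured=[module])

try:
    compile_stage(no_op)
except CompilationError as err:
    print(err)  # prints "cannot determine type of FakeModule"
```

A stage that only captures simple values, such as `Stage(fn=lambda x: x + 1, captured=[3])`, passes through the toy compiler unchanged.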

I created a script to demonstrate this using a trivial no-op module: #8

import torch.nn as nn


class DummyModule(nn.Module):
    """A no-op transform: returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        return img

grez72 commented

I was having a similar issue, and found that adding ToTensor() to the pipeline before my nn.Module transforms fixed it. Perhaps this will work in your case too:

from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.transforms import ToTensor

# Place ToTensor() before any nn.Module transforms (DummyModule is defined above)
img_pipeline = [SimpleRGBImageDecoder(), ToTensor(), DummyModule()]
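One plausible explanation for why this workaround helps, hedged because it depends on the FFCV version: FFCV's pipeline state carries a `jit_mode` flag, and `ToTensor` switches it off, so that op and everything after it run as ordinary Python instead of being compiled by numba. The toy model below illustrates only that ordering effect; `Op` and `run_pipeline` are hypothetical names, not FFCV's API.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class Op:
    name: str
    fn: Callable[[Any], Any]
    compilable: bool = True     # can the toy "numba" compile this op?
    disables_jit: bool = False  # does this op switch JIT mode off (like ToTensor)?


def run_pipeline(ops: List[Op], sample: Any) -> Tuple[Any, List[str]]:
    """Run ops in order; only ops reached while jit_mode is on get compiled."""
    jit_mode = True
    compiled: List[str] = []
    for op in ops:
        # The flag change applies from this op onward.
        if op.disables_jit:
            jit_mode = False
        if jit_mode:
            if not op.compilable:
                raise TypeError(f"cannot compile {op.name} in JIT mode")
            compiled.append(op.name)
        sample = op.fn(sample)
    return sample, compiled


decode = Op("decode", lambda x: x)                           # stands in for SimpleRGBImageDecoder
to_tensor = Op("to_tensor", lambda x: x, disables_jit=True)  # stands in for ToTensor
module = Op("module", lambda x: x, compilable=False)         # stands in for DummyModule

try:
    run_pipeline([decode, module], sample=1)
except TypeError as err:
    print(err)  # prints "cannot compile module in JIT mode"

result, compiled = run_pipeline([decode, to_tensor, module], sample=1)
print(result, compiled)  # prints "1 ['decode']"
```

Under this model, the position of `ToTensor()` matters: the module transform only escapes compilation if it comes after the op that turns JIT mode off.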