Lightning-Universe/lightning-bolts

VisionDataModule set/get transform doesn't change dataset transform

jascase901 opened this issue · 0 comments

๐Ÿ› Bug

Setting the transform on the data module after setup() has been called should also change the transform of the underlying dataset, but it does not.

from pl_bolts.datamodules import MNISTDataModule
from torchvision import transforms as transform_lib

mnist = MNISTDataModule(data_dir="/tmp/mnist")
mnist.prepare_data()
mnist.setup(stage="fit")

print("before set_transform")
print(mnist.dataset_train.dataset.transforms)

# Expect this assignment to change the train dataset transform
mnist.train_transforms = transform_lib.Compose(
    [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.6,), std=(0.5,))]
)

# Expect this to print the new transform (ToTensor + Normalize)
print("after transform")
print(mnist.dataset_train.dataset.transforms)

Results

before set_transform                                                                                                                                                                                                                                                                      
StandardTransform                                                                                                                                                                                                                                                                         
Transform: Compose(                                                                                                                                                                                                                                                                       
               ToTensor()                                                                                                                                                                                                                                                                 
           )                                                                                                                                                                                                                                                                              
after transform                                                                                                                                                                                                                                                                           
StandardTransform                                                                                                                                                                                                                                                                         
Transform: Compose(                                                                                                                                                                                                                                                                       
               ToTensor()                                                                                                                                                                                                                                                                 
           )       

Expected

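I expected the dataset transform to differ after I set the transform on the data module, i.e. the second printout should show the new Compose with ToTensor and Normalize.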
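To make the expected behaviour concrete, here is a minimal sketch that does not use pl_bolts at all. `ToyDataset` and `ToyDataModule` are hypothetical stand-ins, not the library's classes, and the sketch assumes the data module hands the transform to the dataset only once, inside `setup()`, which is what the output above suggests. The property setter shows one possible way to forward a later assignment to a dataset that has already been built.

```python
class ToyDataset:
    """Hypothetical stand-in for a torchvision-style dataset."""

    def __init__(self, transform=None):
        # The transform is stored on the dataset when it is constructed.
        self.transform = transform


class ToyDataModule:
    """Hypothetical data module; NOT the pl_bolts implementation."""

    DEFAULT_TRANSFORM = "default"  # placeholder for a real Compose([...])

    def __init__(self):
        self._train_transforms = None
        self.dataset_train = None

    @property
    def train_transforms(self):
        return self._train_transforms

    @train_transforms.setter
    def train_transforms(self, transform):
        self._train_transforms = transform
        # Forward the new transform to a dataset that already exists,
        # so assignments made after setup() still take effect.
        if self.dataset_train is not None:
            self.dataset_train.transform = transform

    def setup(self):
        # Fall back to the default transform if none was assigned.
        transform = (
            self._train_transforms
            if self._train_transforms is not None
            else self.DEFAULT_TRANSFORM
        )
        self.dataset_train = ToyDataset(transform=transform)


dm = ToyDataModule()
dm.setup()
before = dm.dataset_train.transform

dm.train_transforms = "normalized"  # assigned after setup()
after = dm.dataset_train.transform
```

With this setter, `before` and `after` differ, which is the behaviour I expected from the real data module. If `setup()` in pl_bolts does read `train_transforms` only when it builds the dataset, then assigning the transform before calling `setup()` may work as a workaround in the meantime.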

Environment

  • PyTorch Version (e.g., 1.0): 1.13.1+cu117
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10
  • CUDA/cuDNN version: 11