QData/spacetimeformer

Training with custom dataset. Error: object is not callable


I'm trying to modify the training script to train on my own dataset (the original version uses CIFAR10/MNIST). My dataset is a folder of images in JPG format.
I've run into a problem with the data module.
Here's the modified code in train.py:

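import spacetimeformer as stf
from torchvision import transforms  # used by the "custom" branch
# CUSTOMDset also has to be imported from the module where it is defined

def create_dset(config):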
    INV_SCALER = lambda x: x
    SCALER = lambda x: x
    NULL_VAL = None
    PLOT_VAR_IDXS = None
    PLOT_VAR_NAMES = None
    PAD_VAL = None


    if config.dset in ["mnist", "cifar"]:
        if config.dset == "mnist":
            config.target_points = 28 - config.context_points
            datasetCls = stf.data.image_completion.MNISTDset
            PLOT_VAR_IDXS = [18, 24]
            PLOT_VAR_NAMES = ["18th row", "24th row"]
        else:
            config.target_points = 32 * 32 - config.context_points
            datasetCls = stf.data.image_completion.CIFARDset
            PLOT_VAR_IDXS = [0]
            PLOT_VAR_NAMES = ["Reds"]
        DATA_MODULE = stf.data.DataModule(
            datasetCls=datasetCls,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )

        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )
    # Try to use my own data set here
    elif config.dset == "custom":
        data_dir = "./spacetimeformer/mydata"

        # Define data transformations
        transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize the images to a specific size
            transforms.ToTensor(),          # Convert images to tensors
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the pixel values
        ])
        
        # Create the custom image dataset
        dataset = CUSTOMDset(data_dir, context_points=config.context_points, transform=transform)
        
        DATA_MODULE = stf.data.DataModule(
            datasetCls=dataset,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=args.overfit,
        )

        # Rest of the values remain the same as in the "mnist" and "cifar" cases
        PLOT_VAR_IDXS = [0]
        PLOT_VAR_NAMES = ["Reds"]
        PAD_VAL = None
    
        return (
            DATA_MODULE,
            INV_SCALER,
            SCALER,
            NULL_VAL,
            PLOT_VAR_IDXS,
            PLOT_VAR_NAMES,
            PAD_VAL,
        )
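
Unlike the mnist and cifar branches, the custom branch never sets config.target_points. The cifar branch treats each image as one 32 * 32 step sequence and uses whatever follows the context as targets. If CUSTOMDset flattens its 256 x 256 images the same way (an assumption, since its code isn't shown), the analogous value could be computed with a small helper like this one (image_target_points is a hypothetical name):

```python
def image_target_points(height, width, context_points):
    """Number of target steps when an image is flattened into a single
    height * width sequence whose first `context_points` steps are context.

    Mirrors the cifar branch (32 * 32 - context_points). Using it for
    CUSTOMDset assumes that dataset flattens images the same way.
    """
    total = height * width
    if not 0 < context_points < total:
        raise ValueError(
            f"context_points must be between 1 and {total - 1}, got {context_points}"
        )
    return total - context_points


print(image_target_points(32, 32, 100))    # cifar-sized image
print(image_target_points(256, 256, 100))  # the custom Resize((256, 256))
```

In the custom branch that would become config.target_points = image_target_points(256, 256, config.context_points). Note that a 256 x 256 image gives a 65,536-step sequence versus CIFAR's 1,024, so a smaller Resize may be worth trying.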

I call the function like this:
(data_module, inv_scaler, scaler, null_val, plot_var_idxs, plot_var_names, pad_val) = create_dset(args)
Later, when I call test_dataloader = data_module.test_dataloader(), I get this error:

Traceback (most recent call last):
	File "train_my.py", line 525, in <module>
	  main(args)
	File "train_my.py", line 435, in main
	  test_dataloader = data_module.test_dataloader()
	File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 37, in test_dataloader
	  return self._make_dloader("test", shuffle=shuffle)
	File "/home/spacetimeformer/spacetimeformer/data/datamodule.py", line 44, in _make_dloader
	  self.datasetCls(**self.dataset_kwargs, split=split),
TypeError: 'CUSTOMDset' object is not callable
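
The last frame of the traceback shows where it fails: the DataModule stores datasetCls and only later calls self.datasetCls(**self.dataset_kwargs, split=split) to build the dataset for a split. That call needs the dataset class itself, but the custom branch passes datasetCls=dataset, an already-constructed CUSTOMDset instance, and instances aren't callable. The stdlib-only sketch below reproduces this with hypothetical stand-ins (MiniDataModule, CustomDset); it is not spacetimeformer's code, only the same calling pattern.

```python
class MiniDataModule:
    """Hypothetical stand-in for a DataModule. It mimics the call seen in the
    traceback: self.datasetCls(**self.dataset_kwargs, split=split)."""

    def __init__(self, datasetCls, dataset_kwargs):
        # Only stored here; nothing is called until a split is requested.
        self.datasetCls = datasetCls
        self.dataset_kwargs = dataset_kwargs

    def make_dataset(self, split):
        # Same calling pattern as the failing line in the traceback.
        return self.datasetCls(**self.dataset_kwargs, split=split)


class CustomDset:
    """Hypothetical dataset; its __init__ accepts the `split` keyword."""

    def __init__(self, data_dir, context_points, transform=None, split="train"):
        self.data_dir = data_dir
        self.context_points = context_points
        self.transform = transform
        self.split = split


kwargs = {"data_dir": "./spacetimeformer/mydata", "context_points": 16}

# Buggy: an instance is passed, so calling it raises TypeError.
bad = MiniDataModule(CustomDset(**kwargs), kwargs)
try:
    bad.make_dataset("test")
except TypeError as err:
    print("TypeError:", err)

# Fixed: pass the class; one dataset is created per requested split.
good = MiniDataModule(CustomDset, kwargs)
test_dset = good.make_dataset("test")
print(test_dset.split, test_dset.data_dir)
```

Applied to train.py, and assuming CUSTOMDset's constructor accepts data_dir, context_points, transform and split as keyword arguments, the fix would be to pass datasetCls=CUSTOMDset and move data_dir and transform into dataset_kwargs next to context_points. Passing the class alone is not enough, because dataset_kwargs currently holds only context_points, so data_dir would be missing. This pattern also explains why creating the DataModule succeeds: the error only surfaces once a dataloader is requested.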

When I print data_module, it shows <spacetimeformer.data.datamodule.DataModule object at 0x7f59cd8f84c0>.
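
Because the DataModule builds a separate dataset object per split name (the traceback shows "test"; training and validation presumably request their own splits), a dataset backed by a single folder of JPGs also needs a reproducible way to decide which files belong to which split. One option, sketched here with hypothetical names and assumed 80/10/10 fractions, is to bucket each file by a hash of its name:

```python
import hashlib

SPLITS = ("train", "val", "test")


def assign_split(file_name, val_frac=0.1, test_frac=0.1):
    """Map a file name to 'train', 'val' or 'test' using a stable hash."""
    digest = hashlib.md5(file_name.encode("utf-8")).hexdigest()
    # First 32 bits of the digest, scaled to [0, 1).
    u = int(digest[:8], 16) / 2**32
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"


def files_for_split(file_names, split, val_frac=0.1, test_frac=0.1,
                    exts=(".jpg", ".jpeg")):
    """Return the sorted image file names that belong to `split`."""
    if split not in SPLITS:
        raise ValueError(f"unknown split: {split!r}")
    return sorted(
        name
        for name in file_names
        if name.lower().endswith(exts)
        and assign_split(name, val_frac, test_frac) == split
    )


names = ["cat.jpg", "dog.JPG", "bird.jpeg", "notes.txt"]
for split in SPLITS:
    print(split, files_for_split(names, split))
```

Inside a hypothetical CUSTOMDset.__init__(self, data_dir, context_points, transform=None, split="train"), a call like files_for_split(os.listdir(data_dir), split) would select that split's images. Hashing instead of shuffling keeps each split stable and non-overlapping no matter how often, or in what order, the per-split datasets are constructed.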

My environment: Python 3.8, torch 2.0.1.

Any ideas about the bug, or about how to train on custom data, would be appreciated! I'm new to PyTorch Lightning and struggling to solve this.

@HiFaye4869 have you made any progress? I'm trying to use it on my own dataset as well.

I haven't tried a custom dataset yet, but I plan to soon. One question, though: why aren't you using the setup from the repo, mainly torch 1.11.0?