sthalles/SimCLR

How do I train the SimCLR model with my local dataset?

bestalllen opened this issue · 5 comments

Dear researcher,
Thank you for the open-source code you provided. It has been a great help to me in understanding contrastive learning.
However, I am still unsure how to train the SimCLR model with my local dataset. Could you give me some guidance or tips? I would appreciate it if you could reply to this issue.

Hi! I am not the author, but I would like to give you some advice. To use your own dataset, the only part you need to modify is here:

```python
def get_dataset(self, name, n_views):
    valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(32),
                                                              n_views),
                                                          download=True),

                      'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                      transform=ContrastiveLearningViewGenerator(
                                                          self.get_simclr_pipeline_transform(96),
                                                          n_views),
                                                      download=True)}
```

These lines create the dataset; the author uses public datasets from torchvision. In your case, you should write your own PyTorch dataset class and use it in place of these entries. Don't forget to pass `transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(your image size), n_views)` so that each image still yields `n_views` augmented views.
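
As a rough illustration of the "write a PyTorch dataset class" step, here is a minimal sketch, assuming your images are plain files somewhere under one folder and that labels are not needed for pretraining. The class name `LocalImageDataset`, the `'local'` key and the `-1` dummy label are my own placeholders and are not part of the repo.

```python
import os

from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


class LocalImageDataset(Dataset):
    """Hypothetical unlabeled image dataset read from a local folder."""

    def __init__(self, root, transform=None):
        self.transform = transform
        # Collect every image file under root (recursively), sorted for a stable order.
        self.paths = sorted(
            os.path.join(dirpath, fname)
            for dirpath, _, fnames in os.walk(root)
            for fname in fnames
            if fname.lower().endswith(IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        with Image.open(self.paths[index]) as img:
            img = img.convert('RGB')
        if self.transform is not None:
            # With ContrastiveLearningViewGenerator this becomes a list of n_views tensors.
            img = self.transform(img)
        # Dummy label, so each item is an (images, label) pair like the torchvision datasets.
        return img, -1


# Assumed wiring inside get_dataset (the key name 'local' is hypothetical):
# 'local': lambda: LocalImageDataset(self.root_folder,
#                                    transform=ContrastiveLearningViewGenerator(
#                                        self.get_simclr_pipeline_transform(your_image_size),
#                                        n_views))
```

If the launch script restricts the allowed dataset names, the new key would probably need to be added there as well.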

Hope this can help~

OK, thanks for your reply. I have since run this experiment on my own dataset. Thanks again for your help!

Hello, how did you change the dataset? Could you give me some guidance or tips? I would appreciate it if you could reply to this issue.

I would like to help. Could you describe the specific problem you came across?

Hello, I want to use my own dataset. Its current structure is like this:

  • mydata
    - data0
      - figure1.jpg
      - figure2.jpg
      - figure3.jpg
      - ....
    - data1
      - figure1.jpg
      - figure2.jpg
      - figure3.jpg
      - ....

There are 10 categories, from data0 to data9.
How do I generate a dataset in the same format as the one the author uses, with class_names.txt, fold_indices.txt, test_X.bin and so on?
I would appreciate it if you could reply to this issue.