lucidrains/byol-pytorch

How to train a better backbone with BYOL?

knaffe opened this issue · 1 comment

I used the example code (byol-pytorch/examples/lightning/train.py) with my own data (about 0.7 billion images) to get an improved resnet50.
But when I replace the open-source pretrained resnet50 with the new resnet50 model, my target task gets worse than with the original one (my target task is metric learning).
I tried training BYOL for 50, 100, and 200 epochs, but the target task result is still worse.
Am I missing some settings? How can I evaluate the trained BYOL/improved model?
The code is below; I trained the model with 4 V100 GPUs.


train_mode = True
load_self_pretrain = True

if train_mode:
    if load_self_pretrain:
        # Load the pretrained weights before wrapping the backbone, so that any
        # copy the learner may make of it (e.g. BYOL's target encoder) starts from them.
        # state_dict = load_state_dict_from_url(model_urls[arch],
        #                                         progress=progress)
        state_dict = torch.load(args.pretrain_path, map_location=torch.device('cpu'))
        resnet.load_state_dict(state_dict)

    model = SelfSupervisedLearner(
        resnet,
        image_size = IMAGE_SIZE,
        hidden_layer = 'avgpool',
        projection_size = 256,
        projection_hidden_size = 4096,
        moving_average_decay = 0.99
    )

    ds = ImagesDataset(args.image_folder, IMAGE_SIZE)
    train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
    # Use DistributedDataParallel only when more than one GPU is available.
    distributed_backend = "ddp" if NUM_GPUS >= 2 else None
    trainer = pl.Trainer(
        gpus = NUM_GPUS,
        max_epochs = EPOCHS,
        accumulate_grad_batches = 1,
        default_root_dir = save_path,
        distributed_backend=distributed_backend
    )

    trainer.fit(model, train_loader)
    # Pull the trained backbone out of the BYOL wrapper and save only its weights.
    improved_resnet = model.learner.net
    torch.save(improved_resnet.state_dict(), './improved-resnet50.pth')
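
On the evaluation question, one option for a metric-learning target is to compare backbones directly by nearest-neighbor retrieval on a labeled held-out set, before any fine-tuning. The sketch below is only an illustration under that assumption: `recall_at_1` is a hypothetical helper, not part of byol-pytorch, and the embeddings are assumed to have been computed beforehand with the backbone in `eval()` mode under `torch.no_grad()`.

```python
import torch
import torch.nn.functional as F

def recall_at_1(embeddings: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose nearest neighbor (excluding itself) has the same label."""
    emb = F.normalize(embeddings, dim=1)   # unit vectors, so dot product = cosine similarity
    sim = emb @ emb.t()                    # pairwise similarity matrix, shape (N, N)
    sim.fill_diagonal_(float('-inf'))      # a sample must not retrieve itself
    nearest = sim.argmax(dim=1)            # index of the most similar other sample
    return (labels[nearest] == labels).float().mean().item()

# Toy example: two tight clusters in 2-D.
toy_embeddings = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
toy_labels = torch.tensor([0, 0, 1, 1])
score = recall_at_1(toy_embeddings, toy_labels)
```

Running the same function on embeddings from the open-source pretrained resnet50 and from the BYOL-trained one gives a like-for-like comparison that does not depend on the downstream training setup.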

I have the same confusion: do we need to train the model with BYOL first, and then use the weights to train the whole model, or only train the rest of the model?
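
The two options in that question can be sketched as follows. This is a minimal illustration, not the library's prescribed workflow: `build_downstream`, the tiny stand-in backbones, and the dimensions are all hypothetical, and a real setup would load the BYOL-trained resnet50 weights instead.

```python
import torch
from torch import nn

def build_downstream(backbone: nn.Module, feat_dim: int, out_dim: int, freeze_backbone: bool):
    """Attach a new head to a pretrained backbone; optionally freeze the backbone."""
    if freeze_backbone:
        # Option A: keep the BYOL weights fixed and train only the new head.
        for p in backbone.parameters():
            p.requires_grad = False
    # Option B (freeze_backbone=False): fine-tune backbone and head together.
    head = nn.Linear(feat_dim, out_dim)
    model = nn.Sequential(backbone, head)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)
    return model, optimizer

def make_stand_in_backbone() -> nn.Module:
    # Hypothetical stand-in for the pretrained resnet50 (outputs 16-dim features).
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU())

frozen_model, frozen_opt = build_downstream(make_stand_in_backbone(), 16, 4, freeze_backbone=True)
full_model, full_opt = build_downstream(make_stand_in_backbone(), 16, 4, freeze_backbone=False)
```

Which option works better depends on the target task and the amount of labeled data, so it is worth trying both. With a frozen backbone that contains BatchNorm layers, the backbone usually also needs to be kept in `eval()` mode so its running statistics do not drift.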