hzxie/Pix2Vox

I referred to #28, but my test result looks like this. Can you tell me the reason? Thank you!

saisai1002 opened this issue · 3 comments

# imports assumed from the Pix2Vox repo layout
import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image

import utils.binvox_visualization
import utils.data_transforms
from config import cfg
from models.decoder import Decoder
from models.encoder import Encoder
from models.merger import Merger
from models.refiner import Refiner

encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)

cfg.CONST.WEIGHTS = '/home/baijinggroup/lss/Pix2Vox/pretrained_weights/Pix2Vox-A-ShapeNet.pth'

checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

# strip the 'module.' prefix that nn.DataParallel adds to state-dict keys
fix_checkpoint = {}
fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())

epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
# note: the refiner and merger weights are never loaded here, so below they
# run with their random initialization (the second snippet loads them)

encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()

img1_path = '/home/baijinggroup/lss/Pix2Vox/datasets/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
#img1_path = '/home/baijinggroup/lss/Pix2Vox/datasets/ShapeNetRendering/04090263/1aa5498ac780331f782611375da5ea9a/rendering/00.png'
img1_np = np.asarray(Image.open(img1_path))  # uint8 with values in [0, 255]
#img1_np = cv2.imread(img1_path)

sample = np.array([img1_np])

IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE), 
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])

rendering_images = test_transforms(rendering_images=sample)  # -> tensor of shape (n_views, C, H, W)
rendering_images = rendering_images.unsqueeze(0)  # add the batch dimension: (1, n_views, C, H, W)

with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)

    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)

    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)

generated_volume = generated_volume.squeeze(0)
img_dir = '/home/baijinggroup/lss/Pix2Vox/core/sample_images'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)  # swap two voxel axes so the plotted orientation matches the rendering
rendering_views = utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)

This is the input:
[input rendering image]

This is the model's output:
[output voxel image]

This is the result after adding np.swapaxes(gv, 2, 1) (the swap only changes the displayed orientation, not the reconstruction itself):
[output voxel image, axes swapped]

I got a better result by using cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. to read the image (the same way the author reads images in the dataloader).

[output voxel image: voxels-000153]
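The difference comes down to dtype and value range. Here is a minimal, hedged illustration (reusing img1_path from the snippets in this thread, and assuming the usual 4-channel ShapeNet rendering PNG) of what the two loading methods return:

import cv2
import numpy as np
from PIL import Image

# PIL route from the question: uint8, values in [0, 255], channels in RGB(A) order
pil_img = np.asarray(Image.open(img1_path))
print(pil_img.dtype, pil_img.min(), pil_img.max())   # e.g. uint8 0 255

# cv2 route from the dataloader: float32, values in [0, 1], channels in BGR(A) order
cv_img = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
print(cv_img.dtype, cv_img.min(), cv_img.max())      # e.g. float32 0.0 1.0

If cfg.DATASET.MEAN and cfg.DATASET.STD are defined for [0, 1] inputs (as they appear to be in config.py), pushing uint8 values in [0, 255] through Normalize produces inputs far outside the range the network was trained on, which would explain the degraded output.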

Code:

# same imports as in the first snippet, plus: import cv2
encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)

cfg.CONST.WEIGHTS = '/home/caig/Desktop/pix2vox/Pix2Vox/pretrain/Pix2Vox-A-ShapeNet.pth'
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))

# strip the 'module.' prefix from the keys of all four state dicts
fix_checkpoint = {}
fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())

epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])

# unlike the first snippet, also load the refiner and merger weights
if cfg.NETWORK.USE_REFINER:
    print('Use refiner')
    refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
if cfg.NETWORK.USE_MERGER:
    print('Use merger')
    merger.load_state_dict(fix_checkpoint['merger_state_dict'])

encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()

img1_path = '/home/caig/Desktop/data/shapenet/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
# read the image the same way the dataloader does: float32 with values in [0, 1]
img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

sample = np.array([img1_np])

IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])

rendering_images = test_transforms(rendering_images=sample)
rendering_images = rendering_images.unsqueeze(0)

with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)

    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)

    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)

generated_volume = generated_volume.squeeze(0)

img_dir = '/home/caig/Desktop/pix2vox/Pix2Vox/outputs'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)
rendering_views = utils.binvox_visualization.get_volume_views(gv_new, img_dir, epoch_idx)
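As a side note, the four repeated key-fixing lines can be collapsed into a small helper. This is just a sketch (strip_module_prefix is not a function from the repo):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # nn.DataParallel saves parameters as 'module.<name>'; drop that prefix
    return OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)
                       for k, v in state_dict.items())

# usage:
encoder.load_state_dict(strip_module_prefix(checkpoint['encoder_state_dict']))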

Where did you add this code in this project?

Thanks for the response!

I added this code as a new file in the core folder (e.g. test_single_img.py), then modified runner.py to run it.

[screenshot: the runner.py modification]

Change line 77 of runner.py to the class name in the new file (e.g. in test_single_img.py).
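For reference, a minimal sketch of that runner.py change; the module and function names below are examples, not the actual ones from the screenshot:

# in runner.py, next to the existing 'from core.test import test_net':
from core.test_single_img import test_single_img_net   # hypothetical name

# ...then, where runner.py currently dispatches to test_net(cfg) (around line 77):
test_single_img_net(cfg)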

(Please check test.py and runner.py for more details)

Thanks!