I referred to #28, but my test result looks like this. Can you tell me the reason? Thank you.
saisai1002 opened this issue · 3 comments
saisai1002 commented
encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)
cfg.CONST.WEIGHTS = '/home/baijinggroup/lss/Pix2Vox/pretrained_weights/Pix2Vox-A-ShapeNet.pth'
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
fix_checkpoint = {}
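# strip the 'module.' prefix from the state-dict keys (presumably added because the models were saved while wrapped in nn.DataParallel)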
fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()
img1_path = '/home/baijinggroup/lss/Pix2Vox/datasets/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
#img1_path = '/home/baijinggroup/lss/Pix2Vox/datasets/ShapeNetRendering/04090263/1aa5498ac780331f782611375da5ea9a/rendering/00.png'
img1_np = np.asarray(Image.open(img1_path))
#img1_np = cv2.imread(img1_path)
sample = np.array([img1_np])
IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])
rendering_images = test_transforms(rendering_images=sample)
rendering_images = rendering_images.unsqueeze(0)
with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)
    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)
    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)
generated_volume = generated_volume.squeeze(0)
img_dir = '/home/baijinggroup/lss/Pix2Vox/core/sample_images'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)
rendering_views = utils.binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir), epoch_idx)
This is the model's output:
This is the result after I added np.swapaxes(gv, 2, 1):
brian220 commented
I got a better result by using cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
to read the image (the same way the author does in the dataloader).
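(If you want to keep using PIL, a minimal sketch of building an equivalent array is below. The dataloader feeds the transforms float32 values in [0, 1], whereas Image.open gives uint8 in [0, 255]; the channel reordering is only my assumption that OpenCV's BGR(A) order should be matched, so drop it if it doesn't apply. img1_path is the same path as above.)

import numpy as np
from PIL import Image

# Mimic cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. with PIL:
# scale to [0, 1] and reorder RGB(A) -> BGR(A), keeping the alpha channel last.
img1_np = np.asarray(Image.open(img1_path)).astype(np.float32) / 255.
if img1_np.ndim == 3 and img1_np.shape[-1] == 4:
    img1_np = img1_np[..., [2, 1, 0, 3]]
elif img1_np.ndim == 3:
    img1_np = img1_np[..., ::-1]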
Code:
encoder = Encoder(cfg)
decoder = Decoder(cfg)
refiner = Refiner(cfg)
merger = Merger(cfg)
cfg.CONST.WEIGHTS = '/home/caig/Desktop/pix2vox/Pix2Vox/pretrain/Pix2Vox-A-ShapeNet.pth'
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=torch.device('cpu'))
fix_checkpoint = {}
fix_checkpoint['encoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['encoder_state_dict'].items())
fix_checkpoint['decoder_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['decoder_state_dict'].items())
fix_checkpoint['refiner_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['refiner_state_dict'].items())
fix_checkpoint['merger_state_dict'] = OrderedDict((k.split('module.')[1:][0], v) for k, v in checkpoint['merger_state_dict'].items())
epoch_idx = checkpoint['epoch_idx']
encoder.load_state_dict(fix_checkpoint['encoder_state_dict'])
decoder.load_state_dict(fix_checkpoint['decoder_state_dict'])
if cfg.NETWORK.USE_REFINER:
    print('Use refiner')
    refiner.load_state_dict(fix_checkpoint['refiner_state_dict'])
if cfg.NETWORK.USE_MERGER:
    print('Use merger')
    merger.load_state_dict(fix_checkpoint['merger_state_dict'])
encoder.eval()
decoder.eval()
refiner.eval()
merger.eval()
img1_path = '/home/caig/Desktop/data/shapenet/ShapeNetRendering/02691156/1a04e3eab45ca15dd86060f189eb133/rendering/00.png'
img1_np = cv2.imread(img1_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
sample = np.array([img1_np])
IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])
rendering_images = test_transforms(rendering_images=sample)
rendering_images = rendering_images.unsqueeze(0)
with torch.no_grad():
    image_features = encoder(rendering_images)
    raw_features, generated_volume = decoder(image_features)
    if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
        generated_volume = merger(raw_features, generated_volume)
    else:
        generated_volume = torch.mean(generated_volume, dim=1)
    if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
        generated_volume = refiner(generated_volume)
generated_volume = generated_volume.squeeze(0)
img_dir = '/home/caig/Desktop/pix2vox/Pix2Vox/outputs'
gv = generated_volume.cpu().numpy()
gv_new = np.swapaxes(gv, 2, 1)
rendering_views = utils.binvox_visualization.get_volume_views(gv_new, os.path.join(img_dir), epoch_idx)
lui-shex commented
Where did you add this code in this project?
Thanks for the response!
brian220 commented
Where did you add this code in this project?
Thanks for the response!
I added this code as a new file in the core folder and gave it a name (e.g. test_single_img.py),
then modified runner.py to run it.
Change line 77 here to the class name in the new file (e.g. in test_single_img.py).
(Please check test.py and runner.py for more details.)
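In case a concrete sketch of that wiring helps (test_single_img.py and test_single_img_net are just example names I made up, so use whatever matches your files):

# core/test_single_img.py -- wrap the code above in a function that takes the config,
# similar to the test function in core/test.py:
def test_single_img_net(cfg):
    # build encoder/decoder/refiner/merger, load the pretrained weights,
    # read and transform the image, run inference, save the volume views
    ...

# runner.py -- import the new function and call it instead of the default test
# entry point (the line referred to above), roughly:
# from core.test_single_img import test_single_img_net
# ...
# test_single_img_net(cfg)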
Thanks!