xxlong0/ESTDepth

Reproduce the result in your paper

Closed this issue · 8 comments

Hi, thanks for your great work!

I have a few questions about how to reproduce your evaluation results on ScanNet.
First, I tried to use the pretrained model you provided to evaluate the ScanNet test set, modifying eval_hybrid.py for the evaluation. The evaluation metrics are computed by your code (test mode returns the depth_metrics). However, the results are not satisfactory. I am wondering whether this is because the provided model is not the best checkpoint, or because of the code.
Besides, what is the difference between your two testing modes?

Hello. Could you please provide more details about your modifications?
In the joint test mode, the model exploits short-term temporal coherence within each clip, while in ESTM mode it exploits long-term temporal coherence along the whole input video. You can refer to the paper for more details.
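
A minimal sketch of the difference, based on the model call in the eval code quoted below (pre_costs / pre_cam_poses carry the temporal memory); how the updated memory is returned in ESTM mode is an assumption here, for illustration only:

# Joint mode (sketch): each clip is evaluated independently,
# so the cost-volume memory is reset for every window.
for sample in dataloader:
    outputs, metrics = model(sample["imgs"], sample["cam_poses"],
                             sample["cam_intr"], sample,
                             pre_costs=None, pre_cam_poses=None, mode='test')

# ESTM mode (sketch): the memory is carried across windows, so the model
# can fuse evidence along the whole video. The "pre_costs"/"pre_cam_poses"
# output keys are hypothetical unpacking of the updated memory.
pre_costs, pre_cam_poses = None, None
for sample in dataloader:
    outputs, metrics = model(sample["imgs"], sample["cam_poses"],
                             sample["cam_intr"], sample,
                             pre_costs=pre_costs, pre_cam_poses=pre_cam_poses,
                             mode='test')
    pre_costs, pre_cam_poses = outputs.get("pre_costs"), outputs.get("pre_cam_poses")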

Thanks for your reply! In eval_hybrid.py, the function test_scannet uses the SevenScenes class as the test dataset, so I changed it to ScanNet, along with a few other changes needed to read the dataset successfully. I will paste the modified test_scannet function and the ScanNet dataset class below. I also made sure that the frame interval is 10 (the interval is determined by the index in the ScanNet image filename).
By the way, it would be very helpful if you could share your evaluation scripts.

test_scannet:

def test_scannet(model, args):
    model.eval()
    # data loader
    # dataset, dataloader
    #test_dataset = SevenScenes(args.datapath, seq_length=args.seq_len,
    #                           seq_inter=args.seq_len - 2, frame_interval=10, eval_all=False)
    test_dataset = ScannetDataset(args.datapath, args.testlist, depth_min=args.depth_min,
                   depth_max=args.depth_max, mode='test', reloadscan=True,
                   n_frames=args.seq_len)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                 num_workers=4, pin_memory=True)

    evaluation_dir = args.evalpath + "_joint_seqlen" + str(args.seq_len)

    os.makedirs(evaluation_dir, exist_ok=True)

    # main loop
    scenes_file = open(args.testlist, 'r')


    with torch.no_grad():
        for scene in scenes_file.readlines():
            scene = scene.rstrip()
            print(scene)
            rgb_dir = os.path.join(evaluation_dir, scene, 'rgb')
            init_depth_dir = os.path.join(evaluation_dir, scene, 'init_depth')
            init_prob_dir = os.path.join(evaluation_dir, scene, 'init_prob')
            init_probvolume_dir = os.path.join(evaluation_dir, scene, 'init_probvolume')

            refined_depth_dir = os.path.join(evaluation_dir, scene, 'refined_depth')
            refined_prob_dir = os.path.join(evaluation_dir, scene, 'refined_prob')
            refined_probvolume_dir = os.path.join(evaluation_dir, scene, 'refined_probvolume')

            dirs = [rgb_dir, init_depth_dir, init_prob_dir,
                    refined_depth_dir, refined_prob_dir,
                    init_probvolume_dir, refined_probvolume_dir]

            for d in dirs:
                os.makedirs(d, exist_ok=True)

        #test_dataset.reset(scene)
        test_metrics = {"a1_0": 0, "a2_0": 0, "a3_0": 0, "abs_diff_0": 0, "abs_rel_0": 0,
                        "sq_rel_0": 0, "rmse_0": 0, "rmse_log_0": 0}
        pre_costs = None
        pre_cam_poses = None
        for index, sample in enumerate(test_dataloader):
            sample_cuda = tocuda(sample)

            outputs, metrics = model(sample_cuda["imgs"],
                                     sample_cuda["cam_poses"],
                                     sample_cuda["cam_intr"],
                                     sample_cuda,
                                     pre_costs=pre_costs,
                                     pre_cam_poses=pre_cam_poses,
                                     mode='test')
            intrinsic = sample_cuda["cam_intr"]

            dmaps_gt = sample_cuda['dmaps']
            dmasks_gt = sample_cuda['dmasks']
            image_outputs = {}

            # evaluate the depth metrics; suffix 0 denotes the refined depth
            for k, v in test_metrics.items():
                # incremental running mean over all samples seen so far
                test_metrics[k] = test_metrics[k] + (metrics[k] - test_metrics[k]) / (index + 1)
            if index % 100 == 0:
                print("[TEST]: ", test_metrics)

            for img_i in range(dmaps_gt.shape[1] - 2):
                rgb_basename = os.path.basename(sample["img_path"][img_i + 1][0])

                _, img_ext = os.path.splitext(rgb_basename)
                rgb_filepath = os.path.join(rgb_dir, rgb_basename)
                # cv2.imwrite(rgb_filepath,
                #             cv2.cvtColor(sample["img_raws"][:, img_i + 1, :, :, :].squeeze().numpy(),
                #                          cv2.COLOR_RGB2BGR))

                if args.save_init_depth == "True":
                    init_depth = np.float16(outputs[("depth", img_i, 2)].squeeze(1).cpu().numpy())
                    init_depth_filepath = os.path.join(init_depth_dir,
                                                       rgb_basename.replace(img_ext, ".npy"))
                    np.save(init_depth_filepath, init_depth)

                    init_depth_color = colorize_depth(
                        outputs[("depth", img_i, 2)].squeeze(1),
                        max_depth=5.0).permute(0, 2, 3, 1).squeeze().cpu().numpy()
                    init_depth_color_filepath = os.path.join(init_depth_dir,
                                                             rgb_basename.replace(img_ext, ".jpg"))
                    cv2.imwrite(init_depth_color_filepath,
                                cv2.cvtColor(np.uint8(init_depth_color), cv2.COLOR_RGB2BGR))
                if args.save_init_prob == "True":
                    init_prob = colorize_probmap(
                        outputs[("init_prob", img_i)].squeeze(1)).permute(0, 2, 3, 1).squeeze().cpu().numpy()
                    init_prob_filepath = os.path.join(init_prob_dir,
                                                      rgb_basename.replace(img_ext, ".jpg"))
                    cv2.imwrite(init_prob_filepath, cv2.cvtColor(np.uint8(init_prob), cv2.COLOR_RGB2BGR))
                    np.save(init_prob_filepath.replace('jpg', 'npy'),
                            np.float16(outputs[("init_prob", img_i)].squeeze().cpu().numpy()))

                if args.save_refined_depth == "True":
                    refined_depth = np.float16(outputs[("depth", img_i, 0)].squeeze(1).cpu().numpy())
                    refined_depth_filepath = os.path.join(refined_depth_dir,
                                                          rgb_basename.replace(img_ext, ".npy"))
                    np.save(refined_depth_filepath, refined_depth)

                    refined_depth_color = colorize_depth(
                        outputs[("depth", img_i, 0)].squeeze(1),
                        max_depth=5.0).permute(0, 2, 3, 1).squeeze().cpu().numpy()
                    refined_depth_color_filepath = os.path.join(refined_depth_dir,
                                                                rgb_basename.replace(img_ext, ".jpg"))
                    cv2.imwrite(refined_depth_color_filepath,
                                cv2.cvtColor(np.uint8(refined_depth_color), cv2.COLOR_RGB2BGR))

                if args.save_refined_prob == "True":
                    refined_prob = colorize_probmap(
                        outputs[("fused_prob", img_i)].squeeze(1)).permute(0, 2, 3, 1).squeeze().cpu().numpy()
                    refined_prob_filepath = os.path.join(refined_prob_dir,
                                                         rgb_basename.replace(img_ext, ".jpg"))
                    cv2.imwrite(refined_prob_filepath,
                                cv2.cvtColor(np.uint8(refined_prob), cv2.COLOR_RGB2BGR))
                    np.save(refined_prob_filepath.replace('jpg', 'npy'),
                            np.float16(outputs[("fused_prob", img_i)].squeeze().cpu().numpy()))
        print(test_metrics)

ScannetDataset:

class ScannetDataset(data.Dataset):
    def __init__(self, dataset_path, split_txt=None, height=256, width=320, n_frames=5,
                 depth_min=0.1, depth_max=10., mode='train', reloadscan=False):
        super(ScannetDataset, self).__init__()
        self.dataset_path = dataset_path
        self.n_frames = n_frames
        self.height = height
        self.width = width
        self.depth_min = depth_min
        self.depth_max = depth_max

        self.reloadscan = reloadscan

        self.mode = mode  # train or test

        if split_txt is not None and os.path.exists(split_txt):
            self.scenes = _read_split_file(split_txt)
        else:
            self.scenes = sorted(os.listdir(self.dataset_path))

        self.build_dataset_index_train(r=self.n_frames)

        scale_w = self.width / 640.
        scale_h = self.height / 480.
        self.cam_intr = torch.tensor([[577.87 * scale_w, 0, 319.5 * scale_w],
                                      [0, 577.87 * scale_h, 239.5 * scale_h],
                                      [0, 0, 1]]).to(torch.float32)

        self.proc_totensor = m_preprocess.to_tensor()

    def __len__(self):
        return len(self.dataset_index)

    def shape(self):
        return [self.n_frames, self.height, self.width]

    def read_sample_train(self, index):
        data_blob = self.dataset_index[index]
        assert self.n_frames == data_blob['n_frames']

        images = []
        images_paths = []
        img_ids = []

        poses = []
        poses_paths = []
        pose_ids = []

        depths = []
        dmasks = []
        depths_paths = []
        depth_ids = []

        for i in range(self.n_frames):
            image = cv2.imread(data_blob['images'][i])

            img_id = re.findall(r'\d+', os.path.basename(data_blob['images'][i]))
            img_ids.append(img_id)

            images_paths.append(data_blob['images'][i])
            image = cv2.resize(image, (self.width, self.height))

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)

            # load pose
            pose = np.loadtxt(data_blob['poses'][i], delimiter=' ').astype(np.float32)

            pose_id = re.findall(r'\d+', os.path.basename(data_blob['poses'][i]))
            pose_ids.append(pose_id)

            poses.append(pose)
            poses_paths.append(data_blob['poses'][i])

            # load depth
            depth = cv2.imread(data_blob['depths'][i], cv2.IMREAD_ANYDEPTH)
            depth = cv2.resize(depth, (self.width, self.height))

            depth_id = re.findall(r'\d+', os.path.basename(data_blob['depths'][i]))
            depth_ids.append(depth_id)

            depth = (depth.astype(np.float32)) / 1000.0

            # mask out depths outside [depth_min, depth_max] and non-finite values
            dmask = (depth >= self.depth_min) & (depth <= self.depth_max) & np.isfinite(depth)
            depth[~dmask] = 0

            # require at least half of the pixels to carry valid depth
            ratio = np.sum(np.float32(dmask)) / (self.width * self.height)
            assert ratio > 0.5

            depths.append(depth)
            dmasks.append(dmask)
            depths_paths.append(data_blob['depths'][i])

        images = np.stack(images, axis=0).astype(np.float32)
        poses = np.stack(poses, axis=0).astype(np.float32)

        assert np.all(np.isfinite(poses))
        assert (img_ids == pose_ids) & (img_ids == depth_ids)

        depths = np.stack(depths, axis=0).astype(np.float32)
        dmasks = np.stack(dmasks, axis=0)

        return images, poses, depths, dmasks, img_ids, images_paths

    def __getitem__(self, index):
        #images, poses, depths, dmasks, frameid, images_paths = self.read_sample_train(index)
        if index >= len(self.dataset_index):
            images, poses, depths, dmasks, frameid, images_paths = self.read_sample_train(index)

        # retry with a random index if loading fails (e.g. invalid pose or depth)
        flag = True
        while flag:
            try:
                images, poses, depths, dmasks, frameid, images_paths = self.read_sample_train(index)
                flag = False
            except Exception:
                tmp = np.random.randint(0, self.__len__(), 1)[0]
                print("data load error!", index, "use:  ", tmp)
                index = tmp

        # it seems that augment will influence accuracy
        # do_augument = np.random.uniform(0, 1, size=1)
        # if do_augument < 0.5:
        #     images = augument(images)

        sample = {
            'imgs': torch.from_numpy(images).permute(0, 3, 1, 2).to(torch.float32),  # [N,3,H,W]
            'dmaps': torch.from_numpy(depths).unsqueeze(1).to(torch.float32),  # [N,1,H,W]
            'dmasks': torch.from_numpy(dmasks).unsqueeze(1),  # [N,1,H,W]
            'cam_poses': torch.from_numpy(poses).to(torch.float32),  # [N,4,4]
            'cam_intr': self.cam_intr.to(torch.float32),
            'img_path': images_paths
        }

        return sample

    def _load_scan(self, scan, interval, if_dump=True):
        """

        :param scan:
        :param interval: 2 if train mode; 10 if test mode
        :return:
        """
        scan_path = os.path.join(self.dataset_path, scan)

        datum_file = os.path.join(scan_path, 'scene.npy')
        # really need to sample scene more densely (skip every 2 frames not 4)
        if (not os.path.exists(datum_file)) or self.reloadscan:
            #print("load ", datum_file, self.reloadscan, type(self.reloadscan))
            print("load", scan_path)
            imfiles = glob.glob(os.path.join(scan_path, 'pose', '*.txt'))
            ixs = sorted([int(os.path.basename(x).split('.')[0]) for x in imfiles])
            poses = []
            for i in ixs[::interval]:
                posefile = os.path.join(scan_path, 'pose', '%d.txt' % i)
                pose = np.loadtxt(posefile, delimiter=' ').astype(np.float32)

                # stop at the first invalid pose; only frames up to here are kept
                if not np.all(np.isfinite(pose)):
                    break
                poses.append(posefile)

            images = []
            for i in ixs[::interval]:
                imfile = os.path.join(scan_path, 'color', '%d.jpg' % i)
                images.append(imfile)

            depths = []
            for i in ixs[::interval]:
                depthfile = os.path.join(scan_path, 'depth', '%d.png' % i)
                depths.append(depthfile)

            valid_num = len(poses)

            scene_info = {
                "images": images[:valid_num],
                "depths": depths[:valid_num],
                "poses": poses
            }

            if if_dump:
                np.save(datum_file, scene_info)
            return scene_info

        else:
            return np.load(datum_file, allow_pickle=True).item()

    def build_dataset_index_train(self, r=4):
        self.dataset_index = []
        data_id = 0
        skip = r // 2
        #skip = 1

        for scan in self.scenes:
            scanid = int(re.findall(r'scene(.+?)_', scan)[0])

            # per the _load_scan docstring: interval 2 for train, 10 for test
            interval = 2 if self.mode == 'train' else 10
            scene_info = self._load_scan(scan, interval=interval)
            images = scene_info["images"]
            depths = scene_info["depths"]
            poses = scene_info["poses"]

            for i in range(r, len(images) - r, skip):
                training_example = {}
                training_example['depths'] = depths[i - r:i + r + 1]
                training_example['images'] = images[i - r:i + r + 1]
                training_example['poses'] = poses[i - r:i + r + 1]
                training_example['n_frames'] = r
                training_example['id'] = data_id

                self.dataset_index.append(training_example)
                data_id += 1

Sorry for the confusion. There is no need to change the SevenScenes dataset class to the ScanNet dataset class, even when testing on ScanNet.
data/scannet.py is only for training; if you use it for testing, the actual test set will differ from the one used for the numbers reported in the paper.
Since SevenScenes shares the same data structure as ScanNet, just prepare a directory containing rgb, depth, and pose subdirectories, run eval_hybrid.py on ScanNet, and calculate the metrics.
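
For reference, one plausible layout inferred from the loaders quoted in this thread (the training loader reads color/, depth/, and pose/ with per-frame indices; the exact directory and file naming expected by the SevenScenes test loader may differ, so treat this as a sketch):

scannet_test/
    scene0707_00/
        color/   # per-frame RGB images, e.g. 0.jpg, 10.jpg, ...  ('rgb' per the note above)
        depth/   # 16-bit depth maps in millimetres, e.g. 0.png, 10.png, ...
        pose/    # 4x4 camera-to-world matrices, e.g. 0.txt, 10.txt, ...
    scene0708_00/
        ...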

Thanks for your help! I will try it again.

Sorry to bother you again, but I have tried to evaluate on ScanNet as you suggested, and the metrics are still not good enough. Do I need to change any parameters? (By the way, I have confirmed that the frame interval is 10.)
Here are the results in the ESTM test mode:
[screenshot of evaluation metrics]

My bash script is:

python eval_hybrid_seq.py --seq_len 5 --summary_freq 10 --ndepths 64 \
--loadckpt ./checkpoint/model_000006.ckpt \
--datapath ../scannet_test \
--evalpath ./output/hybrid_EST_V4_ndepths64 \
--testlist ./data/scannet_split/test_split.txt --IF_EST_transformer True \
--depth_min 0.1 --depth_max 10. --save_init_prob False --save_refined_prob False

Hello. We run the evaluation after upsampling pred_depth to the original resolution of the GT depth (480, 640); evaluating at different resolutions will produce different numbers. Just use the code below to do the evaluation. If the problem remains, let me know.

def cal_metrics_scannet(args):
    dataDir = "./refined_depth"
    gtDATA = ""
    l1_errors_all = []
    abs_relative_errors_all = []
    sq_rel_errors_all = []
    rmse_log_errors_all = []
    rmse_errors_all = []
    scale_invariant_errors_all = []
    a1_errors_all = []
    a2_errors_all = []
    a3_errors_all = []

    MIN_DEPTH = 0.1
    MAX_DEPTH = 10.0

    depth_types = ["refined_depth", "init_depth"]
    for depth_type in depth_types:
        for scene in sorted(os.listdir(dataDir)):
            if not os.path.isdir(os.path.join(dataDir, scene)):
                continue

            print(scene)
            pred_depth_dir = os.path.join(dataDir, scene, depth_type)

            l1_errors = []
            abs_relative_errors = []
            sq_rel_errors = []
            rmse_log_errors = []
            rmse_errors = []
            scale_invariant_errors = []
            a1_errors = []
            a2_errors = []
            a3_errors = []
            count = 0

            for filename in sorted(os.listdir(pred_depth_dir)):
                if filename.endswith(".npy"):

                    fileindex = [int(s) for s in re.findall(r'\d+', filename)][0]
                    # check the filename
                    gt_depth = cv2.imread(
                        os.path.join(gtDATA, scene, "frame-%06d.depth.pgm" % fileindex),
                        -1) / 1000.

                    pred_depth = np.float32(np.load(
                        os.path.join(pred_depth_dir, filename)).squeeze())
                    # upsample the prediction to the GT resolution before evaluation
                    pred_depth = cv2.resize(pred_depth, (640, 480),
                                            interpolation=cv2.INTER_LINEAR)

                    valid_mask = compute_valid_depth_mask(gt_depth, min_thred=MIN_DEPTH,
                                                          max_thred=MAX_DEPTH)

                    gt_depth = gt_depth[valid_mask]
                    pred_depth = pred_depth[valid_mask]
                    # clamp predictions to the valid range; map NaNs to MAX_DEPTH
                    pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
                    pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH
                    pred_depth[pred_depth != pred_depth] = MAX_DEPTH

                    assert np.isfinite(np.sum(gt_depth))
                    assert np.isfinite(np.sum(pred_depth))
                    if np.size(gt_depth) < 100:
                        continue

                    count += 1

                    l1_error = l1(gt_depth, pred_depth)
                    abs_relative_error = abs_relative(depth_gt=gt_depth, depth_pred=pred_depth)
                    rmse_error = rmse(gt_depth, pred_depth)
                    scale_invariant_error = scale_invariant(gt_depth, pred_depth)
                    sq_rel_error = sq_relative(pred_depth, gt_depth)
                    rmse_log_error = rmse_log(gt_depth, pred_depth)

                    a1 = ratio_threshold(gt_depth, pred_depth, 1.25)
                    a2 = ratio_threshold(gt_depth, pred_depth, 1.25 * 1.25)
                    a3 = ratio_threshold(gt_depth, pred_depth, 1.25 * 1.25 * 1.25)

                    l1_errors.append(l1_error)
                    abs_relative_errors.append(abs_relative_error)
                    rmse_errors.append(rmse_error)
                    sq_rel_errors.append(sq_rel_error)
                    rmse_log_errors.append(rmse_log_error)
                    scale_invariant_errors.append(scale_invariant_error)
                    a1_errors.append(a1)
                    a2_errors.append(a2)
                    a3_errors.append(a3)

                    l1_errors_all.append(l1_error)
                    abs_relative_errors_all.append(abs_relative_error)
                    rmse_errors_all.append(rmse_error)
                    sq_rel_errors_all.append(sq_rel_error)
                    rmse_log_errors_all.append(rmse_log_error)
                    scale_invariant_errors_all.append(scale_invariant_error)
                    a1_errors_all.append(a1)
                    a2_errors_all.append(a2)
                    a3_errors_all.append(a3)

            # per-scene means
            mean_l1_error = np.mean(np.array(l1_errors))
            mean_abs_relative_error = np.mean(np.array(abs_relative_errors))
            mean_sq_rel_error = np.mean(np.array(sq_rel_errors))
            mean_rmse_error = np.mean(np.array(rmse_errors))
            mean_rmse_log_error = np.mean(np.array(rmse_log_errors))
            mean_scale_invariant_error = np.mean(np.array(scale_invariant_errors))

            mean_a1_error = np.mean(np.array(a1_errors))
            mean_a2_error = np.mean(np.array(a2_errors))
            mean_a3_error = np.mean(np.array(a3_errors))
            print("mean_l1_error", mean_l1_error)
            print("a<1.25", mean_a1_error)
            print("a<1.25^2", mean_a2_error)
            print("a<1.25^3", mean_a3_error)
            print("abs.rel", mean_abs_relative_error)
            print("sq.rel", mean_sq_rel_error)
            print("rmse", mean_rmse_error)
            print("rmse_log", mean_rmse_log_error)
            print("scale.inv", mean_scale_invariant_error)

            error_np = [mean_l1_error, mean_a1_error, mean_a2_error, mean_a3_error,
                        mean_abs_relative_error, mean_sq_rel_error, mean_rmse_error,
                        mean_rmse_log_error, mean_scale_invariant_error]
            np.save(os.path.join(dataDir, scene, depth_type + "_evaluation_depth_errors.npy"),
                    [error_np, count])

            file = open(os.path.join(dataDir, scene, depth_type + "_evaluation_depth_errors.txt"), 'w+')
            file.write("mean_l1_error: " + str(mean_l1_error) + "\n")
            file.write("a<1.25: " + str(mean_a1_error) + "\n")
            file.write("a<1.25^2: " + str(mean_a2_error) + "\n")
            file.write("a<1.25^3: " + str(mean_a3_error) + "\n")
            file.write("abs.rel: " + str(mean_abs_relative_error) + "\n")
            file.write("sq.rel: " + str(mean_sq_rel_error) + "\n")
            file.write("rmse: " + str(mean_rmse_error) + "\n")
            file.write("rmse log: " + str(mean_rmse_log_error) + "\n")
            file.write("scale.inv: " + str(mean_scale_invariant_error) + "\n")
            file.close()

        # means over all accumulated samples
        mean_l1_error = np.mean(np.array(l1_errors_all))
        mean_abs_relative_error = np.mean(np.array(abs_relative_errors_all))
        mean_sq_rel_error = np.mean(np.array(sq_rel_errors_all))
        mean_rmse_error = np.mean(np.array(rmse_errors_all))
        mean_rmse_log_error = np.mean(np.array(rmse_log_errors_all))
        mean_scale_invariant_error = np.mean(np.array(scale_invariant_errors_all))

        mean_a1_error = np.mean(np.array(a1_errors_all))
        mean_a2_error = np.mean(np.array(a2_errors_all))
        mean_a3_error = np.mean(np.array(a3_errors_all))
        print("mean_l1_error", mean_l1_error)
        print("a<1.25", mean_a1_error)
        print("a<1.25^2", mean_a2_error)
        print("a<1.25^3", mean_a3_error)
        print("abs.rel", mean_abs_relative_error)
        print("sq.rel", mean_sq_rel_error)
        print("rmse", mean_rmse_error)
        print("rmse_log", mean_rmse_log_error)
        print("scale.inv", mean_scale_invariant_error)

        file = open(os.path.join(dataDir, depth_type + "_evaluation_errors_all.txt"), 'w+')
        file.write("mean_l1_error: " + str(mean_l1_error) + "\n")
        file.write("a<1.25: " + str(mean_a1_error) + "\n")
        file.write("a<1.25^2: " + str(mean_a2_error) + "\n")
        file.write("a<1.25^3: " + str(mean_a3_error) + "\n")
        file.write("abs.rel: " + str(mean_abs_relative_error) + "\n")
        file.write("sq.rel: " + str(mean_sq_rel_error) + "\n")
        file.write("rmse: " + str(mean_rmse_error) + "\n")
        file.write("rmse log: " + str(mean_rmse_log_error) + "\n")
        file.write("scale.inv: " + str(mean_scale_invariant_error) + "\n")
        file.close()

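For reference, the metric helpers called above (l1, abs_relative, sq_relative, rmse, rmse_log, scale_invariant, ratio_threshold, compute_valid_depth_mask) are not shown in this thread. Below is a minimal sketch using the standard depth-metric definitions (Eigen et al.); it may differ in detail from the repo's own utils:

import numpy as np

def compute_valid_depth_mask(depth, min_thred, max_thred):
    # valid pixels: finite depth inside (min_thred, max_thred)
    return (depth > min_thred) & (depth < max_thred) & np.isfinite(depth)

def l1(gt, pred):
    return np.mean(np.abs(gt - pred))

def abs_relative(depth_gt, depth_pred):
    return np.mean(np.abs(depth_gt - depth_pred) / depth_gt)

def sq_relative(depth_pred, depth_gt):
    return np.mean((depth_gt - depth_pred) ** 2 / depth_gt)

def rmse(gt, pred):
    return np.sqrt(np.mean((gt - pred) ** 2))

def rmse_log(gt, pred):
    return np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))

def scale_invariant(gt, pred):
    # scale-invariant log error (Eigen et al., NeurIPS 2014)
    d = np.log(pred) - np.log(gt)
    return np.sqrt(np.mean(d ** 2) - np.mean(d) ** 2)

def ratio_threshold(gt, pred, threshold):
    # fraction of pixels with max(pred/gt, gt/pred) below the threshold
    ratio = np.maximum(gt / pred, pred / gt)
    return np.mean(ratio < threshold)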

Hi, the current results are basically consistent with those in the paper. Thank you for your patient help!
[screenshot of updated evaluation metrics]

Great! Closing the issue, since it is solved.