hkchengrex/XMem

Inference on a stream rather than a folder, and speed problem

yarinbar opened this issue · 5 comments

First of all, thank you so much for this awesome package!

I am trying to use this package as a real-time processor for a constant stream of images.
What I did was write the following:

    def __init__(self,
                 xmem_ckpt,
                 fbrs_ckpt,
                 s2m_ckpt,
                 nnunet_ckpt=None,
                 **kwargs,
                 ):
        super().__init__()

        self.save_hyperparameters()

        self.image_size = 512
        self.num_classes = 6
        self.nnunet_ckpt = nnunet_ckpt
        self.xmem_ckpt = xmem_ckpt
        self.fbrs_ckpt = fbrs_ckpt
        self.s2m_ckpt = s2m_ckpt

        self.idxs_in_memory = []

        config = VIDEO_INFERENCE_CONFIG.copy()
        config['num_objects'] = self.num_classes
        config['size'] = self.image_size
        config['fbrs_ckpt'] = fbrs_ckpt
        config['s2m_ckpt'] = s2m_ckpt
        config['enable_long_term'] = True
        config['enable_long_term_count_usage'] = True

        self.xmem = XMem(config, xmem_ckpt, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval()

        if xmem_ckpt is not None:
            model_weights = torch.load(xmem_ckpt)
            self.xmem.load_weights(model_weights, init_as_zero_if_needed=True)

        self.processor = InferenceCore(self.xmem, config)
        self.processor.set_all_labels(list(range(1, self.num_classes + 1)))

    def xmem_step(self, image, mask, idx):
        image = image.to(self.device)

        valid_labels = None

        with torch.cuda.amp.autocast(enabled=True):

            if mask is not None:
                # A reference mask is available for this frame: pin it in
                # permanent memory so it is never evicted.
                mask = mask.to(self.device)
                valid_labels = range(1, self.num_classes + 1)
                self.processor.put_to_permanent_memory(image, mask)
                self.idxs_in_memory.append(idx)

            # The mask (if any) was already added to permanent memory above, so
            # avoid inserting it into the working memory a second time.
            do_not_add_mask_to_memory = mask is not None
            prob = self.processor.step(image,
                                       mask,
                                       valid_labels=valid_labels,
                                       do_not_add_mask_to_memory=do_not_add_mask_to_memory)

            # Channel 0 is the background, so shifting by 1 makes the background
            # label -1 and object labels start at 0.
            out_mask = torch.argmax(prob, dim=0) - 1
            return out_mask
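
For reference, the per-frame driver loop looks roughly like this (the wrapper class name `StreamSegmenter`, `frame_stream`, `first_mask`, and the checkpoint path are placeholders):

    # Sketch of a driver loop for the class above; all names here are placeholders.
    segmenter = StreamSegmenter(xmem_ckpt='XMem.pth', fbrs_ckpt=None, s2m_ckpt=None)

    for idx, frame in enumerate(frame_stream):            # frame: (C, H, W) tensor
        mask = first_mask if idx == 0 else None           # seed the memory on frame 0 only
        out_mask = segmenter.xmem_step(frame, mask, idx)  # per-pixel labels, background = -1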

Currently my problems are the following:

  1. I want to use a standalone image segmentation model to generate "ground truth" masks to keep XMem in check. The problem is that I do not know how to retroactively add frames to memory and make XMem use them for the following frames.
  2. The first few frames are processed very quickly, but there is a significant loss of speed after about 10 frames, to the point that a new frame comes out only once every 4 seconds, and eventually processing almost stops completely (I use an RTX 4090, 64 GB of RAM, and a 13th Gen Intel Core i7-13700 @ 2.10 GHz).

Thanks!

  1. There are a lot of ways to do this (one is implemented in DEVA: https://github.com/hkchengrex/Tracking-Anything-with-DEVA). I would start by looking at MemoryManager and KeyValueMemoryStore; a simple, non-retroactive sketch follows below this list.
  2. I guess it is either adding more objects or not utilizing the long-term memory.
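
A minimal, non-retroactive sketch of the first suggestion, reusing the `xmem_step` method from the snippet above: run the standalone model every N frames and feed its mask in as a fresh reference (that method already pins such frames via `put_to_permanent_memory`). `REFRESH_EVERY` and `predict_mask` are placeholders, not part of the XMem API.

    REFRESH_EVERY = 30  # assumed refresh interval; tune for your stream

    def stream_step(self, image, idx):
        """Process one frame; periodically re-anchor the memory with the standalone model."""
        mask = None
        if idx % REFRESH_EVERY == 0:
            # predict_mask stands in for the standalone segmentation model; it
            # must return a mask in whatever format xmem_step already expects.
            mask = self.predict_mask(image)
        # xmem_step (from the snippet above) handles put_to_permanent_memory
        # and the step() call for frames that come with a mask.
        return self.xmem_step(image, mask, idx)

Truly retroactive insertion (re-anchoring on a frame that has already been processed) is the harder part, and that is what the pointers to MemoryManager and KeyValueMemoryStore are for.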

Where did you implement it in DEVA? I looked into the classes you mentioned, but is there an example of how to use them properly?


Hi, I ran into the same 'slow down' problem here. Have you solved it?

Long-term memory is enabled, and there is no obvious change in the frames when the speed drops.

I found out this was because torch.topk became very slow when the input matrix exceeded a certain size, so I set "max_mid_term_frames" to 8 and everything works fine now.
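
For reference, this is roughly where that cap fits into the config built in the snippet at the top of the thread (the value 8 is the one reported above; the comments are my reading of the keys, so double-check them against the XMem defaults):

    config = VIDEO_INFERENCE_CONFIG.copy()
    config['enable_long_term'] = True              # consolidate old working-memory frames into long-term memory
    config['enable_long_term_count_usage'] = True  # track usage so rarely-used long-term entries can be evicted
    # Cap the mid-term (working) memory so the matrices fed to torch.topk stay small.
    config['max_mid_term_frames'] = 8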