ttlmh/Bridge-Prompt

Dataset preparation


Hi, thank you very much for your work. I would like to reproduce your network on my own dataset, and I have run into the following questions:

  1. Why does datasets.py define so many classes for a single dataset (e.g. for Breakfast: Breakfast, Breakfast_feat, Breakfast_acti, Breakfast_FRAMES, ...)?
  2. In the __getitem__ of both the GTEA and SALADS datasets, there is the following code:
    if self.pretrain:
        vid = torch.from_numpy(vid)
        vid = torch.unique_consecutive(vid)
        vid = vid.numpy()
        vid = np.ma.masked_equal(vid, 0)
        vid = vid.compressed()
        vid = np.pad(vid, (0, 10 - vid.shape[0]), 'constant', constant_values=(0, -1))
    I understand this code as collecting the actions that appear in a clip, to be used for generating the text prompts. However, it changes the original labels, turning [0,0,0,0,0,1,1,2,2,2,2,2,3,3,3,4] into [1, 2, 3, 4, -1, -1, -1, -1, -1, -1] (a standalone run is shown below). How is the loss computed with labels transformed in this way?
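
For reference, here is a minimal standalone run of the snippet above on the example labels, using only numpy and torch:

    import numpy as np
    import torch

    # Per-frame labels of one sampled window; 0 marks background frames in GTEA.
    vid = np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 4])

    # Collapse consecutive repeats, keeping only the ordered sequence of actions.
    vid = torch.unique_consecutive(torch.from_numpy(vid)).numpy()  # [0 1 2 3 4]

    # Mask out the background class (0) and drop it.
    vid = np.ma.masked_equal(vid, 0).compressed()  # [1 2 3 4]

    # Pad with -1 up to a fixed length of 10 so every window has the same shape.
    vid = np.pad(vid, (0, 10 - vid.shape[0]), 'constant', constant_values=(0, -1))
    print(vid)  # [ 1  2  3  4 -1 -1 -1 -1 -1 -1]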
ttlmh commented

Hi, thanks for your interest in our work!

  1. The other classes were intended for experimental use during development; they are unused in the final version of Bridge-Prompt.

  2. Before processing, raw labels like [0,0,0,0,0,1,1,2,2,2,2,2,3,3,3,4] precisely label the action of each frame in a window. Post-processed labels like [1, 2, 3, 4, -1, -1, -1, -1, -1, -1] (action 0 is masked out for GTEA since it refers to background frames) record the order and the count of the actions that appear in a window. Bridge-Prompt does not track the action of each individual frame; instead, it deals with the action sequence of a video clip, so the latter form is better suited to our method. Besides, the final losses are computed through vision-language matching rather than a traditional classification head, so the exact form of the labels does not affect the loss calculation (a generic sketch of such a matching loss is shown below).
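
For intuition only, the matching loss can be thought of along the lines of a CLIP-style symmetric contrastive objective between clip features and text-prompt features. The sketch below is a generic illustration, not necessarily the exact objective used in Bridge-Prompt; the batch size, feature dimension, and temperature are placeholders:

    import torch
    import torch.nn.functional as F

    B, D = 8, 512  # placeholder batch size and feature dimension
    # Hypothetical paired embeddings: video-clip features and features of the
    # text prompts built from each clip's ordered action sequence.
    vid_feat = F.normalize(torch.randn(B, D), dim=-1)
    txt_feat = F.normalize(torch.randn(B, D), dim=-1)

    # Each clip should match the prompt generated from its own action sequence.
    logits = vid_feat @ txt_feat.t() / 0.07  # 0.07 is an illustrative temperature
    targets = torch.arange(B)
    loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

In this view, the labels only determine which text prompt is generated for each clip, so the padded -1 entries never enter the loss directly.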

Thanks!