sukjunhwang/IFC

Questions about Memory

9p15p opened this issue · 2 comments

9p15p commented

Thanks for your great work.

I have two questions about memory_bus and memory_pos.

The first one:
In the paper, the memory tokens help features from different frames communicate with each other.
However, in the code, it seems the communication is designed among layers rather than among frames.

        for layer_idx in range(self.num_layers):
            output = torch.cat((output, memory_bus))

            output = self.enc_layers[layer_idx](output, src_mask=mask,
                                                src_key_padding_mask=src_key_padding_mask, pos=pos)
            output, memory_bus = output[:hw, :, :], output[hw:, :, :]

            memory_bus = memory_bus.view(M, bs, t, c).permute(2,1,0,3).flatten(1,2) # TxBMxC
            memory_bus = self.bus_layers[layer_idx](memory_bus)
            memory_bus = memory_bus.view(t, bs, M, c).permute(2,1,0,3).flatten(1,2) # MxBTxC
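
For reference, this is how I traced the shapes in that loop with dummy tensors (the sizes below are placeholders, not the actual config), just to see what gets concatenated and split:

    import torch

    # dummy sizes, only for tracing shapes
    hw, M, bs, t, c = 49, 8, 2, 5, 256
    output = torch.randn(hw, bs * t, c)      # HW x BT x C spatial tokens
    memory_bus = torch.randn(M, bs * t, c)   # M  x BT x C memory tokens

    # the memory tokens are appended to the spatial tokens of every frame ...
    cat = torch.cat((output, memory_bus))    # (HW+M) x BT x C
    # ... the encoder layer would attend over these HW+M tokens here ...
    output, memory_bus = cat[:hw], cat[hw:]  # split back apart
    print(output.shape, memory_bus.shape)    # torch.Size([49, 10, 256]) torch.Size([8, 10, 256])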

The second one:
It seems self.memory_bus and self.memory_pos are never updated. Intuitively, I would guess it helps if they are updated along with the frames.

memory_bus = self.memory_bus

        self.memory_bus = torch.nn.Parameter(torch.randn(num_memory_bus, d_model))
        self.memory_pos = torch.nn.Parameter(torch.randn(num_memory_bus, d_model))
        if num_memory_bus:
            nn.init.kaiming_normal_(self.memory_bus, mode="fan_out", nonlinearity="relu")
            nn.init.kaiming_normal_(self.memory_pos, mode="fan_out", nonlinearity="relu")

        self.return_intermediate_dec = return_intermediate_dec

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def pad_zero(self, x, pad, dim=0):
        if x is None:
            return None
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        return torch.cat((x, x.new_zeros(pad_shape)), dim=dim)

    def forward(self, src, mask, query_embed, pos_embed, is_train):
        # prepare for enc-dec
        bs = src.shape[0] // self.num_frames if is_train else 1
        t = src.shape[0] // bs
        _, c, h, w = src.shape

        memory_bus = self.memory_bus
        memory_pos = self.memory_pos

        # encoder
        src = src.view(bs*t, c, h*w).permute(2, 0, 1)               # HW, BT, C
        frame_pos = pos_embed.view(bs*t, c, h*w).permute(2, 0, 1)   # HW, BT, C
        frame_mask = mask.view(bs*t, h*w)                           # BT, HW

        src, memory_bus = self.encoder(src, memory_bus, memory_pos, src_key_padding_mask=frame_mask, pos=frame_pos, is_train=is_train)

        # decoder
        dec_src = src.view(h*w, bs, t, c).permute(2, 0, 1, 3).flatten(0,1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)     # Q, B, C
        tgt = torch.zeros_like(query_embed)

        dec_pos = pos_embed.view(bs, t, c, h*w).permute(1, 3, 0, 2).flatten(0,1)
        dec_mask = mask.view(bs, t*h*w)                             # B, THW

        clip_hs = self.clip_decoder(tgt, dec_src, memory_bus, memory_pos, memory_key_padding_mask=dec_mask,
                                    pos=dec_pos, query_pos=query_embed, is_train=is_train)

        ret_memory = src.permute(1,2,0).reshape(bs*t, c, h, w)

        return clip_hs, ret_memory
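
To double-check my reading of the reshapes in forward, I traced them with dummy tensors as well (the sizes are made up; this is just a shape sanity check, not the actual model):

    import torch

    bs, t, c, h, w = 2, 5, 256, 7, 7
    src = torch.randn(bs * t, c, h, w)        # backbone features per frame, BT x C x H x W

    # encoder input: HW x BT x C, i.e. per-frame spatial tokens
    enc_src = src.view(bs * t, c, h * w).permute(2, 0, 1)
    print(enc_src.shape)                      # torch.Size([49, 10, 256])

    # decoder input: the whole clip is flattened, so each batch element carries T*HW tokens
    dec_src = enc_src.view(h * w, bs, t, c).permute(2, 0, 1, 3).flatten(0, 1)
    print(dec_src.shape)                      # torch.Size([245, 2, 256])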

Am I misunderstanding something?

Hi @9p15p ,

I believe your two questions can be answered together.

output, memory_bus = output[:hw, :, :], output[hw:, :, :]

Here, we separate the concatenated spatial tokens and memory tokens.
Since a batch is composed of multiple frames of a video, the memory tokens actually come from multiple frames.
Then, the memory tokens communicate with each other through the lines below.
memory_bus = memory_bus.view(M, bs, t, c).permute(2,1,0,3).flatten(1,2) # TxBMxC
memory_bus = self.bus_layers[layer_idx](memory_bus)
memory_bus = memory_bus.view(t, bs, M, c).permute(2,1,0,3).flatten(1,2) # MxBTxC
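
To make the axis swap concrete, here is a small standalone trace (the sizes and the nn.TransformerEncoderLayer are only stand-ins for the real bus_layers): after the first reshape, the frame axis T comes first, so a sequence-first attention layer applied to this tensor lets each memory token attend across the T frames of the clip.

    import torch
    import torch.nn as nn

    M, bs, t, c = 8, 2, 5, 256
    memory_bus = torch.randn(M, bs * t, c)                   # M x BT x C

    # bring the frame axis to the front: T x BM x C
    bus_in = memory_bus.view(M, bs, t, c).permute(2, 1, 0, 3).flatten(1, 2)
    print(bus_in.shape)                                      # torch.Size([5, 16, 256])

    # a sequence-first encoder layer (stand-in for self.bus_layers[layer_idx])
    # now runs self-attention over the T frames for every (batch, memory-token) pair
    bus_layer = nn.TransformerEncoderLayer(d_model=c, nhead=8)
    bus_out = bus_layer(bus_in)

    # back to M x BT x C for the next concatenation with the spatial tokens
    memory_bus = bus_out.view(t, bs, M, c).permute(2, 1, 0, 3).flatten(1, 2)
    print(memory_bus.shape)                                  # torch.Size([8, 10, 256])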

The memory tokens also get updated at each encoder layer, as the for loop iterates over these lines:

for layer_idx in range(self.num_layers):

If you have more questions, feel free to ask.
Thank you

9p15p commented

Thank you! I get it.