Questions about Memory
9p15p opened this issue · 2 comments
Thanks for your great work.
I have two questions about memory_bus and memory_pos.
The first one:
In the paper, the memory tokens help features in different frames communicate with each other.
However, in the code, it seems the communication is designed to happen among the encoder layers instead.
for layer_idx in range(self.num_layers):
    output = torch.cat((output, memory_bus))
    output = self.enc_layers[layer_idx](output, src_mask=mask,
                                        src_key_padding_mask=src_key_padding_mask, pos=pos)
    output, memory_bus = output[:hw, :, :], output[hw:, :, :]

    memory_bus = memory_bus.view(M, bs, t, c).permute(2,1,0,3).flatten(1,2) # TxBMxC
    memory_bus = self.bus_layers[layer_idx](memory_bus)
    memory_bus = memory_bus.view(t, bs, M, c).permute(2,1,0,3).flatten(1,2) # MxBTxC
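For reference, this is how I traced the shape bookkeeping of those two reshapes (a toy check with my own sizes and names, not the repo's code):

import torch

M, bs, t, c = 8, 2, 5, 256                 # memory slots, videos, frames, channels
memory_bus = torch.randn(M, bs * t, c)     # M x BT x C, as inside the loop

# M x BT x C -> T x BM x C
tmp = memory_bus.view(M, bs, t, c).permute(2, 1, 0, 3).flatten(1, 2)
assert tmp.shape == (t, bs * M, c)

# T x BM x C -> M x BT x C
back = tmp.view(t, bs, M, c).permute(2, 1, 0, 3).flatten(1, 2)
assert back.shape == (M, bs * t, c)
assert torch.equal(back, memory_bus)       # the round trip is lossless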
The second one:
It seems self.memory_bus and self.memory_pos are not updated. Intuitively, I guess it would be helpful if they were updated along with the frames.
        self.memory_bus = torch.nn.Parameter(torch.randn(num_memory_bus, d_model))
        self.memory_pos = torch.nn.Parameter(torch.randn(num_memory_bus, d_model))
        if num_memory_bus:
            nn.init.kaiming_normal_(self.memory_bus, mode="fan_out", nonlinearity="relu")
            nn.init.kaiming_normal_(self.memory_pos, mode="fan_out", nonlinearity="relu")

        self.return_intermediate_dec = return_intermediate_dec

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def pad_zero(self, x, pad, dim=0):
        if x is None:
            return None
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        return torch.cat((x, x.new_zeros(pad_shape)), dim=dim)

    def forward(self, src, mask, query_embed, pos_embed, is_train):
        # prepare for enc-dec
        bs = src.shape[0] // self.num_frames if is_train else 1
        t = src.shape[0] // bs
        _, c, h, w = src.shape

        memory_bus = self.memory_bus
        memory_pos = self.memory_pos

        # encoder
        src = src.view(bs*t, c, h*w).permute(2, 0, 1)               # HW, BT, C
        frame_pos = pos_embed.view(bs*t, c, h*w).permute(2, 0, 1)   # HW, BT, C
        frame_mask = mask.view(bs*t, h*w)                           # BT, HW

        src, memory_bus = self.encoder(src, memory_bus, memory_pos, src_key_padding_mask=frame_mask, pos=frame_pos, is_train=is_train)

        # decoder
        dec_src = src.view(h*w, bs, t, c).permute(2, 0, 1, 3).flatten(0,1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)     # Q, B, C
        tgt = torch.zeros_like(query_embed)
        dec_pos = pos_embed.view(bs, t, c, h*w).permute(1, 3, 0, 2).flatten(0,1)
        dec_mask = mask.view(bs, t*h*w)                             # B, THW

        clip_hs = self.clip_decoder(tgt, dec_src, memory_bus, memory_pos, memory_key_padding_mask=dec_mask,
                                    pos=dec_pos, query_pos=query_embed, is_train=is_train)

        ret_memory = src.permute(1,2,0).reshape(bs*t, c, h, w)

        return clip_hs, ret_memory
Am I misunderstanding something?
Hi @9p15p,
I believe your two questions can be answered together.
IFC/projects/IFC/ifc/models/transformer.py, line 134 (commit fb2ee45)
Here, we split the concatenated spatial tokens and memory tokens apart.
Since a batch is composed of multiple frames in a video, the memory tokens are actually from multiple frames.
Then, the memory tokens communicate using the lines below.
IFC/projects/IFC/ifc/models/transformer.py, lines 136 to 138 (commit fb2ee45)
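To make that concrete: a standard transformer encoder layer attends over its first (sequence) dimension, and after the reshape into T x BM x C that dimension is the frame axis, so memory tokens belonging to different frames of the same video exchange information. A minimal sketch with toy sizes (not the actual configuration of the bus layers in the repo):

import torch
import torch.nn as nn

t, bs, M, c = 5, 2, 8, 256                                   # frames, videos, memory slots, channels
bus_layer = nn.TransformerEncoderLayer(d_model=c, nhead=8)   # expects (seq, batch, C)

memory_bus = torch.randn(t, bs * M, c)                       # T x BM x C
out = bus_layer(memory_bus)                                  # self-attention mixes along T, i.e. across frames
assert out.shape == (t, bs * M, c)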
The memory tokens also get updated at each encoder layer, as the for loop iterates over these lines.
IFC/projects/IFC/ifc/models/transformer.py, line 129 (commit fb2ee45)
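In other words, self.memory_bus is only the learned initial value of the memory tokens: within a forward pass it is refined layer by layer for the current clip, while the parameter itself is updated by backpropagation during training rather than frame by frame. A rough sketch of that flow (hypothetical names, not our code):

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, num_memory=8, d_model=256, num_layers=3):
        super().__init__()
        # learned initial state of the memory tokens
        self.memory_bus = nn.Parameter(torch.randn(num_memory, d_model))
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model, nhead=8) for _ in range(num_layers)]
        )

    def forward(self, bt):
        # expand the initial tokens to every frame of the clip
        mem = self.memory_bus.unsqueeze(1).repeat(1, bt, 1)   # M x BT x C
        for layer in self.layers:
            mem = layer(mem)            # refined at every layer within the forward pass
        return mem

enc = ToyEncoder()
out = enc(bt=10)                        # a clip with B*T = 10 frames
out.sum().backward()
assert enc.memory_bus.grad is not None  # the parameter is learned across clips by backprop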
If you have more questions, feel free to ask.
Thank you
Thank you! I get it.