Code snippet to reduce VRAM usage when there are too many frames to process.
hoveychen commented
Based on #20, I've modified the code to reduce VRAM usage during processing.
Usage:
Replace the register_extended_attention_pnp() function in tokenflow_utils.py with the code snippet below.
def register_extended_attention_pnp(model, injection_schedule):
    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward_original(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            for frame in range(n_frames):
                out = []
                for j in range(h):
                    sim = torch.matmul(q[frame, j], k[frame, j].transpose(-1, -2)) * self.scale  # (seq_len, seq_len)
                    out.append(torch.matmul(sim.softmax(dim=-1), v[frame, j]))  # h * (seq_len, head_dim)
                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim)  # (h, seq_len, head_dim)
                out_all.append(out)  # n_frames * (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)
            return out

        def forward_extended(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            window_size = 3
            for frame in range(n_frames):
                out = []
                # sliding window to improve speed.
                window = range(max(0, frame - window_size // 2), min(n_frames, frame + window_size // 2 + 1))
                for j in range(h):
                    sim_all = []
                    for kframe in window:
                        sim_all.append(torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale)  # window * (seq_len, seq_len)
                    sim_all = torch.cat(sim_all).reshape(len(window), seq_len, seq_len).transpose(0, 1)  # (seq_len, window, seq_len)
                    sim_all = sim_all.reshape(seq_len, len(window) * seq_len)  # (seq_len, window * seq_len)
                    out.append(torch.matmul(sim_all.softmax(dim=-1), v[window, j].reshape(len(window) * seq_len, head_dim)))  # h * (seq_len, head_dim)
                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim)  # (h, seq_len, head_dim)
                out_all.append(out)  # n_frames * (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)
            return out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            out_source = forward_original(q[:n_frames], k[:n_frames], v[:n_frames])
            out_uncond = forward_extended(q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames])
            out_cond = forward_extended(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:])

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)  # (3 * n_frames, seq_len, dim)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
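For context, a minimal sketch of how the patched function could be hooked up. The names pipe and scheduler and the 0.5 injection fraction are placeholders I'm assuming here, not part of the repo's actual run script; the only requirement from the snippet above is that the first argument exposes a .unet attribute.

# Hypothetical wiring sketch -- `pipe` stands for whatever object exposes `.unet`
# in the TokenFlow pipeline, and `scheduler.timesteps` is assumed to be set already.
n_timesteps = len(scheduler.timesteps)
pnp_attn_t = int(n_timesteps * 0.5)                    # inject QK for the first half of the steps (assumed fraction)
qk_injection_timesteps = scheduler.timesteps[:pnp_attn_t]
register_extended_attention_pnp(pipe, qk_injection_timesteps)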
Note
The code slightly modifies the extended attention method from the paper: the self-attention of each frame is extended only across a sliding window of 3 consecutive key frames instead of across all key frames.
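To make the saving concrete, here is a small back-of-the-envelope sketch (the helper name and the example numbers are mine, purely illustrative): per head and per query frame, full extended attention materializes a (seq_len, n_frames * seq_len) score matrix, while the windowed version caps it at (seq_len, window_size * seq_len).

# Rough per-head attention-matrix size in MiB, assuming fp32 attention scores.
def attn_matrix_mib(seq_len, n_key_frames, bytes_per_el=4):
    return seq_len * n_key_frames * seq_len * bytes_per_el / 2**20

seq_len = 64 * 64            # e.g. a 64x64 latent feature map flattened to tokens
n_frames = 40
print(attn_matrix_mib(seq_len, n_frames))  # full extended attention: ~2560 MiB per head
print(attn_matrix_mib(seq_len, 3))         # 3-frame sliding window:  ~192 MiB per head

Increasing window_size in forward_extended trades some of that VRAM back for a wider temporal context.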