omerbt/TokenFlow

Confusion between reshape_heads_to_batch_dim and head_to_batch_dim

liubo-cs opened this issue · 0 comments

The usage of two pairs of method names in tokenflow_utils.py is confusing:
1. reshape_heads_to_batch_dim and head_to_batch_dim
2. reshape_batch_dim_to_heads and batch_dim_to_head
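To make the two pairs concrete, here is a minimal NumPy sketch of the reshapes they appear to perform. It assumes both names in each pair do the same thing, following the attention convention used by diffusers (the reshape_* names look like the older spelling of head_to_batch_dim / batch_dim_to_head). These are standalone re-implementations with heads passed explicitly, not the library methods themselves.

```python
import numpy as np


def head_to_batch_dim(x, heads):
    """Split the channel axis into heads and fold heads into the batch axis.

    (batch, seq, heads * dim) -> (batch * heads, seq, dim)
    """
    batch, seq, inner = x.shape
    dim = inner // heads
    x = x.reshape(batch, seq, heads, dim).transpose(0, 2, 1, 3)
    return x.reshape(batch * heads, seq, dim)


def batch_dim_to_head(x, heads):
    """Inverse of head_to_batch_dim.

    (batch * heads, seq, dim) -> (batch, seq, heads * dim)
    """
    batch_heads, seq, dim = x.shape
    batch = batch_heads // heads
    x = x.reshape(batch, heads, seq, dim).transpose(0, 2, 1, 3)
    return x.reshape(batch, seq, heads * dim)


x = np.arange(2 * 4 * 6).reshape(2, 4, 6)  # batch=2, seq=4, heads*dim=6
y = head_to_batch_dim(x, heads=3)
print(y.shape)  # (6, 4, 2)
print(np.array_equal(batch_dim_to_head(y, heads=3), x))  # True
```

Because the two operations are exact inverses, an alias from one spelling to the other only changes which name is looked up, not what the attention code computes.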

For example, in tokenflow_utils.py, head_to_batch_dim appears in two blocks (at lines 140 and 241, respectively).

To run the pnp example successfully, I added the following line before the block at line 241, and it works:
self.head_to_batch_dim = self.reshape_heads_to_batch_dim
However, to run the sdedit example successfully, I have to add the same line in a different place (the block at line 140).

The same applies to the other pair, reshape_batch_dim_to_heads and batch_dim_to_head. Is there a principled way to handle this?
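One possible approach is to generalize the one-line alias self.head_to_batch_dim = self.reshape_heads_to_batch_dim into a helper that is applied once to every module, instead of patching each code path separately. This is a sketch under the assumption that each pair consists of the older and newer diffusers names for the same method, so the mismatch comes from the installed diffusers version. The helper name add_attention_aliases and the LegacyAttention stand-in class are hypothetical and not part of TokenFlow or diffusers. Pinning diffusers to the version the repository was written against would be the other obvious option.

```python
# Hypothetical mapping: newer name -> older name (assumed to be equivalent methods).
ALIASES = {
    "head_to_batch_dim": "reshape_heads_to_batch_dim",
    "batch_dim_to_head": "reshape_batch_dim_to_heads",
}


def add_attention_aliases(module):
    """Hypothetical helper: expose both spellings of each reshape method.

    If a module defines only one name of a pair, the other name is bound to
    the same method. Modules that define both names or neither are unchanged.
    """
    for new_name, old_name in ALIASES.items():
        has_new = hasattr(module, new_name)
        has_old = hasattr(module, old_name)
        if has_old and not has_new:
            setattr(module, new_name, getattr(module, old_name))
        elif has_new and not has_old:
            setattr(module, old_name, getattr(module, new_name))
    return module


class LegacyAttention:
    """Stand-in for an attention module that only has the old names."""

    def reshape_heads_to_batch_dim(self, tensor):
        return ("heads_to_batch", tensor)

    def reshape_batch_dim_to_heads(self, tensor):
        return ("batch_to_heads", tensor)


attn = add_attention_aliases(LegacyAttention())
print(attn.head_to_batch_dim("x"))  # ('heads_to_batch', 'x')
print(attn.batch_dim_to_head("y"))  # ('batch_to_heads', 'y')
```

Since the helper only adds a name when exactly one spelling of a pair is present, it should be safe to call on every module of the UNet (for example, in a loop over unet.modules(), which torch.nn.Module provides) before the TokenFlow hooks run, so the pnp and sdedit paths would see the same names.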