lucidrains/h-transformer-1d

Algorithm Mismatch

jinmang2 opened this issue · 3 comments

Paper Implementation

In the implementation, the blocked Q, K, V tensors are coarsened level by level with the code below.

```python
qkvs = [(q, k, v, mask)]

for level in range(num_levels):
    q, k, v = map(lambda t: rearrange(t, 'b (n r) d -> b n r d', r = 2), (q, k, v))

    if exists(mask):
        mask = repeat(mask, 'b (n r) -> b n r', r = 2)

    # masked mean for queries and keys, but not values
    q = masked_aggregate(q, mask, dim = 2)
    k = masked_aggregate(k, mask, dim = 2)
    v = masked_aggregate(v, mask, dim = 2, average = False)

    if exists(mask):
        mask = torch.any(mask, dim = 2)

    qkvs.append((q, k, v, mask))  # collect the coarsened tensors for this level
```
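
For context, the snippet relies on a masked_aggregate helper that is not shown here. A minimal sketch consistent with how it is called above (a masked mean along a dimension when average = True, a masked sum otherwise) might look like the following; this is an illustrative reconstruction, not necessarily the repository's exact implementation:

```python
import torch

def exists(val):
    return val is not None

def masked_aggregate(tensor, mask = None, dim = -1, average = True):
    # illustrative sketch: reduce `tensor` along `dim`, counting only positions
    # where `mask` is True; with average = True this is a masked mean
    if not exists(mask):
        return tensor.mean(dim = dim) if average else tensor.sum(dim = dim)

    # broadcast the mask over any trailing feature dimensions
    diff = tensor.ndim - mask.ndim
    mask = mask[(..., *((None,) * diff))]

    tensor = tensor.masked_fill(~mask, 0.)
    agg = tensor.sum(dim = dim)

    if average:
        total = mask.sum(dim = dim).clamp(min = 1)
        agg = agg / total

    return agg
```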

The final result of the matrix-matrix product (Equation 29 / Equation 69 in the paper) is then assembled with the for loop below.

```python
Y = 0
A = 0

for Y_level, A_level in Ys:
    if torch.is_tensor(Y):
        Y = repeat(Y, 'b n d -> b (n r) d', r = 2)

    if torch.is_tensor(A):
        A = repeat(A, 'b n -> b (n r)', r = 2)

    Y = Y_level + Y
    A = A_level + A

out = Y / rearrange(A + eps, 'b n -> b n ()')
```
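
As a toy illustration (with made-up shapes and random values, not taken from the repository), the loop upsamples the running numerator Y and denominator A by a factor of 2 whenever it moves to a finer level, accumulates the per-level contributions, and only normalizes once at the end:

```python
import torch
from einops import repeat, rearrange

eps = 1e-8

# hypothetical per-level partial results, coarsest level first:
# (batch, num_blocks, dim) numerators and (batch, num_blocks) denominators
Ys = [
    (torch.randn(1, 4, 8), torch.rand(1, 4)),   # coarse level
    (torch.randn(1, 8, 8), torch.rand(1, 8)),   # fine level
]

Y = 0
A = 0
for Y_level, A_level in Ys:
    if torch.is_tensor(Y):
        Y = repeat(Y, 'b n d -> b (n r) d', r = 2)   # upsample running numerator
    if torch.is_tensor(A):
        A = repeat(A, 'b n -> b (n r)', r = 2)       # upsample running denominator
    Y = Y_level + Y
    A = A_level + A

out = Y / rearrange(A + eps, 'b n -> b n ()')
print(out.shape)  # torch.Size([1, 8, 8])
```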

What is the problem?

However, with the current code it is not possible to include the contribution of the level-0 white (off-diagonal) blocks in the figure below, even though Equation 70 of the paper includes the corresponding attention matrix entries.

[fig2: block structure of the hierarchical attention matrix; the level-0 off-diagonal blocks are shown in white]

I think the off-diagonal near-interaction terms (level 0) should also be added to match Equation 69!
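
For illustration only, one way to form such level-0 near-interaction terms is to let every query block also score against its immediately adjacent key blocks. The sketch below is a hypothetical shape-level example (the helper name near_field_scores and its block_size argument are made up here, and this is not necessarily how the fix was eventually implemented):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def near_field_scores(q, k, block_size = 16):
    # q, k: (batch, seq, dim), with seq divisible by block_size
    q = rearrange(q, 'b (n r) d -> b n r d', r = block_size)
    k = rearrange(k, 'b (n r) d -> b n r d', r = block_size)

    # pad the block axis so every block has a left and a right neighbour
    k_padded = F.pad(k, (0, 0, 0, 0, 1, 1), value = 0.)
    k_left, k_right = k_padded[:, :-2], k_padded[:, 2:]

    # attention logits of each query block against its two neighbouring key blocks
    sim_left  = torch.einsum('b n i d, b n j d -> b n i j', q, k_left)
    sim_right = torch.einsum('b n i d, b n j d -> b n i j', q, k_right)
    return sim_left, sim_right

q = torch.randn(1, 64, 32)
k = torch.randn(1, 64, 32)
sim_left, sim_right = near_field_scores(q, k, block_size = 16)
print(sim_left.shape, sim_right.shape)  # torch.Size([1, 4, 16, 16]) twice
```

The resulting per-neighbour logits could then be folded into the same numerator/denominator accumulation shown above.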

@jinmang2 Hi MyungHoon! I think you are right, thank you for catching this bug! I've released the changes in 0.1.6 (https://github.com/lucidrains/h-transformer-1d/releases/tag/0.1.6); do you want to review it and see if it resolves the issue you described?

Thanks for solving it so quickly 😄

Thank you for the opportunity to review; I will check the 0.1.6 code and leave a comment!

closing, since i think it's fine now