theislab/mubind

matmul in one-step instead of per batch in ActivitiesLayer forward

Closed this issue · 3 comments

@johschnee the following change in the selex module speeds it up significantly, i.e. the calculation is done in one step across all batches, instead of an iterative for loop with indexing per batch.
old

# iterative multiplication
# for i in range(self.n_batches):
# eta = torch.exp(self.log_etas[i, :])
# out[batch == i] = out[batch == i] * eta

new
# multiplication in one step
etas = torch.exp(self.log_etas)
out = out * etas[batch, :]
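For context, a tiny self-contained sketch confirming the two versions agree (shapes and names here are invented for illustration, they are not mubind's):

```python
import torch

torch.manual_seed(0)
n_batches, n_rows, n_cols = 3, 8, 5            # illustrative sizes
log_etas = torch.randn(n_batches, n_cols)      # stands in for self.log_etas
out = torch.rand(n_rows, n_cols)
batch = torch.randint(0, n_batches, (n_rows,))

# old: iterative multiplication per batch
out_loop = out.clone()
for i in range(n_batches):
    eta = torch.exp(log_etas[i, :])
    out_loop[batch == i] = out_loop[batch == i] * eta

# new: one step, indexing the eta table with the batch vector
out_vec = out * torch.exp(log_etas)[batch, :]

assert torch.allclose(out_loop, out_vec)
```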

Do you think this can also be done for the activities layer, i.e. a one-step calculation instead of a for loop? This is the code snippet. It seems non-doable due to the matmul step. Please discuss in case you need more input.
for i in range(self.n_batches):
    a = torch.exp(torch.stack(list(self.log_activities), dim=1)[i, :, :])
    batch_mask = batch == i
    b = binding_per_mode[batch_mask]
    if self.ignore_kernel is not None:
        mask = self.ignore_kernel != 1  # i.e. == False
        scores[batch_mask] = torch.matmul(b[:, mask], a[mask, :])
    else:
        # debug output for shape inspection
        print(scores[batch_mask].shape)
        print(b.shape, a.shape)
        # assert False
        scores[batch_mask] = torch.matmul(b, a)

Leaving some working example snippets below in case this is simpler to explain. Also sharing on StackOverflow in case there is input there.

import numpy as np
p = 2
q = 9
r = 512
t = 4

# data initialization
np.random.seed(500)
S = np.random.rand(r, q)
A = np.random.randint(0, 3, size=(p, t, q))
B = np.random.rand(r, t)
categories = np.random.randint(0, p, r)

# iterative (slow)
print('iterative')
for i in range(p):
    # print(i)
    a = A[i, :, :]
    mask = categories == i
    b = B[mask]
    print(b.shape, a.shape, S[mask].shape,
          np.matmul(b, a).shape)
    S[mask] = np.matmul(b, a)
print(S.shape, np.sum(np.matmul(b, a)))  # np.sum(scores)

# iterative (by category)
print('')
print('by category')
np.random.seed(500)
scores = np.random.rand(r, q)
i = 0
print('category 0')
a = A[i, :, :]
mask = categories == i
b = B[mask]
print(b.shape, a.shape, S[mask].shape)
scores[mask] = np.matmul(b, a)
i = 1
print('category 1')
a = A[i, :, :]
mask = categories == i
b = B[mask]
print(b.shape, a.shape, S[mask].shape)
scores[mask] = np.matmul(b, a)
print(scores.shape, np.sum(np.matmul(b, a)))  # np.sum(scores)

Problem here

# attempt to multiply once, indexing all categories only once (not possible)
np.random.seed(500)
S = np.random.rand(r, q)

# attempt to use the categories vector
a = A[categories, :, :]
b = B[categories]
# due to the shapes of the arrays, this multiplication is not possible
print('\nshapes during attempted use case')
print(b.shape, a.shape, S[categories].shape)
S[categories] = np.matmul(b, a)

print(scores.shape, np.sum(np.matmul(b, a)))  # np.sum(scores)
iterative
(250, 4) (4, 9) (250, 9) (250, 9)
(262, 4) (4, 9) (262, 9) (262, 9)
(512, 9) 5599.511791050638

by category
category 0
(250, 4) (4, 9) (250, 9)
category 1
(262, 4) (4, 9) (262, 9)
(512, 9) 5599.511791050638

shapes during attempted use case
(512, 4) (512, 4, 9) (512, 9)
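The one-shot indexing does become possible if each row of B is treated as a 1 x t matrix, so the product turns into a batched matmul over r stacked matrices. A NumPy sketch reusing the variables from the snippets above:

```python
import numpy as np

# same initialization as the snippets above
p, q, r, t = 2, 9, 512, 4
np.random.seed(500)
S = np.random.rand(r, q)
A = np.random.randint(0, 3, size=(p, t, q))
B = np.random.rand(r, t)
categories = np.random.randint(0, p, r)

# reference: the iterative per-category result
S_loop = S.copy()
for i in range(p):
    mask = categories == i
    S_loop[mask] = np.matmul(B[mask], A[i, :, :])

# one step: (r, 1, t) @ (r, t, q) -> (r, 1, q), then drop the singleton axis
S_vec = np.matmul(B[:, None, :], A[categories, :, :]).squeeze(1)

assert np.allclose(S_loop, S_vec)
```

An equivalent formulation is `np.einsum('rt,rtq->rq', B, A[categories])`; the unsqueeze/squeeze pair is the same reshaping idea used in the torch fix.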

Speed improved with this snippet, based on the StackOverflow suggestion:

if option == 1:
    b = binding_per_mode.unsqueeze(1)
    a = torch.exp(torch.stack(list(self.log_activities), dim=1))
    result = torch.matmul(b, a[batch, :, :])
    scores = result.squeeze(1)
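A minimal standalone check of this pattern (shapes invented; `activities` stands in for the exponentiated stack of `log_activities`):

```python
import torch

torch.manual_seed(0)
n_batches, n_modes, n_cols, n_rows = 2, 4, 9, 16
activities = torch.exp(torch.randn(n_batches, n_modes, n_cols))
binding_per_mode = torch.rand(n_rows, n_modes)
batch = torch.randint(0, n_batches, (n_rows,))

# per-batch loop, as in the original activities layer
scores_loop = torch.zeros(n_rows, n_cols)
for i in range(n_batches):
    mask = batch == i
    scores_loop[mask] = torch.matmul(binding_per_mode[mask], activities[i])

# one step: (rows, 1, modes) @ (rows, modes, cols) -> (rows, 1, cols)
b = binding_per_mode.unsqueeze(1)
scores = torch.matmul(b, activities[batch, :, :]).squeeze(1)

assert torch.allclose(scores_loop, scores)
```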