Possible to parallelize emission
ksachdeva opened this issue · 2 comments
Hi @yjlolo
It's me again. I wanted to discuss whether it is possible to parallelize the emission.

Let's look at this snippet:
```python
x_recon = torch.zeros([batch_size, T_max, self.input_dim]).to(x.device)
mu_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
logvar_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
mu_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
logvar_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
z_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
z_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
for t in range(T_max):
    # q(z_t | z_{t-1}, x_{t:T})
    mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev,
                                   rnn_bidirection=self.rnn_bidirection)
    zt_q = self.reparameterization(mu_q, logvar_q)
    z_prev = zt_q
    # p(z_t | z_{t-1})
    mu_p, logvar_p = self.transition(z_prev)
    zt_p = self.reparameterization(mu_p, logvar_p)
    xt_recon = self.emitter(zt_q).contiguous()
    mu_q_seq[:, t, :] = mu_q
    logvar_q_seq[:, t, :] = logvar_q
    z_q_seq[:, t, :] = zt_q
    mu_p_seq[:, t, :] = mu_p
    logvar_p_seq[:, t, :] = logvar_p
    z_p_seq[:, t, :] = zt_p
    x_recon[:, t, :] = xt_recon
```
As per the code above, `self.emitter` is called inside the loop over time steps.
Here is a thought: if the emitter (function/model) is written so that it accepts inputs of shape `(time_steps, z_dim)` instead of `(z_dim)`, then we can take it out of the for-loop. Since you are already storing `z_q` per time step (i.e. `z_q_seq[:, t, :] = zt_q`), we could then simply feed `z_q_seq` to the emitter after the loop, as in the sketch below.
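Something like the following, assuming `emitter` is built from `nn.Linear` layers and elementwise nonlinearities (these apply to the last dimension and broadcast over any leading dimensions, so the emitter itself may not even need changes):

```python
# Inside the loop, only the per-step latents are collected:
#     z_q_seq[:, t, :] = zt_q
# After the loop, emit all time steps in one batched call.
# An input of shape (batch_size, T_max, z_dim) then yields a
# reconstruction of shape (batch_size, T_max, input_dim).
x_recon = self.emitter(z_q_seq).contiguous()
```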
What do you think about this? Am I ignoring some aspect that makes the model invalid?
Regards
Kapil
Yes, in this case taking the emitter outside of the for-loop is definitely better in terms of computational speed. But I guess the boost wouldn't be large, because `emitter` is simply composed of a few `nn.Linear` layers, which should be quite fast to compute; the gain would depend more on the number of time steps.
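If you want to measure it, here is a rough sketch (with made-up sizes and a stand-in two-layer emitter, not the model's real dimensions) that times both variants and checks they agree:

```python
import time

import torch
import torch.nn as nn

# Stand-in sizes and emitter, just for illustration; the actual model's
# dimensions and layers will differ.
batch_size, T_max, z_dim, input_dim = 64, 100, 16, 88
emitter = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, input_dim))
z_q_seq = torch.randn(batch_size, T_max, z_dim)

def per_step():
    # Emission inside the time loop, as in the current code.
    x_recon = torch.zeros(batch_size, T_max, input_dim)
    for t in range(T_max):
        x_recon[:, t, :] = emitter(z_q_seq[:, t, :])
    return x_recon

def batched():
    # Emission in a single call over all time steps.
    return emitter(z_q_seq)

with torch.no_grad():
    for fn in (per_step, batched):
        fn()  # warm-up
        start = time.perf_counter()
        for _ in range(20):
            fn()
        print(f"{fn.__name__}: {(time.perf_counter() - start) / 20 * 1e3:.2f} ms")
    # Both variants should produce the same reconstruction (up to float error).
    torch.testing.assert_close(per_step(), batched())
```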
But please feel free to experiment and PR!
Thanks, @yjlolo, for confirming. I understand the performance gain may not be large for this dataset and this arrangement of the emitter, but the same architecture applied to a harder problem with a richer emitter may benefit more. I just wanted to confirm that doing this neither introduces a bug nor violates the paper.
Regards
Kapil