yjlolo/pytorch-deep-markov-model

Possible to parallelize emission

ksachdeva opened this issue · 2 comments

Hi @yjlolo

It's me again. I wanted to discuss whether it is possible to parallelize the emission.

Let's look at this snippet

        x_recon = torch.zeros([batch_size, T_max, self.input_dim]).to(x.device)
        mu_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        logvar_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        mu_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        logvar_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        z_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        z_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
        for t in range(T_max):
            # q(z_t | z_{t-1}, x_{t:T})
            mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev,
                                           rnn_bidirection=self.rnn_bidirection)
            zt_q = self.reparameterization(mu_q, logvar_q)
            z_prev = zt_q
            # p(z_t | z_{t-1})
            mu_p, logvar_p = self.transition(z_prev)
            zt_p = self.reparameterization(mu_p, logvar_p)

            xt_recon = self.emitter(zt_q).contiguous()

            mu_q_seq[:, t, :] = mu_q
            logvar_q_seq[:, t, :] = logvar_q
            z_q_seq[:, t, :] = zt_q
            mu_p_seq[:, t, :] = mu_p
            logvar_p_seq[:, t, :] = logvar_p
            z_p_seq[:, t, :] = zt_p
            x_recon[:, t, :] = xt_recon

As per the above code self.emitter is called inside the loop (of time steps).

Here is a thought -

If the emitter (function/model) is written so that it takes (time_steps, z_dim) as its input shape instead of (z_dim), then we can move it out of the for-loop.

Since you already store z_q per time step (i.e. z_q_seq[:, t, :] = zt_q), we could then simply feed z_q_seq to an emitter that takes input of shape (time_steps, z_dim).
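In fact, since nn.Linear applies to the last dimension and broadcasts over any leading dimensions, a plain MLP emitter should already accept the full (batch, time, z_dim) tensor without modification. A minimal sketch (the emitter below is a stand-in with assumed sizes, not the repo's actual module):

```python
import torch
import torch.nn as nn

# Stand-in emitter: a small MLP of nn.Linear layers (sizes are assumptions).
z_dim, x_dim = 4, 3
emitter = nn.Sequential(nn.Linear(z_dim, 8), nn.ReLU(), nn.Linear(8, x_dim))

batch_size, T_max = 2, 5
z_q_seq = torch.randn(batch_size, T_max, z_dim)

# Per-time-step calls, as in the current forward loop.
looped = torch.stack([emitter(z_q_seq[:, t, :]) for t in range(T_max)], dim=1)

# One batched call: nn.Linear acts on the last dim, so the extra
# time dimension is carried through automatically.
batched = emitter(z_q_seq)

# Both paths produce the same reconstruction tensor.
assert torch.allclose(looped, batched, atol=1e-6)
```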

What do you think about this? Am I ignoring some aspect that would make the model invalid?

Regards
Kapil

Yes, in this case taking the emitter outside of the for loop is definitely better in terms of computational speed.
But I suspect the boost wouldn't be large, because the emitter is simply composed of a few nn.Linear layers, which are quick to compute; the gain would depend more on the number of time steps.
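To get a rough sense of the difference, here is a micro-benchmark sketch (all sizes are made-up assumptions, and wall-clock numbers will vary by machine):

```python
import time
import torch
import torch.nn as nn

# Assumed sizes; the real model's dimensions may differ.
batch_size, T_max, z_dim, x_dim = 64, 100, 16, 88
emitter = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, x_dim))
z_q_seq = torch.randn(batch_size, T_max, z_dim)

def avg_time(fn, n=20):
    # Average wall-clock time over n runs.
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

t_loop = avg_time(lambda: torch.stack(
    [emitter(z_q_seq[:, t, :]) for t in range(T_max)], dim=1))
t_batch = avg_time(lambda: emitter(z_q_seq))
print(f"loop: {t_loop * 1e3:.2f} ms  batched: {t_batch * 1e3:.2f} ms")
```

The batched call avoids T_max separate kernel launches and the final torch.stack, so the relative gain grows with the number of time steps, as noted above.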

But please feel free to experiment and open a PR!

Thanks, @yjlolo, for confirming. I understand the performance gain may not be large for this dataset and this emitter, but the same architecture applied to a more difficult problem with a richer emitter may benefit.

I just wanted to confirm that doing that does not introduce a bug and/or violate the paper.

Regards
Kapil