ioangatop/srVAE

Ground truth in forward pass?

Constantin771 opened this issue

Hello, and first of all, thank you for the code!
I have a question about the forward pass in the srVAE model:

```python
def forward(self, x, **kwargs):
    """ Forward pass through the inference and the generative model. """
    # y ~ f(x) (deterministic)
    y = self.compressed_transoformation(x)
    # u ~ q(u| y)
    u_q_mean, u_q_logvar = self.q_u(y)
    u_q = self.reparameterize(u_q_mean, u_q_logvar)
    # z ~ q(z| x, y)
    z_q_mean, z_q_logvar = self.q_z(x)
    z_q = self.reparameterize(z_q_mean, z_q_logvar)
    # x ~ p(x| y, z)
    x_logits = self.p_x((y, z_q))
    # y ~ p(y| u)
    y_logits = self.p_y(u_q)
```

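It looks like `p_x` receives the deterministic `y` computed directly from the input `x` (i.e. the ground truth), rather than the `y` predicted by `p_y` from `u_q`. Shouldn't `p_x` be conditioned on the predicted `y` instead?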
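To make the question concrete, here is a toy, runnable sketch of the same data flow in plain Python. Everything in it is hypothetical and not taken from the srVAE repository: the sub-networks are fixed elementwise maps on lists of floats, `compressed_transformation` is a stand-in 2x average pooling, `p_y` returns a point prediction of `y` rather than distribution logits, and the `use_predicted_y` switch is invented for illustration. With the default setting it mirrors the quoted `forward`, where `p_x` sees `y = f(x)`; with `use_predicted_y=True` it conditions `p_x` on the prediction from `p_y`, which is the alternative asked about above.

```python
import math
import random

# Hypothetical toy stand-ins for the srVAE sub-networks. An "image" is a flat
# list of floats and every "network" is a fixed elementwise map; they only
# mirror the data flow of the quoted forward pass.

def compressed_transformation(x):
    """y = f(x): deterministic 2x downscaling by averaging neighbouring pairs."""
    return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]

def reparameterize(mean, logvar, rng):
    """Elementwise sample mean + std * eps with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]

def q_u(y):
    """Toy q(u | y): returns (mean, logvar)."""
    return [0.5 * v for v in y], [-2.0] * len(y)

def q_z(x):
    """Toy q(z | x): returns (mean, logvar)."""
    return [0.1 * v for v in x], [-2.0] * len(x)

def p_y(u):
    """Toy p(y | u): returns a point prediction of y (not distribution logits)."""
    return [2.0 * v for v in u]

def p_x(y, z):
    """Toy p(x | y, z): upsample y by repetition and add z."""
    up = [v for v in y for _ in range(2)]
    return [a + b for a, b in zip(up, z)]

def forward(x, rng, use_predicted_y=False):
    y = compressed_transformation(x)            # y = f(x), computed from the data
    u_mean, u_logvar = q_u(y)
    u = reparameterize(u_mean, u_logvar, rng)   # u ~ q(u | y)
    z_mean, z_logvar = q_z(x)
    z = reparameterize(z_mean, z_logvar, rng)   # z ~ q(z | x)
    y_pred = p_y(u)                             # prediction of y from u
    # Quoted forward pass: p_x is conditioned on y = f(x).
    # Alternative from the question: condition p_x on the prediction from p_y.
    y_cond = y_pred if use_predicted_y else y
    x_out = p_x(y_cond, z)
    return x_out, y_pred, y_cond, y

rng = random.Random(0)
x_out, y_pred, y_cond, y = forward([0.0, 1.0, 2.0, 3.0], rng)
print(y_cond == y)  # True: p_x sees the deterministic y
```

For example, with `x = [0.0, 1.0, 2.0, 3.0]` the default call conditions `p_x` on `y = [0.5, 2.5]`, while `use_predicted_y=True` conditions it on a noisy reconstruction of that `y` obtained through the sampled `u`.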

Best regards!