[Feature Request] Inconsistent behaviour of batch broadcast when computing log_marginal of non-Gaussian likelihoods.
shixinxing opened this issue
🚀 Feature Request
When I call log_marginal() and expected_log_prob(), Gaussian and non-Gaussian likelihoods handle batch shapes inconsistently. GaussianLikelihood accepts observations with extra leading batch dimensions, while non-Gaussian likelihoods do not.
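One plausible reason the Gaussian case broadcasts is that its expected log-likelihood has a standard closed form, E[log N(y | f, σ²)] = -½ (log(2πσ²) + ((y - μ)² + v) / σ²), which is purely elementwise in y, the marginal mean μ and variance v of q(f), and the noise σ². The sketch below uses plain torch rather than GPyTorch, so it does not claim to be GPyTorch's code. The function name gaussian_expected_log_prob is hypothetical, and the per-batch noise of shape [3, 1] is an assumption meant to mirror batch_shape=[3].

```python
import math

import torch


def gaussian_expected_log_prob(y, mean, variance, noise):
    # Hypothetical helper: closed-form E_{f ~ N(mean, variance)}[log N(y | f, noise)].
    # Every operation is elementwise, so y may carry extra leading batch dims.
    return -0.5 * (torch.log(2 * math.pi * noise) + ((y - mean) ** 2 + variance) / noise)


y = torch.rand(4, 3, 8)            # observations with an extra leading batch dim of size 4
mean = torch.randn(3, 8)           # marginal means of q(f)
variance = torch.rand(3, 8) + 0.1  # marginal variances of q(f), kept positive
noise = torch.full((3, 1), 0.5)    # assumed: one noise variance per batch member
res = gaussian_expected_log_prob(y, mean, variance, noise)
print(res.shape)  # torch.Size([4, 3, 8])
```

Nothing in this formula introduces a new dimension, so the extra batch dimension of y simply broadcasts against the [3, 8] marginals.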
Problem
In the following code snippet, the observations y have shape [4, 3, 8], and the latent function q_f is a batched multivariate Gaussian of shape [3, 8] (batch size 3, 8 data points). For GaussianLikelihood, both the log-marginal and the expected log-likelihood broadcast over the extra leading batch dimension and return results of shape [4, 3, 8]. LaplaceLikelihood instead raises a "not broadcastable" error.
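Non-Gaussian likelihoods have no closed form, so they are typically evaluated numerically. Assuming the likelihood is evaluated at Gauss-Hermite quadrature points stacked along a new leading dimension (20 points here, an assumed count), the error can be reproduced from shapes alone. The sketch below uses only torch.broadcast_shapes and does not call GPyTorch.

```python
import torch

num_locs = 20  # assumed number of quadrature points
quad_shape = torch.Size([num_locs, 3, 8])  # quadrature points x batch x data points
obs_shape = torch.Size([4, 3, 8])          # observations with an extra batch dim

try:
    torch.broadcast_shapes(quad_shape, obs_shape)
    broadcast_ok = True
except RuntimeError:
    # Leading dims 20 and 4 are both non-singleton and differ.
    broadcast_ok = False
print(broadcast_ok)  # False

# A singleton axis between the quadrature dim and the batch dims lets them broadcast.
fixed = torch.broadcast_shapes(torch.Size([num_locs, 1, 3, 8]), obs_shape)
print(fixed)  # torch.Size([20, 4, 3, 8])
```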
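import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, LaplaceLikelihood

y = torch.rand(4,3,8)  # observations with an extra leading batch dim of size 4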
mean, L = torch.randn(3, 8), torch.rand(3, 8, 8)
q_f = MultivariateNormal(mean=mean, covariance_matrix=L@L.mT)
# Gaussian likelihood
lk = GaussianLikelihood(batch_shape=[3]) # batched noise
# expected_log_prob
print(lk.expected_log_prob(target=y, input=q_f).shape) # [4,3,8]
# log_marginal
print(lk.log_marginal(y, q_f).shape) # [4,3,8]
# Non-Gaussian likelihood (Laplace)
laplace_lk = LaplaceLikelihood(batch_shape=[3])
# expected_log_prob
print(laplace_lk.expected_log_prob(observations=y, function_dist=q_f).shape) # Error: not broadcastable
# log_marginal
print(laplace_lk.log_marginal(y, q_f).shape) # Error: not broadcastable
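A minimal sketch of how a quadrature-based expected log-likelihood could accept the extra batch dimension, assuming Gauss-Hermite quadrature over the marginals of q(f). The idea is to insert a singleton axis for each extra leading batch dimension of y, so that the quadrature axis and the batch axes of y broadcast instead of colliding. This uses plain torch.distributions plus NumPy's hermgauss and is not GPyTorch's implementation. The helper broadcast_expected_log_prob and the Laplace scale of 0.5 are hypothetical.

```python
import math

import numpy as np
import torch
from torch.distributions import Laplace, MultivariateNormal


def broadcast_expected_log_prob(log_prob_fn, y, q_f, num_locs=20):
    """Gauss-Hermite estimate of E_{f ~ q_f}[log p(y | f)] for each entry of y.

    Hypothetical helper: y may carry more leading batch dims than q_f.
    log_prob_fn(f, y) must return the elementwise log-density log p(y | f).
    """
    x, w = np.polynomial.hermite.hermgauss(num_locs)
    x = torch.as_tensor(x, dtype=y.dtype)
    w = torch.as_tensor(w, dtype=y.dtype)
    mean, std = q_f.mean, q_f.variance.sqrt()  # marginals, shape [*batch, n]
    extra = y.dim() - mean.dim()  # number of extra leading batch dims in y
    # Quadrature axis first, then one singleton axis per extra batch dim of y.
    view = (num_locs,) + (1,) * (extra + mean.dim())
    f = mean + math.sqrt(2.0) * std * x.view(view)  # [num_locs, 1, ..., *batch, n]
    log_probs = log_prob_fn(f, y)  # broadcasts to [num_locs, *y.shape]
    return (w.view(view) * log_probs).sum(0) / math.sqrt(math.pi)


def laplace_log_prob(f, obs):
    # Assumed: one Laplace scale of 0.5 per batch member, mirroring batch_shape=[3].
    scale = torch.full((3, 1), 0.5)
    return Laplace(f, scale).log_prob(obs)


y = torch.rand(4, 3, 8)
mean, L = torch.randn(3, 8), torch.rand(3, 8, 8)
# A small jitter keeps the covariance safely positive definite.
q_f = MultivariateNormal(mean, covariance_matrix=L @ L.mT + 1e-3 * torch.eye(8))
res = broadcast_expected_log_prob(laplace_log_prob, y, q_f)
print(res.shape)  # torch.Size([4, 3, 8])
```

With a Gaussian log-density in place of the Laplace one, the same helper should reproduce the closed-form Gaussian expected log-likelihood, because Gauss-Hermite quadrature is exact for the quadratic term (y - f)².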