cornellius-gp/gpytorch

[Feature Request] Inconsistent behaviour of batch broadcast when computing log_marginal of non-Gaussian likelihoods.

shixinxing opened this issue · 0 comments

🚀 Feature Request

When I call log_marginal() and expected_log_prob(), Gaussian and non-Gaussian likelihoods behave inconsistently: the Gaussian likelihood accepts extra leading batch dimensions on the observations, while non-Gaussian likelihoods do not.

Problem

In the following code snippet, the observation $y$ has shape [4, 3, 8], and the latent variable follows a Gaussian distribution $q(f)$ with batch shape [3] and event shape [8]. While GaussianLikelihood broadcasts both the log-marginal and the expected log-likelihood over the extra leading batch dimension of $y$, LaplaceLikelihood raises a "not broadcastable" error on the same inputs.

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, LaplaceLikelihood

y = torch.rand(4, 3, 8)
mean, L = torch.randn(3, 8), torch.rand(3, 8, 8)
q_f = MultivariateNormal(mean=mean, covariance_matrix=L @ L.mT)

# log_marginal for Gaussian likelihood
lk = GaussianLikelihood(batch_shape=torch.Size([3]))  # batched noise

# expected_log_prob
print(lk.expected_log_prob(target=y, input=q_f).shape)  # [4,3,8]
# log_marginal
print(lk.log_marginal(y, q_f).shape) # [4,3,8]

# log_marginal for other likelihoods
laplace_lk = LaplaceLikelihood(batch_shape=torch.Size([3]))

# expected_log_prob
print(laplace_lk.expected_log_prob(observations=y, function_dist=q_f).shape)  # Error: not broadcastable
# log_marginal
print(laplace_lk.log_marginal(y, q_f).shape)  # Error: not broadcastable
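For reference, the broadcasting behaviour the Gaussian path already supports can be reproduced with plain torch.distributions. Below is a minimal sketch (not gpytorch API; `mc_expected_log_prob` and its arguments are hypothetical names) of a Monte-Carlo estimate of the expected log-likelihood under a Laplace observation model, where the observations carry an extra leading batch dimension and broadcast cleanly against samples from a diagonal $q(f)$:

```python
import torch
from torch.distributions import Laplace, Normal

def mc_expected_log_prob(y, f_mean, f_var, scale=1.0, num_samples=64):
    """Monte-Carlo estimate of E_{q(f)}[log p(y | f)] for a Laplace likelihood."""
    # Draw samples from the diagonal posterior q(f): shape [S, 3, 8]
    f_samples = Normal(f_mean, f_var.sqrt()).rsample(torch.Size([num_samples]))
    # Insert a singleton dim so the sample dim does not clash with y's extra
    # leading batch dim: [S, 1, 3, 8] broadcast against y of shape [4, 3, 8]
    log_probs = Laplace(f_samples.unsqueeze(1), scale).log_prob(y)  # [S, 4, 3, 8]
    # Average over the sample dimension, keeping y's batch dims intact
    return log_probs.mean(dim=0)  # [4, 3, 8]

y = torch.rand(4, 3, 8)
f_mean, f_var = torch.randn(3, 8), torch.rand(3, 8)
print(mc_expected_log_prob(y, f_mean, f_var).shape)  # torch.Size([4, 3, 8])
```

This is the shape behaviour I would expect from laplace_lk.expected_log_prob as well: the extra observation batch dimension simply broadcasts, as it does for GaussianLikelihood.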