Label flattening fails with custom mean function from another GP
Closed this issue · 1 comment
neildhir commented
🐛 Bug
When using a custom mean function that wraps the posterior mean of another GP, requesting a prediction fails: the shape of the training labels does not match the shape of the prior mean, so GPyTorch raises an error while flattening the labels.
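For reference, my understanding of the shape conventions at play (not stated in the error itself): a GPyTorch Mean module should return a tensor of shape batch_shape x n with no trailing outcome dimension, whereas BoTorch's posterior(X).mean keeps an outcome dimension of size m. A minimal sketch of that difference, using ConstantMean purely for illustration:

import torch
from gpytorch.means import ConstantMean

x = torch.rand(5, 1)  # n x d test inputs
# A GPyTorch mean module returns shape `batch_shape x n`, here (5,),
# which matches the shape of the training targets that the exact-GP
# prediction strategy flattens.
print(ConstantMean()(x).shape)  # torch.Size([5])
# A BoTorch posterior mean instead has shape `n x m`, e.g. (5, 1) for a
# single-output model, so returning it unchanged from a custom Mean
# yields the mismatch reproduced below.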
To reproduce
**Code snippet to reproduce**
import torch
from botorch.models.transforms import Normalize, Standardize
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from gpytorch.means import Mean
import math
f = lambda x: torch.sin(math.pi * x) + 0.1
# Problem set up
d = 1
bounds = torch.tensor([[-4.0], [4.0]])
train_X1 = torch.linspace(bounds[0].item(), bounds[1].item(), 20).unsqueeze(-1).to(torch.float64)
train_Y1 = f(train_X1)
train_Y1 += 0.1 * torch.randn_like(train_Y1)
# Build first GP model
gp1 = SingleTaskGP(train_X1, train_Y1, input_transform=Normalize(d,bounds=bounds), outcome_transform=Standardize(m=1))
mll = ExactMarginalLogLikelihood(gp1.likelihood, gp1)
fit_gpytorch_mll(mll)
# Custom mean function using the mean function of a fitted GP
class MyMean(Mean):
    def __init__(self, gp, **kwargs):
        super(MyMean, self).__init__(**kwargs)
        self.gp = gp

    def forward(self, x):
        with torch.no_grad():
            batch_shape = x.shape[:-1]
            posterior_mean = self.gp.posterior(x).mean
            print("Input shape:", x.shape)
            print("Posterior mean shape:", posterior_mean.shape)
            if batch_shape == posterior_mean.shape[:-1]:
                return posterior_mean
            else:
                return posterior_mean.reshape(*batch_shape, -1)
# Fit second GP with mean function from first GP using MyMean()
train_X2 = torch.tensor([-2.5,0.,1.]).unsqueeze(-1).to(torch.float64)
train_Y2 = f(train_X2)
train_Y2 += 0.1 * torch.randn_like(train_Y2)
gp2 = SingleTaskGP(
    train_X2,
    train_Y2,
    outcome_transform=Standardize(m=1),
    input_transform=Normalize(d=train_X2.shape[1], bounds=bounds),
    mean_module=MyMean(gp1),  # Custom mean used here
)
mll = ExactMarginalLogLikelihood(gp2.likelihood, gp2)
fit_gpytorch_mll(mll)
# Test of posterior yields error
x = torch.randn(3, 1) # Example input tensor
gp2.posterior(x).mean
**Stack trace/error message**
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:47, in DefaultPredictionStrategy.__init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root, inv_root)
46 try:
---> 47 train_labels = train_labels.reshape(
48 *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
49 )
50 except RuntimeError:
RuntimeError: shape '[1]' is invalid for input of size 3
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[324], line 2
1 x = torch.randn(3, 1) # Example input tensor
----> 2 gp2.posterior(x)
File ~/anaconda3/lib/python3.10/site-packages/botorch/models/gpytorch.py:383, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform, **kwargs)
377 X, output_dim_idx = add_output_dim(
378 X=X, original_batch_shape=self._input_batch_shape
379 )
380 # NOTE: BoTorch's GPyTorchModels also inherit from GPyTorch's ExactGP, thus
381 # self(X) calls GPyTorch's ExactGP's __call__, which computes the posterior,
382 # rather than e.g. SingleTaskGP's forward, which computes the prior.
--> 383 mvn = self(X)
384 if observation_noise is not False:
385 if self._num_outputs > 1:
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_gp.py:294, in ExactGP.__call__(self, *args, **kwargs)
291 train_output = super().__call__(*train_inputs, **kwargs)
293 # Create the prediction strategy for
--> 294 self.prediction_strategy = prediction_strategy(
295 train_inputs=train_inputs,
296 train_prior_dist=train_output,
297 train_labels=self.train_targets,
298 likelihood=self.likelihood,
299 )
301 # Concatenate the input to the training input
302 full_inputs = []
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:37, in prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood)
35 else:
36 cls = DefaultPredictionStrategy
---> 37 return cls(train_inputs, train_prior_dist, train_labels, likelihood)
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/kernels/scale_kernel.py:124, in ScaleKernel.prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood)
123 def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
--> 124 return self.base_kernel.prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood)
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/kernels/kernel.py:445, in Kernel.prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood)
438 def prediction_strategy(
439 self,
440 train_inputs: Tensor,
(...)
443 likelihood: GaussianLikelihood,
444 ) -> exact_prediction_strategies.PredictionStrategy:
--> 445 return exact_prediction_strategies.DefaultPredictionStrategy(
446 train_inputs, train_prior_dist, train_labels, likelihood
447 )
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:51, in DefaultPredictionStrategy.__init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root, inv_root)
47 train_labels = train_labels.reshape(
48 *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
49 )
50 except RuntimeError:
---> 51 raise RuntimeError(
52 "Flattening the training labels failed. The most common cause of this error is "
53 + "that the shapes of the prior mean and the training labels are mismatched. "
54 + "The shape of the train targets is {0}, ".format(train_labels.shape)
55 + "while the reported shape of the mean is {0}.".format(train_prior_dist.mean.shape)
56 )
58 self.train_inputs = train_inputs
59 self.train_prior_dist = train_prior_dist
RuntimeError: Flattening the training labels failed. The most common cause of this error is that the shapes of the prior mean and the training labels are mismatched. The shape of the train targets is torch.Size([3]), while the reported shape of the mean is torch.Size([3, 1]).
Expected Behavior
A prediction from GP2 using the default covariance and the mean function from GP1.
System information
Please complete the following information:
- GPyTorch Version: 1.11
- PyTorch Version: 2.2.0.post100
- macOS Sonoma 14.4
Additional context
I am also attempting to use the variance from the first GP in the second GP once this is working.
neildhir commented
Solved by changing `posterior_mean = self.gp.posterior(x).mean` to `posterior_mean = self.gp.posterior(x).mean.unsqueeze(-1)`.
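For anyone landing here later, a note on my reading of the underlying shape convention: a GPyTorch Mean is expected to return batch_shape x n, while gp.posterior(x).mean has shape ... x n x 1 for a single-output model. Under that assumption, an alternative (an untested sketch, not the exact one-line change above) is to drop the trailing outcome dimension inside forward:

def forward(self, x):
    with torch.no_grad():
        # posterior(x).mean has shape `... x n x m`; squeeze the trailing
        # outcome dimension so the returned mean has shape `batch_shape x n`,
        # matching the training targets that GPyTorch flattens.
        return self.gp.posterior(x).mean.squeeze(-1)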