cornellius-gp/gpytorch

Label flattening fails with custom mean function from another GP

Closed this issue ยท 1 comments

๐Ÿ› Bug

Using a custom mean function, which is the posterior mean of another GP, when getting a prediction the input shape of the labels and the posterior mean do not match, hence throwing an error.

To reproduce

** Code snippet to reproduce **

import torch

from botorch.models.transforms import Normalize, Standardize
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from gpytorch.means import Mean

import math

f = lambda x: torch.sin(math.pi * x) + 0.1
# Problem set up
d = 1
bounds = torch.tensor([[-4.0], [4.0]])
train_X1 = torch.linspace(bounds[0].item(), bounds[1].item(), 20).unsqueeze(-1).to(torch.float64)
train_Y1 = f(train_X1)
train_Y1 += 0.1 * torch.randn_like(train_Y1)
# Build first GP model
gp1 = SingleTaskGP(train_X1, train_Y1, input_transform=Normalize(d,bounds=bounds), outcome_transform=Standardize(m=1))
mll = ExactMarginalLogLikelihood(gp1.likelihood, gp1)
fit_gpytorch_mll(mll)

# Custom mean function using the mean function of a fitted GP
class MyMean(Mean):
    def __init__(self, gp, **kwargs):
        super(MyMean, self).__init__(**kwargs)
        self.gp = gp
    def forward(self, x):
        with torch.no_grad():
            batch_shape = x.shape[:-1]
            posterior_mean = self.gp.posterior(x).mean

            print("Input shape:", x.shape)
            print("Posterior mean shape:", posterior_mean.shape)

        if batch_shape == posterior_mean.shape[:-1]:
            return posterior_mean
        else:
            return posterior_mean.reshape(*batch_shape, -1)

# Fit second GP with mean function from first GP using MyMean()
train_X2 = torch.tensor([-2.5,0.,1.]).unsqueeze(-1).to(torch.float64)
train_Y2 = f(train_X2)
train_Y2 += 0.1 * torch.randn_like(train_Y2)
gp2 = SingleTaskGP(train_X2,
                   train_Y2,
                   outcome_transform=Standardize(m=1),
                   input_transform=Normalize(d=train_X2.shape[1],bounds=bounds),
                   mean_module=MyMean(gp1), # Custom mean used here
                   )
mll = ExactMarginalLogLikelihood(gp2.likelihood, gp2)
fit_gpytorch_mll(mll)

# Test of posterior yields error
x = torch.randn(3, 1)  # Example input tensor
gp2.posterior(x).mean

** Stack trace/error message **

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:47, in DefaultPredictionStrategy.__init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root, inv_root)
     46 try:
---> 47     train_labels = train_labels.reshape(
     48         *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
     49     )
     50 except RuntimeError:

RuntimeError: shape '[1]' is invalid for input of size 3

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[324], line 2
      1 x = torch.randn(3, 1)  # Example input tensor
----> 2 gp2.posterior(x)

File ~/anaconda3/lib/python3.10/site-packages/botorch/models/gpytorch.py:383, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform, **kwargs)
    377     X, output_dim_idx = add_output_dim(
    378         X=X, original_batch_shape=self._input_batch_shape
    379     )
    380 # NOTE: BoTorch's GPyTorchModels also inherit from GPyTorch's ExactGP, thus
    381 # self(X) calls GPyTorch's ExactGP's __call__, which computes the posterior,
    382 # rather than e.g. SingleTaskGP's forward, which computes the prior.
--> 383 mvn = self(X)
    384 if observation_noise is not False:
    385     if self._num_outputs > 1:

File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_gp.py:294, in ExactGP.__call__(self, *args, **kwargs)
    291     train_output = super().__call__(*train_inputs, **kwargs)
    293     # Create the prediction strategy for
--> 294     self.prediction_strategy = prediction_strategy(
    295         train_inputs=train_inputs,
    296         train_prior_dist=train_output,
    297         train_labels=self.train_targets,
    298         likelihood=self.likelihood,
    299     )
    301 # Concatenate the input to the training input
    302 full_inputs = []

File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:37, in prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood)
     35 else:
     36     cls = DefaultPredictionStrategy
---> 37 return cls(train_inputs, train_prior_dist, train_labels, likelihood)

File ~/anaconda3/lib/python3.10/site-packages/gpytorch/kernels/scale_kernel.py:124, in ScaleKernel.prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood)
    123 def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
--> 124     return self.base_kernel.prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood)

File ~/anaconda3/lib/python3.10/site-packages/gpytorch/kernels/kernel.py:445, in Kernel.prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood)
    438 def prediction_strategy(
    439     self,
    440     train_inputs: Tensor,
   (...)
    443     likelihood: GaussianLikelihood,
    444 ) -> exact_prediction_strategies.PredictionStrategy:
--> 445     return exact_prediction_strategies.DefaultPredictionStrategy(
    446         train_inputs, train_prior_dist, train_labels, likelihood
    447     )

File ~/anaconda3/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:51, in DefaultPredictionStrategy.__init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root, inv_root)
     47     train_labels = train_labels.reshape(
     48         *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
     49     )
     50 except RuntimeError:
---> 51     raise RuntimeError(
     52         "Flattening the training labels failed. The most common cause of this error is "
     53         + "that the shapes of the prior mean and the training labels are mismatched. "
     54         + "The shape of the train targets is {0}, ".format(train_labels.shape)
     55         + "while the reported shape of the mean is {0}.".format(train_prior_dist.mean.shape)
     56     )
     58 self.train_inputs = train_inputs
     59 self.train_prior_dist = train_prior_dist

RuntimeError: Flattening the training labels failed. The most common cause of this error is that the shapes of the prior mean and the training labels are mismatched. The shape of the train targets is torch.Size([3]), while the reported shape of the mean is torch.Size([3, 1]).

Expected Behavior

A prediction from GP2 using the default covariance and the mean function from GP1.

System information

Please complete the following information:

  • GPyTorch Version: 1.11
  • PyTorch Version: 2.2.0.post100
  • mac OS Sonoma 14.4

Additional context

Am attempting to use the variance from the first GP as well in the second GP once I can get this working.

Solved by changing posterior_mean = self.gp.posterior(x).mean to posterior_mean = self.gp.posterior(x).mean.unsqueeze(-1).