[Bug] Error in tutorials and derivative GPs
Codesflow-Simon opened this issue · 1 comment
🐛 Bug
Derivative GPs are not training on derivative information. Derivative GPs, such as the ones in the tutorial linked below, are not actually trained on the derivative information; they only evaluate the derivatives once the model is in model.eval().
To reproduce
The bug exists in the tutorial on 2d derivatives:
https://docs.gpytorch.ai/en/stable/examples/08_Advanced_Usage/Simple_GP_Regression_Derivative_Information_2d.html
The code below is exactly the code from the tutorial, except that I have added some extra code between the ## mll debug ## comment flags to print intermediate values that show the bug. The injected code does exactly the same thing as -mll(output, train_y), but makes the computation explicit so the intermediate values can be inspected.
Printing the mean returned by mll.likelihood(output) with print(mean[:12].reshape(4, 3)) shows that both derivative outputs (which appear as 2 of every 3 terms, hence the reshape for readability) are zero. There is no reason this should be the case unless the data is constant, which it isn't.
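(As a side note on that reshape, here is a tiny sketch of the layout I am assuming, with the three outputs interleaved per training point:)

```python
# Layout assumed above: one row per training point, columns [value, df/dx, df/dy].
per_point = mean.reshape(-1, 3)
values  = per_point[:, 0]   # function-value outputs
dfx_out = per_point[:, 1]   # x-derivative outputs (observed to be zero in train mode)
dfy_out = per_point[:, 2]   # y-derivative outputs (observed to be zero in train mode)
```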
This means that the diff tensor will always contain incorrect values for the derivative outputs. Using this bad diff, the rest of my injected snippet computes the negative marginal log likelihood.
Running the script, the loss reported by loss = -mll(output, train_y) is identical to the one from my calculation, meaning the same error is present in the built-in code.
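For reference, what the injected snippet (and, as far as I can tell, the built-in ExactMarginalLogLikelihood) computes is the standard exact GP marginal log likelihood over all N = 300 training targets (100 points × 3 outputs per point), which GPyTorch then divides by N:

$$\log p(\mathbf{y} \mid X) = -\tfrac{1}{2}\left[(\mathbf{y}-\boldsymbol{\mu})^\top K^{-1}(\mathbf{y}-\boldsymbol{\mu}) + \log\lvert K\rvert + N \log 2\pi\right]$$

Here diff corresponds to y − μ, inv_quad to the first term and logdet to the second, so a mean μ with zero derivative entries feeds directly into the reported loss.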
In my own further testing I have noticed that calling model.eval() and likelihood.eval() corrects this: the mean output is then as expected. This is how the tutorial is still able to produce good plots of the derivatives.
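For completeness, here is a minimal sketch of the check described above (it reuses the model, likelihood and train_x defined in the reproduction script below, and is not part of the tutorial):

```python
# Sketch of the train/eval discrepancy described above (assumes model, likelihood,
# and train_x from the reproduction script below have already been constructed).
model.train(); likelihood.train()
train_mean = likelihood(model(train_x)).mean   # derivative columns come out as zeros here
model.eval(); likelihood.eval()
with torch.no_grad(), gpytorch.settings.debug(False):   # debug(False) silences the training-input warning
    eval_mean = likelihood(model(train_x)).mean          # derivative columns are non-zero here
print(train_mean[:4])
print(eval_mean[:4])
model.train(); likelihood.train()   # switch back before training
```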
import torch
import gpytorch
import math
from matplotlib import cm
from matplotlib import pyplot as plt
import numpy as np
def franke(X, Y):
    term1 = .75*torch.exp(-((9*X - 2).pow(2) + (9*Y - 2).pow(2))/4)
    term2 = .75*torch.exp(-((9*X + 1).pow(2))/49 - (9*Y + 1)/10)
    term3 = .5*torch.exp(-((9*X - 7).pow(2) + (9*Y - 3).pow(2))/4)
    term4 = .2*torch.exp(-(9*X - 4).pow(2) - (9*Y - 7).pow(2))
    f = term1 + term2 + term3 - term4
    dfx = -2*(9*X - 2)*9/4 * term1 - 2*(9*X + 1)*9/49 * term2 + \
          -2*(9*X - 7)*9/4 * term3 + 2*(9*X - 4)*9 * term4
    dfy = -2*(9*Y - 2)*9/4 * term1 - 9/10 * term2 + \
          -2*(9*Y - 3)*9/4 * term3 + 2*(9*Y - 7)*9 * term4
    return f, dfx, dfy
xv, yv = torch.meshgrid(torch.linspace(0, 1, 10), torch.linspace(0, 1, 10), indexing="ij")
train_x = torch.cat((
    xv.contiguous().view(xv.numel(), 1),
    yv.contiguous().view(yv.numel(), 1)),
    dim=1
)
f, dfx, dfy = franke(train_x[:, 0], train_x[:, 1])
train_y = torch.stack([f, dfx, dfy], -1).squeeze(1)
train_y += 0.05 * torch.randn(train_y.size()) # Add noise to both values and gradients
class GPModelWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPModelWithDerivatives, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.base_kernel = gpytorch.kernels.RBFKernelGrad(ard_num_dims=2)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3) # Value + x-derivative + y-derivative
model = GPModelWithDerivatives(train_x, train_y, likelihood)
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50
# Find optimal model hyperparameters
model.train()
likelihood.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.05) # Includes GaussianLikelihood parameters
# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)

    ## mll debug ##
    ###############
    # Recompute what -mll(output, train_y) does internally, to expose intermediate values
    likelihood_out = mll.likelihood(output)
    mean, covar = likelihood_out.loc, likelihood_out.lazy_covariance_matrix
    diff = train_y.flatten(-2) - mean
    print(mean[:12].reshape(4, 3))
    # Repeat the covar to match the batch shape of diff
    if diff.shape[:-1] != covar.batch_shape:
        if len(diff.shape[:-1]) < len(covar.batch_shape):
            diff = diff.expand(covar.shape[:-1])
        else:
            padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
            covar = covar.repeat(
                *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                1,
                1,
            )
    covar = covar.evaluate_kernel()
    inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
    num_data = output.event_shape.numel()
    my_mll = res.div_(num_data)
    print(f"inv_quad: {-inv_quad.item()}, logdet: {-logdet.item()}, neg mll: {-my_mll.item()}")
    ## mll debug end ##
    ###################

    loss = -mll(output, train_y)
    loss.backward()
    print("Iter %d/%d - Loss: %.3f   lengthscales: %.3f, %.3f   noise: %.3f" % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.squeeze()[0],
        model.covar_module.base_kernel.lengthscale.squeeze()[1],
        model.likelihood.noise.item()
    ))
    optimizer.step()
# Set into eval mode
model.eval()
likelihood.eval()
# Initialize plots
fig, ax = plt.subplots(2, 3, figsize=(14, 10))
# Test points
n1, n2 = 50, 50
xv, yv = torch.meshgrid(torch.linspace(0, 1, n1), torch.linspace(0, 1, n2), indexing="ij")
f, dfx, dfy = franke(xv, yv)
# Make predictions
with torch.no_grad(), gpytorch.settings.fast_computations(log_prob=False, covar_root_decomposition=False):
    test_x = torch.stack([xv.reshape(n1*n2, 1), yv.reshape(n1*n2, 1)], -1).squeeze(1)
    predictions = likelihood(model(test_x))
    mean = predictions.mean
extent = (xv.min(), xv.max(), yv.max(), yv.min())
ax[0, 0].imshow(f, extent=extent, cmap=cm.jet)
ax[0, 0].set_title('True values')
ax[0, 1].imshow(dfx, extent=extent, cmap=cm.jet)
ax[0, 1].set_title('True x-derivatives')
ax[0, 2].imshow(dfy, extent=extent, cmap=cm.jet)
ax[0, 2].set_title('True y-derivatives')
ax[1, 0].imshow(mean[:, 0].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 0].set_title('Predicted values')
ax[1, 1].imshow(mean[:, 1].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 1].set_title('Predicted x-derivatives')
ax[1, 2].imshow(mean[:, 2].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 2].set_title('Predicted y-derivatives')
plt.show()
System information
- GPyTorch 1.11
- Torch 2.2.1+cu118
- Ubuntu 22.04.4 LTS
This behavior is expected and is no cause for concern