cornellius-gp/gpytorch

[Bug] Models not JIT traceable/exportable to TorchScript after Fantasization

Opened this issue · 4 comments

๐Ÿ› Bug

Fantasizing a model, i.e. conditioning it on new data points via the get_fantasy_model method, makes the resulting model untraceable with JIT, so it cannot be exported to TorchScript.

To reproduce

** Code snippet to reproduce **

import math
import torch
import gpytorch

X = torch.linspace(0, 1, 100)
y = torch.sin(X * (2 * math.pi)) + torch.randn(X.size()) * 0.2

train_x, fantasy_x = X[:90], X[90:]
train_y, fantasy_y = y[:90], y[90:]


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

training_iter = 5
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

model.eval()
X_test = torch.linspace(0, 1, 51)
pred = model(X_test)
fantasized_model = model.get_fantasy_model(fantasy_x, fantasy_y)

class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        output_dist = self.gp(x)
        return output_dist.mean, output_dist.variance


with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    fantasized_model.eval()
    pred = fantasized_model(X_test)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(fantasized_model), X_test)

** Stack trace/error message **

{
	"name": "RuntimeError",
	"message": "Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
<Outputs the new `train_train_covar` matrix which includes the new data points>

	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 15
     13 test_x = torch.linspace(0, 1, 51)
     14 pred = fantasized_model(test_x)  # Do precomputation
---> 15 traced_model = torch.jit.trace(MeanVarModelWrapper(fantasized_model), test_x)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1000, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    993 from torch._utils_internal import (
    994     check_if_torch_exportable,
    995     log_torch_jit_trace_exportability,
    996     log_torchscript_usage,
    997 )
    999 log_torchscript_usage(\"trace\")
-> 1000 traced_func = _trace_impl(
   1001     func,
   1002     example_inputs,
   1003     optimize,
   1004     check_trace,
   1005     check_inputs,
   1006     check_tolerance,
   1007     strict,
   1008     _force_outplace,
   1009     _module_class,
   1010     _compilation_unit,
   1011     example_kwarg_inputs,
   1012     _store_inputs,
   1013 )
   1015 if check_if_torch_exportable():
   1016     from torch._export.converter import TS2EPConverter

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:695, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    693         else:
    694             raise RuntimeError(\"example_kwarg_inputs should be a dict\")
--> 695     return trace_module(
    696         func,
    697         {\"forward\": example_inputs},
    698         None,
    699         check_trace,
    700         wrap_check_inputs(check_inputs),
    701         check_tolerance,
    702         strict,
    703         _force_outplace,
    704         _module_class,
    705         example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    706         _store_inputs=_store_inputs,
    707     )
    708 if (
    709     hasattr(func, \"__self__\")
    710     and isinstance(func.__self__, torch.nn.Module)
    711     and func.__name__ == \"forward\"
    712 ):
    713     if example_inputs is None:

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1275, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1273 else:
   1274     example_inputs = make_tuple(example_inputs)
-> 1275     module._c._create_method_from_trace(
   1276         method_name,
   1277         func,
   1278         example_inputs,
   1279         var_lookup_fn,
   1280         strict,
   1281         _force_outplace,
   1282         argument_names,
   1283         _store_inputs,
   1284     )
   1286 check_trace_method = module._c._get_method(method_name)
   1288 # Check the trace against new traces created from user-specified inputs

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1543, in Module._slow_forward(self, *input, **kwargs)
   1541         recording_scopes = False
   1542 try:
-> 1543     result = self.forward(*input, **kwargs)
   1544 finally:
   1545     if recording_scopes:

Cell In[8], line 7, in MeanVarModelWrapper.forward(self, x)
      6 def forward(self, x):
----> 7     output_dist = self.gp(x)
      8     return output_dist.mean, output_dist.variance

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:333, in ExactGP.__call__(self, *args, **kwargs)
    328 # Make the prediction
    329 with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
    330     (
    331         predictive_mean,
    332         predictive_covar,
--> 333     ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
    335 # Reshape predictive mean to match the appropriate event shape
    336 predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:322, in DefaultPredictionStrategy.exact_prediction(self, joint_mean, joint_covar)
    317     test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
    318     test_train_covar = joint_covar[..., self.num_train :, : self.num_train]
    320 return (
    321     self.exact_predictive_mean(test_mean, test_train_covar),
--> 322     self.exact_predictive_covar(test_test_covar, test_train_covar),
    323 )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:409, in DefaultPredictionStrategy.exact_predictive_covar(self, test_test_covar, test_train_covar)
    405         return test_test_covar + MatmulLinearOperator(test_train_covar, covar_correction_rhs.mul(-1))
    407 precomputed_cache = self.covar_cache
--> 409 covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
    410 if torch.is_tensor(test_test_covar):
    411     return to_linear_operator(
    412         torch.add(
    413             test_test_covar, covar_inv_quad_form_root @ covar_inv_quad_form_root.transpose(-1, -2), alpha=-1
    414         )
    415     )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:117, in DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar)
    104 r\"\"\"
    105 Computes :math:`K_{X^{*}X} S` given a precomputed cache
    106 Where :math:`S` is a tensor such that :math:`SS^{\\top} = (K_{XX} + \\sigma^2 I)^{-1}`
   (...)
    113     :obj:`~linear_operator.operators.LinearOperator`: :math:`K_{X^{*}X} S`
    114 \"\"\"
    115 # Here the precomputed cache represents S,
    116 # where S S^T = (K_XX + sigma^2 I)^-1
--> 117 return test_train_covar.matmul(precomputed_cache)

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
<Outputs the new `train_train_covar` matrix which includes the new data points>
}

Expected Behavior

The fantasized model should be JIT traceable and exportable to TorchScript.

System information

  • GPyTorch version: 1.13
  • PyTorch version: 2.4.1
  • OS: macOS Sonoma 14.5

Additional context

The error occurs because new_covar_cache, created in this line, is still part of the computational graph and therefore tracks gradients. new_covar_cache is the updated precomputed cache of the training-data covariance matrix, extended with the new observations.

Detaching this value from the computational graph in the same line solves the issue: the matrix becomes a gradient-free tensor, and the model can then be JIT traced. I can make a PR with this fix if that's helpful.
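To illustrate the mechanism outside GPyTorch, here is a minimal sketch in plain PyTorch. The `CachedMatmul` module and the toy `cache` tensor are hypothetical stand-ins for the prediction strategy and new_covar_cache, not GPyTorch code. The sketch only assumes that the tracer refuses to embed a tensor with `requires_grad=True` as a constant, which is what the error message above reports.

```python
import torch


class CachedMatmul(torch.nn.Module):
    """Toy module (hypothetical) that holds a precomputed tensor as a plain attribute."""

    def __init__(self, cache):
        super().__init__()
        self.cache = cache  # neither a Parameter nor a registered buffer

    def forward(self, x):
        return x @ self.cache


torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
cache = w @ w.T  # non-leaf tensor that still tracks gradients
x = torch.randn(2, 4)

error_text = ""
with torch.no_grad():
    try:
        # The tracer has to embed `cache` as a constant, which fails while it requires grad.
        torch.jit.trace(CachedMatmul(cache), x)
        trace_failed = False
    except RuntimeError as err:
        trace_failed = True
        error_text = str(err)

    # A detached copy is gradient-free, so the tracer can embed it as a constant.
    traced = torch.jit.trace(CachedMatmul(cache.detach()), x)
    traced_out = traced(x)
```

Note that `torch.no_grad()` does not help here: it stops new operations from being recorded, but a tensor created earlier with gradient tracking keeps `requires_grad=True`.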

What if we want to differentiate some downstream computation of the fantasized model w.r.t. the training inputs (or the fantasy location)? Always detaching this would prevent that. I guess we could instead detach it only when trying to JIT the model?
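The trade-off raised here can be seen in a small plain-PyTorch sketch. The `toy_cache` function is a hypothetical stand-in for a covariance cache built from the fantasy locations; it is not how GPyTorch computes its caches.

```python
import torch


def toy_cache(x):
    """Toy RBF-style matrix built from input locations (hypothetical stand-in for a cache)."""
    diff = x[:, None] - x[None, :]
    return torch.exp(-diff ** 2)


fantasy_x = torch.tensor([0.1, 0.5, 0.9], requires_grad=True)

# Attached cache: gradients of a downstream quantity reach the input locations.
attached = toy_cache(fantasy_x)
attached.sum().backward()
grad_attached = fantasy_x.grad.clone()

# Detached cache: the graph is cut, so backward() has nothing to differentiate.
detached = toy_cache(fantasy_x).detach()
try:
    detached.sum().backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
```

In this sketch the attached cache yields a gradient for `fantasy_x`, while the detached one cannot be differentiated at all, which is why detaching unconditionally would be a behavior change for users who rely on those gradients.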

We already have a trace_mode setting that indicates the model is being called with the intention of computing traceable caches (which should be detached). I think this should be fixable by just detaching when that setting is on.
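A rough sketch of that idea, in plain PyTorch: the `tracing_enabled` context manager, the `_TRACING` flag, and the `finalize_cache` helper below are all hypothetical names standing in for the trace_mode setting and the place where the cache is stored. This is not the actual GPyTorch change.

```python
import contextlib

import torch

_TRACING = False  # hypothetical flag, standing in for the trace_mode setting


@contextlib.contextmanager
def tracing_enabled():
    """Hypothetical context manager that switches the stand-in flag on."""
    global _TRACING
    previous = _TRACING
    _TRACING = True
    try:
        yield
    finally:
        _TRACING = previous


def finalize_cache(cache):
    """Detach the cache only when the tracing flag is on (hypothetical helper)."""
    return cache.detach() if _TRACING else cache


w = torch.randn(3, 3, requires_grad=True)
cache = w @ w.T

kept = finalize_cache(cache)  # flag off: still attached to the graph
with tracing_enabled():
    for_trace = finalize_cache(cache)  # flag on: detached, safe to embed as a constant
```

With the flag off the cache stays differentiable, and with it on the values are identical but gradient-free, so both use cases are served by the same code path.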

@jacobrgardner yes, I think that's correct. It's just a bug that new_covar_cache is not detached, since in the usual (unfantasized) case the precomputed_cache in this line is always detached.

So I think it should work if we just detach new_covar_cache before adding it to the cache in the get_fantasy_strategy method.

@Balandat @jacobrgardner thanks for your input on this. I have opened a PR with this fix in #2605. Please let me know if you have any comments on it.