[Bug] Models not JIT traceable/exportable to TorchScript after Fantasization
Opened this issue · 4 comments
🐛 Bug
Fantasization, i.e. conditioning a model on new data points, renders the model unexportable to TorchScript: once the `get_fantasy_model` method has been called, the model can no longer be JIT traced.
To reproduce
**Code snippet to reproduce**

```python
import math

import torch
import gpytorch

# Toy data: train on the first 90 points, hold out the last 10 for fantasization
X = torch.linspace(0, 1, 100)
y = torch.sin(X * (2 * math.pi)) + torch.randn(X.size()) * 0.2
train_x, fantasy_x = X[:90], X[90:]
train_y, fantasy_y = y[:90], y[90:]


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# Short training loop
training_iter = 5
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

model.eval()
X_test = torch.linspace(0, 1, 51)
pred = model(X_test)

# Condition the trained model on the held-out points
fantasized_model = model.get_fantasy_model(fantasy_x, fantasy_y)


class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        output_dist = self.gp(x)
        return output_dist.mean, output_dist.variance


with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    fantasized_model.eval()
    pred = fantasized_model(X_test)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(fantasized_model), X_test)
```
**Stack trace/error message**

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 15
     13 test_x = torch.linspace(0, 1, 51)
     14 pred = fantasized_model(test_x) # Do precomputation
---> 15 traced_model = torch.jit.trace(MeanVarModelWrapper(fantasized_model), test_x)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1000, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    993 from torch._utils_internal import (
    994     check_if_torch_exportable,
    995     log_torch_jit_trace_exportability,
    996     log_torchscript_usage,
    997 )
    999 log_torchscript_usage("trace")
-> 1000 traced_func = _trace_impl(
   1001     func,
   1002     example_inputs,
   1003     optimize,
   1004     check_trace,
   1005     check_inputs,
   1006     check_tolerance,
   1007     strict,
   1008     _force_outplace,
   1009     _module_class,
   1010     _compilation_unit,
   1011     example_kwarg_inputs,
   1012     _store_inputs,
   1013 )
   1015 if check_if_torch_exportable():
   1016     from torch._export.converter import TS2EPConverter

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:695, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    693 else:
    694     raise RuntimeError("example_kwarg_inputs should be a dict")
--> 695 return trace_module(
    696     func,
    697     {"forward": example_inputs},
    698     None,
    699     check_trace,
    700     wrap_check_inputs(check_inputs),
    701     check_tolerance,
    702     strict,
    703     _force_outplace,
    704     _module_class,
    705     example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    706     _store_inputs=_store_inputs,
    707 )
    708 if (
    709     hasattr(func, "__self__")
    710     and isinstance(func.__self__, torch.nn.Module)
    711     and func.__name__ == "forward"
    712 ):
    713     if example_inputs is None:

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1275, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1273 else:
   1274     example_inputs = make_tuple(example_inputs)
-> 1275 module._c._create_method_from_trace(
   1276     method_name,
   1277     func,
   1278     example_inputs,
   1279     var_lookup_fn,
   1280     strict,
   1281     _force_outplace,
   1282     argument_names,
   1283     _store_inputs,
   1284 )
   1286 check_trace_method = module._c._get_method(method_name)
   1288 # Check the trace against new traces created from user-specified inputs

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1543, in Module._slow_forward(self, *input, **kwargs)
   1541 recording_scopes = False
   1542 try:
-> 1543     result = self.forward(*input, **kwargs)
   1544 finally:
   1545     if recording_scopes:

Cell In[8], line 7, in MeanVarModelWrapper.forward(self, x)
      6 def forward(self, x):
----> 7     output_dist = self.gp(x)
      8     return output_dist.mean, output_dist.variance

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:333, in ExactGP.__call__(self, *args, **kwargs)
    328 # Make the prediction
    329 with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
    330     (
    331         predictive_mean,
    332         predictive_covar,
--> 333     ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
    335 # Reshape predictive mean to match the appropriate event shape
    336 predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:322, in DefaultPredictionStrategy.exact_prediction(self, joint_mean, joint_covar)
    317 test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
    318 test_train_covar = joint_covar[..., self.num_train :, : self.num_train]
    320 return (
    321     self.exact_predictive_mean(test_mean, test_train_covar),
--> 322     self.exact_predictive_covar(test_test_covar, test_train_covar),
    323 )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:409, in DefaultPredictionStrategy.exact_predictive_covar(self, test_test_covar, test_train_covar)
    405     return test_test_covar + MatmulLinearOperator(test_train_covar, covar_correction_rhs.mul(-1))
    407 precomputed_cache = self.covar_cache
--> 409 covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
    410 if torch.is_tensor(test_test_covar):
    411     return to_linear_operator(
    412         torch.add(
    413             test_test_covar, covar_inv_quad_form_root @ covar_inv_quad_form_root.transpose(-1, -2), alpha=-1
    414         )
    415     )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_prediction_strategies.py:117, in DefaultPredictionStrategy._exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar)
    104 r"""
    105 Computes :math:`K_{X^{*}X} S` given a precomputed cache
    106 Where :math:`S` is a tensor such that :math:`SS^{\top} = (K_{XX} + \sigma^2 I)^{-1}`
    (...)
    113 :obj:`~linear_operator.operators.LinearOperator`: :math:`K_{X^{*}X} S`
    114 """
    115 # Here the precomputed cache represents S,
    116 # where S S^T = (K_XX + sigma^2 I)^-1
--> 117 return test_train_covar.matmul(precomputed_cache)

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
<Outputs the new `train_train_covar` matrix, which includes the new data points>
```
Expected Behavior
The fantasized model should be JIT traceable and exportable to TorchScript, just like the base model.
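Indeed, tracing works before fantasization: the same wrapper around the base model exports cleanly under the same settings (a minimal check of my own, following the usual GPyTorch `trace_mode` pattern; not part of the original report):

```python
# The un-fantasized model traces without error, so the regression is
# specific to get_fantasy_model:
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    model.eval()
    _ = model(X_test)  # precompute caches before tracing
    base_traced = torch.jit.trace(MeanVarModelWrapper(model), X_test)
mean, var = base_traced(X_test)  # traced module returns (mean, variance)
```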
System information
- GPyTorch version: 1.13
- PyTorch version: 2.4.1
- OS: macOS Sonoma 14.5
Additional context
The error occurs because the `new_covar_cache` created in this line, which is the updated precomputed cache of the training-data covariance matrix including the new observations, is still part of the computational graph (and therefore tracks gradients). Detaching this value from the computational graph on the same line solves the issue: the matrix becomes a gradient-free tensor and the model can be JIT traced. I can make a PR with this fix if that's helpful.
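A quick way to see this (my own diagnostic sketch; `prediction_strategy` and `covar_cache` are GPyTorch internals, so the exact attributes may differ between versions):

```python
# After the precomputation calls above, the fantasized model's covariance
# cache still tracks gradients, while the base model's cache is detached:
print(fantasized_model.prediction_strategy.covar_cache.requires_grad)  # True
print(model.prediction_strategy.covar_cache.requires_grad)             # False
```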
What if we want to differentiate some downstream computation of the fantasized model w.r.t. the training inputs (or the fantasy location)? Wouldn't detaching this unconditionally prevent that? I guess we could detach it only when trying to JIT the model instead?
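For example (my own sketch of the use case, not code from the thread; whether gradients currently flow this way depends on GPyTorch internals):

```python
# Differentiate a downstream quantity of the fantasized posterior w.r.t. the
# fantasy inputs; an unconditional detach of the fantasy cache would cut
# this gradient path:
fx = fantasy_x.clone().requires_grad_(True)
fm = model.get_fantasy_model(fx, fantasy_y)
downstream = fm(X_test).mean.pow(2).sum()
downstream.backward()
print(fx.grad.shape)  # gradients w.r.t. the fantasy locations
```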
We already have a `trace_mode` setting that indicates the model is being called with the intention of computing traceable caches (which should be detached). I think this should be fixable by simply detaching when that setting is on.
@jacobrgardner yes, I think that's correct. It's just a bug that the `new_covar_cache` is not detached, since in the usual (unfantasized) case the `precomputed_cache` in this line is always detached. So I think it should work if we just detach `new_covar_cache` before adding it to the cache in the `get_fantasy_strategy` method.
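For reference, the discussed change amounts to a one-liner in `get_fantasy_strategy` (a sketch; `new_covar_cache` follows the issue's description, and the surrounding code will differ across GPyTorch versions). The `trace_mode` alternative would guard the same line with `if settings.trace_mode.on():`.

```python
# Inside DefaultPredictionStrategy.get_fantasy_strategy, before the updated
# cache is stored on the new strategy: detach it, mirroring the already
# detached covar_cache in the unfantasized path.
new_covar_cache = new_covar_cache.detach()
```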
@Balandat @jacobrgardner thanks for your input on this. I have opened a PR with this fix in #2605. Please let me know if you have any comments on it.