bug: `ConcretizationTypeError` when trying to use `prob_model.predictive()`
PaulScemama opened this issue · 5 comments
Bug Report
Hi! I've trained a prob_model and created checkpoints. I then run prob_model.load_state and attempt to produce predictions on the test set. However, I'm getting the following error:
...
pspec=PartitionSpec('processes',)
] b
from line /home/pscemama/bayesian-conformal-sets/.venv/lib/python3.10/site-packages/orbax/checkpoint/utils.py:63 (sync_global_devices)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
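For context, a ConcretizationTypeError is typically raised when a traced value is used where JAX needs a concrete one (e.g. Python control flow under jit). A minimal, Fortuna-independent snippet that triggers the same error class:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Python `if` needs a concrete boolean, but under jit `x > 0` is a
    # tracer, so this raises a ConcretizationTypeError
    if x > 0:
        return x
    return -x

f(jnp.array(1.0))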
The only non-standard thing I've done is use my own custom model, which is here:
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

act = jax.nn.swish


class AlexNet(nn.Module):
    """An AlexNet model for CIFAR-10."""

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        self.hidden_layers = AlexNetHiddenLayers(dtype=self.dtype)
        self.last_layer = AlexNetLastLayer(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        x = self.hidden_layers(x, train)
        x = self.last_layer(x, train)
        return x
class AlexNetHiddenLayers(nn.Module):
    """Hidden convolutional layers of the AlexNet model."""

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        # [32, 32, 3]
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        # [32, 32, 64]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [16, 16, 64]
        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        # [16, 16, 128]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [8, 8, 128]
        x = nn.Conv(features=256, kernel_size=(2, 2))(x)
        # [8, 8, 256]
        x = act(x)
        x = nn.Conv(features=128, kernel_size=(2, 2))(x)
        # [8, 8, 128]
        x = act(x)
        x = nn.Conv(features=64, kernel_size=(2, 2))(x)
        # [8, 8, 64]
        x = act(x)
        # flatten the spatial dimensions before the dense head
        x = x.reshape((x.shape[0], -1))
        return x
class AlexNetLastLayer(nn.Module):
    """Dense classification head of the AlexNet model."""

    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=self.output_dim, dtype=self.dtype)(x)
        return x
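For what it's worth, the model runs fine standalone; a quick shape sanity check with a dummy CIFAR-10 batch:

model = AlexNet(output_dim=10)
x = jnp.ones((8, 32, 32, 3))  # dummy NHWC CIFAR-10 batch
variables = model.init(jax.random.PRNGKey(0), x, train=False)
logits = model.apply(variables, x, train=False)
print(logits.shape)  # (8, 10)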
Steps to reproduce:
# Model
prob_model = ProbClassifier(
    model=AlexNet(output_dim=10),
    posterior_approximator=LaplacePosteriorApproximator(),
    prior=IsotropicGaussianPrior(log_var=jnp.log(PRIOR_VAR)),
)
prob_model.load_state("../sgd_checkpoints/checkpoint_11532/")
test_log_probs = prob_model.predictive.log_prob(data_loader=test_loader)
# RAISES ERROR
Other information:
The data is coming from a torch dataloader and converted with .from_torch_dataloader(). Let me know if you need more information on the actual data.
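The conversion on my side looks roughly like this (DataLoader.from_torch_data_loader is the name in Fortuna's docs; the exact name may differ by version, and the dataset construction here is illustrative):

from torch.utils.data import DataLoader as TorchDataLoader

from fortuna.data import DataLoader

# test_dataset is a torch Dataset yielding (image, label) pairs
torch_loader = TorchDataLoader(test_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader.from_torch_data_loader(torch_loader)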
My hunch is that maybe I'm doing something wrong here. Any guidance is appreciated :)
Hi Paul,
could you provide a reproducible example? The error you get is not at Fortuna's level, so I don't really know what's going on 😄
@gianlucadetommaso it looks like it was my mistake! I passed the directory "/checkpoint_11532" to load_state instead of the file "/checkpoint_11532/checkpoint". It might be useful to catch such an error (e.g. check whether the input is a directory or a file), because with Orbax you pass in the directory.
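Something like this hypothetical guard at the top of load_state is what I have in mind (all names here are made up for illustration):

import os

def _check_checkpoint_path(path: str) -> None:
    # Hypothetical check: fail fast with a clear message instead of letting
    # an opaque error surface from deep inside the restore call.
    if os.path.isdir(path):
        candidate = os.path.join(path, "checkpoint")
        hint = f"; did you mean '{candidate}'?" if os.path.isfile(candidate) else "."
        raise ValueError(
            f"'{path}' is a directory, but load_state expects a checkpoint file{hint}"
        )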
Alright! I'm refactoring the checkpointing to work with Orbax anyway. This is part of #96, which will also enable model sharding.
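For reference, Orbax's checkpointing is directory-based; roughly (a sketch assuming the current orbax.checkpoint API):

import jax.numpy as jnp
import orbax.checkpoint

state = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}  # any pytree of arrays

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpointer.save("/tmp/ckpt_dir", state)  # note: a directory, not a file
restored = checkpointer.restore("/tmp/ckpt_dir")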
Looking forward to it, @gianlucadetommaso! As I've been using Fortuna more, I've had some thoughts on possible feature enhancements / pull requests. What would be the best place to discuss them?
If you want to discuss something at a high level, I would open a Discussion. If you have a more concrete bug/feature request in mind, I'd suggest opening an issue. If you find some small problems and want to directly propose a quick fix, feel free to open a PR.