awslabs/fortuna

bug: `ConcretizationTypeError` when trying to use `prob_model.predictive()`

PaulScemama opened this issue · 5 comments

Bug Report

Hi! I've trained a prob_model and created checkpoints. I then run prob_model.load_state and attempt to produce predictions on the test set. However, I'm getting the following error:

...
  pspec=PartitionSpec('processes',)
] b
    from line /home/pscemama/bayesian-conformal-sets/.venv/lib/python3.10/site-packages/orbax/checkpoint/utils.py:63 (sync_global_devices)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
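
For context, a ConcretizationTypeError generally means a traced (abstract) value was used where JAX needs a concrete one, e.g. in Python control flow. A minimal illustration (recent JAX versions may raise a more specific subclass):

import jax

@jax.jit
def f(x):
    if x > 0:  # needs a concrete bool, but x is a tracer under jit
        return x
    return -x

f(1.0)  # raises a ConcretizationTypeError (or a subclass of it)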

The only non-standard thing I've done is use my own custom model, shown here:

from typing import Any
import flax.linen as nn
import jax.numpy as jnp
import jax

act = jax.nn.swish


class AlexNet(nn.Module):
    """An AlexNet model for CIFAR-10."""

    output_dim: int
    dtype: Any = jnp.float32

    def setup(self):
        self.hidden_layers = AlexNetHiddenLayers(dtype=self.dtype)
        self.last_layer = AlexNetLastLayer(output_dim=self.output_dim, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, train: bool = True) -> jnp.ndarray:
        x = self.hidden_layers(x, train)
        x = self.last_layer(x, train)
        return x


class AlexNetHiddenLayers(nn.Module):
    """Hidden convolutional layers of the AlexNet model."""

    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        # [32, 32, 3]
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        # [32, 32, 64]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [16, 16, 64]

        x = nn.Conv(features=128, kernel_size=(3, 3))(x)
        # [16, 16, 128]
        x = act(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        # [8, 8, 128]

        x = nn.Conv(features=256, kernel_size=(2, 2))(x)
        # [8, 8, 256]
        x = act(x)

        x = nn.Conv(features=128, kernel_size=(2, 2))(x)
        # [8, 8, 128]
        x = act(x)

        x = nn.Conv(features=64, kernel_size=(2, 2))(x)
        # [8, 8, 64]
        x = act(x)

        x = x.reshape((x.shape[0], -1))
        return x


class AlexNetLastLayer(nn.Module):
    output_dim: int
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = True):
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=256, dtype=self.dtype)(x)
        x = act(x)
        x = nn.Dense(features=self.output_dim, dtype=self.dtype)(x)
        return x
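
To sanity-check the shapes annotated in the comments above, the model can be initialized on a dummy CIFAR-10 batch. A minimal sketch (the batch size and PRNG seed are arbitrary):

import jax
import jax.numpy as jnp

model = AlexNet(output_dim=10)
x = jnp.ones((1, 32, 32, 3))  # dummy CIFAR-10 batch
variables = model.init(jax.random.PRNGKey(0), x)
out = model.apply(variables, x, train=False)
print(out.shape)  # (1, 10)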

Steps to reproduce:

# Imports are my best guess at the Fortuna paths used (assumed, not from
# the original report):
from fortuna.prob_model import ProbClassifier, LaplacePosteriorApproximator
from fortuna.prob_model.prior import IsotropicGaussianPrior
import jax.numpy as jnp

PRIOR_VAR = 1.0  # placeholder; the value used in training is not shown

# Model
prob_model = ProbClassifier(
    model=AlexNet(output_dim=10),
    posterior_approximator=LaplacePosteriorApproximator(),
    prior=IsotropicGaussianPrior(log_var=jnp.log(PRIOR_VAR)),
)
prob_model.load_state("../sgd_checkpoints/checkpoint_11532/")
# test_loader: a Fortuna data loader built from a torch dataloader (see below)
test_log_probs = prob_model.predictive.log_prob(data_loader=test_loader)
# RAISES ERROR

Other information:

The data comes from a torch dataloader and is converted with .from_torch_dataloader(). Let me know if you need more information about the actual data.
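
For context, the conversion looks roughly like this. A minimal sketch, assuming the Fortuna method is named from_torch_data_loader and that test_torch_loader is a standard torch.utils.data.DataLoader yielding (inputs, targets) batches:

from fortuna.data import DataLoader

test_loader = DataLoader.from_torch_data_loader(test_torch_loader)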

My hunch is that maybe I'm doing something wrong here. Any guidance is appreciated :)

Hi Paul,
could you provide a reproducible example? The error you're getting isn't raised at Fortuna's level, so I don't really know what's going on 😄

@gianlucadetommaso it looks like it was my mistake! I passed the directory "/checkpoint_11532" to load_state instead of the file "/checkpoint_11532/checkpoint". It might be useful to catch such an error (e.g. check whether the input is a directory or a file), because with Orbax you pass in the directory.
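
A guard along these lines could surface the mistake early. A hypothetical sketch, not Fortuna's actual code; the function name and error wording are illustrative:

from pathlib import Path

def check_checkpoint_path(path: str) -> None:
    # load_state expects a checkpoint file, while Orbax-style APIs take
    # the enclosing directory, so a directory here is likely a mistake.
    p = Path(path)
    if p.is_dir():
        raise ValueError(
            f"{path} is a directory; pass the checkpoint file inside it "
            f"(e.g. {p / 'checkpoint'})."
        )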

Alright! In any case, I'm refactoring the checkpointing to work with Orbax. This is part of #96, which will also enable model sharding.
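
For reference, the plain Orbax API does take a directory. A minimal sketch of orbax-checkpoint itself, not of the Fortuna refactor:

import orbax.checkpoint

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
state = {"step": 11532, "params": [1.0, 2.0]}  # any pytree
# Orbax saves to and restores from a checkpoint *directory*, not a file.
checkpointer.save("/tmp/my_checkpoint", state, force=True)  # force=True overwrites
restored = checkpointer.restore("/tmp/my_checkpoint")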

Looking forward to it, @gianlucadetommaso! As I've been using Fortuna more, I've also had some thoughts on possible feature enhancements / pull requests. What would be the best place to discuss these?

If you want to discuss something at a high level, I would open a Discussion. If you have a more concrete bug/feature request in mind, I'd suggest opening an issue. If you find a small problem and want to propose a quick fix directly, feel free to open a PR.