AdaptiveMotorControlLab/CEBRA

`ValueError` when passing both a continuous and discrete auxiliary variable


To reproduce:

```python
import cebra
import numpy as np

cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
cebra_model.fit(np.random.rand(100, 1000),   # neural data X
                np.random.rand(2, 1000),     # continuous auxiliary variable
                np.arange(0, 1000))          # discrete auxiliary variable
```
```
File ~/CEBRA-dev/cebra/integrations/sklearn/cebra.py:863, in CEBRA._prepare_fit(self, X, *y)
    861 self.device_ = sklearn_utils.check_device(self.device)
    862 self.offset_ = self._compute_offset()
--> 863 dataset, is_multisession = self._prepare_data(X, y)
    865 loader, solver_name = self._prepare_loader(
    866     dataset,
    867     max_iterations=self.max_iterations,
    868     is_multisession=is_multisession)
    869 model = self._prepare_model(dataset, is_multisession)

File ~/CEBRA-dev/cebra/integrations/sklearn/cebra.py:655, in CEBRA._prepare_data(self, X, y)
    653 else:
    654     if not _are_sessions_equal(X, y):
--> 655         raise ValueError(
    656             f"Invalid number of samples or labels sessions: provide one session for single-session training, "
    657             f"and make sure the number of samples in X and y need match, "
    658             f"got {len(X)} and {[len(y_i) for y_i in y]}.")
    659     is_multisession = False
    660     dataset = _get_dataset(X, y)

ValueError: Invalid number of samples or labels sessions: provide one session for single-session training, and make sure the number of samples in X and y need match, got 100 and [2, 1000].
```

This seems related to the `_are_sessions_equal` helper, which doesn't handle multiple `y`, as it considers the case of `y` being a list of labels for the multi-session training setup:

```python
def _are_sessions_equal(X, y):
    """Check if data and labels have the same number of sessions for all sets of labels."""
    return np.array([len(X) == len(y_i) for y_i in y]).all()
```

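Fixed: the inputs were transposed. CEBRA expects arrays of shape `(n_samples, n_features)`, with one row per time point in `X` and in each label array, so the call should be: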

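```python
cebra_model.fit(np.random.rand(1000, 100),   # X: 1000 samples x 100 features
                np.random.rand(1000, 2),     # continuous label: 1000 samples x 2 dims
                np.arange(0, 1000))          # discrete label: one value per sample
```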
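Since the error came from transposed inputs, a quick shape check before calling `fit` can catch this with a more direct message. This is a minimal sketch: `check_first_axis` is a hypothetical helper, not part of CEBRA, and it assumes the `(n_samples, n_features)` layout used in the working call above, i.e. that axis 0 indexes samples in every array.

```python
import numpy as np


def check_first_axis(X, *labels):
    """Hypothetical pre-fit check: every label array must have as many rows as X.

    Assumes axis 0 indexes samples (time points) in X and in every label array.
    """
    n_samples = X.shape[0]
    for i, y_i in enumerate(labels):
        if y_i.shape[0] != n_samples:
            raise ValueError(
                f"label {i} has {y_i.shape[0]} rows but X has {n_samples}; "
                "the arrays may be transposed")
    return n_samples


X = np.random.rand(1000, 100)      # 1000 samples, 100 features
y_cont = np.random.rand(1000, 2)   # continuous auxiliary variable, 2 dims
y_disc = np.arange(0, 1000)        # discrete auxiliary variable, one value per sample
n = check_first_axis(X, y_cont, y_disc)  # passes and returns 1000
```

Running the same check on the transposed arrays from the reproducer (`X.T`, `y_cont.T`) raises a `ValueError` at the first label instead of inside `fit`.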