`ValueError` when passing both a continuous and discrete auxiliary variable
Closed this issue · 1 comments
CeliaBenquet commented
To reproduce:
import cebra
import numpy as np
cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
cebra_model.fit(np.random.rand(100, 1000),
np.random.rand(2, 1000),
np.arange(0, 1000))
File ~/CEBRA-dev/cebra/integrations/sklearn/cebra.py:863, in CEBRA._prepare_fit(self, X, *y)
861 self.device_ = sklearn_utils.check_device(self.device)
862 self.offset_ = self._compute_offset()
--> 863 dataset, is_multisession = self._prepare_data(X, y)
865 loader, solver_name = self._prepare_loader(
866 dataset,
867 max_iterations=self.max_iterations,
868 is_multisession=is_multisession)
869 model = self._prepare_model(dataset, is_multisession)
File ~/CEBRA-dev/cebra/integrations/sklearn/cebra.py:655, in CEBRA._prepare_data(self, X, y)
653 else:
654 if not _are_sessions_equal(X, y):
--> 655 raise ValueError(
656 f"Invalid number of samples or labels sessions: provide one session for single-session training, "
657 f"and make sure the number of samples in X and y need match, "
658 f"got {len(X)} and {[len(y_i) for y_i in y]}.")
659 is_multisession = False
660 dataset = _get_dataset(X, y)
ValueError: Invalid number of samples or labels sessions: provide one session for single-session training, and make sure the number of samples in X and y need match, got 100 and [2, 1000].
This is related to the `_are_sessions_equal` method, which doesn't handle multiple `y`, as it considers the case of `y` being a list of labels for the multisession training setup:
def _are_sessions_equal(X, y):
"""Check if data and labels have the same number of sessions for all sets of labels."""
return np.array([len(X) == len(y_i) for y_i in y]).all()
CeliaBenquet commented
Fixed... the samples need to be along the first axis of every array, so I should do:
cebra_model.fit(np.random.rand(1000, 100),
np.random.rand(1000, 2),
np.arange(0, 1000))