AdaptiveMotorControlLab/CEBRA

Adding example script with custom Loader for PyTorch API documentation

timonmerk opened this issue ยท 3 comments

The CEBRA documentation is very comprehensive and presents in a lot of detail the parameterization.
In the current form however the focus seems to explain the scikit-learn API and there is no example script for using the PyTorch API: https://cebra.ai/docs/usage.html

But for many options I am unsure how to parametrize them in the scikit-learn API. For example when using discrete behavioral data, it's currently not possible to specify empirical or discretesampling:

prior: str = dataclasses.field(

I think this is also intended to not overload the cebra.Cebra intialization or the model.fit() function with too many parameters?

Therefore I thought that maybe adding a minimal example in the usage.rst of how a dataloader with "non-scikitlearn API" conform parameters could be used using PyTorch directly:

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#    file="auxiliary_behavior_data.h5",
#    key="auxiliary_variables",
#    columns=["continuous1", "continuous2", "continuous3"],
# )

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")

# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")

Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)

# 4. Initialize Cebra Model
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")

# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#    dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)

# 6. Fit model
cebra_model.fit(loader=loader)

# 7. Transform embedding
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, neural_model.get_offset().__len__(), axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")

# 8. Potentially plot embedding
plot_embedding(
    X_train_emb,
    discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
    markersize=10,
)

Thanks @timonmerk I think that's a great idea! We can also link it to the Allen Demo, which uses the PyTorch API: https://cebra.ai/docs/demo_notebooks/Demo_Allen.html -- would you like to make a PR?

Edit to add, we DO have an example, just to be sure you saw it, https://cebra.ai/docs/usage.html#quick-start-torch-api-example, but agree it's not nearly as expressive as your good suggestion to make the sklearn Quick Start guide!

Thanks! Yes, that notebook was super helpful already but might be good to link it also in the usage.rst.
I also compiled the documentation and ran my linked example locally, but of course it's a bit difficult to test it since it's in a rst file only..

stes commented

but of course it's a bit difficult to test it since it's in a rst file only..

Actually this is included in the unit tests!

Locally, you can run make test or also directly (adapt as needed)

python -m pytest --ff --doctest-modules -m "not requires_dataset" tests ./docs/source/usage.rst cebra