a-antoniades/Neuroformer

Providing configs/NF/finetune_visnav_all.yaml config file


First of all, thank you for your awesome work!

While following your instructions to fine-tune the Neuroformer model,
I noticed that the configs/NF/finetune_visnav_all.yaml file is missing in the repository.

If possible, could you kindly provide the file, or a snippet showing how to add Modalities.Behavior.Variables.(Data, dt, Predict, Objective) to the config file as described in the README? This would be greatly appreciated.

Thank you!

Hi! Thanks for using our codebase, and for your kind words.

Yes, basically all you need to do is change the modalities configuration to the following:

modalities:
  behavior:
    n_layers: 4
    variables:
      phi:
        data: phi
        dt: 0.05
        objective: regression
        predict: true
      speed:
        data: speed
        dt: 0.05
        objective: regression
        predict: true
      th:
        data: th
        dt: 0.05
        objective: regression
        predict: true
    window: 0.05

Notice predict: true. This means that the variable will be inferred (predicted) by the model rather than used as input. You can also have a look at ./configs/Visnav/lateral/mconf_predict_all.yaml. Please let me know if this helps!
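One way to see the effect of predict: true is to split the configured variables into model inputs and prediction targets. The sketch below writes the YAML above as the equivalent Python dict (what a loader such as PyYAML's yaml.safe_load would return, with true becoming True) and uses a hypothetical helper, split_variables, which is not part of the Neuroformer codebase.

```python
# The behavior block from the YAML above as a plain Python dict
# (assumption: the config is read into ordinary nested dicts).
modalities = {
    "behavior": {
        "n_layers": 4,
        "window": 0.05,
        "variables": {
            "phi": {"data": "phi", "dt": 0.05, "objective": "regression", "predict": True},
            "speed": {"data": "speed", "dt": 0.05, "objective": "regression", "predict": True},
            "th": {"data": "th", "dt": 0.05, "objective": "regression", "predict": True},
        },
    }
}


def split_variables(modalities):
    """Hypothetical helper: separate model inputs from prediction targets."""
    inputs, targets = [], []
    for modality_name, modality in modalities.items():
        for var_name, var in modality["variables"].items():
            key = f"{modality_name}.{var_name}"
            # predict: true -> the model infers it; otherwise it is fed as input
            (targets if var.get("predict", False) else inputs).append(key)
    return inputs, targets


inputs, targets = split_variables(modalities)
print("inputs:", inputs)
print("targets:", targets)
```

With every variable set to predict: true, all three end up as targets; setting predict: false on speed, for example, would move behavior.speed into the inputs list instead.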

Thank you so much for your prompt response!

When I tried to apply the config you provided to fine-tune a pretrained model created by running neuroformer_train.py with configs/Visnav/lateral/mconf_pretrain.yaml, I came across the following error.

Is this error expected (i.e., are we supposed to use behavioral inputs/outputs during pretraining in order to fine-tune later), or is there anything you think I did wrong? Thank you for your kind help!

[Error message]

RuntimeError: Error(s) in loading state_dict for Neuroformer:
Missing key(s) in state_dict: "modality_embeddings.behavior.phi.mlp.ln.weight", "modality_embeddings.behavior.phi.mlp.ln.bias", "modality_embeddings.behavior.phi.mlp.mlp.0.weight", "modality_embeddings.behavior.phi.mlp.mlp.2.weight", "modality_embeddings.behavior.phi.temp_emb.div_term", "modality_embeddings.behavior.phi.temp_emb.pe", "modality_embeddings.behavior.speed.mlp.ln.weight", "modality_embeddings.behavior.speed.mlp.ln.bias", "modality_embeddings.behavior.speed.mlp.mlp.0.weight", "modality_embeddings.behavior.speed.mlp.mlp.2.weight", "modality_embeddings.behavior.speed.temp_emb.div_term", "modality_embeddings.behavior.speed.temp_emb.pe", "modality_embeddings.behavior.th.mlp.ln.weight", "modality_embeddings.behavior.th.mlp.ln.bias", "modality_embeddings.behavior.th.mlp.mlp.0.weight", "modality_embeddings.behavior.th.mlp.mlp.2.weight", "modality_embeddings.behavior.th.temp_emb.div_term", "modality_embeddings.behavior.th.temp_emb.pe", "modality_embeddings.behavior.pos_emb.inv_freq", "modality_projection_heads.behavior.phi.weight", "modality_projection_heads.behavior.phi.bias", "modality_projection_heads.behavior.speed.weight", "modality_projection_heads.behavior.speed.bias", "modality_projection_heads.behavior.th.weight", "modality_projection_heads.behavior.th.bias", "neural_visual_transformer.modalities_blocks.behavior.0.ln1.weight", "neural_visual_transformer.modalities_blocks.behavior.0.ln1.bias", "neural_visual_transformer.modalities_blocks.behavior.0.ln2.weight", "neural_visual_transformer.modalities_blocks.behavior.0.ln2.bias", "neural_visual_transformer.modalities_blocks.behavior.0.attn.query.weight", "neural_visual_transformer.modalities_blocks.behavior.0.attn.query.bias", "neural_visual_transformer.modalities_blocks.behavior.0.attn.key.weight", "neural_visual_transformer.modalities_blocks.behavior.0.attn.key.bias", "neural_visual_transformer.modalities_blocks.behavior.0.attn.value.weight", 
"neural_visual_transformer.modalities_blocks.behavior.0.attn.value.bias", "neural_visual_transformer.modalities_blocks.behavior.0.attn.proj.weight", "neural_visual_transformer.modalities_blocks.behavior.0.attn.proj.bias", "neural_visual_transformer.modalities_blocks.behavior.0.mlp.0.weight", "neural_visual_transformer.modalities_blocks.behavior.0.mlp.0.bias", "neural_visual_transformer.modalities_blocks.behavior.0.mlp.2.weight", "neural_visual_transformer.modalities_blocks.behavior.0.mlp.2.bias", "neural_visual_transformer.modalities_blocks.behavior.0.ln_f.weight", "neural_visual_transformer.modalities_blocks.behavior.0.ln_f.bias", "neural_visual_transformer.modalities_blocks.behavior.1.ln1.weight", "neural_visual_transformer.modalities_blocks.behavior.1.ln1.bias", "neural_visual_transformer.modalities_blocks.behavior.1.ln2.weight", "neural_visual_transformer.modalities_blocks.behavior.1.ln2.bias", "neural_visual_transformer.modalities_blocks.behavior.1.attn.query.weight", "neural_visual_transformer.modalities_blocks.behavior.1.attn.query.bias", "neural_visual_transformer.modalities_blocks.behavior.1.attn.key.weight", "neural_visual_transformer.modalities_blocks.behavior.1.attn.key.bias", "neural_visual_transformer.modalities_blocks.behavior.1.attn.value.weight", "neural_visual_transformer.modalities_blocks.behavior.1.attn.value.bias", "neural_visual_transformer.modalities_blocks.behavior.1.attn.proj.weight", "neural_visual_transformer.modalities_blocks.behavior.1.attn.proj.bias", "neural_visual_transformer.modalities_blocks.behavior.1.mlp.0.weight", "neural_visual_transformer.modalities_blocks.behavior.1.mlp.0.bias", "neural_visual_transformer.modalities_blocks.behavior.1.mlp.2.weight", "neural_visual_transformer.modalities_blocks.behavior.1.mlp.2.bias", "neural_visual_transformer.modalities_blocks.behavior.1.ln_f.weight", "neural_visual_transformer.modalities_blocks.behavior.1.ln_f.bias", "neural_visual_transformer.modalities_blocks.behavior.2.ln1.weight", 
"neural_visual_transformer.modalities_blocks.behavior.2.ln1.bias", "neural_visual_transformer.modalities_blocks.behavior.2.ln2.weight", "neural_visual_transformer.modalities_blocks.behavior.2.ln2.bias", "neural_visual_transformer.modalities_blocks.behavior.2.attn.query.weight", "neural_visual_transformer.modalities_blocks.behavior.2.attn.query.bias", "neural_visual_transformer.modalities_blocks.behavior.2.attn.key.weight", "neural_visual_transformer.modalities_blocks.behavior.2.attn.key.bias", "neural_visual_transformer.modalities_blocks.behavior.2.attn.value.weight", "neural_visual_transformer.modalities_blocks.behavior.2.attn.value.bias", "neural_visual_transformer.modalities_blocks.behavior.2.attn.proj.weight", "neural_visual_transformer.modalities_blocks.behavior.2.attn.proj.bias", "neural_visual_transformer.modalities_blocks.behavior.2.mlp.0.weight", "neural_visual_transformer.modalities_blocks.behavior.2.mlp.0.bias", "neural_visual_transformer.modalities_blocks.behavior.2.mlp.2.weight", "neural_visual_transformer.modalities_blocks.behavior.2.mlp.2.bias", "neural_visual_transformer.modalities_blocks.behavior.2.ln_f.weight", "neural_visual_transformer.modalities_blocks.behavior.2.ln_f.bias", "neural_visual_transformer.modalities_blocks.behavior.3.ln1.weight", "neural_visual_transformer.modalities_blocks.behavior.3.ln1.bias", "neural_visual_transformer.modalities_blocks.behavior.3.ln2.weight", "neural_visual_transformer.modalities_blocks.behavior.3.ln2.bias", "neural_visual_transformer.modalities_blocks.behavior.3.attn.query.weight", "neural_visual_transformer.modalities_blocks.behavior.3.attn.query.bias", "neural_visual_transformer.modalities_blocks.behavior.3.attn.key.weight", "neural_visual_transformer.modalities_blocks.behavior.3.attn.key.bias", "neural_visual_transformer.modalities_blocks.behavior.3.attn.value.weight", "neural_visual_transformer.modalities_blocks.behavior.3.attn.value.bias", 
"neural_visual_transformer.modalities_blocks.behavior.3.attn.proj.weight", "neural_visual_transformer.modalities_blocks.behavior.3.attn.proj.bias", "neural_visual_transformer.modalities_blocks.behavior.3.mlp.0.weight", "neural_visual_transformer.modalities_blocks.behavior.3.mlp.0.bias", "neural_visual_transformer.modalities_blocks.behavior.3.mlp.2.weight", "neural_visual_transformer.modalities_blocks.behavior.3.mlp.2.bias", "neural_visual_transformer.modalities_blocks.behavior.3.ln_f.weight", "neural_visual_transformer.modalities_blocks.behavior.3.ln_f.bias".

[The script I used to pretrain]

python neuroformer_train.py \
       --dataset lateral \
       --config configs/Visnav/lateral/mconf_pretrain.yaml

[The script I used to fine-tune]

python neuroformer_train.py --dataset lateral \
                            --finetune  \
                            --resume "models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/None/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/model.pt" \
                            --config configs/Visnav/lateral/mconf_finetune.yaml \
                            --loss_bprop speed phi th

Thanks for bringing this up. It seems that, for some reason, model.load_state_dict() was not called with strict=False in this code. strict=False is what allows loading the previously trained weights while initializing a model with new parameters (the behavioral prediction ones). Please pull the latest code from main; you should now be able to train. Please let me know if this works or if you have any more questions!
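To illustrate what strict=False does, here is a minimal, self-contained PyTorch sketch. The two toy modules (Pretrained, Finetune) and the behavior_head layer are hypothetical stand-ins for the pretrained checkpoint and the fine-tuning model with its new behavioral modules; they are not Neuroformer's classes.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):
    """Toy stand-in for the pretrained model (no behavior modules)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)


class Finetune(nn.Module):
    """Toy stand-in for the fine-tuning model: same backbone plus a new head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.behavior_head = nn.Linear(8, 3)  # not present in the checkpoint


checkpoint = Pretrained().state_dict()
model = Finetune()

# strict=True (the default) raises "Missing key(s) in state_dict";
# strict=False copies every matching key and reports the rest.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The shared backbone weights are copied from the checkpoint, while behavior_head keeps its fresh initialization. In a real run, printing result.missing_keys is a quick way to confirm that only the newly added behavior parameters (like the modality_embeddings.behavior.* and modalities_blocks.behavior.* keys in the error above) were left uninitialized.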

Thank you for the prompt response!

After pulling the latest script, the code now runs past the point where the previous error occurred. However, another error related to the NF data loader has come up. I encountered this error previously when trying to pretrain the model with all modalities, and I suspect it might be related to the data loader for behavior data.

Have you seen this error before? I set up the repository according to the README, so I'm unsure why this issue is occurring.

Here are my environment details:

  • Python: 3.11.9
  • Conda: 23.9.0
  • Pip: 23.3.1
  • Pickle package (retrieved by executing import pickle; print(pickle.format_version)): 4.0

Thank you for your assistance!

Traceback (most recent call last):
  File "/data//.conda/envs/neuroformer/lib/python3.11/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data//.conda/envs/neuroformer/lib/python3.11/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'NFDataloader.__getitem__.<locals>.<lambda>'

This looks like a PyTorch versioning issue, as the error is associated with the DataLoader. Could you try upgrading PyTorch to the correct version for your system setup (CUDA, OS, etc.)?

(ps: no I have not seen that error before.)
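When comparing setups, it helps to report exact versions. A small sketch using standard platform and torch attributes (the environment_report name is just for illustration):

```python
import platform

import torch


def environment_report():
    """Collect the version details that matter when comparing setups."""
    return {
        "python": platform.python_version(),
        "os": platform.system(),
        "torch": torch.__version__,
        # None for CPU-only builds of PyTorch
        "torch_cuda": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```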

Hi,

I've confirmed that I have CUDA version 12.2 and have upgraded PyTorch to version 2.3.0 by running pip3 install -U torch torchvision torchaudio, but the error still persists. Could you please let me know which CUDA and PyTorch versions you used to run this code?

Thank you so much for your help!

I believe I've found a temporary fix for those using CUDA 12.2 and PyTorch 2.3.0.

The error seems to be caused by the inability to pickle the lambda functions (lines 1085 and 1103) within the __getitem__ method of the NFDataloader class.

I've updated the code as follows, and it now runs past the point where the NFDataLoader error previously occurred.

class NFDataloader(Dataset):
    ...
    # Replaces the lambdas previously defined inside __getitem__. Unlike a local
    # lambda, a bound method can be pickled (as long as the dataset instance can),
    # which DataLoader worker processes need when sending samples back.
    def ret_defdict(self):
        return collections.defaultdict(dict)

    def __getitem__(self, idx):
        ...

        ## BEHAVIOR ##
        if self.modalities is not None:
            x['modalities'] = collections.defaultdict(self.ret_defdict)  # prev. line 1085
            ...
                    if variable['predict'] is True:
                        # TODO: implement for more than just 0.05 curr window
                        # pick only current_state behavior
                        if 'modalities' not in y.keys():
                            y['modalities'] = collections.defaultdict(self.ret_defdict)  # prev. line 1103
                        ...
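The failure can be reproduced with the standard library alone. In this sketch the class Sample and the function nested_defaultdict are hypothetical, and it assumes the original lambdas returned collections.defaultdict(dict), as the replacement ret_defdict method suggests. A lambda defined inside a method cannot be pickled, while a bound method of a picklable object or a module-level function can.

```python
import collections
import pickle


def nested_defaultdict():
    # Module-level function: pickle records it by its qualified name.
    return collections.defaultdict(dict)


class Sample:
    def ret_defdict(self):
        return collections.defaultdict(dict)

    def with_lambda(self):
        # Local lambda: '<locals>' in its qualified name, so pickling fails.
        return collections.defaultdict(lambda: collections.defaultdict(dict))

    def with_bound_method(self):
        # The workaround above: pickled as (instance, method name).
        return collections.defaultdict(self.ret_defdict)

    def with_function(self):
        return collections.defaultdict(nested_defaultdict)


def can_pickle(obj):
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError, TypeError):
        return False


sample = Sample()
for name in ("with_lambda", "with_bound_method", "with_function"):
    d = getattr(sample, name)()
    d["behavior"]["speed"] = 1.0
    print(name, can_pickle(d))
```

One design consideration: pickling a bound method also pickles the object it is bound to, so anything pickled that holds a defaultdict(self.ret_defdict) carries the dataset instance with it, which may be costly for a large dataset. A module-level function (or functools.partial(collections.defaultdict, dict)) avoids that overhead.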

Hmm, that's frustrating; sorry you had to go through that. Perhaps I should pin the library versions to avoid issues like these in the future.

If you want, please feel free to make a pull request with this change, and I'll merge it into the main branch. If you don't do it in the next few days, I'll just do it myself.

Thanks a lot for engaging, and please do let me know if there are any other issues you're facing / need help with.

Thanks!

I have a minor question regarding fine-tuning. Should the training data be set up like this for fine-tuning? I'm referring to the last lines of neuroformer_train.py:

    if args.finetune:
        trainer = Trainer(model, finetune_dataset, test_dataset, tconf, config)
    else:
        trainer = Trainer(model, train_dataset, test_dataset, tconf, config)
    trainer.train()

It depends on what you want to do. If you look at

    train_intervals, test_intervals, finetune_intervals = split_data_by_interval(intervals, r_split=0.8, r_split_ft=0.01)

the fine-tuning split is 0.01 of the training data. Alternatively, you can fine-tune on all of the decoding training data, i.e. pass train_dataset instead of finetune_dataset.
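To get a feel for how small the fine-tuning split is, here is a hypothetical stand-in for split_data_by_interval. It is not the repository's implementation: it assumes a random r_split fraction of intervals goes to training, the remainder to testing, and that the fine-tuning intervals are an r_split_ft fraction drawn from the training intervals.

```python
import random


def split_intervals(intervals, r_split=0.8, r_split_ft=0.01, seed=0):
    """Hypothetical stand-in for split_data_by_interval (not the repo's code)."""
    rng = random.Random(seed)
    shuffled = list(intervals)
    rng.shuffle(shuffled)
    # Assumed: first r_split of the shuffled intervals for training, rest for testing.
    n_train = int(len(shuffled) * r_split)
    train, test = shuffled[:n_train], shuffled[n_train:]
    # Assumed: the fine-tuning set is a small slice of the training intervals.
    n_ft = max(1, int(len(train) * r_split_ft))
    finetune = train[:n_ft]
    return sorted(train), sorted(test), sorted(finetune)


intervals = list(range(10_000))  # interval IDs, for illustration only
train, test, finetune = split_intervals(intervals)
print(len(train), len(test), len(finetune))
```

Under these assumptions, 10,000 intervals yield 8,000 training, 2,000 test, and only 80 fine-tuning intervals, so passing finetune_dataset to Trainer fine-tunes on a small fraction of the decoding data, whereas passing train_dataset uses all of the training intervals.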

P.S. For any additional questions or issues unrelated to this thread, please feel free to open a new issue or email me!

Got it! Thanks for the explanation! I'll reach out via email or open a new issue if I have any more questions. Thanks again for your help!