QData/spacetimeformer

Pickle issue due to lambda functions

josh0tt opened this issue · 2 comments

Traceback (most recent call last):
  File "train.py", line 961, in <module>
    main(args)
  File "train.py", line 941, in main
    trainer.fit(forecaster, datamodule=data_module)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 113, in launch
    mp.start_processes(
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 189, in start_processes
    process.start()
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/context.py", line 283, in _Popen
    return Popen(process_obj)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/home/jott2/miniconda3/envs/spacetimeformer/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'Embedding.__init__.<locals>.<lambda>'

I'm not able to confirm this, but it is very likely a pytorch lightning versioning issue. Try v1.6. I had to upgrade several times over the course of the project. I thought the pinned version was in the requirements.txt, but it's apparently missing from the public version. There is one more code update coming with the final version of this paper over the next few weeks, but I don't plan to future-proof this repo against newer pytorch lightning releases because there are just too many breaking changes. This probably applies to #64 as well: these dependencies change fast, and you need to use a late-2021 or early-2022 version of everything.
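A quick sanity check (not something that exists in the repo) would be a couple of lines near the top of train.py that warn when the installed pytorch lightning is newer than what the scripts were written against; the "1.7.0" cutoff below is illustrative:

from packaging import version
import pytorch_lightning as pl

# Hypothetical guard: the scripts were developed against ~1.6, so warn on anything newer.
if version.parse(pl.__version__) >= version.parse("1.7.0"):
    print(f"Warning: pytorch-lightning {pl.__version__} is newer than the ~1.6 releases this repo targets; expect API breakage.")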

Edit: if you have to use a newer version of pytorch lightning, this could also be a ddp vs. dp problem. ddp launches multiple processes and pickles everything that gets passed between them. Because of the efficient attention layers, I only ever trained models on a single multi-GPU node, so the training scripts are not tested with distributed (ddp) training. If you are training on one node, make sure you're using the slower but simpler data parallel (dp) mode.
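For what it's worth, the failure can be reproduced without the model at all: a spawn-based launcher pickles the module, and a lambda assigned inside __init__ is a local object that pickle refuses. A minimal sketch of the pattern (the classes here are illustrative, not the repo's actual Embedding):

import pickle
import torch.nn as nn

class LambdaEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # A lambda defined inside __init__ is a local object, so the whole
        # module becomes unpicklable when a spawn-based launcher ships it
        # to worker processes.
        self.null_emb = lambda x: x

class PicklableEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Identity() (or any function defined at module level) pickles fine.
        self.null_emb = nn.Identity()

pickle.dumps(PicklableEmbedding())  # succeeds
try:
    pickle.dumps(LambdaEmbedding())
except AttributeError as err:
    print(err)  # Can't pickle local object 'LambdaEmbedding.__init__.<locals>.<lambda>'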

Yeah, it looks like it was a pytorch lightning versioning issue. I was able to get it running with pytorch-lightning==1.9.4 after making the following changes to the Trainer arguments:

accelerator="gpu",
devices=4,
strategy="ddp",
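For context, this is roughly how those arguments slot into the Trainer construction (the surrounding values are placeholders, not the script's exact settings). Under pytorch-lightning 1.9.x, strategy="ddp" launches workers by re-running the script as subprocesses instead of going through torch.multiprocessing.spawn's pickling step the way "ddp_spawn" does, which is presumably why the lambda stops being a problem:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",   # subprocess launcher; avoids pickling the model between processes
    max_epochs=100,   # placeholder value
)
# trainer.fit(forecaster, datamodule=data_module)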

Looking forward to the updated final version!