arsedler9/lfads-torch

What learning parameters should I use to avoid NaN values?


Hi Andrew,

I applied the multisession_PCR analysis from the tutorial to my own dataset and got the error below. It seems to be caused by a training parameter. The architecture parameters from the multisession_PCR.yaml file used for this run are also listed below.
Could you please advise me which value to change and how to fix this?

[Error message]
ERROR trial_runner.py:993 -- Trial run_model_2595b_00009: Error processing event.
ray.exceptions.RayTaskError(ValueError): ray::ImplicitFunc.train() (pid=2684, ip=127.0.0.1, repr=run_model)
File "python\ray_raylet.pyx", line 859, in ray._raylet.execute_task
File "python\ray_raylet.pyx", line 863, in ray._raylet.execute_task
File "python\ray_raylet.pyx", line 810, in ray._raylet.execute_task.function_executor
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray_private\function_manager.py", line 674, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
return method(self, *_args, **_kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\trainable.py", line 355, in train
raise skipped from exception_cause(skipped)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\function_trainable.py", line 325, in entrypoint
return self._trainable_func(
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
return method(self, *_args, **_kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\function_trainable.py", line 651, in _trainable_func
output = fn()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\util.py", line 365, in inner
trainable(config, **fn_kwargs)
File "c:\windows\system32\lfads-torch\lfads_torch\run_model.py", line 78, in run_model
trainer.fit(
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 771, in fit
self._call_and_handle_interrupt(
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 724, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 812, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1237, in _run
results = self._run_stage()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1324, in _run_stage
return self._run_train()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1354, in _run_train
self.fit_loop.run()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
self.advance(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 269, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
self.advance(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 208, in advance
batch_output = self.batch_loop.run(batch, batch_idx)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
self.advance(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
self.advance(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 203, in advance
result = self._run_optimization(
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 256, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 369, in _optimizer_step
self.trainer._call_lightning_module_hook(
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1596, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\core\lightning.py", line 1625, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\core\optimizer.py", line 168, in step
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\strategies\strategy.py", line 193, in optimizer_step
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\plugins\precision\precision_plugin.py", line 155, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\optim\optimizer.py", line 140, in wrapper
out = func(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\optim\adamw.py", line 120, in step
loss = closure()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\plugins\precision\precision_plugin.py", line 140, in _wrap_closure
closure_result = closure()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 148, in call
self._result = self.closure(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 134, in closure
step_output = self._step_fn()
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 427, in _training_step
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1766, in _call_strategy_hook
output = fn(*args, **kwargs)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\strategies\strategy.py", line 333, in training_step
return self.model.training_step(*args, **kwargs)
File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 487, in training_step
return self._shared_step(batch, batch_idx, "train")
File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 357, in _shared_step
output = self.forward(
File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 231, in forward
ic_post = self.ic_prior.make_posterior(ic_mean, ic_std)
File "c:\windows\system32\lfads-torch\lfads_torch\modules\priors.py", line 30, in make_posterior
return Independent(Normal(post_mean, post_std), 1)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\distributions\normal.py", line 56, in init
super(Normal, self).init(batch_shape, validate_args=validate_args)
File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\distributions\distribution.py", line 56, in init
raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (980, 100)) of distribution Normal(loc: torch.Size([980, 100]), scale: torch.Size([980, 100])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=)

[Architecture parameters from the multisession_PCR.yaml file]
encod_data_dim: 150
encod_seq_len: 40
recon_seq_len: ${model.encod_seq_len}
ext_input_dim: 0
ic_enc_seq_len: 0
ic_enc_dim: 100
ci_enc_dim: 100
ci_lag: 1
con_dim: 100
co_dim: 6
ic_dim: 100
gen_dim: 100
fac_dim: 150

Hi! Does this happen immediately or after a few training steps? I think it's likely that the learning rate is too high. I would try running just a single model with scripts/run_single.py (using your new config) on the data a few times with lower learning rates to get a sense of what learning rates would be appropriate.
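For concreteness, a minimal sketch of that kind of single run, assuming run_model accepts an overrides dict of dotted config keys and a config_path argument, as in scripts/run_single.py (the config path below is illustrative; point it at whichever top-level config pulls in your multisession_PCR model file):

    # Sketch only: this assumes run_model(overrides=..., config_path=...) is
    # the entry point used by scripts/run_single.py, and that your model
    # config exposes lr_init the way the tutorial YAML files do.
    from lfads_torch.run_model import run_model

    run_model(
        overrides={"model.lr_init": 1e-4},  # lower initial learning rate
        config_path="../configs/multi.yaml",  # illustrative; use your config
    )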

Thanks for the reply.
The error occurs after several training steps.
Which parameter do you mean by the learning rate?

The lr_init parameter controls the initial learning rate, which is reduced over the course of training. I’d recommend trying a few different values (1e-3, 3e-4, 1e-4, 3e-5, etc.) and visualizing the loss curves with tensorboard / wandb to see which allows a quick but stable descent.
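A sweep over those values could look something like the sketch below (same hypothetical run_model entry point and model.lr_init key as in the earlier sketch); give each run its own working/output directory so the logs don't collide, then compare the loss curves:

    # Sketch only: sequential sweep over candidate initial learning rates,
    # reusing the hypothetical run_model entry point from the sketch above.
    from lfads_torch.run_model import run_model

    for lr in [1e-3, 3e-4, 1e-4, 3e-5]:
        run_model(
            overrides={"model.lr_init": lr},
            config_path="../configs/multi.yaml",  # illustrative path
        )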

Thank you for your kind suggestion.
I have tried various learning rates (1e-4, 1e-5, 1e-6, 1e-7, 1e-8) and still get the same error.
Are there any other possible causes?

Hmmm… what are your batch size, sequence length, and number of neurons? It looks like your sequence length may be close to 1k; 100-300 time steps is more typical. You’d probably find it easier to fit that data with a 3-5x larger bin size.
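For what it's worth, re-binning is usually just a sum over groups of adjacent time bins. A small self-contained numpy sketch (with made-up shapes, not your data):

    import numpy as np

    # Toy example: merge every 3 adjacent 20 ms bins into one 60 ms bin
    # by summing spike counts along the time axis.
    rng = np.random.default_rng(0)
    spikes = rng.poisson(0.1, size=(100, 120, 150))  # (trials, time, neurons)

    factor = 3
    n_trials, n_time, n_neurons = spikes.shape
    n_keep = n_time - n_time % factor  # drop any partial bin at the end
    rebinned = (
        spikes[:, :n_keep, :]
        .reshape(n_trials, n_keep // factor, factor, n_neurons)
        .sum(axis=2)
    )
    print(rebinned.shape)  # (100, 40, 150)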

Those parameters are set in accordance with the tutorial. The batch size is 1000, the sequence length is 40 time steps (40 bins of 20 ms), and the number of neurons varies from session to session, ranging from 30 to 200. The number of sessions varies from 20 to 80 per brain region, and there are 20 or 30 conditions.

Hmm, could you upload your files to Google Drive and send me a link to arsedler9@gmail.com?

Hey @Riverside-ms, just checking in: were you able to resolve this?

Sorry for the late reply. I have emailed you; please check it.