[Selected index k out of range] while executing torch.topk
SeungHunJeon commented
In `utils.py`, the density estimation function takes the k smallest entries of each row of the `cdist` matrix with `topk`.
The author hard-coded k = 1000, but this appears to exceed the number of available samples, so `topk` fails.
(I am running notebooks/latent_go_explore_maze.ipynb.)
```
File ~/workspace/lge/lge/lge.py:330, in LatentGoExplore.explore(self, total_timesteps, callback)
    328 else:
    329 callback = [self.module_learner]
--> 330 self.model.learn(total_timesteps, callback=callback, log_interval=1000)

File ~/anaconda3/envs/lge/lib/python3.9/site-packages/stable_baselines3/sac/sac.py:309, in SAC.learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
    299 def learn(
    300 self: SelfSAC,
    301 total_timesteps: int,
   (...)
    306 progress_bar: bool = False,
    307 ) -> SelfSAC:
--> 309 return super().learn(
    310 total_timesteps=total_timesteps,
    311 callback=callback,
    312 log_interval=log_interval,
    313 tb_log_name=tb_log_name,
    314 reset_num_timesteps=reset_num_timesteps,
...
     70 cdist = torch.cdist(x, samples)
---> 71 dist_to_kst = cdist.topk(k, largest=False)[0][:, -1]
     72 return -dist_to_kst

RuntimeError: selected index k out of range
```
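For reference, the failing lines compute a k-nearest-neighbour density score: for each query point, the negative distance to its k-th closest sample. `topk` runs along the last dimension of the `cdist` matrix, whose size is the number of samples, so it can only succeed when k is at most that number. Below is a dependency-free Python sketch written purely for illustration (the name `knn_density` is hypothetical and not part of lge), assuming points are plain tuples of floats:

```python
import math
from typing import List, Sequence


def knn_density(x: Sequence[Sequence[float]],
                samples: Sequence[Sequence[float]],
                k: int) -> List[float]:
    """Score each point in x by minus its distance to its k-th nearest sample.

    Illustrative stand-in for:
        cdist = torch.cdist(x, samples)
        dist_to_kst = cdist.topk(k, largest=False)[0][:, -1]
        return -dist_to_kst
    """
    n = len(samples)
    if not 1 <= k <= n:
        # k > n is the case that makes torch.topk raise
        # "selected index k out of range"; k < 1 leaves no k-th distance to read
        raise ValueError(f"k={k} must be between 1 and the number of samples ({n})")
    scores = []
    for point in x:
        # One row of the cdist matrix: Euclidean distance to every sample
        row = [math.dist(point, s) for s in samples]
        # Sort ascending and keep the k-th smallest distance
        kth = sorted(row)[k - 1]
        scores.append(-kth)
    return scores
```

With k = 1000 and fewer than 1000 samples collected early in training, the check at the top fails in the same situation where the torch version raises.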
datake commented
Same issue here.
ildefons commented
Same for me
ildefons commented
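@datake @SeungHunJeon, one possible way to handle the small sample size at the beginning of training is to shrink k right after `k = 1000` in `utils.py`:

```python
k = 1000
# Early in training there can be fewer samples than k: fall back to half the sample count
if samples.shape[0] < k:
    k = max(1, samples.shape[0] // 2)  # keep k >= 1 so that [:, -1] has a column to read
```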
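The same idea can be sketched as a standalone helper. The name `effective_k` is hypothetical (it is not in lge); it assumes the halving heuristic proposed above, applied whenever fewer than k samples exist. A threshold of k/2 alone would leave sample counts between k/2 and k - 1 unhandled, and the result is clamped to at least 1:

```python
def effective_k(num_samples: int, k: int = 1000) -> int:
    """Return a k that is valid for topk over a row of num_samples distances.

    Hypothetical helper: when fewer than k samples exist, fall back to half
    the sample count, but never go below 1.
    """
    if num_samples < 1:
        raise ValueError("need at least one sample to estimate density")
    if num_samples < k:
        return max(1, num_samples // 2)
    return k


# The result always lies in [1, num_samples]
for n in (1, 10, 600, 1000, 5000):
    print(n, effective_k(n))
```

In `utils.py` it would be used as `k = effective_k(samples.shape[0])` just before `cdist.topk(k, largest=False)`. A simpler alternative is `k = min(k, samples.shape[0])`, which keeps k as large as the data allows instead of halving it.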