Farama-Foundation/Minari

[Question] Counting total steps takes very long

Closed this issue · 2 comments

Question

The current implementation appears to iterate through the entire dataset to determine the length of each episode and then sums them to get the total number of steps. Is there a more efficient way to achieve this? Per-episode lengths or total steps seem like properties that could be accessed in O(1) time.

Edit: I found that dataset.storage.total_steps returns the total number of steps very quickly, without iterating over all episodes. Is there a reason we need to check whether self.episode_indices is None?

The reason behind it is that a MinariDataset can point to a subset of the data in a MinariStorage, for example after filter_episodes.
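To illustrate why the storage total cannot always be reused, here is a minimal sketch with made-up episode lengths. The names storage_episode_lengths and view_total_steps are hypothetical and are not part of Minari; the list comprehension only mimics what a filter such as filter_episodes would select.

```python
# Hypothetical per-episode lengths held by the underlying storage.
storage_episode_lengths = [3, 5, 2, 4]

# Total over everything in storage (what storage.total_steps would report).
storage_total_steps = sum(storage_episode_lengths)


def view_total_steps(episode_indices):
    """Total steps for a dataset view restricted to the given episode indices."""
    return sum(storage_episode_lengths[i] for i in episode_indices)


# Keep only episodes with at least 4 steps, mimicking a filtered dataset.
filtered_indices = [i for i, n in enumerate(storage_episode_lengths) if n >= 4]

# The filtered view covers fewer steps than the storage it points to.
filtered_total_steps = view_total_steps(filtered_indices)
```

In this toy setup the filtered view's total differs from the storage total, so a dataset that points to a subset has to compute its own count.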

However, I am noticing right now that there is a bug in the implementation that causes total_steps to always be recomputed from the episodes. Specifically, self.episode_indices is never None.

Thanks for pointing this out! Do you want to make a PR to fix it? I went ahead and did it myself, since I think it is important to include the fix in 0.5.0, which we are releasing today. Thanks!

To fix it, I would cache the total steps here:

    if episode_indices is None:
        episode_indices = np.arange(self._data.total_episodes)

And remove the check here:

    if self.episode_indices is None:
        self._total_steps = self.storage.total_steps
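The suggested change could look roughly like the sketch below. ToyStorage and ToyDataset are hypothetical stand-ins, not Minari's actual classes, and the sketch assumes the storage keeps its total step count as metadata that is cheap to read. The idea is to cache the total at the point where episode_indices is still None, since it is replaced by a full index range right afterwards.

```python
from typing import List, Optional, Sequence


class ToyStorage:
    """Hypothetical stand-in for MinariStorage; holds per-episode lengths."""

    def __init__(self, episode_lengths: Sequence[int]) -> None:
        self.episode_lengths: List[int] = list(episode_lengths)
        self.total_episodes: int = len(self.episode_lengths)
        # Assumption: the total is kept as metadata, so reading it is cheap.
        self.total_steps: int = sum(self.episode_lengths)


class ToyDataset:
    """Hypothetical stand-in for MinariDataset; may view a subset of storage."""

    def __init__(
        self,
        storage: ToyStorage,
        episode_indices: Optional[Sequence[int]] = None,
    ) -> None:
        self.storage = storage
        self._total_steps: Optional[int] = None
        if episode_indices is None:
            # Full view: cache the storage total here, because after this
            # branch episode_indices is never None again.
            episode_indices = range(storage.total_episodes)
            self._total_steps = storage.total_steps
        self.episode_indices: List[int] = list(episode_indices)

    @property
    def total_steps(self) -> int:
        if self._total_steps is None:
            # Subset view: sum the selected episode lengths once, then cache.
            self._total_steps = sum(
                self.storage.episode_lengths[i] for i in self.episode_indices
            )
        return self._total_steps


storage = ToyStorage([3, 5, 2, 4])
full_view = ToyDataset(storage)          # total cached at construction
subset_view = ToyDataset(storage, [0, 2])  # total computed on first access
```

With this arrangement the property no longer needs to test episode_indices against None; it only checks whether a cached value exists.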

Sorry, I missed the notification. Thank you for the quick fix!