google-research/batch_rl

How to train an offline agent on the huge dataset (50 million)?

LQNew opened this issue · 3 comments

LQNew commented

Hi, I have read your paper published at ICML 2020, and I am now trying to do some research on offline image data. I have noticed that when training an online agent such as DQN, the replay buffer capacity is usually set to 1 million; once the amount of collected data exceeds 1 million, new data overwrites the oldest data in the replay buffer. But when training DQN on offline data, such as the data you released, the dataset contains 50M transitions. How do I train an agent on such a huge dataset? Since the machine's memory is limited, it's impossible to load all 50M transitions into memory at once. I wonder how you solved this problem; if you implemented your solution in this project, please point me to it. Finally, I really appreciate your great work and your open-source code!

agarwl commented

That's correct, it's not possible to load the entire dataset into memory. For that reason, the dataset is stored as 50 files of 1M transitions each, corresponding to the replay buffer contents checkpointed every 4 iterations during training.
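
For concreteness, here is a minimal sketch of loading a single one of those 1M-transition checkpoint files with dopamine's out-of-graph replay buffer. The constructor arguments and the data path below are illustrative, not prescriptive:

  # Minimal sketch: load one 1M-transition checkpoint with dopamine.
  # Assumes the dataset checkpoints (suffixes 0-49) were downloaded into a
  # local replay_logs directory for one game/run (the path is illustrative).
  from dopamine.replay_memory import circular_replay_buffer

  buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=(84, 84),
      stack_size=4,
      replay_capacity=1000000,
      batch_size=32)
  # Suffix k in [0, 49] selects the k-th of the 50 stored buffers.
  buffer.load('/path/to/Pong/1/replay_logs', suffix=7)
  batch = buffer.sample_transition_batch()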

Now, to load this huge dataset, I created a new FixedReplayBuffer class for loading the offline dataset; it is instantiated by the offline agents (e.g., here) and, at each iteration, loads 5 replay buffers of 1M transitions each, chosen at random from the set of 50 buffers.
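
Condensed, that per-iteration loading logic looks roughly like the sketch below; this illustrates the idea and is not the exact FixedReplayBuffer code:

  import numpy as np
  from dopamine.replay_memory import circular_replay_buffer

  NUM_CHECKPOINTS = 50  # the dataset ships as 50 buffer checkpoints
  NUM_BUFFERS = 5       # how many checkpoints to hold in RAM at once

  def load_random_buffers(data_dir):
    """Sketch: load 5 of the 50 checkpoints, chosen uniformly at random."""
    suffixes = np.random.choice(NUM_CHECKPOINTS, NUM_BUFFERS, replace=False)
    buffers = []
    for suffix in suffixes:
      rb = circular_replay_buffer.OutOfGraphReplayBuffer(
          observation_shape=(84, 84), stack_size=4,
          replay_capacity=1000000, batch_size=32)
      rb.load(data_dir, suffix)
      buffers.append(rb)
    return buffers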

That said, even an experiment that loads 5 buffers (5M transitions) still has a large RAM requirement, so there are two possible alternatives:

  1. If you want to use the entire 50M dataset, you can use the same Atari dataset released by DeepMind, which can be loaded directly into TensorFlow as a stream, so it does not suffer from RAM issues. See this colab for an example, and the sketch after this list for the general shape of such a pipeline.

  2. If you want to stick to dopamine, the suggestion I have is to use the smaller subsampled datasets, as done in Section 6 of the paper as well as in a couple of other papers (e.g., this NeurIPS'20 paper and an ICLR'20 submission).
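
For option 1, the linked colab has the authoritative loading code; the general shape is a tf.data pipeline over sharded record files, roughly as sketched below. The file pattern and feature names here are placeholders, so take the real spec from the colab:

  import tensorflow as tf

  # Placeholder shard pattern; the real paths and feature spec are in the colab.
  files = tf.io.gfile.glob('/path/to/atari_shards/run_1-*')
  dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')

  def parse(example_proto):
    # Hypothetical feature spec, for illustration only.
    features = {
        'observation': tf.io.FixedLenFeature([], tf.string),
        'action': tf.io.FixedLenFeature([], tf.int64),
        'reward': tf.io.FixedLenFeature([], tf.float32),
    }
    return tf.io.parse_single_example(example_proto, features)

  # Because the data streams from disk, RAM use stays flat regardless of size.
  dataset = (dataset
             .map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
             .shuffle(10000)
             .batch(32)
             .prefetch(tf.data.experimental.AUTOTUNE))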

Doing option 2 is very simple: set the replay buffer capacity to something smaller than 1M (say, 50000), so that only the first 50000 data points are loaded from each dataset file. This can be done by changing the size in the gin file here.
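
Alternatively, you can override the capacity at launch time through gin bindings instead of editing the file. A minimal sketch, assuming the binding name used in the repo's gin configs (double-check both names against the linked file):

  import gin

  # Load the offline DQN config, overriding the replay capacity so that only
  # the first 50000 transitions of each checkpoint file are kept in RAM.
  # Both the config path and the binding name are assumptions to verify.
  gin.parse_config_files_and_bindings(
      ['batch_rl/fixed_replay/configs/dqn.gin'],
      bindings=['WrappedFixedReplayBuffer.replay_capacity = 50000'])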

The following function from the FixedReplayBuffer class would be helpful in that regard (note the clipping done to free unused RAM when replay_capacity is smaller than 1M):

  # Requires TF1-style logging and dopamine's replay buffer:
  #   import tensorflow.compat.v1 as tf
  #   from dopamine.replay_memory import circular_replay_buffer

  def _load_buffer(self, suffix):
    """Loads an OutOfGraphReplayBuffer from the checkpoint files at `suffix`."""
    try:
      # pytype: disable=attribute-error
      tf.logging.info(
          f'Starting to load from ckpt {suffix} from {self._data_dir}')
      replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
          *self._args, **self._kwargs)
      replay_buffer.load(self._data_dir, suffix)
      # pylint: disable=protected-access
      replay_capacity = replay_buffer._replay_capacity
      tf.logging.info(f'Capacity: {replay_capacity}')
      for name, array in replay_buffer._store.items():
        # This frees unused RAM if replay_capacity is smaller than 1M
        # (the +100 keeps a small slack past capacity).
        replay_buffer._store[name] = array[:replay_capacity + 100].copy()
        tf.logging.info(f'{name}: {replay_buffer._store[name].shape}')
      tf.logging.info(
          f'Loaded replay buffer ckpt {suffix} from {self._data_dir}')
      # pylint: enable=protected-access
      # pytype: enable=attribute-error
      return replay_buffer
    except tf.errors.NotFoundError:
      return None
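
Once several buffers are loaded, sampling simply delegates to one of them chosen at random. Roughly, and assuming the loaded buffers are kept in a list self._replay_buffers (a sketch, not the exact code):

  def sample_transition_batch(self, batch_size=None, indices=None):
    """Sketch: sample a batch from one randomly chosen loaded buffer."""
    buffer_index = np.random.randint(len(self._replay_buffers))
    return self._replay_buffers[buffer_index].sample_transition_batch(
        batch_size=batch_size, indices=indices)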

Please let me know if you have any more questions!

LQNew commented

Hi @agarwl, thank you for your detailed reply! I have learned a lot from it, and my question has now been answered.

agarwl commented

Reopening this for visibility :)