takuseno/d3rlpy

[QUESTION] Breaking changes in Replay Buffer and TransitionMinibatch

jamartinh opened this issue · 6 comments

Describe the bug
Hi @takuseno, in the previous version of d3rlpy I was able to use a different ReplayBuffer library.

To make that possible, I created a MiniBatch subclass in the following way:

import numpy as np
from d3rlpy.dataset import Transition, TransitionMiniBatch  # d3rlpy v1.x imports

class MiniBatch(TransitionMiniBatch):
    # batch arrays exposed through the read-only properties below
    _observations: np.ndarray
    _actions: np.ndarray
    _rewards: np.ndarray
    _next_observations: np.ndarray
    _terminals: np.ndarray
    _n_steps: np.ndarray

    def set_data(self, obs, act, rew, next_obs, done, n_steps):
        # overwrite the batch arrays with data coming from the external replay buffer
        self._observations = obs
        self._actions = act
        self._rewards = rew
        self._next_observations = next_obs
        self._terminals = done
        self._n_steps = n_steps

    @property
    def observations(self):
        return self._observations

    @property
    def actions(self):
        return self._actions

    @property
    def rewards(self):
        return self._rewards

    @property
    def next_observations(self):
        return self._next_observations

    @property
    def terminals(self):
        return self._terminals

    @property
    def n_steps(self):
        return self._n_steps

def get_d3rlpy_batch(sample):
    """
    Utility to convert an arbitrary batch into a d3rlpy TransitionMiniBatch.
    """
    observation_shape = sample["obs"][0].shape
    action_size = sample["act"][0].shape[0]
    observation = sample["obs"][0]
    action = sample["act"][0]
    reward = sample["rew"][0]
    next_observation = sample["next_obs"][0]
    terminal = sample["done"][0]
    n_steps = np.ones_like(sample["done"][0], dtype=int)

    transition = Transition(observation_shape, action_size, observation, action, reward, next_observation, terminal)
    batch = MiniBatch([transition])
    batch.set_data(sample["obs"],
                   sample["act"],
                   sample["rew"],
                   sample["next_obs"],
                   sample["done"],
                   n_steps)
    return batch

And then everything worked well.

However, with the current implementation, which initially looked simpler, I am running into a possible problem with the intervals values. As far as I can see, intervals are used as gamma**intervals, so I opted to fill them with ones; another alternative is to fill them all with zeros, which makes the effective discount 0.99**0 = 1 all the time.
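To make the two options concrete, here is a quick NumPy check (plain NumPy, not d3rlpy code) of the effective discount for each choice:

import numpy as np

gamma = 0.99

# intervals of ones -> effective discount gamma**1 = 0.99 (standard 1-step TD)
print(gamma ** np.ones(4))   # [0.99 0.99 0.99 0.99]

# intervals of zeros -> effective discount gamma**0 = 1.0 (no discounting at all)
print(gamma ** np.zeros(4))  # [1. 1. 1. 1.]

With intervals set to ones, I now build the batch like this: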

def get_d3rlpy_batch(sample):
    """
    Utility to convert an arbitrary batch into a d3rlpy TransitionMiniBatch.
    """

    batch = TransitionMiniBatch(observations=cast_recursively(sample["obs"], np.float32),
                                actions=cast_recursively(sample["act"], np.float32),
                                rewards=cast_recursively(sample["rew"], np.float32),
                                next_observations=cast_recursively(sample["next_obs"], np.float32),
                                terminals=cast_recursively(sample["done"], np.float32),
                                intervals=cast_recursively(np.ones_like(sample["done"]), np.float32),
                                )
    return batch
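For reference, a hypothetical sample dict with the keys this utility indexes (the shapes and dtypes here are illustrative only, not anything imposed by d3rlpy):

import numpy as np

batch_size, obs_dim, act_dim = 32, 8, 2
sample = {
    "obs": np.random.randn(batch_size, obs_dim).astype(np.float32),
    "act": np.random.randn(batch_size, act_dim).astype(np.float32),
    "rew": np.random.randn(batch_size, 1).astype(np.float32),
    "next_obs": np.random.randn(batch_size, obs_dim).astype(np.float32),
    "done": np.zeros((batch_size, 1), dtype=np.float32),
}
batch = get_d3rlpy_batch(sample)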
 

It seems that the discount factor is now handled with the help of the new ReplayBuffer implementation?

My experiments now diverge with the refactored code.

For instance:

Before:

batch = get_d3rlpy_batch(sample)
if from_dataset:
    metric = {"critic_loss": algo.impl.update_critic(batch)}
    losses.append(metric)
    # losses.append(algo.update(batch))
else:
    # update algo parameters
    losses.append(algo.update(batch))

Now I do:

batch = get_d3rlpy_batch(sample)
if from_dataset:
    t_batch = TorchMiniBatch.from_batch(batch, args.device,
                                        observation_scaler=algo.observation_scaler,
                                        action_scaler=algo.action_scaler)
    metric = algo.impl.update_critic(t_batch)
    losses.append(metric)  # losses.append(algo.update(batch))
else:
    # update algo parameters
    metrics = algo.update(batch)
    losses.append(metrics)

But as I said, the SAC algorithm now diverges on a task where it converged beautifully with the previous version.

Thanks in advance for any help on this.

@jamartinh Thank you for reporting this. This issue was initially reported in #346 and was considered fixed at commit eab9e9f. If you use v2.2.0, it should be fixed. Could you check which version you're using?
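For example, you can check it with:

import d3rlpy
print(d3rlpy.__version__)  # should be 2.2.0 or later to include the fix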

Thanks @takuseno. What about the case where the intervals values are not at hand, because the batch is created from a different ReplayBuffer that does not store intervals?

How should one fill them?
What is the purpose of the intervals variable, and why is it used in SAC (online learning)?

intervals is used to support multi-step training. For the normal case (1-step TD), intervals are supposed to be ones. This is used for all Q-learning-based algorithms in d3rlpy.
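As a rough illustration (variable names are made up, this is not d3rlpy's internal code), intervals enters a Q-learning target like this:

import numpy as np

def td_target(rewards, next_q, terminals, intervals, gamma=0.99):
    # the bootstrap term is discounted by gamma**interval, so intervals of
    # ones reproduce the standard 1-step TD target
    return rewards + (gamma ** intervals) * (1.0 - terminals) * next_q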

FYI: This is the place where interval is set:

It seems that the issue has been resolved. Let me close this issue. Feel free to reopen this if there is any further discussion.