[QUESTION] Breaking changes in Replay Buffer and TransitionMinibatch
jamartinh opened this issue · 6 comments
Describe the bug
Hi @takuseno, in the previous version of d3rlpy I was able to use a different ReplayBuffer library.
To make that possible, I created a MiniBatch subclass in the following way:
import numpy as np
from d3rlpy.dataset import Transition, TransitionMiniBatch  # previous (pre-refactor) d3rlpy API

class MiniBatch(TransitionMiniBatch):
    # _transitions: list = list()
    _observations: np.ndarray
    _actions: np.ndarray
    _rewards: np.ndarray
    _next_observations: np.ndarray
    _terminals: np.ndarray
    _n_steps: np.ndarray

    def set_data(self, obs, act, rew, next_obs, done, n_steps):
        # self._transitions = list()
        self._observations = obs
        self._actions = act
        self._rewards = rew
        self._next_observations = next_obs
        self._terminals = done
        self._n_steps = n_steps

    @property
    def observations(self):
        return self._observations

    @property
    def actions(self):
        return self._actions

    @property
    def rewards(self):
        return self._rewards

    @property
    def next_observations(self):
        return self._next_observations

    @property
    def terminals(self):
        return self._terminals

    @property
    def n_steps(self):
        return self._n_steps


def get_d3rlpy_batch(sample):
    """Utility to convert any batch type to d3rlpy.TransitionMiniBatch."""
    observation_shape = sample["obs"][0].shape
    action_size = sample["act"][0].shape[0]
    observation = sample["obs"][0]
    action = sample["act"][0]
    reward = sample["rew"][0]
    next_observation = sample["next_obs"][0]
    terminal = sample["done"][0]
    n_steps = np.ones_like(sample["done"][0], dtype=int)
    transition = Transition(observation_shape, action_size, observation, action,
                            reward, next_observation, terminal)
    batch = MiniBatch([transition])
    batch.set_data(sample["obs"],
                   sample["act"],
                   sample["rew"],
                   sample["next_obs"],
                   sample["done"],
                   n_steps)
    return batch
And then everything worked well.
However, with the current implementation, which initially looked simpler, I am facing a "problem?" with the intervals values.
As you may see, I have reviewed how intervals is used (as gamma**intervals), so I opted to pass ones; another alternative is to pass all zeros, which basically makes the discount gamma**intervals = 0.99**0 = 1 all the time.
def get_d3rlpy_batch(sample):
    """utility to be able to convert any batch type to d3rlpy.TransitionMiniBatch"""
    batch = TransitionMiniBatch(
        observations=cast_recursively(sample["obs"], np.float32),
        actions=cast_recursively(sample["act"], np.float32),
        rewards=cast_recursively(sample["rew"], np.float32),
        next_observations=cast_recursively(sample["next_obs"], np.float32),
        terminals=cast_recursively(sample["done"], np.float32),
        intervals=cast_recursively(np.ones_like(sample["done"]), np.float32),
    )
    return batch
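For reference, a quick check of the difference between the two options (assuming the discount really is applied as gamma ** intervals, as described above):

import numpy as np

gamma = 0.99
intervals_ones = np.ones(4, dtype=np.float32)    # 1-step TD: bootstrap discounted by gamma
intervals_zeros = np.zeros(4, dtype=np.float32)  # gamma ** 0 == 1, i.e. no discounting at all

print(gamma ** intervals_ones)   # [0.99 0.99 0.99 0.99]
print(gamma ** intervals_zeros)  # [1. 1. 1. 1.]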
It seems that the discount factor is now handled with the help of the new ReplayBuffer implementation?
Since the refactored code, my experiments now diverge.
For instance:
Before:
batch = get_d3rlpy_batch(sample)
if from_dataset:
    metric = {"critic_loss": algo.impl.update_critic(batch)}
    losses.append(metric)
    # losses.append(algo.update(batch))
else:
    # update algo parameters
    losses.append(algo.update(batch))
Now I do:
batch = get_d3rlpy_batch(sample)
if from_dataset:
    t_batch = TorchMiniBatch.from_batch(batch, args.device,
                                        observation_scaler=algo.observation_scaler,
                                        action_scaler=algo.action_scaler)
    metric = algo.impl.update_critic(t_batch)
    losses.append(metric)  # losses.append(algo.update(batch))
else:
    # update algo parameters
    metrics = algo.update(batch)
    losses.append(metrics)
But as said, the SAC algorithm now diverges on a task where it converged beautifully with the previous version.
Thanks in advance for any help on this
@jamartinh Thank you for reporting this. This issue was initially reported in #346 and was considered fixed in commit eab9e9f. If you use v2.2.0, it should be fixed. Could you check which version you're using?
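For example, one quick way to check the installed version:

import d3rlpy
print(d3rlpy.__version__)  # e.g. "2.2.0"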
Thanks @takuseno. So what about the case where the intervals values are not at hand, because the batch is created from a different ReplayBuffer that does not store intervals?
How should one fill this?
What is the purpose of the intervals variable, and why is it used in SAC (online learning)?
intervals is used to support multi-step training. For the normal case (1-step TD), intervals are supposed to be ones. This is used for all Q-learning-based algorithms in d3rlpy.
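To make that concrete, here is a simplified sketch (not d3rlpy's actual code) of how an interval of n would typically enter an n-step TD target:

def td_target(rewards, next_q, terminal, gamma, interval):
    """Simplified n-step TD target: accumulated rewards plus a bootstrap
    term discounted by gamma ** interval (interval == n; 1 for 1-step TD)."""
    n_step_return = sum(gamma ** i * r for i, r in enumerate(rewards))
    return n_step_return + (gamma ** interval) * (1.0 - terminal) * next_q

# 1-step case: interval = 1, so the bootstrap is discounted by gamma once
print(td_target([1.0], next_q=10.0, terminal=0.0, gamma=0.99, interval=1))            # ≈ 10.9
# 3-step case: interval = 3, so the bootstrap is discounted by gamma ** 3
print(td_target([1.0, 1.0, 1.0], next_q=10.0, terminal=0.0, gamma=0.99, interval=3))  # ≈ 12.673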
FYI: this is the place where interval is set:
It seems that the issue has been resolved. Let me close this issue. Feel free to reopen this if there is any further discussion.