[BUG] Error when running DT on Atari
AndssY opened this issue · 2 comments
AndssY commented
Using d3rlpy v2.2.0 (https://github.com/takuseno/d3rlpy/releases/tag/v2.2.0), running the following script:
import d3rlpy

dataset, env = d3rlpy.datasets.get_atari_transitions(
    'breakout',
    fraction=0.01,
    num_stack=4,
)

dt = d3rlpy.algos.DiscreteDecisionTransformerConfig(
    batch_size=64,
    num_heads=1,
    learning_rate=1e-4,
    max_timestep=1000,
    num_layers=3,
    position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
    encoder_factory=d3rlpy.models.VectorEncoderFactory([128], exclude_last_activation=True),
    observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
    context_size=20,
    warmup_tokens=100000,
).create()

dt.fit(
    dataset,
    n_steps=100000,
    n_steps_per_epoch=1000,
    eval_env=env,
    eval_target_return=500,
)
raises the following error:
Traceback (most recent call last):
  ......
  ......
  File "...d3rlpy/d3rlpy/preprocessing/observation_scalers.py", line 322, in fit_with_trajectory_slicer
    total_sum += np.sum(traj.observations, axis=0)
ValueError: non-broadcastable output operand with shape (1,84,84) doesn't match the broadcast shape (4,84,84)
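The error comes from NumPy's in-place addition: `total_sum += x` cannot enlarge `total_sum`, so the result of broadcasting must already have `total_sum`'s shape. The minimal sketch below reproduces the same message under an assumption read off the error text, not checked against the d3rlpy source: the scaler's accumulator is sized from a single unstacked frame `(1, 84, 84)`, while each trajectory slice with `num_stack=4` yields `(T, 4, 84, 84)` observations. The helper name `accumulate` and the shapes used are hypothetical.

```python
import numpy as np


def accumulate(total_sum: np.ndarray, trajectory_obs: np.ndarray) -> np.ndarray:
    """Add the per-element sum over time of one trajectory to an accumulator.

    Hypothetical stand-in for the line in the traceback. The in-place '+='
    writes into total_sum, so the summed trajectory must broadcast to
    total_sum's existing shape.
    """
    total_sum += np.sum(trajectory_obs, axis=0)
    return total_sum


# Assumed shapes: accumulator sized from one grayscale frame, but the
# trajectory holds 4-frame stacks, so np.sum(..., axis=0) gives (4, 84, 84).
total_sum = np.zeros((1, 84, 84))
stacked_trajectory = np.zeros((10, 4, 84, 84))

try:
    accumulate(total_sum, stacked_trajectory)
except ValueError as err:
    # Prints NumPy's "non-broadcastable output operand ..." message.
    print(err)
```

Broadcasting `(1, 84, 84)` against `(4, 84, 84)` would succeed in an out-of-place `a + b` (the result is `(4, 84, 84)`), which is exactly why the in-place form fails: the output operand is fixed at `(1, 84, 84)`.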
takuseno commented
@AndssY Hi, thank you for the issue. Please check this example:
You need to change encoder_factory and observation_scaler to resolve the issue.
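The linked example is not preserved in this thread, so which replacements it uses is an assumption here. For image observations, d3rlpy provides `d3rlpy.models.PixelEncoderFactory` (a convolutional encoder) and `d3rlpy.preprocessing.PixelObservationScaler`, which normalizes pixels as x / 255; `VectorEncoderFactory` is intended for flat vector observations. The NumPy sketch below only illustrates the idea behind pixel scaling, not d3rlpy's implementation: it is a fixed constant division, so unlike `StandardObservationScaler` there are no dataset statistics to fit and no accumulator whose shape could disagree with the 4-frame stacks. The function name `pixel_scale` is hypothetical.

```python
import numpy as np


def pixel_scale(observations: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 values in [0, 1].

    Hypothetical illustration of pixel normalization: a fixed division by
    255, with no fitted mean/std, so it works for any observation shape.
    """
    return observations.astype(np.float32) / 255.0


# A hypothetical batch of 4-frame-stacked Atari observations: (batch, 4, 84, 84).
batch = np.random.randint(0, 256, size=(2, 4, 84, 84), dtype=np.uint8)
scaled = pixel_scale(batch)
print(scaled.shape, scaled.dtype)
```

In the reporter's config, the corresponding change would presumably be along the lines of `encoder_factory=d3rlpy.models.PixelEncoderFactory(...)` and `observation_scaler=d3rlpy.preprocessing.PixelObservationScaler()`; check the exact arguments against the d3rlpy documentation for the installed version.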
AndssY commented
Thanks!