[BUG] Issue with CUDA Device-Side Assertion Failure During Training
Lucien-Evans-123 opened this issue · 2 comments
Hello d3rlpy author,
I've encountered a recurring CUDA device-side assertion failure while training a DiscreteDecisionTransformer model on my dataset. Training starts but fails almost immediately with a "device-side assert triggered" error. The issue persists even after verifying tensor device consistency and index ranges. Below is a summary of the code leading to the error and the error message itself (with repetitive parts abbreviated for clarity).
My code:
dt = d3rlpy.algos.DiscreteDecisionTransformerConfig().create(device="cuda:0")
dt.fit(
    dataset_mr,
    n_steps=100000,
    n_steps_per_epoch=1000,
    eval_env=env,
    eval_target_return=0.001,  # specify the target environment return
)
Some Error Message:
2024-03-05 15:41.25 [info ] Signatures have been automatically determined. action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]) observation_signature=Signature(dtype=[dtype('float32')], shape=[(1127,)]) reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)])
2024-03-05 15:41.25 [info ] Action-space has been automatically determined. action_space=<ActionSpace.DISCRETE: 2>
2024-03-05 15:41.25 [info ] Action size has been automatically determined. action_size=60
2024-03-05 15:41.25 [info ] dataset info dataset_info=DatasetInfo(observation_signature=Signature(dtype=[dtype('float32')], shape=[(1127,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=60)
2024-03-05 15:41.25 [info ] Directory is created at d3rlpy_logs/DiscreteDecisionTransformer_20240305154125
2024-03-05 15:41.25 [debug ] Building models...
2024-03-05 15:41.26 [debug ] Models have been built.
2024-03-05 15:41.26 [info ] Parameters params={'observation_shape': [1127], 'action_size': 60, 'config': {'type': 'discrete_decision_transformer', 'params': {'batch_size': 128, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'context_size': 20, 'max_timestep': 1000, 'learning_rate': 0.0006, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'optim_factory': {'type': 'adam', 'params': {'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'num_heads': 8, 'num_layers': 6, 'attn_dropout': 0.1, 'resid_dropout': 0.1, 'embed_dropout': 0.1, 'activation_type': 'gelu', 'embed_activation_type': 'tanh', 'position_encoding_type': <PositionEncodingType.GLOBAL: 'global'>, 'warmup_tokens': 10240, 'final_tokens': 30000000, 'clip_grad_norm': 1.0, 'compile': False}}}
Epoch 1/100: 0%| | 0/1000 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [43,0,0], thread: [64,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [43,0,0], thread: [65,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [43,0,0], thread: [66,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
File "anaconda3/envs/dt/lib/python3.9/site-packages/d3rlpy/models/torch/transformers.py", line 204, in forward
global_embedding = torch.gather(batched_global_embedding, 1, last_t)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
I have confirmed that the two tensors involved (batched_global_embedding and last_t) are on the same CUDA device, and the index values passed to torch.gather appear to be within the bounds of the tensor dimensions, which suggests the problem may not be directly related to tensor indexing or device allocation.
Could you please provide any insights or suggestions on how to address or debug this issue further? I'm running this on PyTorch version 2.2.1+cu121 with a compatible CUDA version. This problem has been quite persistent, and any guidance or assistance would be greatly appreciated.
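In case it helps, here is a minimal sketch of what I plan to try next (the dt_cpu name and the reduced n_steps are illustrative, not from my actual run). Since device-side asserts are reported asynchronously, forcing blocking CUDA launches or re-running a short fit on CPU should surface the exact out-of-range index in the traceback:
import os

# Must be set before the first CUDA call so errors are reported at the failing op.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Re-running a short fit on CPU turns the device-side assert into a plain
# IndexError that names the offending index. create() without a device
# argument should fall back to CPU.
dt_cpu = d3rlpy.algos.DiscreteDecisionTransformerConfig().create()
dt_cpu.fit(dataset_mr, n_steps=1000, n_steps_per_epoch=1000)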
Thank you for your time and help.
@Lucien-Evans-123 Hi, sorry for the late response. My guess is that the episode lengths in your dataset could be longer than max_timestep, whose default value is 1000. If that's the case, you can solve this by:
dt = d3rlpy.algos.DiscreteDecisionTransformerConfig(max_timestep=<max_episode_length>).create(device="cuda:0")
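To find the right value, something like this rough sketch should work; it assumes the replay buffer exposes its episodes and that an episode's length equals the number of stored observations (please check against your d3rlpy version):
# Sketch: estimate the longest episode in the dataset and use it as max_timestep.
max_episode_length = max(len(ep.observations) for ep in dataset_mr.episodes)

dt = d3rlpy.algos.DiscreteDecisionTransformerConfig(
    max_timestep=max_episode_length,
).create(device="cuda:0")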
It works! Thank you!