基于ChatGLM2的RLHF训练问题
UltraZeroyH opened this issue · 2 comments
[2023-08-12 01:22:11,409] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.10.0, git-hash=unknown, git-branch=unknown
Traceback (most recent call last):
File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/train_rlhf.py", line 373, in
main()
File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/train_rlhf.py", line 237, in main
rlhf_engine = DeepSpeedRLHFEngine(
File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/models/rlhf_engine.py", line 146, in init
self.ref = self._init_ref(
File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/models/rlhf_engine.py", line 245, in init_ref
ref_engine, * = deepspeed.initialize(model=ref_model,
File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/init.py", line 157, in initialize
config_class = DeepSpeedConfig(config, mpu)
File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 769, in init
self._configure_train_batch_size()
File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 942, in _configure_train_batch_size
self._batch_assertion()
File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 890, in _batch_assertion
assert train_batch == micro_batch * grad_acc * self.world_size, (
AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * gradient_acc_step * world_size 1024 != 4 * 1 * 8
在使用ChatGLM2作为sft和reward模型,在A100*8的环境上训练的时候,在第三阶段train_rlhf时出现如上报错,尝试了很多方法都没有解决,deepspeed版本是0.10.0,奇怪的点是当--actor_zero_stage是2的时候,能够成功装载actor模型,但是装载reference的时候仍然会报这个错,想请问一下作者有什么建议吗?
这个原因应该是系统认为在运行deepspeed.initialize()
之前world_size
一直都是1,所以ds_config['train_batch_size']
不需要乘上world_size
。只能在运行deepspeed.initialize()
之前,才把ds_config['train_batch_size']
改为乘上world_size
。
RL部分的代码还没来得及修复这个问题,具体可以参见pretrain_wo_trainer.py 第220-221行和pretrain_wo_trainer.py 第292行