🇨🇳中文 | 🌐English | 📖文档/Docs | 🤖模型/Models
MedicalGPT: Training Medical GPT Model
📖 Introduction
MedicalGPT training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining, Supervised Finetuning, Reward Modeling and Reinforcement Learning.
MedicalGPT 训练医疗大模型,实现包括二次预训练、有监督微调、奖励建模、强化学习训练。
分四阶段训练GPT模型,来自Andrej Karpathy的演讲PDFState of GPT,视频Video
基于此,训练领域模型--医疗模型,分四阶段:
- 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练LLaMA模型,以注入领域知识
- 第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
- 第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
- 第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
▶️ Demo
- Hugging Face Demo: doing
我们提供了一个简洁的基于gradio的交互式web界面,启动服务后,可通过浏览器访问,输入问题,模型会返回答案。
启动服务,命令如下:
python scripts/gradio_demo.py --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
参数说明:
--base_model {base_model}
:存放HF格式的LLaMA模型权重和配置文件的目录,也可使用HF Model Hub模型调用名称--lora_model {lora_model}
:LoRA文件所在目录,也可使用HF Model Hub模型调用名称。若lora权重已经合并到预训练模型,则删除--lora_model参数--tokenizer_path {tokenizer_path}
:存放对应tokenizer的目录。若不提供此参数,则其默认值与--lora_model相同;若也未提供--lora_model参数,则其默认值与--base_model相同--use_cpu
: 仅使用CPU进行推理--gpus {gpu_ids}
: 指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
🚀 Training Pipeline
Training Stage:
Stage | Introduction | Notebook | Colab | Python script | Shell script |
---|---|---|---|---|---|
Stage 1: Continue Pretraining | 增量预训练 | run_pretraining.ipynb | pretraining.py | run_pt.sh | |
Stage 2: Supervised Fine-tuning | 有监督微调 | run_supervised_finetuning.ipynb | supervised_finetuning.py | run_sft.sh | |
Stage 3: Reward Modeling | 奖励模型建模 | run_reward_modeling.ipynb | reward_modeling.py | run_rm.sh | |
Stage 4: Reinforcement Learning | 强化学习 | run_rl_training.ipynb | rl_training.py | run_rl.sh |
🔥 Inference
训练完成后,现在我们加载训练好的模型,验证模型生成文本的效果。
python scripts/inference.py \
--base_model path_to_llama_hf_dir \
--lora_model path_to_lora \
--with_prompt \
--interactive
参数说明:
--base_model {base_model}
:存放HF格式的LLaMA模型权重和配置文件的目录--lora_model {lora_model}
:LoRA解压后文件所在目录,也可使用HF Model Hub模型调用名称。如果已经合并了LoRA权重到预训练模型,则可以不提供此参数--tokenizer_path {tokenizer_path}
:存放对应tokenizer的目录。若不提供此参数,则其默认值与--lora_model相同;若也未提供--lora_model参数,则其默认值与--base_model相同--with_prompt
:是否将输入与prompt模版进行合并。如果加载Alpaca模型,请务必启用此选项!--interactive
:以交互方式启动,以便进行多次单轮问答--data_file {file_name}
:非交互方式启动下,按行读取file_name中的的内容进行预测--predictions_file {file_name}
:非交互式方式下,将预测的结果以json格式写入file_name--use_cpu
: 仅使用CPU进行推理--gpus {gpu_ids}
: 指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
Inference Examples
Human | Assitant |
---|---|
怎么让自己精力充沛,列5点建议 | 以下是一些可以让你保持精力充沛的建议: 1. 睡眠充足。确保每天晚上得到充足的睡眠时间。 2. 多运动。进行适度的有氧运动,如慢跑、游泳等,可以帮助你提高能量和耐力。 3. 吃健康食品。选择富含蛋白质的食物,如鸡肉、鱼肉和鸡蛋,以及新鲜的水果和蔬菜。 4. 减少咖啡因的摄入量。尽量避免在白天喝含咖啡因的饮料或吃含有咖啡因的食物。 5. 放松身心。尝试冥想、瑜伽或其他形式的放松活动来减轻压力和焦虑感。 |
📚 Dataset
- 240万条中文医疗数据集(包括预训练、指令微调和奖励数据集):shibing624/medical
附上一些通用数据集和医疗数据集的链接
- 50万条中文ChatGPT指令Belle数据集:BelleGroup/train_0.5M_CN
- 100万条中文ChatGPT指令Belle数据集:BelleGroup/train_1M_CN
- 5万条英文ChatGPT指令Alpaca数据集:50k English Stanford Alpaca dataset
- 2万条中文GPT-4指令Alpaca数据集:shibing624/alpaca-zh
- 69万条中文指令Guanaco数据集(Belle50万条+Guanaco19万条):Chinese-Vicuna/guanaco_belle_merge_v1.0
- 22万条中文医疗对话数据集(华佗项目):FreedomIntelligence/HuatuoGPT-sft-data-v1
✅ Todo
- Added multi-round dialogue data fine-tuning method
- add reward model fine-tuning
- add rl fine-tuning
- add medical reward dataset
- add llama in8/int4 training
- add all training and predict demo in colab
☎️ Contact
- Issue(建议) :
- 邮件我:xuming: xuming624@qq.com
- 微信我: 加我微信号:xuming624, 备注:姓名-公司名-NLP 进NLP交流群。
⚠️ 局限性、使用限制与免责声明
基于当前数据和基础模型训练得到的SFT模型,在效果上仍存在以下问题:
-
在涉及事实性的指令上可能会产生违背事实的错误回答。
-
对于具备危害性的指令无法很好的鉴别,由此会产生危害性言论。
-
在一些涉及推理、代码、多轮对话等场景下模型的能力仍有待提高。
基于以上模型局限性,我们要求开发者仅将我们开源的模型权重及后续用此项目生成的衍生物用于研究目的,不得用于商业,以及其他会对社会带来危害的用途。
本项目仅可应用于研究目的,项目开发者不承担任何因使用本项目(包含但不限于数据、模型、代码等)导致的危害或损失。详细请参考免责声明。
项目代码的授权协议为 The Apache License 2.0,代码可免费用做商业用途,模型权重和数据只能用于研究目的。请在产品说明中附加MedicalGPT的链接和授权协议。
😇 Citation
如果你在研究中使用了MedicalGPT,请按如下格式引用:
@misc{MedicalGPT,
title={MedicalGPT: Training Medical GPT Model},
author={Ming Xu},
year={2023},
howpublished={\url{https://github.com/shibing624/MedicalGPT}},
}
😍 Contribute
项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:
- 在
tests
添加相应的单元测试 - 使用
python -m pytest
来运行所有单元测试,确保所有单测都是通过的
之后即可提交PR。
💕 Acknowledgements
Thanks for their great work!