[Help] Gradient explosion when fine-tuning 6b-int4 LoRA with the Adam optimizer
wizardforcel opened this issue · 1 comment
wizardforcel commented
Is there an existing issue for this?
- I have searched the existing issues
Current Behavior
LoRA implementation: https://github.com/DracoUnion/chatglm2-6b-int4-lora/blob/master/model/lora.py
Training code:
import sys
import json
# utils provides load_pytorch_llm and combine_prompt_args
from utils import *
from os import path
import torch
import argparse

def train_handle(args):
    if not args.fname.endswith('.jsonl'):
        print('Please provide a JSONL file')
        return
    ds = open(args.fname, encoding='utf8').read().split('\n')
    ds = [
        json.loads(line) for line in ds if line.strip()
    ]
    if not args.lora_path and path.isfile(args.save_path):
        args.lora_path = args.save_path
    llm, tokenizer = load_pytorch_llm(args.base_path, args.model_path, args.lora_path)
    llm.attach_lora()
    torch.autograd.set_detect_anomaly(True)
    if args.adam:
        optimizer = torch.optim.Adam(llm.parameters(), lr=args.lr)
    else:
        optimizer = torch.optim.SGD(llm.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.schedule_step,
        gamma=args.schedule_gamma
    )
    step = 0
    for epoch in range(args.n_epoch):
        for i, dsi in enumerate(ds):
            # Build the question and answer from the prompt templates
            ques = tokenizer.build_prompt(combine_prompt_args(args.ques_prompt, dsi))
            ans = combine_prompt_args(args.ans_prompt, dsi)
            # Tokenize the question and answer into IDs
            ques_ids = tokenizer.encode(text=ques, add_special_tokens=True, truncation=True)
            ans_ids = tokenizer.encode(text=ans, add_special_tokens=False, truncation=True)
            # Concatenate question and answer IDs into the input IDs
            input_ids = ques_ids + ans_ids + [tokenizer.eos_token_id]
            output_ids = [tokenizer.pad_token_id] * len(ques_ids) + ans_ids + [tokenizer.eos_token_id]
            # Replace <PAD> with -100 so the loss ignores the question part
            output_ids = [(oid if oid != tokenizer.pad_token_id else -100) for oid in output_ids]
            # Batch size is 1, so no padding is needed
            optimizer.zero_grad()
            input_ids = torch.tensor([input_ids]).cuda()
            output_ids = torch.tensor([output_ids]).cuda()
            with torch.autograd.detect_anomaly():
                loss = llm.forward(input_ids=input_ids, labels=output_ids, return_dict=True).loss
                loss.backward()
            print(
                f'epoch: {epoch}\n' +
                f'step: {step}\n' +
                f'ques: {json.dumps(ques, ensure_ascii=False)}\n' +
                f'ans: {json.dumps(ans, ensure_ascii=False)}\n' +
                f'loss: {loss}'
            )
            # Update the parameters
            torch.nn.utils.clip_grad_norm_(llm.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            # Save the weights every `save_step` steps
            if step % args.save_step == 0:
                torch.save(llm.lora_state_dict(), args.save_path)
            step += 1
            # If the batch size were not 1, multiply by the batch size here
            # prog_callback(int(step / self_args['n_epoch'] / len(ds) * 100))
    if step % args.save_step != 0:
        torch.save(llm.lora_state_dict(), args.save_path)

def main():
    parser = argparse.ArgumentParser(prog="GPTTestTrain", description="GPTTestTrain", formatter_class=argparse.RawDescriptionHelpFormatter)
    # parser.add_argument("-v", "--version", action="version", version=f"BookerMarkdownTool version: {__version__}")
    # parser.set_defaults(func=lambda x: parser.print_help())
    # subparsers = parser.add_subparsers()
    # train_parser = subparsers.add_parser("train", help="train GLM model")
    parser.add_argument("fname", help="jsonl file name")
    parser.add_argument("-q", "--ques-prompt", default="{question}", help="prompt for question")
    parser.add_argument("-a", "--ans-prompt", default="{answer}", help="prompt for answer")
    parser.add_argument("-b", "--base-path", default='/data/chatglm2-6b-int4-lora/model', help="path for model code")
    parser.add_argument("-m", "--model-path", help="path for model param (optional)")
    parser.add_argument("-l", "--lora-path", help="path for lora param")
    parser.add_argument("--adam", action='store_true', help="use adam")
    parser.add_argument("save_path", help="path to save lora param")
    parser.add_argument("--lr", type=float, default=1e-4, help="lr")
    parser.add_argument("--schedule-step", type=int, default=500, help="lr schedule step")
    parser.add_argument("--schedule-gamma", type=float, default=0.9, help="lr schedule gamma")
    parser.add_argument("--save-step", type=int, default=30, help="save_step")
    parser.add_argument("-n", "--n-epoch", type=int, default=15, help="n_epoch")
    parser.set_defaults(func=train_handle)
    args = parser.parse_args()
    args.func(args)

if __name__ == '__main__': main()
Dataset: AdvertiseGen
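For reference, each line of the AdvertiseGen JSONL file is a JSON object with "content" and "summary" fields, which the -q/-a prompt templates above interpolate. A minimal sketch of one line, with the summary abridged from the first sample in the logs below:

{"content": "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞", "summary": "简约而不简单的牛仔外套,白色的衣身十分百搭。…"}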
With the Adam optimizer, a NaN is detected in LogSoftmax while training on the second sample:
# python train.py /data/AdvertiseGen/dev.jsonl test.pth -q "请根据以下关键词生成文案:{content}" -a "{summary}" --adam --lr='1e-4'
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
train.py:57: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
epoch: 0
step: 0
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞\n\n 答:"
ans: "简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。"
loss: 5.8046875
/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in LogSoftmaxBackward0. Traceback of forward call that caused the error:
File "train.py", line 107, in <module>
if __name__ == '__main__': main()
File "train.py", line 105, in main
args.func(args)
File "train.py", line 58, in train_handle
loss = llm.forward(input_ids=input_ids, labels=output_ids, return_dict=True).loss
File "/root/.cache/huggingface/modules/transformers_modules/model/modeling_chatglm.py", line 960, in forward
loss = loss_fct(shift_logits, shift_labels)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 3053, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "train.py", line 107, in <module>
if __name__ == '__main__': main()
File "train.py", line 105, in main
args.func(args)
File "train.py", line 59, in train_handle
loss.backward()
File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.
With SGD there is no problem at all:
# python train.py /data/AdvertiseGen/dev.jsonl test.pth -q "请根据以下关键词生成文案:{content}" -a "{summary}"
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
train.py:57: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
with torch.autograd.detect_anomaly():
epoch: 0
step: 0
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞\n\n 答:"
ans: "简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。"
loss: 5.8046875
epoch: 0
step: 1
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裙*材质#针织*颜色#纯色*风格#复古*风格#文艺*风格#简约*图案#格子*图案#纯色*图案#复古*裙型#背带裙*裙长#连衣裙*裙领型#半高领\n\n答:"
ans: "这款BRAND针织两件套连衣裙,简约的纯色半高领针织上衣,修饰着颈部线,尽显优雅气质。同时搭配叠穿起一条背带式的复古格纹裙,整体 散发着一股怀旧的时髦魅力,很是文艺范。"
loss: 7.48828125
epoch: 0
step: 2
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*风格#嘻哈*图案#卡通*图案#印花*图案#撞色*衣样式#卫衣*衣款式#连帽\n\n答 :"
ans: "嘻哈玩转童年,随时<UNK>,没错,出街还是要靠卫衣来装酷哦!时尚个性的连帽设计,率性有范还防风保暖。还有胸前撞色的卡通印花设计 ,靓丽抢眼更富有趣味性,加上前幅大容量又时尚美观的袋鼠兜,简直就是孩子耍帅装酷必备的利器。"
loss: 6.48046875
epoch: 0
step: 3
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裤*风格#英伦*风格#简约\n\n答:"
ans: "裤子是简约大方的版型设计,带来一种极简主义风格而且不乏舒适优雅感,是衣橱必不可少的一件百搭单品。标志性的logo可以体现出一股子浓郁的英伦风情,轻而易举带来独一无二的<UNK>体验。"
loss: 7.12109375
epoch: 0
step: 4
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裙*裙下摆#弧形*裙腰型#高腰*裙长#半身裙*裙款式#不规则*裙款式#收腰\n\n答:"
ans: "这款来自梵凯的半身裙富有十足的设计感,采用了别致的不规则设计,凸显出时尚前卫的格调,再搭配俏皮的高腰设计,收腰提臀的同时还勾勒出优美迷人的身材曲线,而且还帮你拉长腿部比例,释放出优雅娇俏的小女人味。并且独特的弧形下摆还富有流畅的线条美,一颦一动间展现出灵动柔美的气质。"
loss: 5.9609375
I don't know how to troubleshoot this. Are the hyperparameters set incorrectly?
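One common way to localize this kind of failure, as an alternative to anomaly detection, is to scan the gradients for NaN/Inf right after loss.backward(). A minimal sketch under the names used in the script above (the helper check_grads is hypothetical, not part of the repo):

import torch

def check_grads(model):
    # Print any parameter whose gradient contains NaN or Inf,
    # plus its largest finite magnitude, to see which LoRA
    # matrices blow up first.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        grad = param.grad.float()
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            finite = grad[torch.isfinite(grad)]
            peak = finite.abs().max().item() if finite.numel() else float('nan')
            print(f'{name}: bad gradient, max finite |grad| = {peak}')

Calling check_grads(llm) between loss.backward() and clip_grad_norm_ narrows the NaN down to specific parameters without the overhead of rerunning under detect_anomaly.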
Expected Behavior
No response
Steps To Reproduce
...
Environment
- OS: Linux autodl-container-57094f9364-479b788e 5.4.0-153-generic #170-Ubuntu SMP Fri Jun 16 13:43:31 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
- Python: Python 3.8.10
- Transformers: 4.33.2
- PyTorch: 2.1.1+cu118
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) : True
Anything else?
No response
wizardforcel commented
Solved: when optimizing in FP16 with Adam, eps=1e-3 is needed.
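For context, a likely mechanism (my reading, not stated in the original comment): Adam's default eps is 1e-8, far below the smallest normal float16 value (about 6e-5), so in half precision the denominator sqrt(v) + eps can collapse toward zero when the second-moment estimate v is tiny, producing a huge first update that corrupts the weights and surfaces as NaN in LogSoftmax on the next forward pass. A minimal sketch of the fix in the training script above, changing only the eps argument:

if args.adam:
    # A larger eps keeps the Adam denominator representable in FP16;
    # the default 1e-8 underflows and can blow up the update.
    optimizer = torch.optim.Adam(llm.parameters(), lr=args.lr, eps=1e-3)
else:
    optimizer = torch.optim.SGD(llm.parameters(), lr=args.lr)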