THUDM/ChatGLM2-6B

[Help] 6b-int4 lora 微调使用 adam 优化器时梯度爆炸

wizardforcel opened this issue · 1 comments

Is there an existing issue for this?

  • I have searched the existing issues

Current Behavior

LORA 实现:https://github.com/DracoUnion/chatglm2-6b-int4-lora/blob/master/model/lora.py

训练代码:

import sys
from utils import *
from os import path
import torch
import argparse

def train_handle(args):
    if not args.fname.endswith('.jsonl'):
        print('请提供 JSONL 文件')
        return
    ds = open(args.fname, encoding='utf8').read().split('\n')
    ds = [
        json.loads(line) for line in ds if line.strip()
    ]
    if not args.lora_path and path.isfile(args.save_path):
        args.lora_path = args.save_path
    llm, tokenizer = load_pytorch_llm(args.base_path, args.model_path, args.lora_path)
    llm.attach_lora()
    torch.autograd.set_detect_anomaly(True)
    if args.adam:
        optimizer = torch.optim.Adam(llm.parameters(), lr=args.lr)
    else:
        optimizer = torch.optim.SGD(llm.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=args.schedule_step, 
        gamma=args.schedule_gamma
    )
    step = 0
    for epoch in range(args.n_epoch):
        for i, dsi in enumerate(ds):
            # 组装问答和问答提示词
            ques = tokenizer.build_prompt(combine_prompt_args(args.ques_prompt, dsi))
            ans = combine_prompt_args(args.ans_prompt, dsi)
            # 问答转成问答 ID
            ques_ids = tokenizer.encode(text=ques, add_special_tokens=True, truncation=True)
            ans_ids = tokenizer.encode(text=ans, add_special_tokens=False, truncation=True)
            # 问答 ID 拼接输入 ID
            input_ids = ques_ids + ans_ids + [tokenizer.eos_token_id]
            output_ids = [tokenizer.pad_token_id] * len(ques_ids) + ans_ids + [tokenizer.eos_token_id] 
            # 忽略 <PAD>
            output_ids = [(oid if oid != tokenizer.pad_token_id else -100) for oid in output_ids]
            # 因为批量大小为 1,无需填充
            optimizer.zero_grad()
            input_ids = torch.tensor([input_ids]).cuda()
            output_ids = torch.tensor([output_ids]).cuda()
            with torch.autograd.detect_anomaly(): 
                loss = llm.forward(input_ids=input_ids, labels=output_ids, return_dict=True).loss
                loss.backward()
            print(
                f'epoch: {epoch}\n' + 
                f'step: {step}\n' + 
                f'ques: {json.dumps(ques, ensure_ascii=False)}\n' + 
                f'ans: {json.dumps(ans, ensure_ascii=False)}\n' + 
                f'loss: {loss}'
            )
            # 更新梯度
            torch.nn.utils.clip_grad_norm_(llm.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            # 一定步骤保存权重
            if step % args.save_step == 0:
                torch.save(llm.lora_state_dict(), save_path)
            step += 1
            # 如果批量大小不等于 1,需要乘批量大小
            # prog_callback(int(step / self_args['n_epoch'] / len(ds) * 100))

    if step % args.save_step != 0:
        torch.save(llm.lora_state_dict(), save_path)
        
def main():
    parser = argparse.ArgumentParser(prog="GPTTestTrain", description="GPTTestTrain", formatter_class=argparse.RawDescriptionHelpFormatter)
    # parser.add_argument("-v", "--version", action="version", version=f"BookerMarkdownTool version: {__version__}")
    # parser.set_defaults(func=lambda x: parser.print_help())
    # subparsers = parser.add_subparsers()
    # train_parser = subparsers.add_parser("train", help="train GLM model")
    parser.add_argument("fname", help="jsonl file name")
    parser.add_argument("-q", "--ques-prompt", default="{question}", help="prompt for question")
    parser.add_argument("-a", "--ans-prompt", default="{answer}", help="prompt for answer")
    parser.add_argument("-b", "--base-path", default='/data/chatglm2-6b-int4-lora/model', help="path for model code")
    parser.add_argument("-m", "--model-path", help="path for model param (optional)")
    parser.add_argument("-l", "--lora-path", help="path for lora param")
    parser.add_argument("--adam", action='store_true', help="use adam")
    parser.add_argument("save_path", help="path to save lora param")
    parser.add_argument("--lr", type=float, default=1e-4, help="lr")
    parser.add_argument("--schedule-step", type=int, default=500, help="lr schedule step")
    parser.add_argument("--schedule-gamma", type=float, default=0.9, help="lr schedule gamma")
    parser.add_argument("--save-step", type=int, default=30, help="save_step")
    parser.add_argument("-n", "--n-epoch", type=int, default=15, help="n_epoch")
    parser.set_defaults(func=train_handle)
    
    args = parser.parse_args()
    args.func(args)

if __name__ == '__main__': main()
    

数据集:AdvertiseGen

使用Adam优化器,在第二个样本训练的时候检测到了 NAN,位于 logsoftmax:

# python train.py /data/AdvertiseGen/dev.jsonl  test.pth -q "请根据以下关键词生成文案:{content}" -a "{summary}" --adam --lr='1e-4'
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
train.py:57: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
epoch: 0
step: 0
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞\n\n 答:"
ans: "简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。"
loss: 5.8046875
/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in LogSoftmaxBackward0. Traceback of forward call that caused the error:
  File "train.py", line 107, in <module>
    if __name__ == '__main__': main()
  File "train.py", line 105, in main
    args.func(args)
  File "train.py", line 58, in train_handle
    loss = llm.forward(input_ids=input_ids, labels=output_ids, return_dict=True).loss
  File "/root/.cache/huggingface/modules/transformers_modules/model/modeling_chatglm.py", line 960, in forward
    loss = loss_fct(shift_logits, shift_labels)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 3053, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "train.py", line 107, in <module>
    if __name__ == '__main__': main()
  File "train.py", line 105, in main
    args.func(args)
  File "train.py", line 59, in train_handle
    loss.backward()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.

使用 SGD 则完全没问题:

# python train.py /data/AdvertiseGen/dev.jsonl  test.pth -q "请根据以
下关键词生成文案:{content}" -a "{summary}"
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
train.py:57: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
epoch: 0
step: 0
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞\n\n 答:"
ans: "简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。"
loss: 5.8046875
epoch: 0
step: 1
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裙*材质#针织*颜色#纯色*风格#复古*风格#文艺*风格#简约*图案#格子*图案#纯色*图案#复古*裙型#背带裙*裙长#连衣裙*裙领型#半高领\n\n答:"
ans: "这款BRAND针织两件套连衣裙,简约的纯色半高领针织上衣,修饰着颈部线,尽显优雅气质。同时搭配叠穿起一条背带式的复古格纹裙,整体 散发着一股怀旧的时髦魅力,很是文艺范。"
loss: 7.48828125
epoch: 0
step: 2
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#上衣*风格#嘻哈*图案#卡通*图案#印花*图案#撞色*衣样式#卫衣*衣款式#连帽\n\n答 :"
ans: "嘻哈玩转童年,随时<UNK>,没错,出街还是要靠卫衣来装酷哦!时尚个性的连帽设计,率性有范还防风保暖。还有胸前撞色的卡通印花设计 ,靓丽抢眼更富有趣味性,加上前幅大容量又时尚美观的袋鼠兜,简直就是孩子耍帅装酷必备的利器。"
loss: 6.48046875
epoch: 0
step: 3
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裤*风格#英伦*风格#简约\n\n答:"
ans: "裤子是简约大方的版型设计,带来一种极简主义风格而且不乏舒适优雅感,是衣橱必不可少的一件百搭单品。标志性的logo可以体现出一股子浓郁的英伦风情,轻而易举带来独一无二的<UNK>体验。"
loss: 7.12109375
epoch: 0
step: 4
ques: "[Round 1]\n\n问:请根据以下关键词生成文案:类型#裙*裙下摆#弧形*裙腰型#高腰*裙长#半身裙*裙款式#不规则*裙款式#收腰\n\n答:"
ans: "这款来自梵凯的半身裙富有十足的设计感,采用了别致的不规则设计,凸显出时尚前卫的格调,再搭配俏皮的高腰设计,收腰提臀的同时还勾勒出优美迷人的身材曲线,而且还帮你拉长腿部比例,释放出优雅娇俏的小女人味。并且独特的弧形下摆还富有流畅的线条美,一颦一动间展现出灵动柔美的气质。"
loss: 5.9609375

不知道如何排查,是超参数设置不对吗?

Expected Behavior

No response

Steps To Reproduce

...

Environment

- OS: Linux autodl-container-57094f9364-479b788e 5.4.0-153-generic #170-Ubuntu SMP Fri Jun 16 13:43:31 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
- Python: Python 3.8.10
- Transformers: 4.33.2
- PyTorch: 2.1.1+cu118
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) : True

Anything else?

No response

已解决:FP16 使用 Adam 优化的时候需要eps=1e-3

https://zhuanlan.zhihu.com/p/507889212