wnma3mz/llama

有关训练模型保存问题

Gary3410 opened this issue · 3 comments

您好! 感谢您开源了这个很棒的项目。
我目前在3090上训练LLaMA,所以您提供的模型切分训练脚本对降低显存非常有帮助。可我这边在训练时,保存的模型似乎也是其中的一个分块(原本13G的模型,训练保存的checkpont为6G),这是直接torch.save()的问题吗?

说下我个人的理解,不一定对。
首先,在训练中预训练模型的参数已经冻结了,需要保存的是encode_embedding的参数,而这部分参数只需要几MB即可,代码ft_main.py中可以查看“prefix.pth”部分处的代码。
在推理过程中,读取模型参数也是分两部分读取,其中decode部分读取预训练参数,encode读取finetune的参数。
我个人改的方法是在trainer后面加了个embedding参数的保存代码,按照rank将encode_embedding的参数分别保存,然后在推理的时候按照rank依次读取。

您好! 感谢您开源了这个很棒的项目。 我目前在3090上训练LLaMA,所以您提供的模型切分训练脚本对降低显存非常有帮助。可我这边在训练时,保存的模型似乎也是其中的一个分块(原本13G的模型,训练保存的checkpont为6G),这是直接torch.save()的问题吗?

我理解是torch.splittorch.save两个函数的问题(大概率是前者),在split之后依旧保留了原来数据的信息,所以save的时候会把所有数据save。

之后可以试试在split之后做一些类似clone或者detach的操作?因为可以重新load+save恢复到原始大小,加上有其他事情所以就暂时忽略了这个问题😂

您好! 感谢您开源了这个很棒的项目。 我目前在3090上训练LLaMA,所以您提供的模型切分训练脚本对降低显存非常有帮助。可我这边在训练时,保存的模型似乎也是其中的一个分块(原本13G的模型,训练保存的checkpont为6G),这是直接torch.save()的问题吗?

我理解是torch.splittorch.save两个函数的问题(大概率是前者),在split之后依旧保留了原来数据的信息,所以save的时候会把所有数据save。

之后可以试试在split之后做一些类似clone或者detach的操作?因为可以重新load+save恢复到原始大小,加上有其他事情所以就暂时忽略了这个问题😂

split之后时候使用clone函数重新把对象复制出来即可,

ckpt[k] = w.clone()