pytorch-labs/gpt-fast

mmap issue in bf16 of gpt-fast

yanbing-j opened this issue · 1 comment

gpt-fast uses torch.load with mmap=True to load model checkpoints, which can speed up model load time. However, the mapped file ends up unused in the bf16 path: at https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L247, the model is converted from float16 to bfloat16 when running the bf16 model, and .to() allocates a new memory area, so the memory-mapped file is no longer what backs the weights.
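For illustration, here is a minimal sketch of the pattern described above, using a toy module rather than the actual gpt-fast Transformer (the file name and module are placeholders):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model, saved as a float16 checkpoint.
model = nn.Linear(16, 16, dtype=torch.float16)
torch.save(model.state_dict(), "model_fp16.pth")

# mmap=True maps the checkpoint file instead of reading it eagerly;
# assign=True (PyTorch >= 2.1) lets the module reuse the mmap-backed tensors.
state_dict = torch.load("model_fp16.pth", mmap=True, weights_only=True)
model.load_state_dict(state_dict, assign=True)

# .to(dtype) materializes fresh bfloat16 tensors, so the memory-mapped
# float16 storage is no longer backing the model's weights.
model = model.to(dtype=torch.bfloat16)
```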

Meanwhile, for int8/int4, the logic at https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L247 does not make sense: an int8 model should not be converted to bfloat16. int8/int4 currently happens to work only because the quantized weight is, by chance, not registered as a parameter of the int8/int4 modules, so the conversion does not touch it.
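One possible direction, sketched under the assumption that quantized checkpoints can be identified from the file name (the helper and its name are hypothetical, not part of gpt-fast):

```python
from pathlib import Path

import torch
import torch.nn as nn


def maybe_cast(model: nn.Module, checkpoint_path: Path, precision: torch.dtype) -> nn.Module:
    # Hypothetical helper: cast only floating-point checkpoints to the target
    # precision, and leave int8/int4 quantized weights untouched instead of
    # relying on them not being parameters by accident.
    name = checkpoint_path.name
    if "int8" in name or "int4" in name:
        return model
    return model.to(dtype=precision)
```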

@yanbing-j The goal of gpt-fast is to demonstrate the set of optimizations we applied to accelerate inference; model loading is not the major bottleneck, so we didn't optimize it much. But I do agree with your points, and we also welcome PRs for these optimizations.