PaddlePaddle/PaddleNLP

[Bug]: With amp_master_grad and recompute both enabled, weight has no main_grad

Opened this issue · 5 comments

Software environment

- paddlepaddle: 
- paddlepaddle-gpu: 2.6
- paddlenlp: 2.7.1.post0

Duplicate check

  • I have searched the existing issues

Bug description

Normally, with --amp_master_grad enabled, every weight has a main_grad.
However, with recompute=full, the weight seen inside the backward of a custom Python op (PyLayer) has no main_grad.

Steps to reproduce & code

Taking llama training as an example:

  • Set --amp_master_grad true to enable main_grad.
  • Set --recompute true --recompute_granularity full to enable recompute.
  • Set --enable_linear_fused_grad_add true so that llm/llama/fused_layers.py is used. I ran into this problem while developing a feature similar to linear_fused_grad_add.

Change the code at fused_layers.py #L32-L41 to:

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        y = origin_linear(x, weight, bias)

        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias
        return y

    @staticmethod
    def backward(ctx, y_grad):
        weight, = ctx.saved_tensor()  # this weight has no main_grad
        x = ctx.x
        bias = ctx.bias
        if hasattr(weight, "main_grad"):
            print("weight has main_grad")
        else:
            print("weight has no main_grad")

Run llama training, and backward reports that weight has no main_grad.

However, if ctx.save_for_backward and ctx.saved_tensor() are not used, and are replaced by ctx.weight = weight and weight = ctx.weight respectively, then weight does have main_grad.

While debugging, I found that this is probably because, with recompute enabled, save_for_backward triggers the copy at recompute.py#L340, which copies weight into a tensor named weight.name + "cpy" but does not copy main_grad.
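To make the suspected mechanism concrete, here is a framework-free Python sketch. `FakeTensor`, `pack_copy`, and `pack_copy_keep_main_grad` are hypothetical stand-ins, not Paddle APIs: `pack_copy` only mimics the copy described above (a new object named `name + "cpy"` that shares the original buffer). It shows that an attribute attached to the original Python object, such as `main_grad`, does not follow the copy, while keeping a direct reference (the `ctx.weight = weight` route) does.

```python
class FakeTensor:
    """Hypothetical stand-in for a framework tensor: a name plus a data buffer."""

    def __init__(self, name, buffer):
        self.name = name
        self.buffer = buffer  # storage object, shared (not copied) by pack_copy


def pack_copy(x):
    """Mimic the described copy: a new object named x.name + "cpy" that shares
    x's buffer. Only name and buffer are transferred, so ad-hoc Python
    attributes such as main_grad are silently dropped."""
    return FakeTensor(x.name + "cpy", x.buffer)


def pack_copy_keep_main_grad(x):
    """Hypothetical fix: also forward main_grad when the original carries one."""
    y = pack_copy(x)
    if hasattr(x, "main_grad"):
        y.main_grad = x.main_grad  # same object, not a copy
    return y


weight = FakeTensor("linear_0.w_0", [1.0, 2.0])
weight.main_grad = [0.0, 0.0]  # stand-in for the main_grad on the original weight

saved_ref = weight                               # like ctx.weight = weight
saved_copy = pack_copy(weight)                   # like save_for_backward under recompute
saved_fixed = pack_copy_keep_main_grad(weight)   # hypothetical patched copy

print(hasattr(saved_ref, "main_grad"))     # True
print(hasattr(saved_copy, "main_grad"))    # False
print(hasattr(saved_fixed, "main_grad"))   # True
print(saved_copy.buffer is weight.buffer)  # True
```

Forwarding the attribute by reference matters: the copy has to point at the same accumulator, so that gradients written during backward land in the parameter's own main_grad.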

        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias

Why is this split up? Could you try writing it like this instead?

    ctx.save_for_backward(x, weight, bias)
    x, weight, bias = ctx.saved_tensor()

@GuoxiaWang The point I care about in this issue is that, with recompute enabled, writing ctx.save_for_backward(weight) leaves the weight in backward without main_grad.

The form you suggest is how fused_layers.py was originally written. I tested it as well: with recompute=full it hits the strange error below, which would need a separate issue.

    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py", line 37, in forward
    output = self._layers(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1913, in forward
    outputs = self.llama(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1664, in forward
    layer_outputs = self.recompute_training_full(
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1535, in recompute_training_full
    hidden_states = self.recompute_func(
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/utils/__init__.py", line 142, in recompute
    return fleet.recompute.recompute(function, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 532, in recompute
    return _recompute_without_reentrant(function, preserve, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 399, in _recompute_without_reentrant
    outputs = function(*args, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1531, in custom_forward
    return module(*inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1228, in forward
    outputs = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 901, in forward
    query_states = self.q_proj(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/layers/mpu/mp_layers.py", line 516, in forward
    output_parallel = self.linear(
  File "/workspace/PaddleNLP/llm/fused_layers.py", line 36, in forward
    ctx.save_for_backward(x, weight, bias)
  File "/usr/local/lib/python3.10/dist-packages/paddle/autograd/py_layer.py", line 91, in save_for_backward
    self.container = tensors
ValueError: (InvalidArgument) save_for_backward only support Tensor, list of Tensor, tuple of Tensor. (at /opt/paddle/paddle/paddle/fluid/pybind/eager_py_layer.cc:644)

Update: setting reentrant=True for recompute avoids this bug. Only reentrant=False triggers it.
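If the ValueError above is caused by the None bias in the saved tuple (an assumption; the traceback only shows that a non-Tensor element was rejected), another option is to save only the real tensors and remember where the Nones were. The sketch below is plain Python: `FakeCtx` is a hypothetical stand-in for the PyLayer ctx that rejects None the way the error message describes, and `pack_optional` / `unpack_optional` are hypothetical helper names, not Paddle APIs.

```python
class FakeCtx:
    """Hypothetical stand-in for a PyLayer ctx whose save_for_backward, like the
    one in the traceback, only accepts real tensors (here: anything but None)."""

    def save_for_backward(self, *tensors):
        if any(t is None for t in tensors):
            raise ValueError(
                "save_for_backward only support Tensor, list of Tensor, tuple of Tensor."
            )
        self.container = tensors

    def saved_tensor(self):
        return self.container


def pack_optional(ctx, *tensors):
    """Save only the non-None entries and record which positions were None."""
    ctx.none_mask = tuple(t is None for t in tensors)
    ctx.save_for_backward(*(t for t in tensors if t is not None))


def unpack_optional(ctx):
    """Rebuild the original tuple, putting None back where it was."""
    saved = iter(ctx.saved_tensor())
    return tuple(None if is_none else next(saved) for is_none in ctx.none_mask)


ctx = FakeCtx()
pack_optional(ctx, "x", "weight", None)  # strings stand in for tensors; bias=None
x, weight, bias = unpack_optional(ctx)
print(x, weight, bias)  # x weight None
```

This only addresses the None element. Under non-reentrant recompute, the saved weight would presumably still be the "cpy" tensor described earlier, without main_grad.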

@Xreki Could you please ask someone on the Paddle side who is familiar with recompute to take a look?

Xreki commented

![image](https://github.com/PaddlePaddle/PaddleNLP/assets/12538138/84258d77-048e-41a2-9641-6d7a303ba6bf)

@Wong4j This is actually a known issue with reentrant=False.