bug: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
Closed this issue · 0 comments
gaocegege commented
/home/gaocegege/applications/miniconda3/envs/dev/lib/python3.9/site-packages/accelerate/big_modeling.py:108 in register_empty_parameter

    105         if param is not None:
    106             param_cls = type(module._parameters[name])
    107             kwargs = module._parameters[name].__dict__
  ❱ 108             module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
    109
    110     def register_empty_buffer(module, name, buffer):
    111         old_register_buffer(module, name, buffer)

/home/gaocegege/applications/miniconda3/envs/dev/lib/python3.9/site-packages/torch/nn/parameter.py:36 in __new__

     33         if type(data) is torch.Tensor or type(data) is Parameter:
     34             # For ease of BC maintenance, keep this path for standard Tensor.
     35             # Eventually (tm), we should change the behavior for standard Tensor to match.
  ❱  36             return torch.Tensor._make_subclass(cls, data, requires_grad)
     37
     38         # Path for custom tensors: set a flag on the instance to indicate parameter-ness
     39         t = data.detach().requires_grad_(requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
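For context on what the traceback shows: `torch.nn.Parameter.__new__` defaults `requires_grad` to `True`, and PyTorch only allows gradients on floating-point or complex tensors, so recreating an integer-dtype parameter through `param_cls(...)` without forwarding the original tensor's `requires_grad` trips this check. Below is a minimal stdlib-only sketch of that guard logic; `FakeTensor` and `make_parameter` are hypothetical stand-ins (not real torch or accelerate APIs) used so the shape of the failure is visible without installing torch.

```python
# Hypothetical stand-ins mimicking the dtype guard inside
# torch.Tensor._make_subclass that raises the RuntimeError above.
FLOATING = {"float16", "float32", "float64", "bfloat16"}
COMPLEX = {"complex64", "complex128"}


class FakeTensor:
    """Minimal stand-in for torch.Tensor: just a dtype and a grad flag."""

    def __init__(self, dtype, requires_grad=False):
        self.dtype = dtype
        self.requires_grad = requires_grad


def make_parameter(data, requires_grad=True):
    # Mirrors Parameter's default of requires_grad=True and torch's rule
    # that only floating point / complex tensors can require gradients.
    if requires_grad and data.dtype not in FLOATING | COMPLEX:
        raise RuntimeError(
            "Only Tensors of floating point and complex dtype can require gradients"
        )
    return FakeTensor(data.dtype, requires_grad)


# An int64 parameter reproduces the crash, because requires_grad defaults to True:
try:
    make_parameter(FakeTensor("int64"))
except RuntimeError as e:
    print("reproduced:", e)

# Forwarding the original tensor's requires_grad (False for integer params) avoids it:
p = make_parameter(FakeTensor("int64", requires_grad=False), requires_grad=False)
print("ok:", p.dtype, p.requires_grad)
```

The analogous fix on the accelerate side would be for `register_empty_parameter` to pass the original parameter's `requires_grad` through when rebuilding it, instead of relying on `Parameter`'s default of `True` (this sketch only illustrates the failure mode, not the exact patch that closed the issue).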