MegEngine/MegDiffusion

Gradient clipping issues in MegEngine v1.9.x

cheekyshibe opened this issue · 1 comment

Description

Training on a single GPU with gradient clipping enabled in this codebase raises an error under MegEngine 1.9.x. After one iteration of autodiff and parameter update, the model's next forward pass breaks with this error message:

RuntimeError: assertion `filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4' failed at ../../../../../../imperative/src/impl/ops/convolution.cpp:61: megdnn::TensorLayout mgb::imperative::{anonymous}::convolution::do_shape_infer(const mgb::imperative::OpDef&, size_t, megdnn::TensorLayout, megdnn::TensorLayout)
extra message: bad filter ndim for dense convolution: spatial_ndim=2 filter_ndim=0

Here is the simplest example to reproduce this problem:

import megengine
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.autodiff as autodiff

megengine.async_level = 0

class SimpleModel(M.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.conv1 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.conv2 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest")
        x = self.conv2(x)
        return x

if __name__ == "__main__":
    x = F.ones((1, 1, 2, 2))
    model = SimpleModel(in_ch = 1)

    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    gm = autodiff.GradManager()
    gm.attach(model.parameters())

    with gm:
        loss = model(x) + 0
        gm.backward(loss)
        
    optim.clip_grad_norm(model.parameters(), max_norm=1.)
    optimizer.step()
    y = model(x)

Workaround

  • Solution 1: Comment out this line in megdiffusion.scripts.train:

    optim.clip_grad_norm(model.parameters(), FLAGS.grad_clip)

    The model then trains, but without gradient clipping, which is not what we want... 😣

  • Solution 2: Use distributed training; the problem does not occur there.

  • Solution 3: Try changing loss = model(x) + 0 to loss = model(x) 🤔🤔🤔

  • Solution 4: Try deleting x = F.nn.interpolate(x, scale_factor=1, mode="nearest") 🤔🤔🤔

Issue Track

This problem was fixed in MegEngine/MegEngine@df5ebd3, so you can either wait for the MegEngine v1.10 release or build MegEngine from source at or after this commit.

The Python traceback shows that the MegDNN op applied in conv1(x) failed:

  • assertion `filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4' failed
  • bad filter ndim for dense convolution: spatial_ndim=2 filter_ndim=0

In MegEngine/MegDNN, shape inference (shape_infer) is usually performed when we call something like (output,) = apply(op, inp, weight) and dispatch the kernel to the MegDNN computing library. From the input descriptions we can often (though not always) infer properties of the output, such as its shape. This is convenient: for example, to get the shape of the convolution filter in the i-th layer, we don't have to feed in data and wait until the filter tensor has actually been computed. The information is inferred in advance and read directly when needed, as ChannelImpl::get_shape shows:

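TensorShape ChannelImpl::get_shape(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // Shape was inferred in advance: return it without waiting.
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    // Inference failed (ndim == 0): block until the tensor is computed.
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    return ret;
}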
DType ChannelImpl::get_dtype(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    return ret;
}
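The two paths in get_shape (return the inferred layout, or wait for the real tensor) can be mimicked in plain Python. This is a toy sketch, not MegEngine code: ToyTensorInfo, shape_of and zeros_1x1x2x2 are hypothetical names, ndim == 0 is modelled as an empty tuple, and "waiting" is modelled by simply running the computation once.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

Shape = Tuple[int, ...]
UNKNOWN: Shape = ()  # hypothetical: empty tuple plays the role of ndim == 0


@dataclass
class ToyTensorInfo:
    """Hypothetical stand-in for the interpreter's TensorInfo."""
    inferred_shape: Shape           # result of shape inference, may be UNKNOWN
    compute: Callable[[], list]     # produces the real value (the "kernel")
    value: Optional[list] = None
    waits: int = 0                  # how many times get_shape had to compute


def shape_of(value) -> Shape:
    """Shape of a nested Python list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def get_shape(info: ToyTensorInfo) -> Shape:
    # Shape inferred in advance: no computation needed.
    if len(info.inferred_shape) != 0:
        return info.inferred_shape
    # Inference failed: run the computation (once) and read the shape off the result.
    if info.value is None:
        info.value = info.compute()
        info.waits += 1
    return shape_of(info.value)


def zeros_1x1x2x2() -> list:
    return [[[[0.0, 0.0], [0.0, 0.0]]]]


inferred = ToyTensorInfo((1, 1, 2, 2), compute=zeros_1x1x2x2)
pending = ToyTensorInfo(UNKNOWN, compute=zeros_1x1x2x2)
print(get_shape(inferred), inferred.waits)  # (1, 1, 2, 2) 0
print(get_shape(pending), pending.waits)    # (1, 1, 2, 2) 1
```

Either way the caller gets the same shape; an unknown shape only costs a wait, which is why it should never be fatal on its own.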

In MegDNN, filter_ndim=0 means the filter's ndim is unknown, i.e. it could not be inferred in advance. But this SHOULD NOT break our programs, because the shape becomes available once the computation it depends on has finished (compare the wait_tensor call in get_shape). MegEngine/MegEngine@df5ebd3 is the solution.
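To make "unknown should not be fatal" concrete, here is a toy conv2d shape-inference function that returns "unknown" instead of asserting when an operand's ndim is 0, and only raises on genuinely bad shapes. It is a hypothetical Python sketch of the idea, not a transcription of what MegEngine/MegEngine@df5ebd3 changes, and it only handles dense NCHW convolution with a 4-D filter (the real assertion also accepts img_ndim + 4).

```python
from typing import Tuple

Shape = Tuple[int, ...]
UNKNOWN: Shape = ()  # hypothetical: ndim == 0, shape not inferable in advance


def infer_conv2d_shape(src: Shape, filt: Shape, stride: int = 1, padding: int = 0) -> Shape:
    """Infer the NCHW output shape of a dense 2-D convolution (toy version).

    Returns UNKNOWN when either operand's shape is unknown, so the caller can
    fall back to waiting for the real tensors instead of failing.
    """
    if len(src) == 0 or len(filt) == 0:
        return UNKNOWN
    if len(src) != 4:
        raise ValueError(f"expected NCHW input, got ndim={len(src)}")
    spatial_ndim = len(src) - 2
    if len(filt) != spatial_ndim + 2:  # dense filter layout: (OC, IC, KH, KW)
        raise ValueError(
            f"bad filter ndim for dense convolution: "
            f"spatial_ndim={spatial_ndim} filter_ndim={len(filt)}"
        )
    n, c, h, w = src
    oc, ic, kh, kw = filt
    if ic != c:
        raise ValueError(f"channel mismatch: input has {c}, filter expects {ic}")
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    return (n, oc, oh, ow)


# The example model's conv layers: 3x3 kernel, stride 1, padding 1.
print(infer_conv2d_shape((1, 1, 2, 2), (1, 1, 3, 3), stride=1, padding=1))  # (1, 1, 2, 2)
print(infer_conv2d_shape((1, 1, 2, 2), UNKNOWN))  # ()
```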

The story doesn't end here...

It seems that the problem has been solved. But if you are still wondering why the workarounds above work, try debugging the example program. The RuntimeError is raised from MegDNN (C++), so you need gdb. Let's do it!

Build MegEngine from source with -DCMAKE_BUILD_TYPE=Debug.
Run gdb python3, enter catch throw, then run example.py (this may take a few seconds, depending on the machine):

  • Read the output, add a breakpoint at .../imperative/src/impl/ops/convolution.cpp:127, then r twice;
  • The second time desc.layout = do_shape_infer(def, src_ndim, src, filter) executes, p filter shows that its ndim is 0, which means shape inference failed. But where did the failure start?
  • Read the stack trace and add a breakpoint at .../imperative/src/impl/interpreter/interpreter_impl.cpp:440;
    • Then r and l to read the apply_op_impl source:
       SmallVector<Handle> ChannelImpl::apply_op_impl(
               /* ... */) {
           // ...
           {
               MGB_LOCK_GUARD(m_info_spin);
               for (auto i : inputs) {
                   auto info = reinterpret_cast<TensorInfo*>(i);
                   mgb_assert(
                           !info->invalid,
                           "an input tensor is unusable due to previous error");
                   input_infos.push_back(info);
                   input_descs.push_back(info->desc);
               }
           }
           // ...
       }
    • Add the condition info->desc.layout.ndim == 0 to this breakpoint, then r until it breaks here, and check py-bt:
       File ".../imperative/python/megengine/functional/nn.py", line 261, in conv2d
         (output,) = apply(op, inp, weight)
       File ".../imperative/python/megengine/module/conv.py", line 422, in calc_conv
         self.compute_mode,
       File ".../imperative/python/megengine/module/conv.py", line 426, in forward
         return self.calc_conv(inp, self.weight, self.bias)
       File ".../imperative/python/megengine/module/module.py", line 142, in __call__
         outputs = self.forward(*inputs, **kwargs)
       File "example.py", line 21, in forward
         x = self.conv2(x)
      This means that when the conv2(x) kernel is dispatched, its input has already lost its shape info.

Who lost the shape info?

  • Reading the source of ChannelImpl::dispatch_kernel(), we find that the shape-inference result is recorded for the profiler:

       void ChannelImpl::dispatch_kernel(
               std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
               const SmallVector<LogicalTensorDesc>& input_descs,
               /* ... */) {
           // ...
           std::optional<StackManager::Guard> guard;
           if (Profiler::is_profiling()) {
               guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
           }
           auto [output_descs, validated] =
                   OpDef::infer_output_attrs_fallible(*op, input_descs);
           MGB_RECORD_EVENT(ShapeInferEvent, validated);
       
           if (Profiler::is_profiling()) {
           // ...
           }
       }
  • So the profiler should help. Following the documentation, we get this picture:

    [Screenshot: profiler output]

    Note that nr_shape_infer_failure is recorded, so we should look for the point where it first increases.

    [Screenshot: profiler output where nr_shape_infer_failure first increases]

    Now we have the answer: the Resize op (interpolate) cannot infer its output shape in advance.
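The profiler strategy (count shape-inference failures and find the first op where the counter goes up) can be imitated on a toy version of the example's forward pass. Everything here is hypothetical: each op is a Python function from input shape to output shape, resize_dynamic stands in for the Resize op that the profiler flagged, and the counter only mirrors the idea of nr_shape_infer_failure. Note how one failure makes every downstream op fail too, which is why the error surfaces far from its source.

```python
from typing import Callable, List, Optional, Tuple

Shape = Tuple[int, ...]
UNKNOWN: Shape = ()  # hypothetical: ndim == 0, shape not inferable in advance
Op = Tuple[str, Callable[[Shape], Shape]]


def conv_same(shape: Shape) -> Shape:
    # 3x3 conv, stride 1, padding 1, in_ch == out_ch: output shape == input shape.
    return shape


def resize_dynamic(shape: Shape) -> Shape:
    # Stand-in for Resize (interpolate), whose output shape is not inferred in advance.
    return UNKNOWN


def count_infer_failures(ops: List[Op], src: Shape):
    """Infer shapes op by op and count failures, like nr_shape_infer_failure."""
    failures = 0
    first_failure: Optional[str] = None
    shape = src
    log = []
    for name, infer in ops:
        # An op whose input shape is unknown cannot infer its output either.
        shape = infer(shape) if shape != UNKNOWN else UNKNOWN
        if shape == UNKNOWN:
            failures += 1
            if first_failure is None:
                first_failure = name
        log.append((name, shape, failures))
    return first_failure, log


forward = [("conv1", conv_same), ("resize", resize_dynamic), ("conv2", conv_same)]
first, log = count_infer_failures(forward, (1, 1, 2, 2))
for name, shape, failures in log:
    print(f"{name:7s} shape={shape} nr_shape_infer_failure={failures}")
print("first failure:", first)  # first failure: resize
```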

Verification

To verify this, we can insert a reshape op with a constant target shape between the interpolate and conv2 ops, so that conv2's input shape is known again:

class SimpleModel(M.Module):
    # ... (__init__ unchanged)
    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest").reshape(1, 1, 2, 2)
        x = self.conv2(x)
        return x

Run the program again and, congratulations, everything works!
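Why the reshape helps can be shown with a small toy model: a reshape whose target is a literal tuple knows its output shape without looking at its input, so it turns "unknown" back into a concrete shape and conv2 can be inferred again. This is a hypothetical Python sketch (reshape_to, resize_dynamic and infer_chain are illustrative names); a real reshape still checks the element count at run time.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]
UNKNOWN: Shape = ()  # hypothetical: ndim == 0, shape not inferable in advance


def reshape_to(target: Shape) -> Callable[[Shape], Shape]:
    """Shape-inference rule for a reshape with a constant target shape."""
    def infer(_src: Shape) -> Shape:
        # The output shape is fixed by the literal target, known input or not.
        return target
    return infer


def resize_dynamic(_src: Shape) -> Shape:
    # Stand-in for Resize (interpolate), which cannot infer its shape in advance.
    return UNKNOWN


def conv_same(src: Shape) -> Shape:
    # 3x3 conv with stride 1 and padding 1 keeps the shape; unknown stays unknown.
    return src


def infer_chain(src: Shape, ops: List[Callable[[Shape], Shape]]) -> Shape:
    shape = src
    for op in ops:
        shape = op(shape)
    return shape


without_reshape = infer_chain((1, 1, 2, 2), [conv_same, resize_dynamic, conv_same])
with_reshape = infer_chain(
    (1, 1, 2, 2), [conv_same, resize_dynamic, reshape_to((1, 1, 2, 2)), conv_same]
)
print(without_reshape)  # ()
print(with_reshape)     # (1, 1, 2, 2)
```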