Gradient clipping issues in MegEngine v1.9.x
cheekyshibe opened this issue · 1 comment
Description
Training on a single GPU with gradient clipping enabled in this codebase causes an error with MegEngine 1.9.x. After one iteration of auto diff and parameter update, the next forward pass of the model breaks. Error message:
```
RuntimeError: assertion `filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4' failed at ../../../../../../imperative/src/impl/ops/convolution.cpp:61: megdnn::TensorLayout mgb::imperative::{anonymous}::convolution::do_shape_infer(const mgb::imperative::OpDef&, size_t, megdnn::TensorLayout, megdnn::TensorLayout)
extra message: bad filter ndim for dense convolution: spatial_ndim=2 filter_ndim=0
```
Here is a minimal example that reproduces the problem:
```python
import megengine
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.autodiff as autodiff

megengine.async_level = 0


class SimpleModel(M.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.conv1 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.conv2 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest")
        x = self.conv2(x)
        return x


if __name__ == "__main__":
    x = F.ones((1, 1, 2, 2))
    model = SimpleModel(in_ch=1)
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    gm = autodiff.GradManager()
    gm.attach(model.parameters())

    with gm:
        loss = model(x) + 0
        gm.backward(loss)

    optim.clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()

    y = model(x)  # breaks here on MegEngine v1.9.x
```
Workarounds
- Solution 1: Comment out this line in `megdiffusion.scripts.train`: `optim.clip_grad_norm(model.parameters(), FLAGS.grad_clip)`. Then we can train the model without gradient clipping. (But that's not what we want... 😣)
- Solution 2: The problem does not occur when using distributed training.
- Solution 3: Change `loss = model(x) + 0` to `loss = model(x)`. 🤔🤔🤔
- Solution 4: Delete `x = F.nn.interpolate(x, scale_factor=1, mode="nearest")`. 🤔🤔🤔
Issue Track
This problem was fixed in MegEngine/MegEngine@df5ebd3, so you can wait for the MegEngine v1.10 release or build MegEngine from source at a commit later than this one.
The Python traceback messages show that the dnn apply of the `conv1(x)` op failed:

- assertion `filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4' failed
- bad filter ndim for dense convolution: spatial_ndim=2 filter_ndim=0
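For context on what the assertion checks (a minimal sketch, not part of the original report): for a dense 2D convolution the filter layout is `(out_channels, in_channels, kh, kw)`, so its `ndim` should be `spatial_ndim + 2 == 4`, while `filter_ndim=0` means the filter's layout is completely unknown to the shape-inference pass.

```python
import megengine.module as M

# For a dense Conv2d (groups=1) the weight layout is (out_ch, in_ch, kh, kw),
# i.e. ndim == spatial_ndim + 2 == 4; an ndim of 0 would mean "unknown layout".
conv = M.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
print(conv.weight.shape)  # (1, 1, 3, 3)
```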
In MegEngine/MegDNN, it is common to perform shape inference when we do something like `(output,) = apply(op, inp, weight)` and dispatch the kernel to the MegDNN computing library. Given descriptions of the inputs, we can often (though not always) infer information about the output, such as its shape. This is convenient in some situations: for example, if we need the shape of the convolution filter in the i-th layer, we don't have to feed in data and run the computation until the filter tensor is actually produced; instead we try to infer that information in advance and read it directly when needed. For example, `ChannelImpl::get_shape` returns the inferred layout when it is already known and only waits for the computed tensor otherwise:
```cpp
TensorShape ChannelImpl::get_shape(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    return ret;
}
```
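As a quick Python-level illustration (a sketch, not taken from the issue): for a plain convolution, the output shape can usually be answered by this inference path without waiting for the kernel to finish.

```python
import megengine.functional as F

inp = F.ones((1, 1, 2, 2))
weight = F.ones((1, 1, 3, 3))

# The imperative runtime can typically infer this output layout from the
# input descriptions alone, so reading .shape does not need to wait for
# the convolution result.
out = F.conv2d(inp, weight, stride=1, padding=1)
print(out.shape)  # (1, 1, 2, 2)
```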
In MegDNN, `filter_ndim=0` means the filter's `ndim` is unknown (i.e. it could not be inferred in advance). But that SHOULD NOT break our program, because we will still be able to get the shape information once the necessary computation is done. So MegEngine/MegEngine@df5ebd3 is the solution.
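To see the fallback path in action (again just a sketch): even when the runtime cannot infer an op's output shape ahead of time, reading `.shape` waits for the computed tensor and still returns the correct value, exactly like the `wait_tensor` branch of `get_shape` above.

```python
import megengine.functional as F

x = F.ones((1, 1, 2, 2))

# The Resize op behind interpolate is the one whose output shape cannot be
# inferred in advance in v1.9.x; the shape is still available once computed.
y = F.nn.interpolate(x, scale_factor=1, mode="nearest")
print(y.shape)  # (1, 1, 2, 2)
```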
The story does not end here...
It seems the problem has been solved, but if you are still confused about why the workarounds mentioned above work, you can try debugging the example program. Since the RuntimeError is raised from MegDNN, you need gdb, so let's do it!
Build MegEngine from source with `-DCMAKE_BUILD_TYPE=Debug`.
Run `gdb python3`, `catch throw`, then `run example.py` (this might take a while depending on the machine):

- Read the output messages, add a breakpoint at `.../imperative/src/impl/ops/convolution.cpp:127`, then `r` two times.
- The second time `desc.layout = do_shape_infer(def, src_ndim, src, filter)` is executed, `p filter` shows that its `ndim` is `0`, which means shape inference failed. But where did this start?
- Read the stack info and add a breakpoint at `.../imperative/src/impl/interpreter/interpreter_impl.cpp:440`.
- Then `r` and `l`, and read the `apply_op_impl` source code:

```cpp
SmallVector<Handle> ChannelImpl::apply_op_impl(
        // ...
) {
    MGB_LOCK_GUARD(m_info_spin);
    for (auto i : inputs) {
        auto info = reinterpret_cast<TensorInfo*>(i);
        mgb_assert(
                !info->invalid, "an input tensor is unusable due to previous error");
        input_infos.push_back(info);
        input_descs.push_back(info->desc);
    }
    // ...
}
```

- Add the condition `info->desc.layout.ndim == 0`, then `r` until it breaks here, and check `py-bt`:

```
File ".../imperative/python/megengine/functional/nn.py", line 261, in conv2d
    (output,) = apply(op, inp, weight)
File ".../imperative/python/megengine/module/conv.py", line 422, in calc_conv
    self.compute_mode,
File ".../imperative/python/megengine/module/conv.py", line 426, in forward
    return self.calc_conv(inp, self.weight, self.bias)
File ".../imperative/python/megengine/module/module.py", line 142, in __call__
    outputs = self.forward(*inputs, **kwargs)
File "example.py", line 21, in forward
    x = self.conv2(x)
```

  It means that while dispatching the `conv2(x)` kernel, its input has lost its shape info.
Then, who lost the shape info?

- Reading the source code of `ChannelImpl::dispatch_kernel()`, we can find the profiling hooks:

```cpp
void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        // ...
) {
    std::optional<StackManager::Guard> guard;
    if (Profiler::is_profiling()) {
        guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
    }
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    MGB_RECORD_EVENT(ShapeInferEvent, validated);
    if (Profiler::is_profiling()) {
        // ...
    }
}
```
- So using the profiler might be helpful. Following the profiler document, we can get the profiling picture. Note that `nr_shape_infer_failure` is recorded, so we need to focus on the first time it increases.

Now we have the answer: the `Resize` op (behind `interpolate`) cannot infer its output shape in advance.
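For reference, a rough sketch of how such a profile could be collected. This assumes the `Profiler` context manager from `megengine.utils.profiler` as described in the profiling document (the exact arguments may differ between versions), and it reuses the `SimpleModel` class from the repro at the top of this issue.

```python
import megengine.functional as F
import megengine.optimizer as optim
import megengine.autodiff as autodiff
from megengine.utils.profiler import Profiler  # assumed location, see the profiling docs

model = SimpleModel(in_ch=1)        # SimpleModel from the repro above
optimizer = optim.SGD(model.parameters(), lr=1e-3)
gm = autodiff.GradManager()
gm.attach(model.parameters())
x = F.ones((1, 1, 2, 2))

with Profiler():                    # dumps a trace when the block exits
    with gm:
        loss = model(x) + 0
        gm.backward(loss)
    optim.clip_grad_norm(model.parameters(), max_norm=1.0)
    optimizer.step()
    try:
        y = model(x)                # raises the RuntimeError on v1.9.x
    except RuntimeError:
        pass                        # swallow the error so the profiler context exits cleanly
```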
Verification
We can insert a `reshape` op between the `interpolate` and `conv2` ops as follows:
```python
class SimpleModel(M.Module):
    # ...

    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest").reshape(1, 1, 2, 2)
        x = self.conv2(x)
        return x
```
Run the program again: congratulations, everything works fine.