thuml/depyf

[Bug]: depyf.prepare_debug causes torch.compile crash with cell var mismatch error

tombousso opened this issue · 3 comments

Your current environment

latest depyf + torch

๐Ÿ› Describe the bug

To reproduce:

def f(a, b):
    assert a
    print(b)
    [a for _ in [None]]

f(1,2)

import depyf
import torch

f = torch.compile(f)
with depyf.prepare_debug('debug_dir'):
    f(1,2)

Output:

2
/usr/local/lib/python3.10/dist-packages/depyf/explain/enable_debugging.py:153: UserWarning: You are trying to debug `torch.compile`. Please make sure the code runs multiple times to cover all the possible branches.
  warnings.warn((
Traceback (most recent call last):
  File "/.../test.py", line 13, in <module>
    f(1,2)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 437, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 621, in compile_inner
    assert code.co_cellvars == out_code.co_cellvars, msg
AssertionError: cell var mismatch: old code object has cell var ('a',), new code object has cell var ('a', 'b')


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Without depyf.prepare_debug, there are no errors and torch.compile succeeds.
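For context, the failing assertion compares the `co_cellvars` tuples of the original and the transformed code objects. The snippet below is a minimal sketch, independent of depyf and torch, of how a local ends up in `co_cellvars`: only names captured by a nested scope become cell variables. It uses a lambda instead of the list comprehension from the reproduction, because on Python 3.10 (the version in the traceback) a comprehension runs in its own nested scope, while Python 3.12 and later inline comprehensions. The function name `g` is hypothetical and chosen only for illustration.

```python
def g(a, b):
    # `a` is captured by the nested lambda, so the compiler stores it in a cell.
    # `b` is only used in this scope, so it stays a plain local variable.
    print(b)
    return lambda: a

code = g.__code__
cellvars = code.co_cellvars
print(cellvars)  # -> ('a',)
```

In the traceback above, the original code object reports `('a',)` while the new one reports `('a', 'b')`, so `b` was treated as captured even though the original function never captures it.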

@tombousso can you try #26? It should solve this issue.

Thanks @youkaichao! It seems like this issue is solved. But now it's failing in a different part of the code:

import torch
torch.set_default_device('cuda')

from transformer_engine.pytorch.attention import DotProductAttention

model = DotProductAttention(32, 128)
args = (torch.ones([1, 256, 32, 128]), torch.ones([1, 256, 32, 128]), torch.ones([1, 256, 32, 128]))

model = torch.compile(model)
import depyf
with depyf.prepare_debug('debug_dir'):
    out = model.forward(*args)

Output:

  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 616, in compile_inner
    assert code.co_freevars == out_code.co_freevars, msg
AssertionError: free var mismatch: old code object has free var ('data_ptr', 'last_dim_size', 'last_two_dims_size', 'shape', 'stride', 'qkv_format'), new code object has free var ('data_ptr', 'last_dim_size', 'last_two_dims_size', 'qkv_format', 'shape', 'stride')

Again it's only failing when I use prepare_debug. It looks like the co_freevars list is the same, but in a different order.
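The ordering difference can be checked directly from the two tuples in the assertion message. This is a small sketch that only compares the reported values. It makes no claim about where inside depyf or torch the reordering happens, and the variable names are hypothetical.

```python
# Tuples copied verbatim from the AssertionError message above.
old_freevars = ('data_ptr', 'last_dim_size', 'last_two_dims_size',
                'shape', 'stride', 'qkv_format')
new_freevars = ('data_ptr', 'last_dim_size', 'last_two_dims_size',
                'qkv_format', 'shape', 'stride')

# Same names in both tuples?
same_names = set(old_freevars) == set(new_freevars)
# Tuple equality is order-sensitive, which is what the assertion checks.
same_order = old_freevars == new_freevars
# Is the new tuple simply the old one sorted alphabetically?
new_is_sorted = new_freevars == tuple(sorted(old_freevars))

print(same_names, same_order, new_is_sorted)  # -> True False True
```

The new code object's free variables are the old ones in alphabetical order, so the assertion fails on ordering alone.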

Okay, I will close this issue. Please open a new issue to report the new problem.