thuml/depyf

[Issue Tracker]: guards are empty with some pytorch versions

youkaichao opened this issue · 1 comment

Your current environment

The output of `python collect_env.py`

๐Ÿ› Describe the bug

We rely on guard.code_parts to get the code parts of a guard.

Here is a test script to test if the code parts are available:

from typing import List
import torch
from torch import _dynamo as torchdynamo


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward  # return a python callable


@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(guard)
dis.dis(code)

print(f"{torch.__version__=}")
for code_part in guard.code_parts:
    print(code_part) # see if it prints anything

After manual testing, it works for:

  • 2.2.0
  • 2.2.1
  • 2.2.2
  • 2.3.0
  • 2.3.1

But it does not work for 2.4.0, because 2.4.0 uses the cpp guard by default.

Adding `torch._dynamo.config.enable_cpp_guard_manager = False` after `import torch` works around this. However, PyTorch plans to support only the cpp guard in the future, so this workaround may stop working.
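A minimal sketch of applying this workaround defensively, assuming the flag may be missing on some versions. The helper name `disable_cpp_guard_manager` is hypothetical (it is not part of PyTorch or depyf), and a `SimpleNamespace` stands in for `torch._dynamo.config` so the snippet runs without PyTorch installed:

```python
from types import SimpleNamespace


def disable_cpp_guard_manager(config) -> bool:
    """Set `enable_cpp_guard_manager = False` on `config` if the flag exists.

    Hypothetical helper: returns True if the flag was found and set,
    False if this config object has no such flag.
    """
    if hasattr(config, "enable_cpp_guard_manager"):
        config.enable_cpp_guard_manager = False
        return True
    return False


# Stand-ins for torch._dynamo.config (assumption: real usage would pass
# torch._dynamo.config right after `import torch`).
config_with_flag = SimpleNamespace(enable_cpp_guard_manager=True)
config_without_flag = SimpleNamespace()

applied = disable_cpp_guard_manager(config_with_flag)
skipped = disable_cpp_guard_manager(config_without_flag)
print(applied, config_with_flag.enable_cpp_guard_manager)  # True False
print(skipped)  # False
```

A `True` return only means the flag was found and set. It does not guarantee that a given PyTorch version still honors it.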

The solution from the PyTorch team is to make guard.code_parts work even when the cpp guard is in use. See pytorch/pytorch#127977 (review) for more details.

At the time of writing, 2.5.0.dev20240805+cu121 does not have it yet.

In summary, if a PyTorch version is released in which:

  • guard.code_parts is not available by default, and
  • torch._dynamo.config.enable_cpp_guard_manager = False does not work due to deprecation, such as in pytorch/pytorch#132692

then we would be in trouble.
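The first condition can be checked at runtime. A minimal sketch, assuming a guard object either lacks `code_parts` or exposes it as a possibly empty sequence of strings. `read_code_parts` is a hypothetical helper, and the `SimpleNamespace` objects stand in for `cache_entry.check_fn` so the snippet runs without PyTorch:

```python
from types import SimpleNamespace
from typing import List


def read_code_parts(guard) -> List[str]:
    """Return the guard's code parts, or [] if they are missing or empty."""
    parts = getattr(guard, "code_parts", None)
    return list(parts) if parts else []


# Stand-ins for cache_entry.check_fn. The strings are placeholders,
# not real guard output.
guard_with_parts = SimpleNamespace(code_parts=["<guard expr 1>", "<guard expr 2>"])
guard_empty = SimpleNamespace(code_parts=[])
guard_missing = SimpleNamespace()

for name, guard in [
    ("with parts", guard_with_parts),
    ("empty", guard_empty),
    ("missing", guard_missing),
]:
    parts = read_code_parts(guard)
    status = "available" if parts else "NOT available"
    print(f"{name}: code parts {status} ({len(parts)} found)")
```

Treating a missing attribute and an empty list the same way keeps the check independent of how a particular version represents unavailable code parts.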

Fortunately, it seems PyTorch 2.4.0 works with torch._dynamo.config.enable_cpp_guard_manager = False.

I will keep updating the status as new PyTorch versions come out.

fixed by #44