thuml/depyf

[help wanted] Understand how graph breaks work


Everything in depyf works fine.

version: python==3.9.0, depyf==0.15.0

Here's the code I use, based on the README.

import torch
from torch import _dynamo as torchdynamo
from typing import List
import numpy as np
from fn1 import fn1

@torch.compile(backend="eager", dynamic=True)
def toy_example(a, b):
    y = np.zeros((10, ))
    x = a / (torch.abs(a) + 1)
    res = fn1(x, y)
    if b.sum() < 0:  # graph break
        b = b * -1
        b = torch.pow(b, 2)
    else:            # graph break
        b = b * 2
        b = torch.pow(b, 2)
    return x * b, y, res

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    # main()
    # surround the code you want to run inside `with depyf.prepare_debug`
    import depyf
    with depyf.prepare_debug("./dump_src_dir"):
        main()

    # surround the code you want to debug inside `with depyf.debug()`
    with depyf.debug():
        print("hello world")
        main()

Question: According to the generated file dump_src_dir/full_code_for_toy_example_0.py, why is transformed___resume_at_50_2() not called by __transformed_code_0_for_toy_example()? In other words, why is the subgraph generated by torch.compile() not called after the graph break?

# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def __resume_at_50_2(b, y, x, res):
    b = b * -1
    b = torch.pow(b, 2)
    return x * b, y, res

def transformed___resume_at_50_2(b, y, x, res):
    L = {"b": b, "y": y, "x": x, "res": res}
    if __guard_0_for_torch_dynamo_resume_in_toy_example_at_16(L):
        return __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_16(b, y, x, res)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return __resume_at_50_2(b, y, x, res)

def __transformed_code_0_for_toy_example(a, b):
    graph_out_0 = __compiled_fn_1(a, b)
    res = graph_out_0[3]
    x = graph_out_0[2]
    y = __import_torch_dot__dynamo_dot_utils.to_numpy_helper(graph_out_0[1])
    if graph_out_0[0]:
        return __resume_at_50_2(b, y, x, res)
    return __resume_at_72_3(b, y, x, res)

def transformed_toy_example(a, b):
    L = {"a": a, "b": b}
    if __guard_0_for_toy_example(L):
        return __transformed_code_0_for_toy_example(a, b)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return toy_example(a, b)
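
For orientation, here is a rough sketch of what __compiled_fn_1 computes, reconstructed from the unpacking in __transformed_code_0_for_toy_example above. This is an assumption for illustration only, not the actual generated graph:

def compiled_fn_1_sketch(a, b):
    # Everything before the data-dependent `if` is captured in one graph.
    x = a / (torch.abs(a) + 1)
    y = torch.zeros(10)        # the numpy array is traced as a torch tensor
    res = fn1(x, y)            # assuming fn1 is traceable and inlined
    cond = b.sum() < 0         # the branch condition is handed back to Python
    return cond, y, x, res     # matches graph_out_0[0..3] above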

The line return __resume_at_50_2(b, y, x, res) does not mean that the __resume_at_50_2 function is called directly. It means calling the function with Dynamo enabled, which will check the compiled code and guards. This procedure happens in C code, and depyf illustrates the process for you.

Note the comment before __resume_at_50_2:

# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.

This means that when you see return __resume_at_50_2(b, y, x, res) in __transformed_code_0_for_toy_example, you should follow transformed___resume_at_50_2, which, after checking its guards, calls __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_16.
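
To make that dispatch concrete, here is a minimal Python sketch of what the C-level frame evaluation conceptually does. The names and the cache representation are hypothetical; the real logic lives in Dynamo's C extension, and depyf only renders the result as source:

def dynamo_dispatch_sketch(fn, cache_entries, *args):
    # cache_entries: list of (guard_fn, transformed_code) pairs, e.g.
    # (__guard_0_for_torch_dynamo_resume_in_toy_example_at_16,
    #  __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_16).
    frame_locals = dict(zip(fn.__code__.co_varnames, args))
    for guard_fn, transformed_code in cache_entries:
        if guard_fn(frame_locals):
            return transformed_code(*args)
    # No guards matched: fall back to the original bytecode; Dynamo may
    # compile a new variant and prepend it to the cache.
    return fn(*args)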

A further question: when will __guard_1_for_torch_dynamo_resume_in_toy_example_at_16() become true?

In my situation, the code doesn't step into __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_16(); it just returns back to the original script example.py instead of the compiled subgraph, i.e. __compiled_fn_5 Captured Graph 0.py.
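
(Side note: a second guard set such as __guard_1_for_torch_dynamo_resume_in_toy_example_at_16() only appears after Dynamo recompiles, i.e. when a later call fails the existing guards. A hedged illustration, assuming default recompilation behavior:)

toy_example(torch.randn(10), torch.randn(10))  # satisfies __guard_0, runs variant 0
a64 = torch.randn(10, dtype=torch.float64)
b64 = torch.randn(10, dtype=torch.float64)
toy_example(a64, b64)  # dtype change typically fails __guard_0, compiling variant 1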

The guard information should be added back in #30. Please try the latest code.

PyTorch recently switched to C++ guards by default, so you cannot see the Python code. I have switched it back to Python guards so that we can view them.
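
If you want to toggle this yourself, the switch is a torch._dynamo config flag. The exact flag name is version-dependent, so treat this snippet as an assumption to verify against your PyTorch release:

import torch._dynamo.config as dynamo_config

# Version-dependent assumption: some PyTorch releases expose this flag to
# disable the C++ guard manager so that guards are emitted as readable
# Python functions like __guard_0_for_toy_example.
dynamo_config.enable_cpp_guard_manager = False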