[help wanted] debugging one frame with multiple cache entries
imShZh opened this issue · 1 comments
I have been learning torch.compile recently, and this project has been a huge inspiration. Thanks for such an outstanding job!

It seems we can register our decompiled code into torch.compile with the few lines of code below:
```python
from torch._dynamo.utils import orig_code_map
from torch._dynamo.convert_frame import output_codes

output_codes.add(decompiled_and_compiled_back_code)
orig_code_map[decompiled_and_compiled_back_code] = code
```
After that we can freely set breakpoints in the `__transformed_code` files and start debugging.

I wonder how depyf would behave if one frame has multiple cache entries (i.e. a linked list of guards and optimized code). I think the same frame might generate multiple `__transformed_code` files.
My questions are: would torch.compile choose the correct file based on the guards, so we can still debug normally? And under what circumstances will there be multiple cache entries for the same frame?
Thanks for your appreciation and interest :)
> Under what circumstances will there be multiple cache entries for the same frame?

Whenever all existing guards fail, torch.compile generates a new cache entry for you. For example, the following code triggers re-compilation:
```python
import torch

@torch.compile
def f(a, b):
    return a + b

def main():
    x = torch.tensor([1.0])
    f(x, 1)
    f(x, 2)
    f(x, 3)

if __name__ == "__main__":
    main()
```
Run it with `TORCH_LOGS=recompiles_verbose python test.py`, and you will see:
- `f(x, 1)` triggers the first compilation and installs the guard `b == 1`.
- `f(x, 2)` fails that guard, so torch.compile re-compiles and installs the guard `b == 2`.
- `f(x, 3)` fails both guards, so torch.compile re-compiles and installs the guard `b == 3`.
The reason is that PyTorch always tries to embed constants (such as the Python int `b` here) into the graph.
> Would torch.compile choose the correct file based on the guard so we can still debug normally?
Yes, torch.compile will do the job.
> How would depyf behave if we have one frame with multiple cache entries?
depyf just decompiles the code for every cache entry and makes all of them debuggable. torch.compile chooses which cache entry to execute, and you can set breakpoints in any of the files you want.