Issue with torch.compile
f-fuchs opened this issue · 4 comments
Hey,
I get the following error when I try to compile my PyTorch module:
  File "/home/fuchsfa/foundation-models/train.py", line 139, in <module>
    main()
  File "/home/fuchsfa/foundation-models/train.py", line 125, in main
    train(
  File "/home/fuchsfa/foundation-models/src/foundation_models/train.py", line 46, in train
    epoch_loss = train_one_epoch(
  File "/home/fuchsfa/foundation-models/src/foundation_models/train.py", line 95, in train_one_epoch
    output = model(images)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 515, in wrapped_fn
    bound = param_signature.bind(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 516, in torch_dynamo_resume_in_wrapped_fn_at_515
    bound.apply_defaults()
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 518, in torch_dynamo_resume_in_wrapped_fn_at_516
    memos = push_shape_memo(bound.arguments)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_storage.py", line 59, in push_shape_memo
    def push_shape_memo(arguments: dict[str, Any]):
AttributeError: '_thread._local' object has no attribute 'memo_stack'
The error message looks closer to issue #23 than to the other open issue regarding torch.compile (#196).
I also tried running the program with python -O train.py to disable beartype type checking, but this did not change anything.
Yup, this is just what #196 looks like these days. Sadly, this is a limitation on the part of torch.compile.
Okay, this is unfortunate... I really like jaxtyping.
Is there a way to disable the jaxtyping decorators globally, so that I can still use them when the model is not compiled and only disable them when it is compiled?
I tried running it with python -O, but unfortunately that did not help :(
jaxtyping does offer a config environment variable for this (jaxtyping/jaxtyping/_config.py, line 21 in 317cc9e), so setting JAXTYPING_DISABLE=1 might work.
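To get the "checks on when eager, off when compiled" behaviour asked about above, one option is to set the variable from inside train.py based on whether compilation was requested. This is a minimal sketch, not jaxtyping's documented recipe: the --compile flag, the helper names, and the foundation_models import path are hypothetical, and it assumes JAXTYPING_DISABLE only needs to be in the environment before jaxtyping is first imported in the process.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse CLI arguments; --compile is a hypothetical flag for this sketch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--compile", action="store_true",
                        help="wrap the model with torch.compile")
    return parser.parse_args(argv)


def configure_jaxtyping(compile_model: bool) -> None:
    """Disable jaxtyping's runtime checks only for compiled runs."""
    if compile_model:
        # Set this before anything in the process imports jaxtyping,
        # i.e. before importing the annotated model/training code.
        os.environ["JAXTYPING_DISABLE"] = "1"


def main(argv=None):
    args = parse_args(argv)
    configure_jaxtyping(args.compile)
    # Only now import the jaxtyping-annotated code, e.g.
    # (hypothetical path): from foundation_models.train import train
    return args
```

Run as python train.py --compile, checks are disabled; without the flag they stay active. For a one-off run, prefixing the command with JAXTYPING_DISABLE=1 does the same thing without touching the code.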
Thx, that worked 👍