patrick-kidger/jaxtyping

issue with torch.compile

f-fuchs opened this issue · 4 comments

Hey,

I get the following error when I try to compile my PyTorch module:

  File "/home/fuchsfa/foundation-models/train.py", line 139, in <module>
    main()
  File "/home/fuchsfa/foundation-models/train.py", line 125, in main
    train(
  File "/home/fuchsfa/foundation-models/src/foundation_models/train.py", line 46, in train
    epoch_loss = train_one_epoch(
  File "/home/fuchsfa/foundation-models/src/foundation_models/train.py", line 95, in train_one_epoch
    output = model(images)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 515, in wrapped_fn
    bound = param_signature.bind(*args, **kwargs)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 516, in torch_dynamo_resume_in_wrapped_fn_at_515
    bound.apply_defaults()
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 518, in torch_dynamo_resume_in_wrapped_fn_at_516
    memos = push_shape_memo(bound.arguments)
  File "/home/fuchsfa/foundation-models/.venv/lib/python3.10/site-packages/jaxtyping/_storage.py", line 59, in push_shape_memo
    def push_shape_memo(arguments: dict[str, Any]):
AttributeError: '_thread._local' object has no attribute 'memo_stack'

The error message looks closer to issue #23 than to the other open torch.compile issue, #196.

I also tried running the program with python -O train.py to disable beartype type checking, but this did not change anything.

Yup, this is just what #196 looks like these days. Sadly, this is a limitation on the part of torch.compile.

Okay, this is unfortunate... I really like jaxtyping.

Is there a way to disable the jaxtyping decorators globally, so that I can still use them when the model is not compiled and disable them only when it is compiled?

I tried running it with python -O, but unfortunately that did not help :(

jaxtyping does offer a config environment variable for this:

self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0"))

so setting JAXTYPING_DISABLE=1 might work.
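A minimal sketch of the "off only when compiling" switch asked about above. It assumes, based on the quoted config line, that jaxtyping reads JAXTYPING_DISABLE once, when its config object is created on first import, so the variable has to be set before anything imports jaxtyping. The helper name configure_jaxtyping and the --compile flag are hypothetical. For a single run from a shell, JAXTYPING_DISABLE=1 python train.py should have the same effect.

```python
import os
import sys


def configure_jaxtyping(compile_model: bool) -> None:
    """Hypothetical helper: disable jaxtyping only for compiled runs.

    Assumes jaxtyping reads JAXTYPING_DISABLE once, when its config object is
    created on first import, so this must run before jaxtyping is imported.
    """
    if "jaxtyping" in sys.modules:
        # Too late: by assumption the config has already read the variable.
        raise RuntimeError("configure_jaxtyping() must run before importing jaxtyping")
    if compile_model:
        # "1" should disable the jaxtyping decorators. Eager runs are left alone,
        # so they keep the default "0" or whatever value the user exported.
        os.environ["JAXTYPING_DISABLE"] = "1"


if __name__ == "__main__":
    use_compile = "--compile" in sys.argv  # hypothetical command-line flag
    configure_jaxtyping(use_compile)
    # Import the model code (and thus jaxtyping) only after the call above, e.g.:
    # from foundation_models.train import train
```

Eager runs keep the type checks, and only a compiled run turns them off. Raising instead of silently continuing makes an import-order mistake visible, rather than leaving the checks active under torch.compile.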

thx that worked 👍