Lightning-AI/litgpt

Compiled inference failed: "Global state changed while dynamo tracing"


CLI command:
$ litgpt generate stabilityai/stablelm-base-alpha-3b --prompt "Hello, my name is" --compile true

Workarounds tried:

  1. Adding torch._dynamo.config.suppress_errors = True to generate/base.py did not work.
  2. Replacing torch.compile(next_token, ...) with model = torch.compile(model) before the model = fabric.setup_module(model) call works (see the sketch below).
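
For reference, a minimal sketch of the working order of operations from item 2; a toy nn.Linear stands in for litgpt's GPT, and the real change belongs in litgpt/generate/base.py. Judging from the traceback, dynamo appears to trip on Fabric's precision.forward_context() being entered inside the traced frame, so compiling the bare model before fabric.setup_module keeps that context manager outside the compiled region:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)
fabric.launch()

# Toy stand-in for litgpt's GPT model.
model = torch.nn.Linear(8, 8)

# Compile the bare module *before* handing it to Fabric, so that Fabric's
# precision.forward_context() wrapper stays outside the dynamo-traced frame.
model = torch.compile(model)
model = fabric.setup_module(model)

x = fabric.to_device(torch.randn(1, 8))
with torch.no_grad():
    out = model(x)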

Env:

$ pip list | grep torch
pytorch-lightning        2.2.5
torch                    2.3.1
torchmetrics             1.4.0.post0

Full error:

Traceback (most recent call last):
  File "/opt/conda/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "litgpt/__main__.py", line 57, in main
    CLI(parser_data)
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 196, in _run_component
    return component(**cfg)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "litgpt/generate/base.py", line 252, in main
    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "litgpt/generate/base.py", line 127, in generate
    token = next_token(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "litgpt/generate/base.py", line 74, in next_token
    logits = model(x, input_pos)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 138, in forward
    with precision.forward_context():
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
    result = inner_convert(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    assert (
AssertionError: Global state changed while dynamo tracing, please report a bug


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
rasbt commented

Sorry that I can't be more helpful here. I haven't used compilation in LitGPT myself, but I remember from my colleagues that torch.compile doesn't fully support everything yet and that there are some known issues with it.