Error when using torch.block_diag method
joelamzn opened this issue · 2 comments
joelamzn commented
The following code throws the error below on a trn1.32xlarge instance.
```python
>>> import torch, torch_xla
>>> import torch_xla.core.xla_model as xm
>>> device = xm.xla_device()
>>> segments = [torch.ones((1,4), device=device)]
>>> torch.block_diag(*segments[:])
```
Error:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/functional.py", line 1266, in block_diag
    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:3426 : Check failed: !runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::XLANativeFunctions::block_diag(c10::ArrayRef<at::Tensor>)
        at::_ops::block_diag::redispatch(c10::DispatchKeySet, c10::ArrayRef<at::Tensor>)
        at::_ops::block_diag::call(c10::ArrayRef<at::Tensor>)
        PyCFunction_Call
        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCode
        PyRun_InteractiveLoopFlags
        PyRun_AnyFileExFlags
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
```
Setup
- Torch: 2.1.2+cu121
- Torch XLA: 2.1.1
jeffhataws commented
Hi @joelamzn ,
Thanks for reporting the issue. Currently, we disable functionalization by default for performance reasons. Could you try setting XLA_DISABLE_FUNCTIONALIZATION=0 and running your example? I tried your code with this environment variable set and no longer see the error. A minimal sketch of the workaround is below.
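For reference, a minimal sketch of the suggested workaround. It assumes the variable is picked up if set before torch_xla is imported; exporting XLA_DISABLE_FUNCTIONALIZATION=0 in the shell before launching Python works as well.

```python
import os

# Enable functionalization (it is disabled by default for performance reasons).
# Assumption: this must be set before torch_xla is imported; alternatively,
# export XLA_DISABLE_FUNCTIONALIZATION=0 in the shell.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
segments = [torch.ones((1, 4), device=device)]

# With functionalization enabled, block_diag should no longer hit the check failure.
print(torch.block_diag(*segments))
```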
aws-taylor commented
Hello @joelamzn,
We haven't heard from you in a while, so I'm going to resolve this issue. Feel free to re-open it if you require further assistance.