aws-neuron/aws-neuron-sdk

Error when using torch.block_diag method

joelamzn opened this issue · 2 comments

The following code raises the error below on a trn1.32xlarge instance.

>>> import torch, torch_xla
>>> import torch_xla.core.xla_model as xm
>>> device = xm.xla_device()
>>> segments = [torch.ones((1,4), device=device)]
>>> torch.block_diag(*segments[:])

Error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/functional.py", line 1266, in block_diag
    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
RuntimeError: torch_xla/csrc/aten_xla_type.cpp:3426 : Check failed: !runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::XLANativeFunctions::block_diag(c10::ArrayRef<at::Tensor>)

	at::_ops::block_diag::redispatch(c10::DispatchKeySet, c10::ArrayRef<at::Tensor>)


	at::_ops::block_diag::call(c10::ArrayRef<at::Tensor>)

	PyCFunction_Call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyObject_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	PyEval_EvalCode



	PyRun_InteractiveLoopFlags
	PyRun_AnyFileExFlags

	Py_BytesMain
	__libc_start_main
	_start
*** End stack trace ***

Setup

  • Torch: 2.1.2+cu121
  • Torch XLA: 2.1.1

Hi @joelamzn ,

Thanks for reporting the issue. We currently disable functionalization by default for performance reasons. Could you try setting XLA_DISABLE_FUNCTIONALIZATION=0 and rerunning your example? I tried your code with this environment variable set and no longer see the error.
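For reference, here is a sketch of how that workaround could be applied: export the environment variable mentioned in the error message before starting Python, then run the reproduction code from the issue. This assumes a Trainium instance with torch and torch_xla installed, as in the original report.

```shell
# Enable functionalization (disabled by default for performance) so that
# torch_xla can handle aten::block_diag without hitting the check failure.
export XLA_DISABLE_FUNCTIONALIZATION=0

# Re-run the reproduction from the issue under the new setting.
python - <<'EOF'
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
segments = [torch.ones((1, 4), device=device)]
print(torch.block_diag(*segments))
EOF
```

Note that the variable must be set before torch_xla initializes, so exporting it in the shell (or at the very top of the script, before the torch_xla import) is the safest place.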

Hello @joelamzn,

We haven't heard from you in a while, so I'm going to resolve this issue. Feel free to re-open it if you require further assistance.